| """ |
| ============================================================================== |
| API FastAPI RAG/NLP + TinyLlama local pour OVH AI Notebooks |
| ============================================================================== |
| |
| Objectifs de cette version corrigée : |
| 1) Télécharger TinyLlama/TinyLlama-1.1B-Chat-v1.0 dans un dossier local. |
| 2) Charger TinyLlama depuis ce dossier local pour éviter de re-télécharger. |
| 3) Exposer l'API sur le port 8080, compatible avec l'URL publique OVH AI Notebooks. |
| 4) Permettre la consommation depuis ton PC local via l'URL OVH publique. |
| 5) Garder les endpoints RAG existants, mais rendre l'encodeur RAG optionnel. |
| |
| Installation minimale : |
| pip install -U fastapi uvicorn transformers torch huggingface_hub pydantic numpy |
| pip install -U accelerate safetensors sentencepiece protobuf |
| pip install -U faiss-cpu # seulement si tu utilises /index/build ou /search |
| |
| Premier lancement sur OVH AI Notebook : |
| python rag_api_server_local_tinyllama.py \ |
| --host 0.0.0.0 \ |
| --port 8080 \ |
| --download-llm \ |
| --public-url "https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net" |
| |
| Ensuite, depuis ton PC local : |
| export OVH_RAG_URL="https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net" |
| curl -X POST "$OVH_RAG_URL/generate" \ |
| -H "Content-Type: application/json" \ |
| -d '{"prompt":"Explique RAG simplement en français.","max_new_tokens":128}' |
| |
| Notes OVH : |
| - Sur AI Notebooks, l'URL publique mappe le port 8080. Lance donc FastAPI sur 8080. |
| - Si ton notebook est en accès restreint, ajoute ton token dans l'en-tête Authorization |
| selon la configuration OVH. |
| ============================================================================== |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import time |
| from contextlib import asynccontextmanager |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel, Field |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| |
| |
| |
# Configure logging once at import time: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-wide logger used by the functions and classes of this file.
log = logging.getLogger("rag-tinyllama-api")


# Hugging Face Hub repository id used as the default chat model.
TINYLLAMA_REPO_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default local directory the model snapshot is downloaded to and loaded from.
DEFAULT_LLM_DIR = "./models/TinyLlama-1.1B-Chat-v1.0"
|
|
|
|
def env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the process environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; otherwise True only when the
        trimmed, lower-cased value is one of 1/true/yes/y/on.
    """
    raw = os.environ.get(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in ("1", "true", "yes", "y", "on")
    return default
|
|
|
|
def clean_public_url(url: str) -> str:
    """Normalize a base URL: trim surrounding whitespace, then drop trailing slashes."""
    trimmed = url.strip()
    return trimmed.rstrip("/")
|
|
|
|
def ensure_tinyllama_local(model_id: str, local_dir: str, download: bool = True) -> str:
    """Return the absolute local path of the model, downloading it when needed.

    The model is considered available locally only when the directory holds a
    ``config.json``, a tokenizer file AND at least one weight file. Requiring
    weights matters because an interrupted download can leave the small
    config/tokenizer files on disk without the model weights.

    Args:
        model_id: Hugging Face Hub repository id to download from.
        local_dir: Directory where the snapshot is (or will be) stored.
        download: Whether a missing/incomplete model may be fetched from the Hub.

    Returns:
        The resolved local directory as a string.

    Raises:
        RuntimeError: If the model is missing/incomplete and ``download`` is
            False, or if ``huggingface_hub`` is not installed.
    """
    target = Path(local_dir).expanduser().resolve()
    has_config = (target / "config.json").exists()
    has_tokenizer = any((target / name).exists() for name in ("tokenizer.json", "tokenizer.model"))
    # Weight files as written by Hub snapshots (single-file or sharded).
    has_weights = any(
        next(target.glob(pattern), None) is not None
        for pattern in ("*.safetensors", "*.bin")
    )

    if has_config and has_tokenizer and has_weights:
        log.info("TinyLlama déjà disponible localement : %s", target)
        return str(target)

    if not download:
        raise RuntimeError(
            f"TinyLlama introuvable dans {target}. Relance avec --download-llm ou mets "
            f"RAG_LLM_DOWNLOAD=true."
        )

    log.info("Téléchargement TinyLlama depuis Hugging Face : %s -> %s", model_id, target)
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise RuntimeError("Installe huggingface_hub : pip install -U huggingface_hub") from exc

    import inspect

    target.mkdir(parents=True, exist_ok=True)
    download_kwargs = {
        "repo_id": model_id,
        "local_dir": str(target),
        "ignore_patterns": ["*.md", ".gitattributes"],
    }
    # `local_dir_use_symlinks` and `resume_download` are deprecated in recent
    # huggingface_hub releases; pass them only when the installed version
    # still accepts them so the call keeps working across versions.
    accepted = inspect.signature(snapshot_download).parameters
    for legacy_name, legacy_value in (("local_dir_use_symlinks", False), ("resume_download", True)):
        if legacy_name in accepted:
            download_kwargs[legacy_name] = legacy_value
    snapshot_download(**download_kwargs)
    log.info("Téléchargement terminé : %s", target)
    return str(target)
|
|
|
|
| |
| |
| |
@dataclass
class ModelConfig:
    # Hyper-parameters of the custom RAG text encoder (TextEncoder and
    # TransformerEncoderBlock). Several defaults mirror the fallback values
    # used by RAGService._try_load_encoder when a checkpoint has no config.
    vocab_size: int = 32005  # size of the token embedding table
    hidden_size: int = 384  # model width; also the LayerNorm / attention width
    num_hidden_layers: int = 6  # number of TransformerEncoderBlock layers
    num_attention_heads: int = 6  # head_dim is hidden_size // num_attention_heads
    intermediate_size: int = 1536  # width of the MLP hidden layer
    max_position_embeddings: int = 256  # size of the learned position table
    hidden_dropout_prob: float = 0.0  # dropout on embeddings, residual and MLP outputs
    attention_probs_dropout_prob: float = 0.0  # not read by the code visible in this file
    layer_norm_eps: float = 1e-12  # epsilon passed to every LayerNorm
    embedding_dim: int = 384  # dimension of the final sentence embedding
    use_layer_scale: bool = True  # enable learnable per-channel residual scaling
    layer_scale_init: float = 1e-4  # initial value of the layer-scale vectors
|
|
|
|
class TransformerEncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder layer with optional layer scale.

    Each call applies self-attention then an MLP, both as residual branches.
    When ``cfg.use_layer_scale`` is set, each branch output is multiplied by a
    learnable per-channel vector (``gamma1`` / ``gamma2``) before being added.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        width = cfg.hidden_size
        drop_p = cfg.hidden_dropout_prob
        self.num_heads = cfg.num_attention_heads
        self.head_dim = width // cfg.num_attention_heads
        self.ln1 = nn.LayerNorm(width, eps=cfg.layer_norm_eps)
        # Single projection producing queries, keys and values at once.
        self.qkv = nn.Linear(width, 3 * width, bias=True)
        self.proj = nn.Linear(width, width, bias=True)
        self.ln2 = nn.LayerNorm(width, eps=cfg.layer_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(width, cfg.intermediate_size),
            nn.GELU(),
            nn.Linear(cfg.intermediate_size, width),
            nn.Dropout(drop_p),
        )
        self.resid_drop = nn.Dropout(drop_p)
        self.use_ls = cfg.use_layer_scale
        if self.use_ls:
            self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(width))
            self.gamma2 = nn.Parameter(cfg.layer_scale_init * torch.ones(width))

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Run the block.

        Args:
            x: Hidden states indexed as (batch, tokens, channels).
            attn_mask: Per-token mask indexed as (batch, tokens); non-zero
                entries mark key positions that may be attended to.

        Returns:
            A tensor with the same shape as ``x``.
        """
        bsz, seq_len, width = x.shape
        split_shape = (bsz, seq_len, 3, self.num_heads, self.head_dim)

        normed = self.ln1(x)
        # Copy the fused projection into a buffer that has the dtype of the
        # normalized input (kept exactly as in the original implementation).
        packed = normed.new_empty(*split_shape)
        packed.copy_(self.qkv(normed).view(*split_shape))
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Broadcast the key mask over heads and query positions.
        key_mask = attn_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=key_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).contiguous().view(bsz, seq_len, width)
        attended = self.resid_drop(self.proj(attended))
        if self.use_ls:
            attended = attended * self.gamma1
        x = x + attended

        feed_forward = self.mlp(self.ln2(x))
        if self.use_ls:
            feed_forward = feed_forward * self.gamma2
        return x + feed_forward
|
|
|
|
class TextEncoder(nn.Module):
    """Transformer text encoder returning one L2-normalized vector per sequence.

    Token and learned position embeddings are summed, normalized, passed
    through ``cfg.num_hidden_layers`` encoder blocks, mean-pooled over the
    positions selected by the attention mask, projected to
    ``cfg.embedding_dim`` and finally L2-normalized.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.hidden_size
        eps = cfg.layer_norm_eps
        self.tok_emb = nn.Embedding(cfg.vocab_size, width, padding_idx=0)
        self.pos_emb = nn.Embedding(cfg.max_position_embeddings, width)
        self.emb_ln = nn.LayerNorm(width, eps=eps)
        self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)
        layers = []
        for _ in range(cfg.num_hidden_layers):
            layers.append(TransformerEncoderBlock(cfg))
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(width, eps=eps)
        self.proj_head = nn.Sequential(
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, cfg.embedding_dim),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Embed a batch of token-id sequences.

        Args:
            input_ids: Token ids indexed as (batch, tokens).
            attention_mask: Mask indexed as (batch, tokens); it is forwarded to
                every block and used as the weights of the mean pooling.

        Returns:
            One unit-norm embedding per sequence.
        """
        n_seq, n_tok = input_ids.shape
        pos = torch.arange(n_tok, device=input_ids.device)
        pos = pos.unsqueeze(0).expand(n_seq, n_tok)

        hidden = self.tok_emb(input_ids) + self.pos_emb(pos)
        hidden = self.emb_ln(hidden)
        hidden = self.emb_drop(hidden)
        for layer in self.blocks:
            hidden = layer(hidden, attention_mask)
        hidden = self.ln_f(hidden)

        # Masked mean pooling; the clamp avoids a division by zero when a row
        # of the mask is entirely zero.
        weights = attention_mask.unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-6)
        projected = self.proj_head(summed / counts)
        return F.normalize(projected, p=2, dim=-1)
|
|
|
|
| |
| |
| |
class RAGService:
    """Bundle the optional RAG text encoder, the FAISS index and the local LLM.

    The constructor always loads the chat LLM (downloading it first when
    allowed) and, when ``load_encoder`` is true, tries to load the custom text
    encoder. Encoder loading is best-effort: missing paths only log a warning
    and leave the encoder-dependent features disabled.

    Methods that back HTTP endpoints raise ``fastapi.HTTPException`` so the
    error status propagates directly to the client.
    """

    def __init__(
        self,
        ckpt_path: Optional[str],
        tokenizer_dir: Optional[str],
        max_len: int = 128,
        llm_model_id: str = TINYLLAMA_REPO_ID,
        llm_dir: str = DEFAULT_LLM_DIR,
        download_llm: bool = True,
        llm_offline: bool = False,
        public_url: str = "",
        load_encoder: bool = True,
    ):
        """Store the configuration, then load the encoder (optional) and the LLM.

        Args:
            ckpt_path: Path of the encoder checkpoint, or None to disable it.
            tokenizer_dir: Directory of the encoder tokenizer, or None.
            max_len: Maximum token length used when encoding texts.
            llm_model_id: Hub repository id of the chat model.
            llm_dir: Local directory of the chat model snapshot.
            download_llm: Allow downloading the chat model when missing.
            llm_offline: Forbid any Hub access (also disables downloading).
            public_url: Public base URL advertised by the API (may be empty).
            load_encoder: Whether to attempt loading the RAG encoder.
        """
        self.ckpt_path = ckpt_path
        self.tokenizer_dir = tokenizer_dir
        self.max_len = max_len
        self.llm_model_id = llm_model_id
        self.llm_dir = llm_dir
        self.download_llm = download_llm
        self.llm_offline = llm_offline
        self.public_url = clean_public_url(public_url) if public_url else ""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if self.device.type == "cuda":
            # Allow TF32 matmuls on GPU.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.set_float32_matmul_precision("high")

        # RAG encoder state (stays None when the encoder is disabled).
        self.tokenizer = None
        self.model = None
        self.cfg = ModelConfig()
        self.train_metrics: Dict[str, Any] = {}
        self.train_epoch = -1

        # Chat LLM state.
        self.llm_tokenizer = None
        self.llm_model = None
        self.llm_path = ""

        # In-memory index and the corpus it was built from.
        self.index = None
        self.corpus_texts: List[str] = []
        self.corpus_meta: List[Dict[str, Any]] = []

        if load_encoder:
            self._try_load_encoder()
        self._load_llm()

    @property
    def encoder_loaded(self) -> bool:
        """True when both the encoder model and its tokenizer are loaded."""
        return self.model is not None and self.tokenizer is not None

    @staticmethod
    def _align_meta(metas: Optional[List[Dict[str, Any]]], size: int) -> List[Dict[str, Any]]:
        """Return a copy of ``metas`` truncated/padded with ``{}`` to ``size`` items."""
        aligned = list(metas or [])[:size]
        aligned.extend({} for _ in range(size - len(aligned)))
        return aligned

    def _try_load_encoder(self) -> None:
        """Load the RAG encoder when checkpoint and tokenizer are available.

        Missing configuration or files only log a warning and return, leaving
        the encoder disabled.
        """
        if not self.ckpt_path or not self.tokenizer_dir:
            log.warning("Encodeur RAG désactivé : ckpt/tokenizer-dir non fournis.")
            return

        ckpt_path = Path(self.ckpt_path)
        tokenizer_dir = Path(self.tokenizer_dir)
        if not ckpt_path.exists():
            log.warning("Encodeur RAG désactivé : checkpoint introuvable : %s", ckpt_path)
            return
        if not tokenizer_dir.exists() or not any(tokenizer_dir.glob("tokenizer*")):
            log.warning("Encodeur RAG désactivé : tokenizer introuvable : %s", tokenizer_dir)
            return

        log.info("Chargement tokenizer RAG depuis %s", tokenizer_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir))

        log.info("Chargement checkpoint RAG %s", ckpt_path)
        # NOTE(security): torch.load relies on pickle; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(str(ckpt_path), map_location=self.device)
        saved_cfg = ckpt.get("config", {})
        self.cfg = ModelConfig(
            vocab_size=saved_cfg.get("vocab_size", self.tokenizer.vocab_size),
            hidden_size=saved_cfg.get("hidden_size", 384),
            num_hidden_layers=saved_cfg.get("num_hidden_layers", 6),
            num_attention_heads=saved_cfg.get("num_attention_heads", 6),
            intermediate_size=saved_cfg.get("intermediate_size", 1536),
            max_position_embeddings=saved_cfg.get("max_position_embeddings", 256),
            embedding_dim=saved_cfg.get("embedding_dim", 384),
            use_layer_scale=saved_cfg.get("use_layer_scale", True),
            layer_scale_init=saved_cfg.get("layer_scale_init", 1e-4),
        )
        self.model = TextEncoder(self.cfg).to(self.device)
        # strict=False tolerates key mismatches; report them so a checkpoint
        # that only partially matches the architecture is not loaded silently.
        load_result = self.model.load_state_dict(ckpt["model_state"], strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            log.warning(
                "Checkpoint RAG partiellement compatible : clés manquantes=%s, clés inattendues=%s",
                list(load_result.missing_keys),
                list(load_result.unexpected_keys),
            )
        self.model.eval()
        self.train_metrics = ckpt.get("metrics", {})
        self.train_epoch = ckpt.get("epoch", -1)
        log.info("Encodeur RAG chargé sur %s, dim=%s", self.device, self.cfg.embedding_dim)

    def _load_llm(self) -> None:
        """Resolve the local model directory, then load tokenizer and model.

        Raises:
            RuntimeError: Propagated from ``ensure_tinyllama_local`` when the
                model is not available locally and downloading is disabled.
        """
        self.llm_path = ensure_tinyllama_local(
            model_id=self.llm_model_id,
            local_dir=self.llm_dir,
            download=self.download_llm and not self.llm_offline,
        )
        # NOTE(security): trust_remote_code=True lets the model directory run
        # custom Python code; kept for compatibility with the existing setup.
        kwargs = {
            "trust_remote_code": True,
            "local_files_only": self.llm_offline,
        }
        log.info("Chargement tokenizer TinyLlama depuis %s", self.llm_path)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_path, **kwargs)
        if self.llm_tokenizer.pad_token_id is None:
            # generate() needs a pad token id; reuse EOS when none is defined.
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token

        dtype = torch.bfloat16 if self.device.type == "cuda" else torch.float32
        log.info("Chargement modèle TinyLlama depuis %s", self.llm_path)
        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.llm_path,
                torch_dtype=dtype,
                device_map="auto" if self.device.type == "cuda" else None,
                low_cpu_mem_usage=True,
                **kwargs,
            )
        except Exception as exc:
            # Deliberate best-effort fallback: retry without device_map and
            # move the model to the target device manually.
            log.warning("device_map='auto' indisponible ou erreur accelerate (%s). Bascule .to(device).", exc)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.llm_path,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                **kwargs,
            ).to(self.device)
        self.llm_model.eval()
        log.info("TinyLlama chargé : %s sur %s", self.llm_model_id, self.device)

    def require_encoder(self) -> None:
        """Raise HTTPException 503 when the RAG encoder is not loaded."""
        if not self.encoder_loaded:
            raise HTTPException(
                status_code=503,
                detail=(
                    "Encodeur RAG non chargé. Fournis --ckpt et --tokenizer-dir valides, "
                    "ou utilise seulement /generate et /v1/chat/completions."
                ),
            )

    def require_index(self) -> None:
        """Raise HTTPException 400 when no FAISS index is loaded in memory."""
        if self.index is None:
            raise HTTPException(
                status_code=400,
                detail="Aucun index chargé. Appelle POST /index/build ou POST /index/load d'abord.",
            )

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system: Optional[str] = None,
    ) -> str:
        """Generate a completion for ``prompt`` with the local chat model.

        Args:
            prompt: User message.
            max_new_tokens: Maximum number of generated tokens.
            temperature: Sampling temperature; ``0`` selects greedy decoding.
            top_p: Nucleus-sampling threshold (used only when sampling).
            system: Optional system message, used when the tokenizer has a
                chat template.

        Returns:
            The decoded answer without the prompt, stripped of whitespace.

        Raises:
            HTTPException: 503 when the LLM is not loaded.
        """
        if self.llm_model is None or self.llm_tokenizer is None:
            raise HTTPException(status_code=503, detail="TinyLlama non chargé.")

        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        if hasattr(self.llm_tokenizer, "apply_chat_template") and self.llm_tokenizer.chat_template:
            encoded = self.llm_tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
            )
            if isinstance(encoded, torch.Tensor):
                input_ids = encoded.to(self.llm_model.device)
                attention_mask = torch.ones_like(input_ids)
            else:
                # NOTE(review): some transformers versions may return a mapping
                # (input_ids / attention_mask) instead of a tensor; handled
                # defensively — confirm against the installed version.
                input_ids = encoded["input_ids"].to(self.llm_model.device)
                if "attention_mask" in encoded:
                    attention_mask = encoded["attention_mask"].to(self.llm_model.device)
                else:
                    attention_mask = torch.ones_like(input_ids)
            model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        else:
            model_inputs = self.llm_tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048,
            ).to(self.llm_model.device)

        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.llm_tokenizer.pad_token_id,
            "eos_token_id": self.llm_tokenizer.eos_token_id,
            "no_repeat_ngram_size": 2,
        }
        if do_sample:
            # Sampling-only parameters are passed only when sampling; sending
            # them with do_sample=False is an inconsistent generation config.
            gen_kwargs["temperature"] = max(temperature, 1e-5)
            gen_kwargs["top_p"] = top_p

        generated = self.llm_model.generate(**model_inputs, **gen_kwargs)
        prompt_len = model_inputs["input_ids"].shape[-1]
        answer_ids = generated[0][prompt_len:]
        return self.llm_tokenizer.decode(answer_ids, skip_special_tokens=True).strip()

    @torch.no_grad()
    def encode(self, texts: List[str], batch_size: int = 64, normalize: bool = True) -> np.ndarray:
        """Embed ``texts`` with the RAG encoder.

        Args:
            texts: Texts to embed.
            batch_size: Number of texts per forward pass (values below 1 are
                treated as 1).
            normalize: Re-normalize each row to unit L2 norm on the NumPy side.

        Returns:
            A float32 array with one row per text and ``cfg.embedding_dim``
            columns (zero rows for an empty input).

        Raises:
            HTTPException: 503 when the encoder is not loaded.
        """
        self.require_encoder()
        if not texts:
            return np.zeros((0, self.cfg.embedding_dim), dtype=np.float32)

        step = max(1, batch_size)  # range() rejects a zero step
        embs = []
        for i in range(0, len(texts), step):
            chunk = texts[i : i + step]
            enc = self.tokenizer(
                chunk,
                padding=True,
                truncation=True,
                max_length=self.max_len,
                return_tensors="pt",
            ).to(self.device)
            with torch.autocast(
                device_type=self.device.type,
                dtype=torch.bfloat16,
                enabled=(self.device.type == "cuda"),
            ):
                e = self.model(enc["input_ids"], enc["attention_mask"])
            embs.append(e.float().cpu().numpy())

        out = np.concatenate(embs, axis=0)
        if not normalize:
            return out.astype(np.float32)
        norms = np.linalg.norm(out, axis=1, keepdims=True).clip(min=1e-12)
        return (out / norms).astype(np.float32)

    def build_index(self, corpus: List[str], metas: Optional[List[Dict[str, Any]]] = None) -> None:
        """Embed ``corpus`` and replace the in-memory inner-product index.

        Args:
            corpus: Documents to index.
            metas: Optional per-document metadata; when non-empty it must have
                the same length as ``corpus``. Empty/None yields ``{}`` entries.

        Raises:
            HTTPException: 503 without encoder, 400 on a metadata length
                mismatch, 501 when faiss is not installed.
        """
        self.require_encoder()
        if metas and len(metas) != len(corpus):
            # Mismatched metadata would silently attach to the wrong documents.
            raise HTTPException(status_code=400, detail="metadata doit avoir la même taille que documents")
        try:
            import faiss
        except ImportError as exc:
            raise HTTPException(
                status_code=501,
                detail="faiss-cpu non installé. Fais : pip install faiss-cpu",
            ) from exc

        log.info("Indexation de %s documents...", len(corpus))
        embs = self.encode(corpus, batch_size=128)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        # State is replaced only after the new index was built successfully.
        self.index = index
        self.corpus_texts = list(corpus)
        # Copy so later in-place edits never touch the caller's list.
        self.corpus_meta = list(metas) if metas else [{} for _ in corpus]

    def save_index(self, path: str) -> None:
        """Write the index to ``path`` and the corpus to ``path + ".meta.json"``.

        Raises:
            HTTPException: 501 when faiss is not installed, 400 without index.
        """
        try:
            import faiss
        except ImportError as exc:
            raise HTTPException(status_code=501, detail="faiss-cpu non installé.") from exc
        if self.index is None:
            raise HTTPException(status_code=400, detail="Aucun index à sauvegarder.")
        # Create the destination directory, as save_corpus does.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, path)
        with open(path + ".meta.json", "w", encoding="utf-8") as f:
            json.dump({"texts": self.corpus_texts, "meta": self.corpus_meta}, f, ensure_ascii=False)

    def load_index(self, path: str) -> None:
        """Load an index from ``path`` together with its ``.meta.json`` sidecar.

        When the sidecar is missing, the corpus currently in memory is kept
        only if its size matches the index; otherwise it is cleared because it
        cannot describe the vectors of the newly loaded index.

        Raises:
            HTTPException: 501 when faiss is not installed, 404 when ``path``
                does not exist.
        """
        try:
            import faiss
        except ImportError as exc:
            raise HTTPException(status_code=501, detail="faiss-cpu non installé.") from exc
        index_path = Path(path)
        if not index_path.exists():
            raise HTTPException(status_code=404, detail=f"Index introuvable : {path}")
        index = faiss.read_index(str(index_path))
        meta_path = str(index_path) + ".meta.json"
        if os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            texts = list(data.get("texts", []))
            metas = self._align_meta(data.get("meta", []), len(texts))
        elif len(self.corpus_texts) == index.ntotal:
            # Corpus loaded separately and consistent in size: keep it.
            texts = self.corpus_texts
            metas = self._align_meta(self.corpus_meta, len(texts))
        else:
            log.warning(
                "Index chargé sans fichier %s : corpus en mémoire vidé (tailles incompatibles).",
                meta_path,
            )
            texts, metas = [], []
        # Swap everything at once so index and corpus never disagree.
        self.index = index
        self.corpus_texts = texts
        self.corpus_meta = metas
        log.info("Index chargé : %s vecteurs", self.index.ntotal)

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` documents ranked by inner-product score.

        Hits whose id falls outside the loaded corpus are skipped.

        Raises:
            HTTPException: 400 without index, 503 without encoder.
        """
        self.require_index()
        q_emb = self.encode([query])
        scores, ids = self.index.search(q_emb, top_k)
        results = []
        for score, idx in zip(scores[0].tolist(), ids[0].tolist()):
            if idx < 0 or idx >= len(self.corpus_texts):
                continue
            results.append(
                {
                    "score": float(score),
                    "text": self.corpus_texts[idx],
                    "meta": self.corpus_meta[idx] if idx < len(self.corpus_meta) else {},
                    "doc_id": idx,
                }
            )
        return results

    def index_status(self) -> Dict[str, Any]:
        """Return a summary of the index, corpus and encoder state."""
        return {
            "loaded": self.index is not None,
            "n_vectors": int(self.index.ntotal) if self.index is not None else 0,
            "n_documents": len(self.corpus_texts),
            "has_metadata": bool(self.corpus_meta),
            "embedding_dim": self.cfg.embedding_dim if self.encoder_loaded else None,
            "encoder_loaded": self.encoder_loaded,
        }

    def clear_index(self) -> None:
        """Drop the index and the in-memory corpus."""
        self.index = None
        self.corpus_texts = []
        self.corpus_meta = []

    def add_documents(self, documents: List[str], metas: Optional[List[Dict[str, Any]]] = None) -> None:
        """Append documents to the corpus and rebuild the whole index.

        Raises:
            HTTPException: 503 without encoder, 400 for an empty ``documents``
                list or a metadata length mismatch.
        """
        self.require_encoder()
        if not documents:
            raise HTTPException(status_code=400, detail="documents vide")
        if metas is not None and len(metas) != len(documents):
            raise HTTPException(status_code=400, detail="metadata doit avoir la même taille que documents")
        # Pad existing metadata first so the new entries line up with the new
        # documents even when the stored meta list is shorter than the corpus.
        existing_meta = self._align_meta(self.corpus_meta, len(self.corpus_texts))
        new_texts = self.corpus_texts + list(documents)
        new_meta = existing_meta + (list(metas) if metas is not None else [{} for _ in documents])
        self.build_index(new_texts, new_meta)

    def delete_document(self, doc_id: int) -> None:
        """Remove one document and rebuild (or clear) the index.

        The service state is modified only once the rebuild succeeded.

        Raises:
            HTTPException: 503 without encoder, 404 for an unknown ``doc_id``.
        """
        self.require_encoder()
        if doc_id < 0 or doc_id >= len(self.corpus_texts):
            raise HTTPException(status_code=404, detail=f"Document introuvable : doc_id={doc_id}")
        remaining_meta = self._align_meta(self.corpus_meta, len(self.corpus_texts))
        del remaining_meta[doc_id]
        remaining_texts = self.corpus_texts[:doc_id] + self.corpus_texts[doc_id + 1 :]
        if remaining_texts:
            self.build_index(remaining_texts, remaining_meta)
        else:
            self.clear_index()

    def get_document(self, doc_id: int) -> Dict[str, Any]:
        """Return the text and metadata of one document.

        Raises:
            HTTPException: 404 for an unknown ``doc_id``.
        """
        if doc_id < 0 or doc_id >= len(self.corpus_texts):
            raise HTTPException(status_code=404, detail=f"Document introuvable : doc_id={doc_id}")
        return {
            "doc_id": doc_id,
            "text": self.corpus_texts[doc_id],
            "meta": self.corpus_meta[doc_id] if doc_id < len(self.corpus_meta) else {},
        }

    def list_documents(self, offset: int = 0, limit: int = 50) -> Dict[str, Any]:
        """Return one page of documents.

        ``offset`` is clamped to be non-negative and ``limit`` to 1..500.
        """
        offset = max(0, offset)
        limit = max(1, min(limit, 500))
        items = []
        for idx in range(offset, min(offset + limit, len(self.corpus_texts))):
            text = self.corpus_texts[idx]
            items.append({
                "doc_id": idx,
                "text": text,
                "preview": text[:300],
                "chars": len(text),
                "meta": self.corpus_meta[idx] if idx < len(self.corpus_meta) else {},
            })
        return {"offset": offset, "limit": limit, "total": len(self.corpus_texts), "documents": items}

    def save_corpus(self, path: str) -> None:
        """Write texts and metadata as JSON to ``path`` (parent dirs are created)."""
        target = Path(path).expanduser().resolve()
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as f:
            json.dump({"texts": self.corpus_texts, "meta": self.corpus_meta}, f, ensure_ascii=False, indent=2)

    def load_corpus(self, path: str, rebuild_index: bool = True) -> None:
        """Load a JSON corpus and optionally rebuild the index from it.

        The file may use the keys ``texts``/``meta`` or ``documents``/``metadata``.

        Raises:
            HTTPException: 404 when the file is missing, 422 when texts and
                metadata have different lengths.
        """
        source = Path(path).expanduser().resolve()
        if not source.exists():
            raise HTTPException(status_code=404, detail=f"Corpus introuvable : {source}")
        with open(source, "r", encoding="utf-8") as f:
            data = json.load(f)
        texts = data.get("texts") or data.get("documents") or []
        metas = data.get("meta") or data.get("metadata") or [{} for _ in texts]
        if len(metas) != len(texts):
            raise HTTPException(status_code=422, detail="Corpus invalide : texts/meta tailles différentes")
        if rebuild_index:
            self.build_index(texts, metas)
        else:
            self.corpus_texts = list(texts)
            self.corpus_meta = list(metas)
            self.index = None

    def count_llm_tokens(self, text: str) -> int:
        """Count the LLM tokens of ``text`` (special tokens excluded).

        Raises:
            HTTPException: 503 when the LLM tokenizer is not loaded.
        """
        if self.llm_tokenizer is None:
            raise HTTPException(status_code=503, detail="Tokenizer TinyLlama non chargé.")
        return int(len(self.llm_tokenizer(text, add_special_tokens=False)["input_ids"]))

    def generate_batch(
        self,
        prompts: List[str],
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system: Optional[str] = None,
    ) -> List[str]:
        """Generate one answer per prompt by calling ``generate`` sequentially.

        Raises:
            HTTPException: 400 when ``prompts`` is empty.
        """
        if not prompts:
            raise HTTPException(status_code=400, detail="prompts vide")
        return [
            self.generate(
                prompt=p,
                system=system,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            for p in prompts
        ]

    def build_rag_prompt(self, query: str, top_k: int = 3, max_context_chars: int = 2000) -> Dict[str, Any]:
        """Retrieve context for ``query`` and build the grounded prompt.

        Snippets are added in ranking order until ``max_context_chars``
        characters of document text have been used; the last one is truncated.

        Returns:
            A dict with the keys ``context``, ``prompt`` and ``sources``.
        """
        self.require_index()
        results = self.search(query, top_k=top_k)
        pieces, total = [], 0
        for r in results:
            snippet = r["text"]
            remaining = max_context_chars - total
            if remaining <= 0:
                break
            if len(snippet) > remaining:
                snippet = snippet[:remaining]
            pieces.append(f"[doc#{r['doc_id']}] {snippet}")
            total += len(snippet)
        context = "\n\n".join(pieces)
        prompt = (
            "Tu es un assistant qui répond UNIQUEMENT à partir du contexte fourni.\n"
            "Si la réponse n'est pas dans le contexte, dis : Information non disponible.\n\n"
            f"Contexte :\n{context}\n\n"
            f"Question : {query}\n"
            "Réponse :"
        )
        return {"context": context, "prompt": prompt, "sources": results}

    def ask(
        self,
        query: str,
        top_k: int = 3,
        max_context_chars: int = 2000,
        max_new_tokens: int = 256,
        temperature: float = 0.4,
        top_p: float = 0.9,
    ) -> Dict[str, Any]:
        """Answer with RAG when an index is loaded, otherwise with the LLM alone.

        Returns:
            A dict with the keys ``mode`` ("rag" or "llm"), ``answer``,
            ``context``, ``prompt`` and ``sources``.
        """
        if self.index is not None:
            rag_data = self.build_rag_prompt(query, top_k=top_k, max_context_chars=max_context_chars)
            answer = self.generate(
                prompt=rag_data["prompt"],
                system="Tu réponds en français, de manière concise et sourcée par le contexte.",
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            return {"mode": "rag", "answer": answer, **rag_data}
        answer = self.generate(
            prompt=query,
            system="Tu es un assistant utile qui répond clairement en français.",
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return {"mode": "llm", "answer": answer, "context": "", "prompt": query, "sources": []}
|
|
|
|
| |
| |
| |
# Process-wide service instance; None until the lifespan handler creates it
# at application startup.
service: Optional[RAGService] = None
|
|
|
|
def mount_model_files(app: FastAPI, model_files_dir: str) -> None:
    """Serve ``model_files_dir`` as static files under the ``/files`` route.

    The directory is resolved to an absolute path and created when missing.
    """
    root_dir = Path(model_files_dir).expanduser().resolve()
    root_dir.mkdir(parents=True, exist_ok=True)
    static_app = StaticFiles(directory=str(root_dir))
    app.mount("/files", static_app, name="model_files")
    log.info("Dossier /files monté depuis %s", root_dir)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the global service at startup from environment variables.

    Variables read: RAG_CKPT, RAG_TOKENIZER_DIR, RAG_LLM_MODEL, RAG_LLM_DIR,
    RAG_MODEL_FILES_DIR, OVH_PUBLIC_URL, RAG_LLM_DOWNLOAD, RAG_LLM_OFFLINE and
    RAG_LOAD_ENCODER. The static ``/files`` mount is registered before the
    service (and therefore the models) is loaded.
    """
    global service
    env = os.environ
    files_dir = env.get("RAG_MODEL_FILES_DIR", "./models")
    announced_url = env.get("OVH_PUBLIC_URL", "")
    service_kwargs = {
        # Empty strings mean "not configured" and are passed as None.
        "ckpt_path": env.get("RAG_CKPT", "") or None,
        "tokenizer_dir": env.get("RAG_TOKENIZER_DIR", "") or None,
        "llm_model_id": env.get("RAG_LLM_MODEL", TINYLLAMA_REPO_ID),
        "llm_dir": env.get("RAG_LLM_DIR", DEFAULT_LLM_DIR),
        "download_llm": env_bool("RAG_LLM_DOWNLOAD", True),
        "llm_offline": env_bool("RAG_LLM_OFFLINE", False),
        "public_url": announced_url,
        "load_encoder": env_bool("RAG_LOAD_ENCODER", True),
    }

    mount_model_files(app, files_dir)
    service = RAGService(**service_kwargs)
    log.info("Service prêt. URL publique configurée : %s", announced_url or "non fournie")
    yield
    log.info("Arrêt du service.")
|
|
|
|
# ASGI application; models are loaded by the `lifespan` handler at startup.
app = FastAPI(
    title="SecureRAG FR + TinyLlama Local API",
    description="API RAG optionnelle + génération TinyLlama local pour OVH AI Notebooks",
    version="3.2.0",
    lifespan=lifespan,
)


# CORS_ALLOW_ORIGINS is a comma-separated list of origins; the default "*"
# (or an empty value) allows every origin.
allowed_origins_raw = os.environ.get("CORS_ALLOW_ORIGINS", "*")
allowed_origins = [item.strip() for item in allowed_origins_raw.split(",") if item.strip()]
_cors_origins = allowed_origins or ["*"]
# Credentialed CORS (cookies / Authorization sent by browsers) must not be
# combined with a wildcard origin: it is enabled only when the allowed origins
# are listed explicitly.
_cors_allow_credentials = "*" not in _cors_origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| |
| |
| |
class EncodeRequest(BaseModel):
    # Request schema for embedding texts; presumably the body of POST /encode
    # (the handler is not visible in this part of the file).
    texts: List[str] = Field(..., description="Liste de textes à encoder")  # required
    normalize: bool = True  # same name as the `normalize` flag of RAGService.encode
|
|
|
|
class EncodeResponse(BaseModel):
    # Response schema carrying embeddings; presumably returned by /encode.
    embeddings: List[List[float]]  # one vector per input text
    dim: int  # presumably the embedding dimension — confirm in the handler
    elapsed_ms: float  # presumably the server-side processing time in milliseconds
|
|
|
|
class BuildIndexRequest(BaseModel):
    # Request schema for building an index; presumably the body of POST /index/build.
    documents: List[str]  # texts to index
    metadata: Optional[List[Dict[str, Any]]] = None  # optional per-document metadata
    save_path: Optional[str] = Field(None, description="Si fourni, sauvegarde l'index FAISS")
|
|
|
|
class LoadIndexRequest(BaseModel):
    # Request schema holding the path of an index to load; presumably for POST /index/load.
    path: str
|
|
|
|
class SearchRequest(BaseModel):
    # Request schema for a similarity search; presumably the body of POST /search.
    query: str
    top_k: int = Field(5, ge=1, le=100)  # default 5, validated to the range 1..100
|
|
|
|
class SearchResult(BaseModel):
    # One search hit; the fields match the keys of the dicts built by
    # RAGService.search.
    score: float
    text: str
    meta: Dict[str, Any]
    doc_id: int  # position of the document in the in-memory corpus
|
|
|
|
class SearchResponse(BaseModel):
    # Response schema for a search; presumably returned by /search.
    query: str
    results: List[SearchResult]
    elapsed_ms: float  # presumably the server-side processing time in milliseconds
|
|
|
|
class RAGRequest(BaseModel):
    # Request schema for retrieval-augmented generation; presumably the body
    # of POST /rag. Numeric fields are range-validated by pydantic.
    query: str
    top_k: int = Field(3, ge=1, le=20)  # number of retrieved documents
    max_context_chars: int = Field(2000, ge=100, le=20000)  # context size budget in characters
    generate_answer: bool = True  # presumably toggles LLM generation — confirm in the handler
    max_new_tokens: int = Field(256, ge=1, le=2048)
    temperature: float = Field(0.4, ge=0.0, le=2.0)  # 0 selects greedy decoding in RAGService.generate
    top_p: float = Field(0.9, ge=0.01, le=1.0)
|
|
|
|
class RAGResponse(BaseModel):
    # Response schema for retrieval-augmented generation; presumably returned by /rag.
    query: str
    context: str  # concatenated snippets given to the model
    sources: List[SearchResult]
    prompt_template: str  # presumably the full prompt sent to the LLM — confirm in the handler
    answer: Optional[str] = None  # may be absent (default None)
|
|
|
|
class GenerateRequest(BaseModel):
    # Request schema for plain generation; presumably the body of POST /generate.
    # Field names mirror the parameters of RAGService.generate.
    prompt: str
    system: Optional[str] = "Tu es un assistant utile qui répond clairement en français."
    max_new_tokens: int = Field(128, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)  # 0 selects greedy decoding
    top_p: float = Field(0.9, ge=0.01, le=1.0)
|
|
|
|
class GenerateResponse(BaseModel):
    # Response schema for plain generation; presumably returned by /generate.
    prompt: str
    answer: str
    elapsed_ms: float  # presumably the server-side processing time in milliseconds
    model: str  # presumably the model identifier
    model_path: str  # presumably the local model directory
|
|
|
|
class SimilarityRequest(BaseModel):
    # Request schema with the two texts to compare; presumably for POST /similarity.
    text_a: str
    text_b: str
|
|
|
|
class SimilarityResponse(BaseModel):
    # Response schema holding a single cosine similarity value.
    cosine_similarity: float
|
|
|
|
class ChatMessage(BaseModel):
    # One chat message (role + content), used by OpenAIChatCompletionRequest.
    role: str  # free-form string; no role validation is applied here
    content: str
|
|
|
|
class OpenAIChatCompletionRequest(BaseModel):
    # Request schema shaped like an OpenAI chat-completions payload; presumably
    # the body of POST /v1/chat/completions.
    model: Optional[str] = TINYLLAMA_REPO_ID
    messages: List[ChatMessage]
    max_tokens: int = Field(256, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)
    stream: bool = False  # NOTE(review): handling of streaming is not visible here
    stop: Optional[Any] = None  # NOTE(review): handling of stop sequences is not visible here
|
|
|
|
class CompletionRequest(BaseModel):
    # Request schema shaped like an OpenAI text-completions payload; presumably
    # the body of POST /v1/completions.
    model: Optional[str] = TINYLLAMA_REPO_ID
    prompt: str
    max_tokens: int = Field(128, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)
    stream: bool = False  # NOTE(review): handling of streaming is not visible here
    stop: Optional[Any] = None  # NOTE(review): handling of stop sequences is not visible here
|
|
|
|
class CompletionChoice(BaseModel):
    # One generated choice inside a CompletionResponse.
    text: str
    index: int
    logprobs: Optional[Any] = None
    finish_reason: Optional[str] = None
|
|
|
|
class CompletionResponse(BaseModel):
    # Response schema shaped like an OpenAI text-completions result.
    # The field names `id` and `object` are part of the wire format, hence the
    # builtin-shadowing names inside this class body.
    id: str
    object: str
    created: int  # presumably a Unix timestamp — confirm in the handler
    model: str
    choices: List[CompletionChoice]
    usage: Dict[str, Optional[Any]]  # values may be None
|
|
|
|
class BatchGenerateRequest(BaseModel):
    # Request schema for batch generation; presumably the body of
    # POST /generate/batch. Field names mirror RAGService.generate_batch.
    prompts: List[str]  # RAGService.generate_batch rejects an empty list with 400
    system: Optional[str] = "Tu es un assistant utile qui répond clairement en français."
    max_new_tokens: int = Field(128, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)
|
|
|
|
class BatchGenerateResponse(BaseModel):
    # Response schema for batch generation.
    answers: List[str]  # presumably one answer per prompt, in request order
    elapsed_ms: float  # presumably the server-side processing time in milliseconds
    model: str
|
|
|
|
class TokenCountRequest(BaseModel):
    # Request schema holding the text to tokenize; presumably for POST /tokens/count.
    text: str
|
|
|
|
class TokenCountResponse(BaseModel):
    # Response schema for a token count.
    tokens: int
    model: str
|
|
|
|
class SaveIndexRequest(BaseModel):
    # Request schema holding the destination path of the index; presumably
    # for POST /index/save.
    path: str
|
|
|
|
class AddDocumentsRequest(BaseModel):
    # Request schema for appending documents; presumably the body of POST /index/add.
    documents: List[str]
    metadata: Optional[List[Dict[str, Any]]] = None  # RAGService.add_documents requires the same length as documents
    save_path: Optional[str] = None  # presumably an optional path to persist the index — confirm in the handler
|
|
|
|
class CorpusSaveRequest(BaseModel):
    # Request schema holding the destination path of the corpus JSON file;
    # presumably for POST /corpus/save.
    path: str
|
|
|
|
class CorpusLoadRequest(BaseModel):
    # Request schema for loading a corpus file; presumably for POST /corpus/load.
    path: str
    rebuild_index: bool = True  # same name as the flag of RAGService.load_corpus
|
|
|
|
class AskRequest(BaseModel):
    # Request schema for question answering; presumably the body of POST /ask.
    # Field names and defaults mirror the parameters of RAGService.ask.
    query: str
    top_k: int = Field(3, ge=1, le=20)
    max_context_chars: int = Field(2000, ge=100, le=20000)
    max_new_tokens: int = Field(256, ge=1, le=2048)
    temperature: float = Field(0.4, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)
|
|
|
|
class OpenAIEmbeddingsRequest(BaseModel):
    # Request schema shaped like an OpenAI embeddings payload; presumably the
    # body of POST /v1/embeddings.
    model: Optional[str] = "rag-encoder-local"
    input: Any  # untyped here; accepted shapes are decided by the handler (not visible)
    encoding_format: Optional[str] = "float"
|
|
|
|
| |
| |
| |
def _check_service() -> RAGService:
    """Return the global service, or raise HTTP 503 while it is not created."""
    if service is not None:
        return service
    raise HTTPException(status_code=503, detail="Service non prêt, réessaie dans quelques secondes.")
|
|
|
|
| |
| |
| |
@app.get("/")
def root():
    """Describe the running service and list the absolute URL of each endpoint.

    URLs are built from the configured public URL, falling back to the local
    address ``http://127.0.0.1:8080``.
    """
    svc = _check_service()
    base = svc.public_url or "http://127.0.0.1:8080"
    encoder_ready = svc.encoder_loaded
    has_index = svc.index is not None

    # (label, path) pairs in the order they are reported.
    route_table = (
        ("health", "/health"),
        ("generate", "/generate"),
        ("generate_batch", "/generate/batch"),
        ("ask", "/ask"),
        ("tokens_count", "/tokens/count"),
        ("encode", "/encode"),
        ("similarity", "/similarity"),
        ("index_build", "/index/build"),
        ("index_load", "/index/load"),
        ("index_save", "/index/save"),
        ("index_add", "/index/add"),
        ("index_status", "/index/status"),
        ("index_clear", "/index/clear"),
        ("search", "/search"),
        ("rag", "/rag"),
        ("documents_list", "/documents"),
        ("document_get", "/documents/{doc_id}"),
        ("document_delete", "/documents/{doc_id}"),
        ("corpus_save", "/corpus/save"),
        ("corpus_load", "/corpus/load"),
        ("openai_chat", "/v1/chat/completions"),
        ("openai_completions", "/v1/completions"),
        ("openai_embeddings", "/v1/embeddings"),
        ("models", "/v1/models"),
        ("client_snippets", "/client/snippets"),
        ("model_info", "/model/info"),
        ("files_list", "/files/list"),
    )
    endpoints = {label: f"{base}{suffix}" for label, suffix in route_table}

    return {
        "status": "ok",
        "name": "SecureRAG FR + TinyLlama Local API",
        "version": "3.2.0",
        "device": str(svc.device),
        "public_url": svc.public_url,
        "local_test_url": "http://127.0.0.1:8080",
        "llm_model": svc.llm_model_id,
        "llm_path": svc.llm_path,
        "llm_loaded": svc.llm_model is not None,
        "encoder_loaded": encoder_ready,
        "embedding_dim": svc.cfg.embedding_dim if encoder_ready else None,
        "index_loaded": has_index,
        "index_size": svc.index.ntotal if has_index else 0,
        "endpoints": endpoints,
    }
|
|
|
|
@app.get("/health")
def health():
    """Report liveness, CUDA availability, the service device and which models are loaded."""
    svc = _check_service()
    return dict(
        status="ok",
        gpu=torch.cuda.is_available(),
        device=str(svc.device),
        llm_loaded=svc.llm_model is not None,
        encoder_loaded=svc.encoder_loaded,
    )
|
|
|
|
@app.get("/client/snippets")
def client_snippets():
    """Return ready-to-paste client examples (curl, requests, OpenAI SDK) for this server.

    The examples target the service's public URL when one is set, otherwise a
    placeholder OVH notebook URL.
    """
    svc = _check_service()
    base = svc.public_url or "https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net"

    curl_generate = (
        f"curl -X POST '{base}/generate' -H 'Content-Type: application/json' "
        "-d '{\"prompt\":\"Explique TinyLlama en français.\",\"max_new_tokens\":128}'"
    )

    python_requests = (
        "import requests\n"
        f"BASE_URL = '{base}'\n"
        "r = requests.post(BASE_URL + '/generate', json={\n"
        " 'prompt': 'Explique RAG simplement en français.',\n"
        " 'max_new_tokens': 128\n"
        "})\n"
        "print(r.json()['answer'])"
    )

    openai_chat = (
        "from openai import OpenAI\n"
        f"client = OpenAI(base_url='{base}/v1', api_key='not-needed')\n"
        "resp = client.chat.completions.create(\n"
        " model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',\n"
        " messages=[{'role': 'user', 'content': 'Bonjour, réponds en français.'}],\n"
        " max_tokens=128,\n"
        ")\n"
        "print(resp.choices[0].message.content)"
    )

    openai_embeddings = (
        "from openai import OpenAI\n"
        f"client = OpenAI(base_url='{base}/v1', api_key='not-needed')\n"
        "resp = client.embeddings.create(\n"
        " model='rag-encoder-local',\n"
        " input=['Bonjour le monde']\n"
        ")\n"
        "print(resp.data[0].embedding[:5])"
    )

    return {
        "curl_generate": curl_generate,
        "python_requests": python_requests,
        "openai_compatible_python": openai_chat,
        "openai_embeddings_python": openai_embeddings,
    }
|
|
|
|
@app.get("/files/list")
def files_list():
    """List every regular file under the model-files directory, relative to that directory.

    The directory comes from ``RAG_MODEL_FILES_DIR`` (default ``./models``).
    Responds with HTTP 404 when the resolved path does not exist.
    """
    configured = os.environ.get("RAG_MODEL_FILES_DIR", "./models")
    files_dir = Path(configured).expanduser().resolve()
    if not files_dir.exists():
        raise HTTPException(status_code=404, detail=f"Dossier introuvable : {files_dir}")

    relative_files = []
    for entry in sorted(files_dir.rglob("*")):
        if entry.is_file():
            relative_files.append(str(entry.relative_to(files_dir)))
    return {"model_files_dir": str(files_dir), "files": relative_files}
|
|
|
|
@app.get("/model/info")
def model_info():
    """Return the LLM identity and path, the device, encoder state, index status and public URL."""
    svc = _check_service()
    return dict(
        llm_model=svc.llm_model_id,
        llm_path=svc.llm_path,
        device=str(svc.device),
        encoder_loaded=svc.encoder_loaded,
        index=svc.index_status(),
        public_url=svc.public_url,
    )
|
|
|
|
@app.post("/tokens/count", response_model=TokenCountResponse)
def count_tokens(req: TokenCountRequest):
    """Count the LLM tokens of ``req.text`` and report them with the model id."""
    svc = _check_service()
    n_tokens = svc.count_llm_tokens(req.text)
    return TokenCountResponse(tokens=n_tokens, model=svc.llm_model_id)
|
|
|
|
@app.post("/generate/batch", response_model=BatchGenerateResponse)
def generate_batch(req: BatchGenerateRequest):
    """Generate one answer per prompt in a single call; HTTP 400 when no prompt is given."""
    svc = _check_service()
    if not req.prompts:
        raise HTTPException(status_code=400, detail="prompts ne peut pas être vide")

    started = time.time()
    outputs = svc.generate_batch(
        prompts=req.prompts,
        system=req.system,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    elapsed_ms = (time.time() - started) * 1000.0
    return BatchGenerateResponse(answers=outputs, elapsed_ms=elapsed_ms, model=svc.llm_model_id)
|
|
|
|
@app.get("/index/status")
def index_status():
    """Return the service's index status report unchanged."""
    return _check_service().index_status()
|
|
|
|
@app.delete("/index/clear")
def index_clear():
    """Ask the service to clear its index and acknowledge."""
    _check_service().clear_index()
    return dict(status="ok", message="index vidé")
|
|
|
|
@app.post("/index/save")
def index_save(req: SaveIndexRequest):
    """Save the current index to ``req.path``; HTTP 400 when no index exists."""
    svc = _check_service()
    if svc.index is None:
        raise HTTPException(status_code=400, detail="Aucun index à sauvegarder.")
    destination = req.path
    svc.save_index(destination)
    return {"status": "ok", "saved_to": destination}
|
|
|
|
@app.post("/index/add")
def index_add(req: AddDocumentsRequest):
    """Add documents (with optional per-document metadata) to the index.

    Responds with HTTP 400 for an empty document list and HTTP 422 when
    metadata is supplied with a different length than the documents. When
    ``req.save_path`` is set, the index is saved there afterwards.
    """
    svc = _check_service()
    docs, meta = req.documents, req.metadata
    if not docs:
        raise HTTPException(status_code=400, detail="documents ne peut pas être vide")
    if meta is not None and len(meta) != len(docs):
        raise HTTPException(status_code=422, detail="metadata doit avoir la même taille que documents")

    started = time.time()
    svc.add_documents(docs, meta)

    # A falsy save_path means "do not save" and is reported as None.
    saved_to = req.save_path or None
    if saved_to:
        svc.save_index(saved_to)

    index_report = svc.index_status()
    elapsed_ms = (time.time() - started) * 1000.0
    return {
        "status": "ok",
        "added": len(docs),
        "index": index_report,
        "saved_to": saved_to,
        "elapsed_ms": elapsed_ms,
    }
|
|
|
|
@app.get("/documents")
def documents_list(offset: int = 0, limit: int = 50):
    """Return one page of documents; ``offset`` and ``limit`` go straight to the service."""
    page = _check_service().list_documents(offset=offset, limit=limit)
    return page
|
|
|
|
@app.get("/documents/{doc_id}")
def document_get(doc_id: int):
    """Return whatever the service holds for ``doc_id``."""
    document = _check_service().get_document(doc_id)
    return document
|
|
|
|
@app.delete("/documents/{doc_id}")
def document_delete(doc_id: int):
    """Delete document ``doc_id`` and return the index status after the deletion."""
    svc = _check_service()
    svc.delete_document(doc_id)
    index_report = svc.index_status()
    return {"status": "ok", "deleted_doc_id": doc_id, "index": index_report}
|
|
|
|
@app.post("/corpus/save")
def corpus_save(req: CorpusSaveRequest):
    """Save the corpus to ``req.path`` and report how many documents it contains."""
    svc = _check_service()
    svc.save_corpus(req.path)
    doc_count = len(svc.corpus_texts)
    return {"status": "ok", "saved_to": req.path, "n_documents": doc_count}
|
|
|
|
@app.post("/corpus/load")
def corpus_load(req: CorpusLoadRequest):
    """Load a corpus from ``req.path`` and return the index status afterwards."""
    svc = _check_service()
    svc.load_corpus(req.path, rebuild_index=req.rebuild_index)
    index_report = svc.index_status()
    return {"status": "ok", "loaded_from": req.path, "index": index_report}
|
|
|
|
@app.post("/ask")
def ask(req: AskRequest):
    """Forward the question to the service's ``ask`` and add the elapsed time to its result."""
    svc = _check_service()
    started = time.time()
    result = svc.ask(
        query=req.query,
        top_k=req.top_k,
        max_context_chars=req.max_context_chars,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    elapsed_ms = (time.time() - started) * 1000.0
    # The service result is extended in place with the timing.
    result["elapsed_ms"] = elapsed_ms
    return result
|
|
|
|
@app.get("/v1/models")
def openai_models():
    """OpenAI-style model listing containing the single local LLM."""
    svc = _check_service()
    model_card = {
        "id": svc.llm_model_id,
        "object": "model",
        "created": 0,
        "owned_by": "local",
        "path": svc.llm_path,
    }
    return {"object": "list", "data": [model_card]}
|
|
|
|
@app.get("/v1/models/{model_id:path}")
def openai_model_get(model_id: str):
    """OpenAI-style model card; echoes ``model_id``, or the local LLM id when it is empty."""
    svc = _check_service()
    reported_id = model_id if model_id else svc.llm_model_id
    return dict(
        id=reported_id,
        object="model",
        created=0,
        owned_by="local",
        path=svc.llm_path,
    )
|
|
|
|
@app.post("/v1/embeddings")
def openai_embeddings(req: OpenAIEmbeddingsRequest):
    """OpenAI-style embeddings endpoint backed by the local encoder.

    ``input`` must be a string or a non-empty list of strings, and
    ``encoding_format`` must be ``"float"`` or None; anything else yields
    HTTP 400. Token usage is not measured, so both usage counters are None.
    """
    svc = _check_service()
    if req.encoding_format not in (None, "float"):
        raise HTTPException(status_code=400, detail="encoding_format supporté uniquement : float")

    payload = req.input
    if isinstance(payload, str):
        # A single string is treated as a batch of one.
        texts = [payload]
    elif not isinstance(payload, list):
        raise HTTPException(status_code=400, detail="input doit être une chaîne ou une liste de chaînes")
    elif not payload:
        raise HTTPException(status_code=400, detail="input vide")
    elif not all(isinstance(item, str) for item in payload):
        raise HTTPException(status_code=400, detail="input doit être une chaîne ou une liste de chaînes")
    else:
        texts = payload

    vectors = svc.encode(texts, normalize=True)
    entries = []
    for position, vector in enumerate(vectors):
        entries.append({"object": "embedding", "index": position, "embedding": vector.tolist()})

    return {
        "object": "list",
        "model": req.model or "rag-encoder-local",
        "data": entries,
        "usage": {"prompt_tokens": None, "total_tokens": None},
    }
|
|
|
|
@app.post("/v1/chat/completions")
def openai_chat_completions(req: OpenAIChatCompletionRequest):
    """OpenAI-style chat completion backed by the local LLM.

    System messages are joined into the system prompt; every other message is
    flattened to a ``role: content`` line of the user prompt. Streaming and an
    empty message list are rejected with HTTP 400. Token usage is not
    measured, so the token counters are None; only ``elapsed_ms`` is filled.
    """
    svc = _check_service()
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=True non supporté dans cette version locale")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages vide")

    system_chunks = []
    dialogue_lines = []
    for msg in req.messages:
        if msg.role == "system":
            system_chunks.append(msg.content)
        else:
            dialogue_lines.append(f"{msg.role}: {msg.content}")
    prompt = "\n".join(dialogue_lines).strip()
    # An empty system prompt is passed on as None.
    system = "\n".join(system_chunks).strip() or None

    started = time.time()
    reply = svc.generate(
        prompt=prompt,
        system=system,
        max_new_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    created = int(time.time())

    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": reply},
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": None,
        "completion_tokens": None,
        "total_tokens": None,
        "elapsed_ms": (time.time() - started) * 1000.0,
    }
    return {
        "id": f"chatcmpl-local-{created}",
        "object": "chat.completion",
        "created": created,
        "model": svc.llm_model_id,
        "choices": [choice],
        "usage": usage,
    }
|
|
|
|
@app.post("/v1/completions", response_model=CompletionResponse)
def openai_completions(req: CompletionRequest):
    """OpenAI-style text completion backed by the local LLM.

    The prompt is sent to the service without a system prompt. Token usage is
    not measured, so every usage counter in the response is None.

    Raises:
        HTTPException: 400 when ``req.stream`` is set (streaming is unsupported);
            503 (via ``_check_service``) while the service is not ready.
    """
    svc = _check_service()
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=True non supporté dans cette version locale")

    # No timer here: the response carries no elapsed-time field.
    answer = svc.generate(
        prompt=req.prompt,
        system=None,
        max_new_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    created = int(time.time())
    return CompletionResponse(
        id=f"cmpl-local-{created}",
        object="text_completion",
        created=created,
        model=svc.llm_model_id,
        choices=[CompletionChoice(text=answer, index=0, finish_reason="stop")],
        usage={
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
    )
|
|
|
|
@app.post("/v1/engines/{engine}/completions", response_model=CompletionResponse)
def openai_engine_completions(engine: str, req: CompletionRequest):
    """Legacy ``/v1/engines/{engine}/completions`` alias of ``/v1/completions``.

    ``engine`` is accepted for URL compatibility only and is not used: the
    local LLM always answers. The request is handled by ``openai_completions``
    so both routes share one implementation (same validation, same response).

    Raises:
        HTTPException: 400 when ``req.stream`` is set; 503 while the service
            is not ready (both raised by ``openai_completions``).
    """
    return openai_completions(req)
|
|
|
|
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate one answer for ``req.prompt`` and report the time spent and the model used."""
    svc = _check_service()
    started = time.time()
    text = svc.generate(
        prompt=req.prompt,
        system=req.system,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    elapsed_ms = (time.time() - started) * 1000.0
    return GenerateResponse(
        prompt=req.prompt,
        answer=text,
        elapsed_ms=elapsed_ms,
        model=svc.llm_model_id,
        model_path=svc.llm_path,
    )
|
|
|
|
@app.post("/encode", response_model=EncodeResponse)
def encode(req: EncodeRequest):
    """Encode ``req.texts`` into embeddings; HTTP 400 when the list is empty."""
    svc = _check_service()
    if not req.texts:
        raise HTTPException(status_code=400, detail="texts ne peut pas être vide")

    started = time.time()
    matrix = svc.encode(req.texts, normalize=req.normalize)
    as_lists = matrix.tolist()
    # The second axis of the result is reported as the embedding dimension.
    width = int(matrix.shape[1])
    elapsed_ms = (time.time() - started) * 1000.0
    return EncodeResponse(embeddings=as_lists, dim=width, elapsed_ms=elapsed_ms)
|
|
|
|
@app.post("/index/build")
def index_build(req: BuildIndexRequest):
    """Build the index from ``req.documents`` (with optional per-document metadata).

    Responds with HTTP 400 for an empty document list and HTTP 422 when
    metadata is supplied with a different length than the documents. When
    ``req.save_path`` is set, the index is saved there afterwards.
    """
    svc = _check_service()
    docs, meta = req.documents, req.metadata
    if not docs:
        raise HTTPException(status_code=400, detail="documents ne peut pas être vide")
    if meta is not None and len(meta) != len(docs):
        raise HTTPException(status_code=422, detail="metadata doit avoir la même taille que documents")

    started = time.time()
    svc.build_index(docs, meta)

    # A falsy save_path means "do not save" and is reported as None.
    saved_to = req.save_path or None
    if saved_to:
        svc.save_index(saved_to)

    elapsed_ms = (time.time() - started) * 1000.0
    return {
        "status": "ok",
        "n_docs": len(docs),
        "elapsed_ms": elapsed_ms,
        "saved_to": saved_to,
    }
|
|
|
|
@app.post("/index/load")
def index_load(req: LoadIndexRequest):
    """Load an index from ``req.path`` and report how many vectors it holds."""
    svc = _check_service()
    svc.load_index(req.path)
    vector_count = svc.index.ntotal
    return {"status": "ok", "n_vectors": vector_count}
|
|
|
|
@app.post("/search", response_model=SearchResponse)
def search(req: SearchRequest):
    """Run a top-k search for ``req.query`` and return the hits with the time spent."""
    svc = _check_service()
    started = time.time()
    raw_hits = svc.search(req.query, top_k=req.top_k)
    # Each hit is a mapping whose keys match the SearchResult fields.
    hits = [SearchResult(**hit) for hit in raw_hits]
    elapsed_ms = (time.time() - started) * 1000.0
    return SearchResponse(query=req.query, results=hits, elapsed_ms=elapsed_ms)
|
|
|
|
@app.post("/rag", response_model=RAGResponse)
def rag(req: RAGRequest):
    """Retrieve context for ``req.query``, build the prompt and optionally generate an answer.

    Retrieved texts are concatenated in result order until
    ``req.max_context_chars`` characters of document text have been used; the
    last text is cut to fit. The ``[doc#id]`` labels and the blank-line
    separators are not counted against that budget. The answer is only
    generated when ``req.generate_answer`` is set; otherwise it is None.
    """
    svc = _check_service()
    svc.require_index()

    hits = svc.search(req.query, top_k=req.top_k)

    budget = req.max_context_chars
    labelled = []
    for hit in hits:
        text = hit["text"]
        if budget <= 0:
            break
        excerpt = text[:budget] if len(text) > budget else text
        labelled.append(f"[doc#{hit['doc_id']}] {excerpt}")
        budget -= len(excerpt)
    context = "\n\n".join(labelled)

    prompt = (
        "Tu es un assistant qui répond UNIQUEMENT à partir du contexte fourni.\n"
        "Si la réponse n'est pas dans le contexte, dis : Information non disponible.\n\n"
        f"Contexte :\n{context}\n\n"
        f"Question : {req.query}\n"
        "Réponse :"
    )

    if req.generate_answer:
        answer = svc.generate(
            prompt=prompt,
            system="Tu réponds en français, de manière concise et sourcée par le contexte.",
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
        )
    else:
        answer = None

    sources = [SearchResult(**hit) for hit in hits]
    return RAGResponse(
        query=req.query,
        context=context,
        sources=sources,
        prompt_template=prompt,
        answer=answer,
    )
|
|
|
|
@app.post("/similarity", response_model=SimilarityResponse)
def similarity(req: SimilarityRequest):
    """Return the cosine similarity between ``req.text_a`` and ``req.text_b``.

    Both texts are encoded in one call and compared with a dot product. A dot
    product equals the cosine only for unit-length vectors, so normalization
    is requested explicitly rather than left to ``encode``'s default.

    Raises:
        HTTPException: 503 (via ``_check_service``) while the service is not ready.
    """
    svc = _check_service()
    # NOTE(review): assumes normalize=True makes encode return unit-length
    # vectors — confirm against the encoder implementation.
    embs = svc.encode([req.text_a, req.text_b], normalize=True)
    sim = float(np.dot(embs[0], embs[1]))
    return SimilarityResponse(cosine_similarity=sim)
|
|
|
|
| |
| |
| |
| |
@app.exception_handler(Exception)
async def all_exception_handler(request: Request, exc: Exception):
    """Log an exception with its traceback and answer with a JSON error body.

    An ``HTTPException`` keeps its own status code and detail; any other
    exception becomes HTTP 500 with ``str(exc)`` as the detail. The request
    path is included in both cases.
    """
    route_path = request.url.path
    log.exception("Erreur non gérée sur %s", route_path)

    if isinstance(exc, HTTPException):
        status, detail = exc.status_code, exc.detail
    else:
        status, detail = 500, str(exc)
    return JSONResponse(status_code=status, content={"detail": detail, "path": str(route_path)})
|
|
|
|
| |
| |
| |
def main() -> None:
    """Parse the command line, export the settings as environment variables and run uvicorn.

    Every option is published through an environment variable (``RAG_CKPT``,
    ``RAG_TOKENIZER_DIR``, ``RAG_LOAD_ENCODER``, ``RAG_LLM_MODEL``,
    ``RAG_LLM_DIR``, ``RAG_LLM_DOWNLOAD``, ``RAG_LLM_OFFLINE``,
    ``RAG_MODEL_FILES_DIR`` and, when ``--public-url`` is given,
    ``OVH_PUBLIC_URL``) before the server starts.

    An option left off the command line falls back to the value already in
    the environment and only then to the built-in default, so a
    pre-configured environment is not overwritten by CLI defaults. An option
    given explicitly always wins.
    """
    env = os.environ

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080, help="OVH AI Notebooks expose le port 8080")
    parser.add_argument("--ckpt", default=env.get("RAG_CKPT", ""))
    parser.add_argument("--tokenizer-dir", default=env.get("RAG_TOKENIZER_DIR", ""))
    parser.add_argument("--no-encoder", action="store_true", help="Désactive les endpoints encode/search/rag")
    parser.add_argument("--llm-model", default=env.get("RAG_LLM_MODEL", TINYLLAMA_REPO_ID))
    parser.add_argument("--llm-dir", default=env.get("RAG_LLM_DIR", DEFAULT_LLM_DIR), help="Dossier local de TinyLlama")
    parser.add_argument("--download-llm", action="store_true", help="Télécharge TinyLlama si le dossier local est vide")
    # argparse %-formats help strings, so a literal percent sign must be
    # written "%%"; a bare "%" makes --help fail.
    parser.add_argument("--offline-llm", action="store_true", help="Force le chargement 100%% local")
    parser.add_argument(
        "--model-files-dir",
        default=env.get("RAG_MODEL_FILES_DIR", "./models"),
        help="Dossier exposé par /files",
    )
    parser.add_argument("--public-url", default="", help="URL OVH publique, sans slash final")
    parser.add_argument("--workers", type=int, default=1, help="Garde 1 avec GPU pour éviter de dupliquer la VRAM")
    args = parser.parse_args()

    # Publish the resolved settings through the environment.
    env["RAG_CKPT"] = args.ckpt
    env["RAG_TOKENIZER_DIR"] = args.tokenizer_dir
    # Boolean switches: the flag forces the value, otherwise the environment
    # (or the built-in default) is kept.
    env["RAG_LOAD_ENCODER"] = "false" if args.no_encoder else env.get("RAG_LOAD_ENCODER", "true")
    env["RAG_LLM_MODEL"] = args.llm_model
    env["RAG_LLM_DIR"] = args.llm_dir
    env["RAG_LLM_DOWNLOAD"] = "true" if args.download_llm else env.get("RAG_LLM_DOWNLOAD", "true")
    env["RAG_LLM_OFFLINE"] = "true" if args.offline_llm else env.get("RAG_LLM_OFFLINE", "false")
    env["RAG_MODEL_FILES_DIR"] = args.model_files_dir
    if args.public_url:
        env["OVH_PUBLIC_URL"] = clean_public_url(args.public_url)

    # Imported here so uvicorn is only required when the file runs as a script.
    import uvicorn

    if args.workers != 1:
        log.warning("workers>1 peut dupliquer le modèle en VRAM. Recommandé : --workers 1")

    # uvicorn only starts several workers from an import string
    # ("module:attribute"); given an application object with workers > 1 it
    # logs a warning and exits. Worker processes inherit the environment
    # variables set above.
    target = app
    if args.workers > 1:
        module_name = Path(__file__).stem if __name__ == "__main__" else __name__
        target = f"{module_name}:app"
    uvicorn.run(target, host=args.host, port=args.port, workers=args.workers, log_level="info")
|
|
|
|
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()