""" ============================================================================== API FastAPI RAG/NLP + TinyLlama local pour OVH AI Notebooks ============================================================================== Objectifs de cette version corrigée : 1) Télécharger TinyLlama/TinyLlama-1.1B-Chat-v1.0 dans un dossier local. 2) Charger TinyLlama depuis ce dossier local pour éviter de re-télécharger. 3) Exposer l'API sur le port 8080, compatible avec l'URL publique OVH AI Notebooks. 4) Permettre la consommation depuis ton PC local via l'URL OVH publique. 5) Garder les endpoints RAG existants, mais rendre l'encodeur RAG optionnel. Installation minimale : pip install -U fastapi uvicorn transformers torch huggingface_hub pydantic numpy pip install -U accelerate safetensors sentencepiece protobuf pip install -U faiss-cpu # seulement si tu utilises /index/build ou /search Premier lancement sur OVH AI Notebook : python rag_api_server_local_tinyllama.py \ --host 0.0.0.0 \ --port 8080 \ --download-llm \ --public-url "https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net" Ensuite, depuis ton PC local : export OVH_RAG_URL="https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net" curl -X POST "$OVH_RAG_URL/generate" \ -H "Content-Type: application/json" \ -d '{"prompt":"Explique RAG simplement en français.","max_new_tokens":128}' Notes OVH : - Sur AI Notebooks, l'URL publique mappe le port 8080. Lance donc FastAPI sur 8080. - Si ton notebook est en accès restreint, ajoute ton token dans l'en-tête Authorization selon la configuration OVH. 
============================================================================== """ from __future__ import annotations import argparse import json import logging import os import time from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field from transformers import AutoModelForCausalLM, AutoTokenizer # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger("rag-tinyllama-api") TINYLLAMA_REPO_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" DEFAULT_LLM_DIR = "./models/TinyLlama-1.1B-Chat-v1.0" def env_bool(name: str, default: bool = False) -> bool: value = os.environ.get(name) if value is None: return default return value.strip().lower() in {"1", "true", "yes", "y", "on"} def clean_public_url(url: str) -> str: return url.strip().rstrip("/") def ensure_tinyllama_local(model_id: str, local_dir: str, download: bool = True) -> str: """Return a local model path. Download from Hugging Face Hub if requested/missing.""" target = Path(local_dir).expanduser().resolve() has_config = (target / "config.json").exists() has_tokenizer = any((target / name).exists() for name in ["tokenizer.json", "tokenizer.model"]) if has_config and has_tokenizer: log.info("TinyLlama déjà disponible localement : %s", target) return str(target) if not download: raise RuntimeError( f"TinyLlama introuvable dans {target}. Relance avec --download-llm ou mets " f"RAG_LLM_DOWNLOAD=true." 
) log.info("Téléchargement TinyLlama depuis Hugging Face : %s -> %s", model_id, target) try: from huggingface_hub import snapshot_download except ImportError as exc: raise RuntimeError("Installe huggingface_hub : pip install -U huggingface_hub") from exc target.mkdir(parents=True, exist_ok=True) snapshot_download( repo_id=model_id, local_dir=str(target), local_dir_use_symlinks=False, resume_download=True, ignore_patterns=["*.md", ".gitattributes"], ) log.info("Téléchargement terminé : %s", target) return str(target) # ============================================================================= # 1. ARCHITECTURE ENCODEUR RAG 20M - optionnelle # ============================================================================= @dataclass class ModelConfig: vocab_size: int = 32005 hidden_size: int = 384 num_hidden_layers: int = 6 num_attention_heads: int = 6 intermediate_size: int = 1536 max_position_embeddings: int = 256 hidden_dropout_prob: float = 0.0 attention_probs_dropout_prob: float = 0.0 layer_norm_eps: float = 1e-12 embedding_dim: int = 384 use_layer_scale: bool = True layer_scale_init: float = 1e-4 class TransformerEncoderBlock(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.num_heads = cfg.num_attention_heads self.head_dim = cfg.hidden_size // cfg.num_attention_heads self.ln1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps) self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size, bias=True) self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True) self.ln2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps) self.mlp = nn.Sequential( nn.Linear(cfg.hidden_size, cfg.intermediate_size), nn.GELU(), nn.Linear(cfg.intermediate_size, cfg.hidden_size), nn.Dropout(cfg.hidden_dropout_prob), ) self.resid_drop = nn.Dropout(cfg.hidden_dropout_prob) self.use_ls = cfg.use_layer_scale if cfg.use_layer_scale: self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size)) self.gamma2 = 
nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size)) def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: batch, tokens, channels = x.shape h = self.ln1(x) qkv = h.new_empty(batch, tokens, 3, self.num_heads, self.head_dim) qkv.copy_(self.qkv(h).view(batch, tokens, 3, self.num_heads, self.head_dim)) q, k, v = qkv.permute(2, 0, 3, 1, 4) bool_mask = attn_mask[:, None, None, :].bool() attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=bool_mask, dropout_p=0.0, is_causal=False ) attn_out = attn_out.transpose(1, 2).contiguous().view(batch, tokens, channels) attn_out = self.resid_drop(self.proj(attn_out)) if self.use_ls: attn_out = attn_out * self.gamma1 x = x + attn_out mlp_out = self.mlp(self.ln2(x)) if self.use_ls: mlp_out = mlp_out * self.gamma2 return x + mlp_out class TextEncoder(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.hidden_size, padding_idx=0) self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size) self.emb_ln = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps) self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob) self.blocks = nn.ModuleList([TransformerEncoderBlock(cfg) for _ in range(cfg.num_hidden_layers)]) self.ln_f = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps) self.proj_head = nn.Sequential( nn.Linear(cfg.hidden_size, cfg.hidden_size), nn.Tanh(), nn.Linear(cfg.hidden_size, cfg.embedding_dim), ) def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: batch, tokens = input_ids.shape positions = torch.arange(tokens, device=input_ids.device).unsqueeze(0).expand(batch, tokens) x = self.tok_emb(input_ids) + self.pos_emb(positions) x = self.emb_drop(self.emb_ln(x)) for block in self.blocks: x = block(x, attention_mask) x = self.ln_f(x) mask = attention_mask.unsqueeze(-1).float() pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6) emb = 
self.proj_head(pooled) return F.normalize(emb, p=2, dim=-1) # ============================================================================= # 2. SERVICE # ============================================================================= class RAGService: def __init__( self, ckpt_path: Optional[str], tokenizer_dir: Optional[str], max_len: int = 128, llm_model_id: str = TINYLLAMA_REPO_ID, llm_dir: str = DEFAULT_LLM_DIR, download_llm: bool = True, llm_offline: bool = False, public_url: str = "", load_encoder: bool = True, ): self.ckpt_path = ckpt_path self.tokenizer_dir = tokenizer_dir self.max_len = max_len self.llm_model_id = llm_model_id self.llm_dir = llm_dir self.download_llm = download_llm self.llm_offline = llm_offline self.public_url = clean_public_url(public_url) if public_url else "" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if self.device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision("high") self.tokenizer = None self.model = None self.cfg = ModelConfig() self.train_metrics: Dict[str, Any] = {} self.train_epoch = -1 self.llm_tokenizer = None self.llm_model = None self.llm_path = "" self.index = None self.corpus_texts: List[str] = [] self.corpus_meta: List[Dict[str, Any]] = [] if load_encoder: self._try_load_encoder() self._load_llm() @property def encoder_loaded(self) -> bool: return self.model is not None and self.tokenizer is not None def _try_load_encoder(self) -> None: if not self.ckpt_path or not self.tokenizer_dir: log.warning("Encodeur RAG désactivé : ckpt/tokenizer-dir non fournis.") return ckpt_path = Path(self.ckpt_path) tokenizer_dir = Path(self.tokenizer_dir) if not ckpt_path.exists(): log.warning("Encodeur RAG désactivé : checkpoint introuvable : %s", ckpt_path) return if not tokenizer_dir.exists() or not any(tokenizer_dir.glob("tokenizer*")): log.warning("Encodeur RAG désactivé : tokenizer introuvable : %s", tokenizer_dir) 
return log.info("Chargement tokenizer RAG depuis %s", tokenizer_dir) self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir)) log.info("Chargement checkpoint RAG %s", ckpt_path) ckpt = torch.load(str(ckpt_path), map_location=self.device) saved_cfg = ckpt.get("config", {}) self.cfg = ModelConfig( vocab_size=saved_cfg.get("vocab_size", self.tokenizer.vocab_size), hidden_size=saved_cfg.get("hidden_size", 384), num_hidden_layers=saved_cfg.get("num_hidden_layers", 6), num_attention_heads=saved_cfg.get("num_attention_heads", 6), intermediate_size=saved_cfg.get("intermediate_size", 1536), max_position_embeddings=saved_cfg.get("max_position_embeddings", 256), embedding_dim=saved_cfg.get("embedding_dim", 384), use_layer_scale=saved_cfg.get("use_layer_scale", True), layer_scale_init=saved_cfg.get("layer_scale_init", 1e-4), ) self.model = TextEncoder(self.cfg).to(self.device) self.model.load_state_dict(ckpt["model_state"], strict=False) self.model.eval() self.train_metrics = ckpt.get("metrics", {}) self.train_epoch = ckpt.get("epoch", -1) log.info("Encodeur RAG chargé sur %s, dim=%s", self.device, self.cfg.embedding_dim) def _load_llm(self) -> None: self.llm_path = ensure_tinyllama_local( model_id=self.llm_model_id, local_dir=self.llm_dir, download=self.download_llm and not self.llm_offline, ) kwargs = { "trust_remote_code": True, "local_files_only": self.llm_offline, } log.info("Chargement tokenizer TinyLlama depuis %s", self.llm_path) self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_path, **kwargs) if self.llm_tokenizer.pad_token_id is None: self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token dtype = torch.bfloat16 if self.device.type == "cuda" else torch.float32 log.info("Chargement modèle TinyLlama depuis %s", self.llm_path) try: self.llm_model = AutoModelForCausalLM.from_pretrained( self.llm_path, torch_dtype=dtype, device_map="auto" if self.device.type == "cuda" else None, low_cpu_mem_usage=True, **kwargs, ) except Exception as exc: 
log.warning("device_map='auto' indisponible ou erreur accelerate (%s). Bascule .to(device).", exc) self.llm_model = AutoModelForCausalLM.from_pretrained( self.llm_path, torch_dtype=dtype, low_cpu_mem_usage=True, **kwargs, ).to(self.device) self.llm_model.eval() log.info("TinyLlama chargé : %s sur %s", self.llm_model_id, self.device) def require_encoder(self) -> None: """Lève HTTPException 503 si l'encodeur RAG n'est pas chargé.""" if not self.encoder_loaded: raise HTTPException( status_code=503, detail=( "Encodeur RAG non chargé. Fournis --ckpt et --tokenizer-dir valides, " "ou utilise seulement /generate et /v1/chat/completions." ), ) def require_index(self) -> None: """Lève HTTPException 400 si aucun index FAISS n'est chargé en mémoire.""" if self.index is None: raise HTTPException( status_code=400, detail="Aucun index chargé. Appelle POST /index/build ou POST /index/load d'abord.", ) @torch.no_grad() def generate( self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7, top_p: float = 0.9, system: Optional[str] = None, ) -> str: if self.llm_model is None or self.llm_tokenizer is None: raise HTTPException(status_code=503, detail="TinyLlama non chargé.") messages = [] if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": prompt}) if hasattr(self.llm_tokenizer, "apply_chat_template") and self.llm_tokenizer.chat_template: input_ids = self.llm_tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt", ).to(self.llm_model.device) attention_mask = torch.ones_like(input_ids) model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} else: model_inputs = self.llm_tokenizer( prompt, return_tensors="pt", truncation=True, max_length=2048, ).to(self.llm_model.device) do_sample = temperature > 0 generated = self.llm_model.generate( **model_inputs, max_new_tokens=max_new_tokens, temperature=max(temperature, 1e-5), top_p=top_p, do_sample=do_sample, 
pad_token_id=self.llm_tokenizer.pad_token_id, eos_token_id=self.llm_tokenizer.eos_token_id, no_repeat_ngram_size=2, ) prompt_len = model_inputs["input_ids"].shape[-1] answer_ids = generated[0][prompt_len:] return self.llm_tokenizer.decode(answer_ids, skip_special_tokens=True).strip() @torch.no_grad() def encode(self, texts: List[str], batch_size: int = 64, normalize: bool = True) -> np.ndarray: self.require_encoder() if not texts: return np.zeros((0, self.cfg.embedding_dim), dtype=np.float32) embs = [] for i in range(0, len(texts), batch_size): chunk = texts[i : i + batch_size] enc = self.tokenizer( chunk, padding=True, truncation=True, max_length=self.max_len, return_tensors="pt", ).to(self.device) with torch.autocast( device_type=self.device.type, dtype=torch.bfloat16, enabled=(self.device.type == "cuda"), ): e = self.model(enc["input_ids"], enc["attention_mask"]) embs.append(e.float().cpu().numpy()) out = np.concatenate(embs, axis=0) if not normalize: return out.astype(np.float32) norms = np.linalg.norm(out, axis=1, keepdims=True).clip(min=1e-12) return (out / norms).astype(np.float32) def build_index(self, corpus: List[str], metas: Optional[List[Dict[str, Any]]] = None) -> None: self.require_encoder() try: import faiss except ImportError as exc: raise HTTPException( status_code=501, detail="faiss-cpu non installé. 
Fais : pip install faiss-cpu", ) from exc log.info("Indexation de %s documents...", len(corpus)) embs = self.encode(corpus, batch_size=128) index = faiss.IndexFlatIP(embs.shape[1]) index.add(embs) self.index = index self.corpus_texts = list(corpus) self.corpus_meta = metas if metas else [{} for _ in corpus] def save_index(self, path: str) -> None: try: import faiss except ImportError as exc: raise HTTPException(status_code=501, detail="faiss-cpu non installé.") from exc if self.index is None: raise HTTPException(status_code=400, detail="Aucun index à sauvegarder.") faiss.write_index(self.index, path) with open(path + ".meta.json", "w", encoding="utf-8") as f: json.dump({"texts": self.corpus_texts, "meta": self.corpus_meta}, f, ensure_ascii=False) def load_index(self, path: str) -> None: try: import faiss except ImportError as exc: raise HTTPException(status_code=501, detail="faiss-cpu non installé.") from exc index_path = Path(path) if not index_path.exists(): raise HTTPException(status_code=404, detail=f"Index introuvable : {path}") self.index = faiss.read_index(str(index_path)) meta_path = str(index_path) + ".meta.json" if os.path.exists(meta_path): with open(meta_path, "r", encoding="utf-8") as f: data = json.load(f) self.corpus_texts = data.get("texts", []) self.corpus_meta = data.get("meta", []) log.info("Index chargé : %s vecteurs", self.index.ntotal) def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: self.require_index() q_emb = self.encode([query]) scores, ids = self.index.search(q_emb, top_k) results = [] for score, idx in zip(scores[0].tolist(), ids[0].tolist()): if idx < 0 or idx >= len(self.corpus_texts): continue results.append( { "score": float(score), "text": self.corpus_texts[idx], "meta": self.corpus_meta[idx] if idx < len(self.corpus_meta) else {}, "doc_id": idx, } ) return results def index_status(self) -> Dict[str, Any]: return { "loaded": self.index is not None, "n_vectors": int(self.index.ntotal) if self.index is not None 
else 0, "n_documents": len(self.corpus_texts), "has_metadata": bool(self.corpus_meta), "embedding_dim": self.cfg.embedding_dim if self.encoder_loaded else None, "encoder_loaded": self.encoder_loaded, } def clear_index(self) -> None: self.index = None self.corpus_texts = [] self.corpus_meta = [] def add_documents(self, documents: List[str], metas: Optional[List[Dict[str, Any]]] = None) -> None: self.require_encoder() if not documents: raise HTTPException(status_code=400, detail="documents vide") if metas is not None and len(metas) != len(documents): raise HTTPException(status_code=400, detail="metadata doit avoir la même taille que documents") new_texts = self.corpus_texts + list(documents) new_meta = self.corpus_meta + (metas if metas is not None else [{} for _ in documents]) self.build_index(new_texts, new_meta) def delete_document(self, doc_id: int) -> None: self.require_encoder() if doc_id < 0 or doc_id >= len(self.corpus_texts): raise HTTPException(status_code=404, detail=f"Document introuvable : doc_id={doc_id}") del self.corpus_texts[doc_id] if doc_id < len(self.corpus_meta): del self.corpus_meta[doc_id] if self.corpus_texts: self.build_index(self.corpus_texts, self.corpus_meta) else: self.clear_index() def get_document(self, doc_id: int) -> Dict[str, Any]: if doc_id < 0 or doc_id >= len(self.corpus_texts): raise HTTPException(status_code=404, detail=f"Document introuvable : doc_id={doc_id}") return { "doc_id": doc_id, "text": self.corpus_texts[doc_id], "meta": self.corpus_meta[doc_id] if doc_id < len(self.corpus_meta) else {}, } def list_documents(self, offset: int = 0, limit: int = 50) -> Dict[str, Any]: offset = max(0, offset) limit = max(1, min(limit, 500)) items = [] for idx in range(offset, min(offset + limit, len(self.corpus_texts))): text = self.corpus_texts[idx] items.append({ "doc_id": idx, "text": text, "preview": text[:300], "chars": len(text), "meta": self.corpus_meta[idx] if idx < len(self.corpus_meta) else {}, }) return {"offset": offset, 
"limit": limit, "total": len(self.corpus_texts), "documents": items} def save_corpus(self, path: str) -> None: target = Path(path).expanduser().resolve() target.parent.mkdir(parents=True, exist_ok=True) with open(target, "w", encoding="utf-8") as f: json.dump({"texts": self.corpus_texts, "meta": self.corpus_meta}, f, ensure_ascii=False, indent=2) def load_corpus(self, path: str, rebuild_index: bool = True) -> None: source = Path(path).expanduser().resolve() if not source.exists(): raise HTTPException(status_code=404, detail=f"Corpus introuvable : {source}") with open(source, "r", encoding="utf-8") as f: data = json.load(f) texts = data.get("texts") or data.get("documents") or [] metas = data.get("meta") or data.get("metadata") or [{} for _ in texts] if len(metas) != len(texts): raise HTTPException(status_code=422, detail="Corpus invalide : texts/meta tailles différentes") if rebuild_index: self.build_index(texts, metas) else: self.corpus_texts = list(texts) self.corpus_meta = list(metas) self.index = None def count_llm_tokens(self, text: str) -> int: if self.llm_tokenizer is None: raise HTTPException(status_code=503, detail="Tokenizer TinyLlama non chargé.") return int(len(self.llm_tokenizer(text, add_special_tokens=False)["input_ids"])) def generate_batch( self, prompts: List[str], max_new_tokens: int = 128, temperature: float = 0.7, top_p: float = 0.9, system: Optional[str] = None, ) -> List[str]: if not prompts: raise HTTPException(status_code=400, detail="prompts vide") return [ self.generate( prompt=p, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, ) for p in prompts ] def build_rag_prompt(self, query: str, top_k: int = 3, max_context_chars: int = 2000) -> Dict[str, Any]: self.require_index() results = self.search(query, top_k=top_k) pieces, total = [], 0 for r in results: snippet = r["text"] remaining = max_context_chars - total if remaining <= 0: break if len(snippet) > remaining: snippet = snippet[:remaining] 
pieces.append(f"[doc#{r['doc_id']}] {snippet}") total += len(snippet) context = "\n\n".join(pieces) prompt = ( "Tu es un assistant qui répond UNIQUEMENT à partir du contexte fourni.\n" "Si la réponse n'est pas dans le contexte, dis : Information non disponible.\n\n" f"Contexte :\n{context}\n\n" f"Question : {query}\n" "Réponse :" ) return {"context": context, "prompt": prompt, "sources": results} def ask( self, query: str, top_k: int = 3, max_context_chars: int = 2000, max_new_tokens: int = 256, temperature: float = 0.4, top_p: float = 0.9, ) -> Dict[str, Any]: """Répond avec RAG si un index est chargé, sinon génération directe TinyLlama.""" if self.index is not None: rag_data = self.build_rag_prompt(query, top_k=top_k, max_context_chars=max_context_chars) answer = self.generate( prompt=rag_data["prompt"], system="Tu réponds en français, de manière concise et sourcée par le contexte.", max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, ) return {"mode": "rag", "answer": answer, **rag_data} answer = self.generate( prompt=query, system="Tu es un assistant utile qui répond clairement en français.", max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, ) return {"mode": "llm", "answer": answer, "context": "", "prompt": query, "sources": []} # ============================================================================= # 3. 
APP FASTAPI # ============================================================================= service: Optional[RAGService] = None def mount_model_files(app: FastAPI, model_files_dir: str) -> None: path = Path(model_files_dir).expanduser().resolve() path.mkdir(parents=True, exist_ok=True) app.mount("/files", StaticFiles(directory=str(path)), name="model_files") log.info("Dossier /files monté depuis %s", path) @asynccontextmanager async def lifespan(app: FastAPI): global service ckpt = os.environ.get("RAG_CKPT", "") tok_dir = os.environ.get("RAG_TOKENIZER_DIR", "") llm_model_id = os.environ.get("RAG_LLM_MODEL", TINYLLAMA_REPO_ID) llm_dir = os.environ.get("RAG_LLM_DIR", DEFAULT_LLM_DIR) model_files_dir = os.environ.get("RAG_MODEL_FILES_DIR", "./models") public_url = os.environ.get("OVH_PUBLIC_URL", "") download_llm = env_bool("RAG_LLM_DOWNLOAD", True) offline_llm = env_bool("RAG_LLM_OFFLINE", False) load_encoder = env_bool("RAG_LOAD_ENCODER", True) mount_model_files(app, model_files_dir) service = RAGService( ckpt_path=ckpt or None, tokenizer_dir=tok_dir or None, llm_model_id=llm_model_id, llm_dir=llm_dir, download_llm=download_llm, llm_offline=offline_llm, public_url=public_url, load_encoder=load_encoder, ) log.info("Service prêt. 
URL publique configurée : %s", public_url or "non fournie") yield log.info("Arrêt du service.") app = FastAPI( title="SecureRAG FR + TinyLlama Local API", description="API RAG optionnelle + génération TinyLlama local pour OVH AI Notebooks", version="3.2.0", lifespan=lifespan, ) allowed_origins_raw = os.environ.get("CORS_ALLOW_ORIGINS", "*") allowed_origins = [item.strip() for item in allowed_origins_raw.split(",") if item.strip()] app.add_middleware( CORSMiddleware, allow_origins=allowed_origins or ["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ----------------------------------------------------------------------------- # Schémas Pydantic # ----------------------------------------------------------------------------- class EncodeRequest(BaseModel): texts: List[str] = Field(..., description="Liste de textes à encoder") normalize: bool = True class EncodeResponse(BaseModel): embeddings: List[List[float]] dim: int elapsed_ms: float class BuildIndexRequest(BaseModel): documents: List[str] metadata: Optional[List[Dict[str, Any]]] = None save_path: Optional[str] = Field(None, description="Si fourni, sauvegarde l'index FAISS") class LoadIndexRequest(BaseModel): path: str class SearchRequest(BaseModel): query: str top_k: int = Field(5, ge=1, le=100) class SearchResult(BaseModel): score: float text: str meta: Dict[str, Any] doc_id: int class SearchResponse(BaseModel): query: str results: List[SearchResult] elapsed_ms: float class RAGRequest(BaseModel): query: str top_k: int = Field(3, ge=1, le=20) max_context_chars: int = Field(2000, ge=100, le=20000) generate_answer: bool = True max_new_tokens: int = Field(256, ge=1, le=2048) temperature: float = Field(0.4, ge=0.0, le=2.0) top_p: float = Field(0.9, ge=0.01, le=1.0) class RAGResponse(BaseModel): query: str context: str sources: List[SearchResult] prompt_template: str answer: Optional[str] = None class GenerateRequest(BaseModel): prompt: str system: Optional[str] = "Tu es un assistant utile qui 
répond clairement en français." max_new_tokens: int = Field(128, ge=1, le=2048) temperature: float = Field(0.7, ge=0.0, le=2.0) top_p: float = Field(0.9, ge=0.01, le=1.0) class GenerateResponse(BaseModel): prompt: str answer: str elapsed_ms: float model: str model_path: str class SimilarityRequest(BaseModel): text_a: str text_b: str class SimilarityResponse(BaseModel): cosine_similarity: float class ChatMessage(BaseModel): role: str content: str class OpenAIChatCompletionRequest(BaseModel): model: Optional[str] = TINYLLAMA_REPO_ID messages: List[ChatMessage] max_tokens: int = Field(256, ge=1, le=2048) temperature: float = Field(0.7, ge=0.0, le=2.0) top_p: float = Field(0.9, ge=0.01, le=1.0) stream: bool = False stop: Optional[Any] = None class CompletionRequest(BaseModel): model: Optional[str] = TINYLLAMA_REPO_ID prompt: str max_tokens: int = Field(128, ge=1, le=2048) temperature: float = Field(0.7, ge=0.0, le=2.0) top_p: float = Field(0.9, ge=0.01, le=1.0) stream: bool = False stop: Optional[Any] = None class CompletionChoice(BaseModel): text: str index: int logprobs: Optional[Any] = None finish_reason: Optional[str] = None class CompletionResponse(BaseModel): id: str object: str created: int model: str choices: List[CompletionChoice] usage: Dict[str, Optional[Any]] class BatchGenerateRequest(BaseModel): prompts: List[str] system: Optional[str] = "Tu es un assistant utile qui répond clairement en français." 
max_new_tokens: int = Field(128, ge=1, le=2048) temperature: float = Field(0.7, ge=0.0, le=2.0) top_p: float = Field(0.9, ge=0.01, le=1.0) class BatchGenerateResponse(BaseModel): answers: List[str] elapsed_ms: float model: str class TokenCountRequest(BaseModel): text: str class TokenCountResponse(BaseModel): tokens: int model: str class SaveIndexRequest(BaseModel): path: str class AddDocumentsRequest(BaseModel): documents: List[str] metadata: Optional[List[Dict[str, Any]]] = None save_path: Optional[str] = None class CorpusSaveRequest(BaseModel): path: str class CorpusLoadRequest(BaseModel): path: str rebuild_index: bool = True class AskRequest(BaseModel): query: str top_k: int = Field(3, ge=1, le=20) max_context_chars: int = Field(2000, ge=100, le=20000) max_new_tokens: int = Field(256, ge=1, le=2048) temperature: float = Field(0.4, ge=0.0, le=2.0) top_p: float = Field(0.9, ge=0.01, le=1.0) class OpenAIEmbeddingsRequest(BaseModel): model: Optional[str] = "rag-encoder-local" input: Any encoding_format: Optional[str] = "float" # ----------------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------------- def _check_service() -> RAGService: """Retourne le service ou lève 503.""" if service is None: raise HTTPException(status_code=503, detail="Service non prêt, réessaie dans quelques secondes.") return service # ----------------------------------------------------------------------------- # Endpoints # ----------------------------------------------------------------------------- @app.get("/") def root(): svc = _check_service() base = svc.public_url or "http://127.0.0.1:8080" return { "status": "ok", "name": "SecureRAG FR + TinyLlama Local API", "version": "3.2.0", "device": str(svc.device), "public_url": svc.public_url, "local_test_url": "http://127.0.0.1:8080", "llm_model": svc.llm_model_id, "llm_path": svc.llm_path, "llm_loaded": svc.llm_model is not None, "encoder_loaded": 
svc.encoder_loaded, "embedding_dim": svc.cfg.embedding_dim if svc.encoder_loaded else None, "index_loaded": svc.index is not None, "index_size": svc.index.ntotal if svc.index is not None else 0, "endpoints": { "health": f"{base}/health", "generate": f"{base}/generate", "generate_batch": f"{base}/generate/batch", "ask": f"{base}/ask", "tokens_count": f"{base}/tokens/count", "encode": f"{base}/encode", "similarity": f"{base}/similarity", "index_build": f"{base}/index/build", "index_load": f"{base}/index/load", "index_save": f"{base}/index/save", "index_add": f"{base}/index/add", "index_status": f"{base}/index/status", "index_clear": f"{base}/index/clear", "search": f"{base}/search", "rag": f"{base}/rag", "documents_list": f"{base}/documents", "document_get": f"{base}/documents/{{doc_id}}", "document_delete": f"{base}/documents/{{doc_id}}", "corpus_save": f"{base}/corpus/save", "corpus_load": f"{base}/corpus/load", "openai_chat": f"{base}/v1/chat/completions", "openai_completions": f"{base}/v1/completions", "openai_embeddings": f"{base}/v1/embeddings", "models": f"{base}/v1/models", "client_snippets": f"{base}/client/snippets", "model_info": f"{base}/model/info", "files_list": f"{base}/files/list", }, } @app.get("/health") def health(): svc = _check_service() return { "status": "ok", "gpu": torch.cuda.is_available(), "device": str(svc.device), "llm_loaded": svc.llm_model is not None, "encoder_loaded": svc.encoder_loaded, } @app.get("/client/snippets") def client_snippets(): svc = _check_service() base = svc.public_url or "https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net" return { "curl_generate": ( f"curl -X POST '{base}/generate' -H 'Content-Type: application/json' " "-d '{\"prompt\":\"Explique TinyLlama en français.\",\"max_new_tokens\":128}'" ), "python_requests": ( "import requests\n" f"BASE_URL = '{base}'\n" "r = requests.post(BASE_URL + '/generate', json={\n" " 'prompt': 'Explique RAG simplement en français.',\n" " 'max_new_tokens': 128\n" "})\n" 
"print(r.json()['answer'])" ), "openai_compatible_python": ( "from openai import OpenAI\n" f"client = OpenAI(base_url='{base}/v1', api_key='not-needed')\n" "resp = client.chat.completions.create(\n" " model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',\n" " messages=[{'role': 'user', 'content': 'Bonjour, réponds en français.'}],\n" " max_tokens=128,\n" ")\n" "print(resp.choices[0].message.content)" ), "openai_embeddings_python": ( "from openai import OpenAI\n" f"client = OpenAI(base_url='{base}/v1', api_key='not-needed')\n" "resp = client.embeddings.create(\n" " model='rag-encoder-local',\n" " input=['Bonjour le monde']\n" ")\n" "print(resp.data[0].embedding[:5])" ), } @app.get("/files/list") def files_list(): files_dir = Path(os.environ.get("RAG_MODEL_FILES_DIR", "./models")).expanduser().resolve() if not files_dir.exists(): raise HTTPException(status_code=404, detail=f"Dossier introuvable : {files_dir}") files = [str(p.relative_to(files_dir)) for p in sorted(files_dir.rglob("*")) if p.is_file()] return {"model_files_dir": str(files_dir), "files": files} @app.get("/model/info") def model_info(): svc = _check_service() return { "llm_model": svc.llm_model_id, "llm_path": svc.llm_path, "device": str(svc.device), "encoder_loaded": svc.encoder_loaded, "index": svc.index_status(), "public_url": svc.public_url, } @app.post("/tokens/count", response_model=TokenCountResponse) def count_tokens(req: TokenCountRequest): svc = _check_service() return TokenCountResponse(tokens=svc.count_llm_tokens(req.text), model=svc.llm_model_id) @app.post("/generate/batch", response_model=BatchGenerateResponse) def generate_batch(req: BatchGenerateRequest): svc = _check_service() if not req.prompts: raise HTTPException(status_code=400, detail="prompts ne peut pas être vide") t0 = time.time() answers = svc.generate_batch( prompts=req.prompts, system=req.system, max_new_tokens=req.max_new_tokens, temperature=req.temperature, top_p=req.top_p, ) return BatchGenerateResponse(answers=answers, 
elapsed_ms=(time.time() - t0) * 1000.0, model=svc.llm_model_id) @app.get("/index/status") def index_status(): svc = _check_service() return svc.index_status() @app.delete("/index/clear") def index_clear(): svc = _check_service() svc.clear_index() return {"status": "ok", "message": "index vidé"} @app.post("/index/save") def index_save(req: SaveIndexRequest): svc = _check_service() if svc.index is None: raise HTTPException(status_code=400, detail="Aucun index à sauvegarder.") svc.save_index(req.path) return {"status": "ok", "saved_to": req.path} @app.post("/index/add") def index_add(req: AddDocumentsRequest): svc = _check_service() if not req.documents: raise HTTPException(status_code=400, detail="documents ne peut pas être vide") if req.metadata is not None and len(req.metadata) != len(req.documents): raise HTTPException(status_code=422, detail="metadata doit avoir la même taille que documents") t0 = time.time() svc.add_documents(req.documents, req.metadata) saved = None if req.save_path: svc.save_index(req.save_path) saved = req.save_path return { "status": "ok", "added": len(req.documents), "index": svc.index_status(), "saved_to": saved, "elapsed_ms": (time.time() - t0) * 1000.0, } @app.get("/documents") def documents_list(offset: int = 0, limit: int = 50): svc = _check_service() return svc.list_documents(offset=offset, limit=limit) @app.get("/documents/{doc_id}") def document_get(doc_id: int): svc = _check_service() return svc.get_document(doc_id) @app.delete("/documents/{doc_id}") def document_delete(doc_id: int): svc = _check_service() svc.delete_document(doc_id) return {"status": "ok", "deleted_doc_id": doc_id, "index": svc.index_status()} @app.post("/corpus/save") def corpus_save(req: CorpusSaveRequest): svc = _check_service() svc.save_corpus(req.path) return {"status": "ok", "saved_to": req.path, "n_documents": len(svc.corpus_texts)} @app.post("/corpus/load") def corpus_load(req: CorpusLoadRequest): svc = _check_service() svc.load_corpus(req.path, 
rebuild_index=req.rebuild_index) return {"status": "ok", "loaded_from": req.path, "index": svc.index_status()} @app.post("/ask") def ask(req: AskRequest): svc = _check_service() t0 = time.time() data = svc.ask( query=req.query, top_k=req.top_k, max_context_chars=req.max_context_chars, max_new_tokens=req.max_new_tokens, temperature=req.temperature, top_p=req.top_p, ) data["elapsed_ms"] = (time.time() - t0) * 1000.0 return data @app.get("/v1/models") def openai_models(): svc = _check_service() return { "object": "list", "data": [ { "id": svc.llm_model_id, "object": "model", "created": 0, "owned_by": "local", "path": svc.llm_path, } ], } @app.get("/v1/models/{model_id:path}") def openai_model_get(model_id: str): svc = _check_service() return { "id": model_id or svc.llm_model_id, "object": "model", "created": 0, "owned_by": "local", "path": svc.llm_path, } @app.post("/v1/embeddings") def openai_embeddings(req: OpenAIEmbeddingsRequest): svc = _check_service() if req.encoding_format not in (None, "float"): raise HTTPException(status_code=400, detail="encoding_format supporté uniquement : float") if isinstance(req.input, str): inputs = [req.input] elif isinstance(req.input, list): if not req.input: raise HTTPException(status_code=400, detail="input vide") if all(isinstance(x, str) for x in req.input): inputs = req.input else: raise HTTPException(status_code=400, detail="input doit être une chaîne ou une liste de chaînes") else: raise HTTPException(status_code=400, detail="input doit être une chaîne ou une liste de chaînes") embs = svc.encode(inputs, normalize=True) return { "object": "list", "model": req.model or "rag-encoder-local", "data": [ {"object": "embedding", "index": idx, "embedding": emb.tolist()} for idx, emb in enumerate(embs) ], "usage": {"prompt_tokens": None, "total_tokens": None}, } @app.post("/v1/chat/completions") def openai_chat_completions(req: OpenAIChatCompletionRequest): svc = _check_service() if req.stream: raise HTTPException(status_code=400, 
detail="stream=True non supporté dans cette version locale") if not req.messages: raise HTTPException(status_code=400, detail="messages vide") system_parts = [m.content for m in req.messages if m.role == "system"] user_parts = [f"{m.role}: {m.content}" for m in req.messages if m.role != "system"] prompt = "\n".join(user_parts).strip() system = "\n".join(system_parts).strip() or None t0 = time.time() answer = svc.generate( prompt=prompt, system=system, max_new_tokens=req.max_tokens, temperature=req.temperature, top_p=req.top_p, ) created = int(time.time()) return { "id": f"chatcmpl-local-{created}", "object": "chat.completion", "created": created, "model": svc.llm_model_id, "choices": [ { "index": 0, "message": {"role": "assistant", "content": answer}, "finish_reason": "stop", } ], "usage": { "prompt_tokens": None, "completion_tokens": None, "total_tokens": None, "elapsed_ms": (time.time() - t0) * 1000.0, }, } @app.post("/v1/completions", response_model=CompletionResponse) def openai_completions(req: CompletionRequest): svc = _check_service() if req.stream: raise HTTPException(status_code=400, detail="stream=True non supporté dans cette version locale") t0 = time.time() answer = svc.generate( prompt=req.prompt, system=None, max_new_tokens=req.max_tokens, temperature=req.temperature, top_p=req.top_p, ) created = int(time.time()) return CompletionResponse( id=f"cmpl-local-{created}", object="text_completion", created=created, model=svc.llm_model_id, choices=[CompletionChoice(text=answer, index=0, finish_reason="stop")], usage={ "prompt_tokens": None, "completion_tokens": None, "total_tokens": None, }, ) @app.post("/v1/engines/{engine}/completions", response_model=CompletionResponse) def openai_engine_completions(engine: str, req: CompletionRequest): svc = _check_service() if req.stream: raise HTTPException(status_code=400, detail="stream=True non supporté dans cette version locale") t0 = time.time() answer = svc.generate( prompt=req.prompt, system=None, 
max_new_tokens=req.max_tokens, temperature=req.temperature, top_p=req.top_p, ) created = int(time.time()) return CompletionResponse( id=f"cmpl-local-{created}", object="text_completion", created=created, model=svc.llm_model_id, choices=[CompletionChoice(text=answer, index=0, finish_reason="stop")], usage={ "prompt_tokens": None, "completion_tokens": None, "total_tokens": None, }, ) @app.post("/generate", response_model=GenerateResponse) def generate(req: GenerateRequest): svc = _check_service() t0 = time.time() answer = svc.generate( prompt=req.prompt, system=req.system, max_new_tokens=req.max_new_tokens, temperature=req.temperature, top_p=req.top_p, ) return GenerateResponse( prompt=req.prompt, answer=answer, elapsed_ms=(time.time() - t0) * 1000.0, model=svc.llm_model_id, model_path=svc.llm_path, ) @app.post("/encode", response_model=EncodeResponse) def encode(req: EncodeRequest): svc = _check_service() if not req.texts: raise HTTPException(status_code=400, detail="texts ne peut pas être vide") t0 = time.time() embs = svc.encode(req.texts, normalize=req.normalize) return EncodeResponse( embeddings=embs.tolist(), dim=int(embs.shape[1]), elapsed_ms=(time.time() - t0) * 1000.0, ) @app.post("/index/build") def index_build(req: BuildIndexRequest): svc = _check_service() if not req.documents: raise HTTPException(status_code=400, detail="documents ne peut pas être vide") if req.metadata is not None and len(req.metadata) != len(req.documents): raise HTTPException(status_code=422, detail="metadata doit avoir la même taille que documents") t0 = time.time() svc.build_index(req.documents, req.metadata) saved = None if req.save_path: svc.save_index(req.save_path) saved = req.save_path return { "status": "ok", "n_docs": len(req.documents), "elapsed_ms": (time.time() - t0) * 1000.0, "saved_to": saved, } @app.post("/index/load") def index_load(req: LoadIndexRequest): svc = _check_service() # Délègue la vérification d'existence au service (lève 404 proprement) 
svc.load_index(req.path) return {"status": "ok", "n_vectors": svc.index.ntotal} @app.post("/search", response_model=SearchResponse) def search(req: SearchRequest): svc = _check_service() t0 = time.time() results = svc.search(req.query, top_k=req.top_k) return SearchResponse( query=req.query, results=[SearchResult(**r) for r in results], elapsed_ms=(time.time() - t0) * 1000.0, ) @app.post("/rag", response_model=RAGResponse) def rag(req: RAGRequest): svc = _check_service() svc.require_index() results = svc.search(req.query, top_k=req.top_k) pieces, total = [], 0 for r in results: snippet = r["text"] remaining = req.max_context_chars - total if remaining <= 0: break if len(snippet) > remaining: snippet = snippet[:remaining] pieces.append(f"[doc#{r['doc_id']}] {snippet}") total += len(snippet) context = "\n\n".join(pieces) prompt = ( "Tu es un assistant qui répond UNIQUEMENT à partir du contexte fourni.\n" "Si la réponse n'est pas dans le contexte, dis : Information non disponible.\n\n" f"Contexte :\n{context}\n\n" f"Question : {req.query}\n" "Réponse :" ) answer = None if req.generate_answer: answer = svc.generate( prompt=prompt, system="Tu réponds en français, de manière concise et sourcée par le contexte.", max_new_tokens=req.max_new_tokens, temperature=req.temperature, top_p=req.top_p, ) return RAGResponse( query=req.query, context=context, sources=[SearchResult(**r) for r in results], prompt_template=prompt, answer=answer, ) @app.post("/similarity", response_model=SimilarityResponse) def similarity(req: SimilarityRequest): svc = _check_service() embs = svc.encode([req.text_a, req.text_b]) sim = float(np.dot(embs[0], embs[1])) return SimilarityResponse(cosine_similarity=sim) # ----------------------------------------------------------------------------- # Gestionnaire d'erreurs global — converti les RuntimeError résiduelles en 500 # avec un message lisible (et non une trace Python brute). 
# -----------------------------------------------------------------------------
@app.exception_handler(Exception)
async def all_exception_handler(request: Request, exc: Exception):
    """Log the exception and turn it into a JSON error response.

    Returns the exception's own status code and detail for an
    ``HTTPException``; any other exception becomes a 500 whose detail is
    ``str(exc)``. The request path is included in both payloads.
    """
    log.exception("Erreur non gérée sur %s", request.url.path)
    # NOTE(review): Starlette normally routes HTTPException to its dedicated
    # handler before a catch-all ``Exception`` handler runs — confirm whether
    # this branch is ever reached in practice.
    if isinstance(exc, HTTPException):
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail, "path": str(request.url.path)},
        )
    return JSONResponse(
        status_code=500,
        content={"detail": str(exc), "path": str(request.url.path)},
    )


# =============================================================================
# 4. ENTRY POINT
# =============================================================================
def main() -> None:
    """Parse CLI flags, export them as environment variables, start uvicorn.

    Each flag is mirrored into a ``RAG_*`` variable (plus ``OVH_PUBLIC_URL``
    when ``--public-url`` is given); these are presumably read by the
    application at startup. ``--download-llm`` and ``--offline-llm`` only
    force their variable to ``"true"``; when absent, an existing environment
    value is kept (defaults: download ``"true"``, offline ``"false"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080, help="OVH AI Notebooks expose le port 8080")
    parser.add_argument("--ckpt", default="")
    parser.add_argument("--tokenizer-dir", default="")
    parser.add_argument("--no-encoder", action="store_true", help="Désactive les endpoints encode/search/rag")
    parser.add_argument("--llm-model", default=TINYLLAMA_REPO_ID)
    parser.add_argument("--llm-dir", default=DEFAULT_LLM_DIR, help="Dossier local de TinyLlama")
    parser.add_argument("--download-llm", action="store_true", help="Télécharge TinyLlama si le dossier local est vide")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%"; a bare "%" makes --help raise instead of printing.
    parser.add_argument("--offline-llm", action="store_true", help="Force le chargement 100%% local")
    parser.add_argument("--model-files-dir", default="./models", help="Dossier exposé par /files")
    parser.add_argument("--public-url", default="", help="URL OVH publique, sans slash final")
    parser.add_argument("--workers", type=int, default=1, help="Garde 1 avec GPU pour éviter de dupliquer la VRAM")
    args = parser.parse_args()

    # Empty paths are meant to disable the encoder cleanly (no crash at startup).
    os.environ["RAG_CKPT"] = args.ckpt
    os.environ["RAG_TOKENIZER_DIR"] = args.tokenizer_dir
    os.environ["RAG_LOAD_ENCODER"] = "false" if args.no_encoder else "true"
    os.environ["RAG_LLM_MODEL"] = args.llm_model
    os.environ["RAG_LLM_DIR"] = args.llm_dir
    os.environ["RAG_LLM_DOWNLOAD"] = "true" if args.download_llm else os.environ.get("RAG_LLM_DOWNLOAD", "true")
    os.environ["RAG_LLM_OFFLINE"] = "true" if args.offline_llm else os.environ.get("RAG_LLM_OFFLINE", "false")
    os.environ["RAG_MODEL_FILES_DIR"] = args.model_files_dir
    if args.public_url:
        os.environ["OVH_PUBLIC_URL"] = clean_public_url(args.public_url)

    # Imported here so uvicorn is only needed when the module runs as a script.
    import uvicorn

    if args.workers != 1:
        log.warning("workers>1 peut dupliquer le modèle en VRAM. Recommandé : --workers 1")
    # NOTE(review): uvicorn documents that ``workers`` > 1 requires the app to
    # be passed as an import string ("module:app"), not an object — confirm
    # that --workers > 1 actually starts with this call.
    uvicorn.run(app, host=args.host, port=args.port, workers=args.workers, log_level="info")


if __name__ == "__main__":
    main()