mon-workspace / rag_api_server_local_tinyllama.py
Medyassino's picture
Add files using upload-large-folder tool
6cc01b6 verified
"""
==============================================================================
API FastAPI RAG/NLP + TinyLlama local pour OVH AI Notebooks
==============================================================================
Objectifs de cette version corrigée :
1) Télécharger TinyLlama/TinyLlama-1.1B-Chat-v1.0 dans un dossier local.
2) Charger TinyLlama depuis ce dossier local pour éviter de re-télécharger.
3) Exposer l'API sur le port 8080, compatible avec l'URL publique OVH AI Notebooks.
4) Permettre la consommation depuis ton PC local via l'URL OVH publique.
5) Garder les endpoints RAG existants, mais rendre l'encodeur RAG optionnel.
Installation minimale :
pip install -U fastapi uvicorn transformers torch huggingface_hub pydantic numpy
pip install -U accelerate safetensors sentencepiece protobuf
pip install -U faiss-cpu # seulement si tu utilises /index/build ou /search
Premier lancement sur OVH AI Notebook :
python rag_api_server_local_tinyllama.py \
--host 0.0.0.0 \
--port 8080 \
--download-llm \
--public-url "https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net"
Ensuite, depuis ton PC local :
export OVH_RAG_URL="https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net"
curl -X POST "$OVH_RAG_URL/generate" \
-H "Content-Type: application/json" \
-d '{"prompt":"Explique RAG simplement en français.","max_new_tokens":128}'
Notes OVH :
- Sur AI Notebooks, l'URL publique mappe le port 8080. Lance donc FastAPI sur 8080.
- Si ton notebook est en accès restreint, ajoute ton token dans l'en-tête Authorization
selon la configuration OVH.
==============================================================================
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Root logging is configured at import time: INFO level, timestamped messages.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Logger shared by every function in this module.
log = logging.getLogger("rag-tinyllama-api")
# Hugging Face Hub repository id used as the default generation model.
TINYLLAMA_REPO_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default local directory in which the generation model files are stored.
DEFAULT_LLM_DIR = "./models/TinyLlama-1.1B-Chat-v1.0"
def env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the process environment.

    Args:
        name: Name of the environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is absent.  Otherwise ``True`` when its
        value, stripped and lower-cased, is one of ``1``, ``true``, ``yes``,
        ``y`` or ``on``, and ``False`` for anything else.
    """
    raw = os.environ.get(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in {"1", "true", "yes", "y", "on"}
    return default
def clean_public_url(url: str) -> str:
    """Return *url* without surrounding whitespace or trailing slashes.

    Whitespace is stripped first; every ``/`` at the end of the stripped
    string is then removed.
    """
    trimmed = url.strip()
    return trimmed.rstrip("/")
def ensure_tinyllama_local(model_id: str, local_dir: str, download: bool = True) -> str:
    """Return the local model directory, downloading the model when needed.

    The directory is considered complete when it holds ``config.json``, a
    tokenizer file (``tokenizer.json`` or ``tokenizer.model``) and at least one
    weight file (``*.safetensors`` or ``*.bin``).  Requiring the weights keeps
    an interrupted download (small files present, weights missing) from being
    mistaken for a usable model.

    Args:
        model_id: Hugging Face Hub repository id to download from.
        local_dir: Directory that holds, or will hold, the model files.
        download: Whether a missing or incomplete model may be downloaded.

    Returns:
        The absolute path of the model directory, as a string.

    Raises:
        RuntimeError: If the model is not complete locally and ``download`` is
            false, or if ``huggingface_hub`` is not installed.
    """
    target = Path(local_dir).expanduser().resolve()
    has_config = (target / "config.json").exists()
    has_tokenizer = any(
        (target / name).exists() for name in ("tokenizer.json", "tokenizer.model")
    )
    has_weights = any(target.glob("*.safetensors")) or any(target.glob("*.bin"))
    if has_config and has_tokenizer and has_weights:
        log.info("TinyLlama déjà disponible localement : %s", target)
        return str(target)
    if not download:
        raise RuntimeError(
            f"TinyLlama introuvable dans {target}. Relance avec --download-llm ou mets "
            f"RAG_LLM_DOWNLOAD=true."
        )
    log.info("Téléchargement TinyLlama depuis Hugging Face : %s -> %s", model_id, target)
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise RuntimeError("Installe huggingface_hub : pip install -U huggingface_hub") from exc
    target.mkdir(parents=True, exist_ok=True)
    # `local_dir_use_symlinks` and `resume_download` are deliberately not passed:
    # they are deprecated in huggingface_hub (downloads to `local_dir` are real
    # files and interrupted downloads resume on their own), so passing them only
    # triggers warnings or, on releases that dropped them, a TypeError.
    snapshot_download(
        repo_id=model_id,
        local_dir=str(target),
        ignore_patterns=["*.md", ".gitattributes"],
    )
    log.info("Téléchargement terminé : %s", target)
    return str(target)
# =============================================================================
# 1. ARCHITECTURE ENCODEUR RAG 20M - optionnelle
# =============================================================================
@dataclass
class ModelConfig:
    """Hyper-parameters of the optional RAG text encoder defined in this file."""

    # Number of rows of the token embedding table.
    vocab_size: int = 32005
    # Width of the token/position embeddings and of every encoder block.
    hidden_size: int = 384
    # Number of stacked TransformerEncoderBlock layers.
    num_hidden_layers: int = 6
    # Attention heads per block; the per-head size is hidden_size // num_attention_heads.
    num_attention_heads: int = 6
    # Width of the hidden layer of each block's MLP.
    intermediate_size: int = 1536
    # Number of rows of the learned position embedding table.
    max_position_embeddings: int = 256
    # Dropout probability used after the embeddings, the attention projection and the MLP.
    hidden_dropout_prob: float = 0.0
    # NOTE(review): not read by the encoder code visible in this file.
    attention_probs_dropout_prob: float = 0.0
    # Epsilon passed to every LayerNorm.
    layer_norm_eps: float = 1e-12
    # Output dimension of the projection head (size of the returned embedding).
    embedding_dim: int = 384
    # Whether each residual branch is multiplied by a learned per-channel scale.
    use_layer_scale: bool = True
    # Initial value of those per-channel scales.
    layer_scale_init: float = 1e-4
class TransformerEncoderBlock(nn.Module):
    """Pre-LayerNorm encoder block: self-attention then MLP, both residual.

    When ``cfg.use_layer_scale`` is set, each residual branch is multiplied by
    a learned per-channel vector (``gamma1`` / ``gamma2``) before being added
    back to the input.
    """

    def __init__(self, cfg: ModelConfig):
        """Build the block's layers from ``cfg``.

        Raises:
            ValueError: If ``cfg.hidden_size`` is not a multiple of
                ``cfg.num_attention_heads`` (the heads could not tile the
                hidden dimension and ``forward`` would fail on reshape).
        """
        super().__init__()
        if cfg.hidden_size % cfg.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({cfg.hidden_size}) must be divisible by "
                f"num_attention_heads ({cfg.num_attention_heads})"
            )
        self.num_heads = cfg.num_attention_heads
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.ln1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
        # Single fused projection producing queries, keys and values.
        self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size, bias=True)
        self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
        self.ln2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.intermediate_size),
            nn.GELU(),
            nn.Linear(cfg.intermediate_size, cfg.hidden_size),
            nn.Dropout(cfg.hidden_dropout_prob),
        )
        self.resid_drop = nn.Dropout(cfg.hidden_dropout_prob)
        self.use_ls = cfg.use_layer_scale
        if cfg.use_layer_scale:
            self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
            self.gamma2 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Tensor unpacked as ``(batch, tokens, channels)``.
            attn_mask: Tensor indexed as ``(batch, tokens)``; it is cast to
                bool and broadcast over heads and query positions, so a true
                entry lets every query attend to that key position.

        Returns:
            A tensor with the same shape as ``x``.
        """
        batch, tokens, channels = x.shape
        # The fused projection is reshaped in place of the former
        # new_empty()+copy_() pair, which allocated and copied a full extra
        # buffer on every call without changing the result.
        qkv = self.qkv(self.ln1(x)).view(batch, tokens, 3, self.num_heads, self.head_dim)
        # -> (3, batch, heads, tokens, head_dim), unpacked along the first axis.
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        bool_mask = attn_mask[:, None, None, :].bool()
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=bool_mask, dropout_p=0.0, is_causal=False
        )
        # Merge the heads back into the channel dimension.
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, tokens, channels)
        attn_out = self.resid_drop(self.proj(attn_out))
        if self.use_ls:
            attn_out = attn_out * self.gamma1
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        if self.use_ls:
            mlp_out = mlp_out * self.gamma2
        return x + mlp_out
class TextEncoder(nn.Module):
    """Transformer text encoder producing one L2-normalised vector per input.

    Token and learned position embeddings are summed, normalised and fed
    through ``cfg.num_hidden_layers`` encoder blocks.  The final hidden states
    are mean-pooled over the positions selected by the attention mask and
    projected to ``cfg.embedding_dim``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        width = cfg.hidden_size
        eps = cfg.layer_norm_eps
        # Token id 0 is the padding index of the embedding table.
        self.tok_emb = nn.Embedding(cfg.vocab_size, width, padding_idx=0)
        self.pos_emb = nn.Embedding(cfg.max_position_embeddings, width)
        self.emb_ln = nn.LayerNorm(width, eps=eps)
        self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)
        self.blocks = nn.ModuleList(
            [TransformerEncoderBlock(cfg) for _ in range(cfg.num_hidden_layers)]
        )
        self.ln_f = nn.LayerNorm(width, eps=eps)
        self.proj_head = nn.Sequential(
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, cfg.embedding_dim),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token ids.

        Args:
            input_ids: Tensor unpacked as ``(batch, tokens)``.
            attention_mask: Tensor of the same leading shape; it weights the
                mean pooling and is forwarded to every block.

        Returns:
            The projected, L2-normalised embeddings (one row per batch item).
        """
        batch, tokens = input_ids.shape
        position_ids = torch.arange(tokens, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch, tokens)

        hidden = self.tok_emb(input_ids) + self.pos_emb(position_ids)
        hidden = self.emb_drop(self.emb_ln(hidden))
        for layer in self.blocks:
            hidden = layer(hidden, attention_mask)
        hidden = self.ln_f(hidden)

        # Masked mean pooling; the clamp avoids a division by zero when a row
        # of the mask is entirely zero.
        weights = attention_mask.unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-6)
        projected = self.proj_head(summed / counts)
        return F.normalize(projected, p=2, dim=-1)
# =============================================================================
# 2. SERVICE
# =============================================================================
class RAGService:
    """Holds the optional RAG encoder, the local LLM and the in-memory index."""

    def __init__(
        self,
        ckpt_path: Optional[str],
        tokenizer_dir: Optional[str],
        max_len: int = 128,
        llm_model_id: str = TINYLLAMA_REPO_ID,
        llm_dir: str = DEFAULT_LLM_DIR,
        download_llm: bool = True,
        llm_offline: bool = False,
        public_url: str = "",
        load_encoder: bool = True,
    ):
        """Store the configuration, then load the encoder (optional) and the LLM.

        Args:
            ckpt_path: Path of the RAG encoder checkpoint, or None.
            tokenizer_dir: Directory of the RAG encoder tokenizer, or None.
            max_len: Maximum token length used when encoding texts.
            llm_model_id: Hugging Face Hub repository id of the generation model.
            llm_dir: Local directory of the generation model.
            download_llm: Whether the generation model may be downloaded.
            llm_offline: When true, downloading is disabled and the model is
                loaded with ``local_files_only``.
            public_url: Public base URL of the API; normalised with
                ``clean_public_url`` when non-empty.
            load_encoder: Whether to attempt loading the RAG encoder.
        """
        self.ckpt_path = ckpt_path
        self.tokenizer_dir = tokenizer_dir
        self.max_len = max_len
        self.llm_model_id = llm_model_id
        self.llm_dir = llm_dir
        self.download_llm = download_llm
        self.llm_offline = llm_offline
        self.public_url = clean_public_url(public_url) if public_url else ""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            # Allow TF32 matmul/cuDNN kernels when running on a GPU.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.set_float32_matmul_precision("high")
        # RAG encoder state; stays None when the encoder is not loaded.
        self.tokenizer = None
        self.model = None
        self.cfg = ModelConfig()
        self.train_metrics: Dict[str, Any] = {}
        self.train_epoch = -1
        # Generation model state, filled in by _load_llm().
        self.llm_tokenizer = None
        self.llm_model = None
        self.llm_path = ""
        # In-memory index and the corpus entries aligned with its vector ids.
        self.index = None
        self.corpus_texts: List[str] = []
        self.corpus_meta: List[Dict[str, Any]] = []
        if load_encoder:
            self._try_load_encoder()
        # The generation model is always loaded, whatever the encoder state.
        self._load_llm()
@property
def encoder_loaded(self) -> bool:
return self.model is not None and self.tokenizer is not None
def _try_load_encoder(self) -> None:
if not self.ckpt_path or not self.tokenizer_dir:
log.warning("Encodeur RAG désactivé : ckpt/tokenizer-dir non fournis.")
return
ckpt_path = Path(self.ckpt_path)
tokenizer_dir = Path(self.tokenizer_dir)
if not ckpt_path.exists():
log.warning("Encodeur RAG désactivé : checkpoint introuvable : %s", ckpt_path)
return
if not tokenizer_dir.exists() or not any(tokenizer_dir.glob("tokenizer*")):
log.warning("Encodeur RAG désactivé : tokenizer introuvable : %s", tokenizer_dir)
return
log.info("Chargement tokenizer RAG depuis %s", tokenizer_dir)
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir))
log.info("Chargement checkpoint RAG %s", ckpt_path)
ckpt = torch.load(str(ckpt_path), map_location=self.device)
saved_cfg = ckpt.get("config", {})
self.cfg = ModelConfig(
vocab_size=saved_cfg.get("vocab_size", self.tokenizer.vocab_size),
hidden_size=saved_cfg.get("hidden_size", 384),
num_hidden_layers=saved_cfg.get("num_hidden_layers", 6),
num_attention_heads=saved_cfg.get("num_attention_heads", 6),
intermediate_size=saved_cfg.get("intermediate_size", 1536),
max_position_embeddings=saved_cfg.get("max_position_embeddings", 256),
embedding_dim=saved_cfg.get("embedding_dim", 384),
use_layer_scale=saved_cfg.get("use_layer_scale", True),
layer_scale_init=saved_cfg.get("layer_scale_init", 1e-4),
)
self.model = TextEncoder(self.cfg).to(self.device)
self.model.load_state_dict(ckpt["model_state"], strict=False)
self.model.eval()
self.train_metrics = ckpt.get("metrics", {})
self.train_epoch = ckpt.get("epoch", -1)
log.info("Encodeur RAG chargé sur %s, dim=%s", self.device, self.cfg.embedding_dim)
def _load_llm(self) -> None:
self.llm_path = ensure_tinyllama_local(
model_id=self.llm_model_id,
local_dir=self.llm_dir,
download=self.download_llm and not self.llm_offline,
)
kwargs = {
"trust_remote_code": True,
"local_files_only": self.llm_offline,
}
log.info("Chargement tokenizer TinyLlama depuis %s", self.llm_path)
self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_path, **kwargs)
if self.llm_tokenizer.pad_token_id is None:
self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
dtype = torch.bfloat16 if self.device.type == "cuda" else torch.float32
log.info("Chargement modèle TinyLlama depuis %s", self.llm_path)
try:
self.llm_model = AutoModelForCausalLM.from_pretrained(
self.llm_path,
torch_dtype=dtype,
device_map="auto" if self.device.type == "cuda" else None,
low_cpu_mem_usage=True,
**kwargs,
)
except Exception as exc:
log.warning("device_map='auto' indisponible ou erreur accelerate (%s). Bascule .to(device).", exc)
self.llm_model = AutoModelForCausalLM.from_pretrained(
self.llm_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
**kwargs,
).to(self.device)
self.llm_model.eval()
log.info("TinyLlama chargé : %s sur %s", self.llm_model_id, self.device)
def require_encoder(self) -> None:
"""Lève HTTPException 503 si l'encodeur RAG n'est pas chargé."""
if not self.encoder_loaded:
raise HTTPException(
status_code=503,
detail=(
"Encodeur RAG non chargé. Fournis --ckpt et --tokenizer-dir valides, "
"ou utilise seulement /generate et /v1/chat/completions."
),
)
def require_index(self) -> None:
"""Lève HTTPException 400 si aucun index FAISS n'est chargé en mémoire."""
if self.index is None:
raise HTTPException(
status_code=400,
detail="Aucun index chargé. Appelle POST /index/build ou POST /index/load d'abord.",
)
@torch.no_grad()
def generate(
self,
prompt: str,
max_new_tokens: int = 128,
temperature: float = 0.7,
top_p: float = 0.9,
system: Optional[str] = None,
) -> str:
if self.llm_model is None or self.llm_tokenizer is None:
raise HTTPException(status_code=503, detail="TinyLlama non chargé.")
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
if hasattr(self.llm_tokenizer, "apply_chat_template") and self.llm_tokenizer.chat_template:
input_ids = self.llm_tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
).to(self.llm_model.device)
attention_mask = torch.ones_like(input_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
else:
model_inputs = self.llm_tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=2048,
).to(self.llm_model.device)
do_sample = temperature > 0
generated = self.llm_model.generate(
**model_inputs,
max_new_tokens=max_new_tokens,
temperature=max(temperature, 1e-5),
top_p=top_p,
do_sample=do_sample,
pad_token_id=self.llm_tokenizer.pad_token_id,
eos_token_id=self.llm_tokenizer.eos_token_id,
no_repeat_ngram_size=2,
)
prompt_len = model_inputs["input_ids"].shape[-1]
answer_ids = generated[0][prompt_len:]
return self.llm_tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
@torch.no_grad()
def encode(self, texts: List[str], batch_size: int = 64, normalize: bool = True) -> np.ndarray:
self.require_encoder()
if not texts:
return np.zeros((0, self.cfg.embedding_dim), dtype=np.float32)
embs = []
for i in range(0, len(texts), batch_size):
chunk = texts[i : i + batch_size]
enc = self.tokenizer(
chunk,
padding=True,
truncation=True,
max_length=self.max_len,
return_tensors="pt",
).to(self.device)
with torch.autocast(
device_type=self.device.type,
dtype=torch.bfloat16,
enabled=(self.device.type == "cuda"),
):
e = self.model(enc["input_ids"], enc["attention_mask"])
embs.append(e.float().cpu().numpy())
out = np.concatenate(embs, axis=0)
if not normalize:
return out.astype(np.float32)
norms = np.linalg.norm(out, axis=1, keepdims=True).clip(min=1e-12)
return (out / norms).astype(np.float32)
def build_index(self, corpus: List[str], metas: Optional[List[Dict[str, Any]]] = None) -> None:
self.require_encoder()
try:
import faiss
except ImportError as exc:
raise HTTPException(
status_code=501,
detail="faiss-cpu non installé. Fais : pip install faiss-cpu",
) from exc
log.info("Indexation de %s documents...", len(corpus))
embs = self.encode(corpus, batch_size=128)
index = faiss.IndexFlatIP(embs.shape[1])
index.add(embs)
self.index = index
self.corpus_texts = list(corpus)
self.corpus_meta = metas if metas else [{} for _ in corpus]
def save_index(self, path: str) -> None:
try:
import faiss
except ImportError as exc:
raise HTTPException(status_code=501, detail="faiss-cpu non installé.") from exc
if self.index is None:
raise HTTPException(status_code=400, detail="Aucun index à sauvegarder.")
faiss.write_index(self.index, path)
with open(path + ".meta.json", "w", encoding="utf-8") as f:
json.dump({"texts": self.corpus_texts, "meta": self.corpus_meta}, f, ensure_ascii=False)
def load_index(self, path: str) -> None:
try:
import faiss
except ImportError as exc:
raise HTTPException(status_code=501, detail="faiss-cpu non installé.") from exc
index_path = Path(path)
if not index_path.exists():
raise HTTPException(status_code=404, detail=f"Index introuvable : {path}")
self.index = faiss.read_index(str(index_path))
meta_path = str(index_path) + ".meta.json"
if os.path.exists(meta_path):
with open(meta_path, "r", encoding="utf-8") as f:
data = json.load(f)
self.corpus_texts = data.get("texts", [])
self.corpus_meta = data.get("meta", [])
log.info("Index chargé : %s vecteurs", self.index.ntotal)
def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
self.require_index()
q_emb = self.encode([query])
scores, ids = self.index.search(q_emb, top_k)
results = []
for score, idx in zip(scores[0].tolist(), ids[0].tolist()):
if idx < 0 or idx >= len(self.corpus_texts):
continue
results.append(
{
"score": float(score),
"text": self.corpus_texts[idx],
"meta": self.corpus_meta[idx] if idx < len(self.corpus_meta) else {},
"doc_id": idx,
}
)
return results
def index_status(self) -> Dict[str, Any]:
return {
"loaded": self.index is not None,
"n_vectors": int(self.index.ntotal) if self.index is not None else 0,
"n_documents": len(self.corpus_texts),
"has_metadata": bool(self.corpus_meta),
"embedding_dim": self.cfg.embedding_dim if self.encoder_loaded else None,
"encoder_loaded": self.encoder_loaded,
}
def clear_index(self) -> None:
self.index = None
self.corpus_texts = []
self.corpus_meta = []
def add_documents(self, documents: List[str], metas: Optional[List[Dict[str, Any]]] = None) -> None:
self.require_encoder()
if not documents:
raise HTTPException(status_code=400, detail="documents vide")
if metas is not None and len(metas) != len(documents):
raise HTTPException(status_code=400, detail="metadata doit avoir la même taille que documents")
new_texts = self.corpus_texts + list(documents)
new_meta = self.corpus_meta + (metas if metas is not None else [{} for _ in documents])
self.build_index(new_texts, new_meta)
def delete_document(self, doc_id: int) -> None:
self.require_encoder()
if doc_id < 0 or doc_id >= len(self.corpus_texts):
raise HTTPException(status_code=404, detail=f"Document introuvable : doc_id={doc_id}")
del self.corpus_texts[doc_id]
if doc_id < len(self.corpus_meta):
del self.corpus_meta[doc_id]
if self.corpus_texts:
self.build_index(self.corpus_texts, self.corpus_meta)
else:
self.clear_index()
def get_document(self, doc_id: int) -> Dict[str, Any]:
if doc_id < 0 or doc_id >= len(self.corpus_texts):
raise HTTPException(status_code=404, detail=f"Document introuvable : doc_id={doc_id}")
return {
"doc_id": doc_id,
"text": self.corpus_texts[doc_id],
"meta": self.corpus_meta[doc_id] if doc_id < len(self.corpus_meta) else {},
}
def list_documents(self, offset: int = 0, limit: int = 50) -> Dict[str, Any]:
offset = max(0, offset)
limit = max(1, min(limit, 500))
items = []
for idx in range(offset, min(offset + limit, len(self.corpus_texts))):
text = self.corpus_texts[idx]
items.append({
"doc_id": idx,
"text": text,
"preview": text[:300],
"chars": len(text),
"meta": self.corpus_meta[idx] if idx < len(self.corpus_meta) else {},
})
return {"offset": offset, "limit": limit, "total": len(self.corpus_texts), "documents": items}
def save_corpus(self, path: str) -> None:
target = Path(path).expanduser().resolve()
target.parent.mkdir(parents=True, exist_ok=True)
with open(target, "w", encoding="utf-8") as f:
json.dump({"texts": self.corpus_texts, "meta": self.corpus_meta}, f, ensure_ascii=False, indent=2)
def load_corpus(self, path: str, rebuild_index: bool = True) -> None:
source = Path(path).expanduser().resolve()
if not source.exists():
raise HTTPException(status_code=404, detail=f"Corpus introuvable : {source}")
with open(source, "r", encoding="utf-8") as f:
data = json.load(f)
texts = data.get("texts") or data.get("documents") or []
metas = data.get("meta") or data.get("metadata") or [{} for _ in texts]
if len(metas) != len(texts):
raise HTTPException(status_code=422, detail="Corpus invalide : texts/meta tailles différentes")
if rebuild_index:
self.build_index(texts, metas)
else:
self.corpus_texts = list(texts)
self.corpus_meta = list(metas)
self.index = None
def count_llm_tokens(self, text: str) -> int:
if self.llm_tokenizer is None:
raise HTTPException(status_code=503, detail="Tokenizer TinyLlama non chargé.")
return int(len(self.llm_tokenizer(text, add_special_tokens=False)["input_ids"]))
def generate_batch(
self,
prompts: List[str],
max_new_tokens: int = 128,
temperature: float = 0.7,
top_p: float = 0.9,
system: Optional[str] = None,
) -> List[str]:
if not prompts:
raise HTTPException(status_code=400, detail="prompts vide")
return [
self.generate(
prompt=p,
system=system,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
for p in prompts
]
def build_rag_prompt(self, query: str, top_k: int = 3, max_context_chars: int = 2000) -> Dict[str, Any]:
self.require_index()
results = self.search(query, top_k=top_k)
pieces, total = [], 0
for r in results:
snippet = r["text"]
remaining = max_context_chars - total
if remaining <= 0:
break
if len(snippet) > remaining:
snippet = snippet[:remaining]
pieces.append(f"[doc#{r['doc_id']}] {snippet}")
total += len(snippet)
context = "\n\n".join(pieces)
prompt = (
"Tu es un assistant qui répond UNIQUEMENT à partir du contexte fourni.\n"
"Si la réponse n'est pas dans le contexte, dis : Information non disponible.\n\n"
f"Contexte :\n{context}\n\n"
f"Question : {query}\n"
"Réponse :"
)
return {"context": context, "prompt": prompt, "sources": results}
def ask(
self,
query: str,
top_k: int = 3,
max_context_chars: int = 2000,
max_new_tokens: int = 256,
temperature: float = 0.4,
top_p: float = 0.9,
) -> Dict[str, Any]:
"""Répond avec RAG si un index est chargé, sinon génération directe TinyLlama."""
if self.index is not None:
rag_data = self.build_rag_prompt(query, top_k=top_k, max_context_chars=max_context_chars)
answer = self.generate(
prompt=rag_data["prompt"],
system="Tu réponds en français, de manière concise et sourcée par le contexte.",
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
return {"mode": "rag", "answer": answer, **rag_data}
answer = self.generate(
prompt=query,
system="Tu es un assistant utile qui répond clairement en français.",
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
return {"mode": "llm", "answer": answer, "context": "", "prompt": query, "sources": []}
# =============================================================================
# 3. APP FASTAPI
# =============================================================================
# Process-wide service instance; stays None until lifespan() has built it.
service: Optional[RAGService] = None
def mount_model_files(app: FastAPI, model_files_dir: str) -> None:
    """Serve ``model_files_dir`` as static files under the ``/files`` route.

    The directory is created when it does not exist yet.
    """
    directory = Path(model_files_dir).expanduser().resolve()
    directory.mkdir(parents=True, exist_ok=True)
    static_files = StaticFiles(directory=str(directory))
    app.mount("/files", static_files, name="model_files")
    log.info("Dossier /files monté depuis %s", directory)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the global service from environment variables at startup.

    Variables read: ``RAG_CKPT``, ``RAG_TOKENIZER_DIR``, ``RAG_LLM_MODEL``,
    ``RAG_LLM_DIR``, ``RAG_MODEL_FILES_DIR``, ``OVH_PUBLIC_URL``,
    ``RAG_LLM_DOWNLOAD``, ``RAG_LLM_OFFLINE`` and ``RAG_LOAD_ENCODER``.  The
    static ``/files`` route is mounted before the service is constructed.
    """
    global service
    env = os.environ
    public_url = env.get("OVH_PUBLIC_URL", "")

    mount_model_files(app, env.get("RAG_MODEL_FILES_DIR", "./models"))
    service = RAGService(
        # Empty strings mean "not provided".
        ckpt_path=env.get("RAG_CKPT", "") or None,
        tokenizer_dir=env.get("RAG_TOKENIZER_DIR", "") or None,
        llm_model_id=env.get("RAG_LLM_MODEL", TINYLLAMA_REPO_ID),
        llm_dir=env.get("RAG_LLM_DIR", DEFAULT_LLM_DIR),
        download_llm=env_bool("RAG_LLM_DOWNLOAD", True),
        llm_offline=env_bool("RAG_LLM_OFFLINE", False),
        public_url=public_url,
        load_encoder=env_bool("RAG_LOAD_ENCODER", True),
    )
    log.info("Service prêt. URL publique configurée : %s", public_url or "non fournie")
    yield
    log.info("Arrêt du service.")
app = FastAPI(
    title="SecureRAG FR + TinyLlama Local API",
    description="API RAG optionnelle + génération TinyLlama local pour OVH AI Notebooks",
    version="3.2.0",
    lifespan=lifespan,
)
# Comma-separated list of allowed origins; "*" (the default) allows every origin.
allowed_origins_raw = os.environ.get("CORS_ALLOW_ORIGINS", "*")
allowed_origins = [item.strip() for item in allowed_origins_raw.split(",") if item.strip()]
# Credentialed cross-origin requests are only enabled for an explicit origin
# list: combining them with the "*" wildcard is documented as unsupported and
# would in practice let any website send credentialed requests to this API.
_cors_origins = allowed_origins or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -----------------------------------------------------------------------------
# Schémas Pydantic
# -----------------------------------------------------------------------------
# Schemas for embedding, index and search payloads.  The handlers that use
# them are not all visible in this part of the file.

# Input of an embedding request: texts plus whether to L2-normalise the result.
class EncodeRequest(BaseModel):
    texts: List[str] = Field(..., description="Liste de textes à encoder")
    normalize: bool = True


# Embedding vectors together with their dimension and the elapsed time (ms).
class EncodeResponse(BaseModel):
    embeddings: List[List[float]]
    dim: int
    elapsed_ms: float


# Documents (and optional per-document metadata) from which to build an index.
class BuildIndexRequest(BaseModel):
    documents: List[str]
    metadata: Optional[List[Dict[str, Any]]] = None
    save_path: Optional[str] = Field(None, description="Si fourni, sauvegarde l'index FAISS")


# Path of an index to load.
class LoadIndexRequest(BaseModel):
    path: str


# Search query; top_k is validated to the range 1..100.
class SearchRequest(BaseModel):
    query: str
    top_k: int = Field(5, ge=1, le=100)


# One search hit; same keys as the dicts built by RAGService.search().
class SearchResult(BaseModel):
    score: float
    text: str
    meta: Dict[str, Any]
    doc_id: int


# Search results for a query, with the elapsed time (ms).
class SearchResponse(BaseModel):
    query: str
    results: List[SearchResult]
    elapsed_ms: float
# Retrieval-augmented request: retrieval settings plus generation settings.
# generate_answer presumably toggles the generation step (handler not visible here).
class RAGRequest(BaseModel):
    query: str
    top_k: int = Field(3, ge=1, le=20)
    max_context_chars: int = Field(2000, ge=100, le=20000)
    generate_answer: bool = True
    max_new_tokens: int = Field(256, ge=1, le=2048)
    temperature: float = Field(0.4, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)


# Retrieval-augmented response; answer is optional.
class RAGResponse(BaseModel):
    query: str
    context: str
    sources: List[SearchResult]
    prompt_template: str
    answer: Optional[str] = None


# Plain generation request with an optional system instruction.
class GenerateRequest(BaseModel):
    prompt: str
    system: Optional[str] = "Tu es un assistant utile qui répond clairement en français."
    max_new_tokens: int = Field(128, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)


# Generated answer with timing and the identity/location of the model used.
class GenerateResponse(BaseModel):
    prompt: str
    answer: str
    elapsed_ms: float
    model: str
    model_path: str
# Pair of texts to compare.
class SimilarityRequest(BaseModel):
    text_a: str
    text_b: str


# Cosine similarity of the two texts.
class SimilarityResponse(BaseModel):
    cosine_similarity: float


# One chat message (role + content).
class ChatMessage(BaseModel):
    role: str
    content: str


# Chat completion request; field names follow the OpenAI-style payload that
# the /v1/chat/completions route listed by root() suggests.
class OpenAIChatCompletionRequest(BaseModel):
    model: Optional[str] = TINYLLAMA_REPO_ID
    messages: List[ChatMessage]
    max_tokens: int = Field(256, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)
    stream: bool = False
    stop: Optional[Any] = None


# Text completion request (single prompt instead of a message list).
class CompletionRequest(BaseModel):
    model: Optional[str] = TINYLLAMA_REPO_ID
    prompt: str
    max_tokens: int = Field(128, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)
    stream: bool = False
    stop: Optional[Any] = None


# One generated choice of a completion response.
class CompletionChoice(BaseModel):
    text: str
    index: int
    logprobs: Optional[Any] = None
    finish_reason: Optional[str] = None


# Completion response envelope: id, object type, creation time, model,
# choices and usage figures.
class CompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: Dict[str, Optional[Any]]
# Body of POST /generate/batch: several prompts sharing the same settings.
class BatchGenerateRequest(BaseModel):
    prompts: List[str]
    system: Optional[str] = "Tu es un assistant utile qui répond clairement en français."
    max_new_tokens: int = Field(128, ge=1, le=2048)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)


# Response of POST /generate/batch: one answer per prompt.
class BatchGenerateResponse(BaseModel):
    answers: List[str]
    elapsed_ms: float
    model: str


# Body of POST /tokens/count.
class TokenCountRequest(BaseModel):
    text: str


# Response of POST /tokens/count.
class TokenCountResponse(BaseModel):
    tokens: int
    model: str


# Body of POST /index/save: destination path of the index.
class SaveIndexRequest(BaseModel):
    path: str


# Body of POST /index/add: documents to append, optional metadata, and an
# optional path to which the index is saved afterwards.
class AddDocumentsRequest(BaseModel):
    documents: List[str]
    metadata: Optional[List[Dict[str, Any]]] = None
    save_path: Optional[str] = None


# Destination path for saving the corpus.
class CorpusSaveRequest(BaseModel):
    path: str


# Source path for loading a corpus, and whether to rebuild the index from it.
class CorpusLoadRequest(BaseModel):
    path: str
    rebuild_index: bool = True


# Question with retrieval and generation settings; the fields mirror the
# parameters of RAGService.ask().
class AskRequest(BaseModel):
    query: str
    top_k: int = Field(3, ge=1, le=20)
    max_context_chars: int = Field(2000, ge=100, le=20000)
    max_new_tokens: int = Field(256, ge=1, le=2048)
    temperature: float = Field(0.4, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.01, le=1.0)


# Embeddings request; `input` is left untyped (Any).
class OpenAIEmbeddingsRequest(BaseModel):
    model: Optional[str] = "rag-encoder-local"
    input: Any
    encoding_format: Optional[str] = "float"
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _check_service() -> RAGService:
    """Return the global service, or raise HTTPException 503 if it is not built yet."""
    current = service
    if current is not None:
        return current
    raise HTTPException(status_code=503, detail="Service non prêt, réessaie dans quelques secondes.")
# -----------------------------------------------------------------------------
# Endpoints
# -----------------------------------------------------------------------------
# Landing route: service status plus the absolute URL of every endpoint.
# URLs are built from the configured public URL, or from the local address
# when none is configured.
@app.get("/")
def root():
    svc = _check_service()
    base = svc.public_url or "http://127.0.0.1:8080"
    # (label, path) pairs, in the order in which they are reported.
    routes = (
        ("health", "/health"),
        ("generate", "/generate"),
        ("generate_batch", "/generate/batch"),
        ("ask", "/ask"),
        ("tokens_count", "/tokens/count"),
        ("encode", "/encode"),
        ("similarity", "/similarity"),
        ("index_build", "/index/build"),
        ("index_load", "/index/load"),
        ("index_save", "/index/save"),
        ("index_add", "/index/add"),
        ("index_status", "/index/status"),
        ("index_clear", "/index/clear"),
        ("search", "/search"),
        ("rag", "/rag"),
        ("documents_list", "/documents"),
        ("document_get", "/documents/{doc_id}"),
        ("document_delete", "/documents/{doc_id}"),
        ("corpus_save", "/corpus/save"),
        ("corpus_load", "/corpus/load"),
        ("openai_chat", "/v1/chat/completions"),
        ("openai_completions", "/v1/completions"),
        ("openai_embeddings", "/v1/embeddings"),
        ("models", "/v1/models"),
        ("client_snippets", "/client/snippets"),
        ("model_info", "/model/info"),
        ("files_list", "/files/list"),
    )
    encoder_ready = svc.encoder_loaded
    has_index = svc.index is not None
    return {
        "status": "ok",
        "name": "SecureRAG FR + TinyLlama Local API",
        "version": "3.2.0",
        "device": str(svc.device),
        "public_url": svc.public_url,
        "local_test_url": "http://127.0.0.1:8080",
        "llm_model": svc.llm_model_id,
        "llm_path": svc.llm_path,
        "llm_loaded": svc.llm_model is not None,
        "encoder_loaded": encoder_ready,
        "embedding_dim": svc.cfg.embedding_dim if encoder_ready else None,
        "index_loaded": has_index,
        "index_size": svc.index.ntotal if has_index else 0,
        "endpoints": {label: f"{base}{path}" for label, path in routes},
    }
# Health probe: reports CUDA availability, the service device and which
# models are loaded.  Answers 503 (via _check_service) before startup is done.
@app.get("/health")
def health():
    svc = _check_service()
    llm_ready = svc.llm_model is not None
    return {
        "status": "ok",
        "gpu": torch.cuda.is_available(),
        "device": str(svc.device),
        "llm_loaded": llm_ready,
        "encoder_loaded": svc.encoder_loaded,
    }
# Returns ready-to-copy client examples (curl, requests, OpenAI SDK) that
# target the configured public URL, or a placeholder URL when none is set.
@app.get("/client/snippets")
def client_snippets():
    svc = _check_service()
    # Placeholder host used when no public URL was configured.
    base = svc.public_url or "https://TON-ID-NOTEBOOK.notebook.gra.ai.cloud.ovh.net"
    return {
        # Shell one-liner calling POST /generate.
        "curl_generate": (
            f"curl -X POST '{base}/generate' -H 'Content-Type: application/json' "
            "-d '{\"prompt\":\"Explique TinyLlama en français.\",\"max_new_tokens\":128}'"
        ),
        # Same call through the `requests` library.
        "python_requests": (
            "import requests\n"
            f"BASE_URL = '{base}'\n"
            "r = requests.post(BASE_URL + '/generate', json={\n"
            "    'prompt': 'Explique RAG simplement en français.',\n"
            "    'max_new_tokens': 128\n"
            "})\n"
            "print(r.json()['answer'])"
        ),
        # Chat completion through the OpenAI client pointed at {base}/v1.
        "openai_compatible_python": (
            "from openai import OpenAI\n"
            f"client = OpenAI(base_url='{base}/v1', api_key='not-needed')\n"
            "resp = client.chat.completions.create(\n"
            "    model='TinyLlama/TinyLlama-1.1B-Chat-v1.0',\n"
            "    messages=[{'role': 'user', 'content': 'Bonjour, réponds en français.'}],\n"
            "    max_tokens=128,\n"
            ")\n"
            "print(resp.choices[0].message.content)"
        ),
        # Embeddings through the OpenAI client pointed at {base}/v1.
        "openai_embeddings_python": (
            "from openai import OpenAI\n"
            f"client = OpenAI(base_url='{base}/v1', api_key='not-needed')\n"
            "resp = client.embeddings.create(\n"
            "    model='rag-encoder-local',\n"
            "    input=['Bonjour le monde']\n"
            ")\n"
            "print(resp.data[0].embedding[:5])"
        ),
    }
# Lists every regular file found under the model files directory
# (RAG_MODEL_FILES_DIR, default "./models"), as sorted relative paths.
# Answers 404 when the directory does not exist.
@app.get("/files/list")
def files_list():
    configured = os.environ.get("RAG_MODEL_FILES_DIR", "./models")
    files_dir = Path(configured).expanduser().resolve()
    if not files_dir.exists():
        raise HTTPException(status_code=404, detail=f"Dossier introuvable : {files_dir}")
    relative_paths = []
    for entry in sorted(files_dir.rglob("*")):
        if entry.is_file():
            relative_paths.append(str(entry.relative_to(files_dir)))
    return {"model_files_dir": str(files_dir), "files": relative_paths}
@app.get("/model/info")
def model_info():
    """Describe the loaded model and the current index.

    Returns:
        A dict with the LLM identifier and path, the device as a string, the
        encoder-loaded flag, the service's index status and its public URL.
    """
    service = _check_service()
    info = {
        "llm_model": service.llm_model_id,
        "llm_path": service.llm_path,
        "device": str(service.device),
        "encoder_loaded": service.encoder_loaded,
    }
    info["index"] = service.index_status()
    info["public_url"] = service.public_url
    return info
@app.post("/tokens/count", response_model=TokenCountResponse)
def count_tokens(req: TokenCountRequest):
    """Count the LLM tokens of ``req.text``.

    Returns:
        A ``TokenCountResponse`` with the count reported by the service and
        the LLM model identifier.
    """
    service = _check_service()
    n_tokens = service.count_llm_tokens(req.text)
    return TokenCountResponse(tokens=n_tokens, model=service.llm_model_id)
@app.post("/generate/batch", response_model=BatchGenerateResponse)
def generate_batch(req: BatchGenerateRequest):
    """Generate one answer per prompt in a single request.

    Returns:
        A ``BatchGenerateResponse`` with the answers, the wall-clock time of
        the generation call in milliseconds, and the LLM model identifier.

    Raises:
        HTTPException: 400 when ``req.prompts`` is empty.
    """
    service = _check_service()
    if not req.prompts:
        raise HTTPException(status_code=400, detail="prompts ne peut pas être vide")

    generation_args = {
        "prompts": req.prompts,
        "system": req.system,
        "max_new_tokens": req.max_new_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
    }
    started = time.time()
    answers = service.generate_batch(**generation_args)
    elapsed_ms = (time.time() - started) * 1000.0
    return BatchGenerateResponse(answers=answers, elapsed_ms=elapsed_ms, model=service.llm_model_id)
@app.get("/index/status")
def index_status():
    """Return the index status exactly as reported by the service."""
    return _check_service().index_status()
@app.delete("/index/clear")
def index_clear():
    """Clear the service's index and acknowledge with a fixed payload."""
    _check_service().clear_index()
    acknowledgement = {"status": "ok", "message": "index vidé"}
    return acknowledgement
@app.post("/index/save")
def index_save(req: SaveIndexRequest):
    """Save the current index to ``req.path``.

    Returns:
        A dict with an ``"ok"`` status and the destination path.

    Raises:
        HTTPException: 400 when the service holds no index.
    """
    service = _check_service()
    has_index = service.index is not None
    if not has_index:
        raise HTTPException(status_code=400, detail="Aucun index à sauvegarder.")
    destination = req.path
    service.save_index(destination)
    return {"status": "ok", "saved_to": destination}
@app.post("/index/add")
def index_add(req: AddDocumentsRequest):
    """Add documents (with optional metadata) to the index.

    The index is also saved when ``req.save_path`` is set.

    Returns:
        A dict with the number of documents added, the index status, the save
        path (or ``None``) and the elapsed time in milliseconds.

    Raises:
        HTTPException: 400 when ``req.documents`` is empty; 422 when metadata
            is given but its length differs from that of the documents.
    """
    service = _check_service()
    if not req.documents:
        raise HTTPException(status_code=400, detail="documents ne peut pas être vide")
    if req.metadata is not None:
        if len(req.metadata) != len(req.documents):
            raise HTTPException(status_code=422, detail="metadata doit avoir la même taille que documents")

    started = time.time()
    service.add_documents(req.documents, req.metadata)

    # An empty or missing save path means "do not save".
    saved_to = req.save_path or None
    if saved_to is not None:
        service.save_index(saved_to)

    payload = {"status": "ok", "added": len(req.documents)}
    payload["index"] = service.index_status()
    payload["saved_to"] = saved_to
    payload["elapsed_ms"] = (time.time() - started) * 1000.0
    return payload
@app.get("/documents")
def documents_list(offset: int = 0, limit: int = 50):
    """Return one page of documents, as produced by the service.

    Args:
        offset: Forwarded unchanged to ``list_documents``.
        limit: Forwarded unchanged to ``list_documents``.
    """
    return _check_service().list_documents(offset=offset, limit=limit)
@app.get("/documents/{doc_id}")
def document_get(doc_id: int):
    """Return the document identified by ``doc_id``, as produced by the service."""
    return _check_service().get_document(doc_id)
@app.delete("/documents/{doc_id}")
def document_delete(doc_id: int):
    """Delete the document ``doc_id`` and report the resulting index status."""
    service = _check_service()
    service.delete_document(doc_id)
    outcome = {"status": "ok", "deleted_doc_id": doc_id}
    outcome["index"] = service.index_status()
    return outcome
@app.post("/corpus/save")
def corpus_save(req: CorpusSaveRequest):
    """Save the corpus to ``req.path`` and report how many documents it holds."""
    service = _check_service()
    destination = req.path
    service.save_corpus(destination)
    n_documents = len(service.corpus_texts)
    return {"status": "ok", "saved_to": destination, "n_documents": n_documents}
@app.post("/corpus/load")
def corpus_load(req: CorpusLoadRequest):
    """Load a corpus from ``req.path``.

    ``req.rebuild_index`` is forwarded to the service as ``rebuild_index``.

    Returns:
        A dict with an ``"ok"`` status, the source path and the index status.
    """
    service = _check_service()
    source = req.path
    service.load_corpus(source, rebuild_index=req.rebuild_index)
    return {"status": "ok", "loaded_from": source, "index": service.index_status()}
@app.post("/ask")
def ask(req: AskRequest):
    """Forward the question to the service's ``ask`` method.

    Returns:
        The mapping returned by the service, with an extra ``elapsed_ms`` key
        holding the duration of the call in milliseconds.
    """
    service = _check_service()
    ask_args = {
        "query": req.query,
        "top_k": req.top_k,
        "max_context_chars": req.max_context_chars,
        "max_new_tokens": req.max_new_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
    }
    started = time.time()
    payload = service.ask(**ask_args)
    payload["elapsed_ms"] = (time.time() - started) * 1000.0
    return payload
@app.get("/v1/models")
def openai_models():
    """Return an OpenAI-style model list holding the single local LLM."""
    service = _check_service()
    local_model = {
        "id": service.llm_model_id,
        "object": "model",
        "created": 0,
        "owned_by": "local",
        "path": service.llm_path,
    }
    return {"object": "list", "data": [local_model]}
@app.get("/v1/models/{model_id:path}")
def openai_model_get(model_id: str):
    """Return an OpenAI-style model object for ``model_id``.

    The requested id is echoed back as-is; the local LLM id is used only when
    ``model_id`` is empty. The path is always that of the local LLM.
    """
    service = _check_service()
    reported_id = model_id if model_id else service.llm_model_id
    return {
        "id": reported_id,
        "object": "model",
        "created": 0,
        "owned_by": "local",
        "path": service.llm_path,
    }
@app.post("/v1/embeddings")
def openai_embeddings(req: OpenAIEmbeddingsRequest):
    """Embed one string or a list of strings, OpenAI-style.

    Returns:
        A dict with one ``embedding`` entry per input, in input order. The
        ``usage`` token counts are always ``None``.

    Raises:
        HTTPException: 400 when ``encoding_format`` is neither ``None`` nor
            ``"float"``, when ``input`` is an empty list, or when ``input`` is
            neither a string nor a list of strings.
    """
    service = _check_service()
    if req.encoding_format is not None and req.encoding_format != "float":
        raise HTTPException(status_code=400, detail="encoding_format supporté uniquement : float")

    payload = req.input
    wrong_type_detail = "input doit être une chaîne ou une liste de chaînes"
    if isinstance(payload, str):
        texts = [payload]
    else:
        if not isinstance(payload, list):
            raise HTTPException(status_code=400, detail=wrong_type_detail)
        if not payload:
            raise HTTPException(status_code=400, detail="input vide")
        if not all(isinstance(item, str) for item in payload):
            raise HTTPException(status_code=400, detail=wrong_type_detail)
        texts = payload

    vectors = service.encode(texts, normalize=True)
    data = []
    for position, vector in enumerate(vectors):
        data.append({"object": "embedding", "index": position, "embedding": vector.tolist()})
    return {
        "object": "list",
        "model": req.model or "rag-encoder-local",
        "data": data,
        "usage": {"prompt_tokens": None, "total_tokens": None},
    }
@app.post("/v1/chat/completions")
def openai_chat_completions(req: OpenAIChatCompletionRequest):
    """Answer an OpenAI-style chat completion request without streaming.

    System messages are joined into one system prompt (``None`` when empty);
    every other message becomes a ``"role: content"`` line of the prompt.

    Returns:
        A ``chat.completion`` dict with a single choice. Token counts in
        ``usage`` are ``None``; ``usage.elapsed_ms`` holds the time elapsed
        since just before generation started.

    Raises:
        HTTPException: 400 when ``req.stream`` is truthy or ``req.messages``
            is empty.
    """
    service = _check_service()
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=True non supporté dans cette version locale")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages vide")

    # Split the conversation into system instructions and dialogue lines.
    system_chunks, dialogue_lines = [], []
    for message in req.messages:
        if message.role == "system":
            system_chunks.append(message.content)
        else:
            dialogue_lines.append(f"{message.role}: {message.content}")
    prompt = "\n".join(dialogue_lines).strip()
    system = "\n".join(system_chunks).strip()
    if not system:
        system = None

    started = time.time()
    answer = service.generate(
        prompt=prompt,
        system=system,
        max_new_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    created = int(time.time())
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": answer},
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": None,
        "completion_tokens": None,
        "total_tokens": None,
        "elapsed_ms": (time.time() - started) * 1000.0,
    }
    return {
        "id": f"chatcmpl-local-{created}",
        "object": "chat.completion",
        "created": created,
        "model": service.llm_model_id,
        "choices": [choice],
        "usage": usage,
    }
@app.post("/v1/completions", response_model=CompletionResponse)
def openai_completions(req: CompletionRequest):
    """Answer an OpenAI-style text completion request without streaming.

    The prompt is passed to the service's ``generate`` with no system prompt.

    Returns:
        A ``CompletionResponse`` with a single choice whose finish reason is
        ``"stop"``. The token counts in ``usage`` are always ``None``.

    Raises:
        HTTPException: 400 when ``req.stream`` is truthy.
    """
    svc = _check_service()
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=True non supporté dans cette version locale")
    # No timer here: the response carries no elapsed-time field.
    answer = svc.generate(
        prompt=req.prompt,
        system=None,
        max_new_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    created = int(time.time())
    return CompletionResponse(
        id=f"cmpl-local-{created}",
        object="text_completion",
        created=created,
        model=svc.llm_model_id,
        choices=[CompletionChoice(text=answer, index=0, finish_reason="stop")],
        usage={
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
    )
@app.post("/v1/engines/{engine}/completions", response_model=CompletionResponse)
def openai_engine_completions(engine: str, req: CompletionRequest):
    """Legacy ``/v1/engines/{engine}/completions`` alias of ``/v1/completions``.

    Args:
        engine: Engine name from the URL. It is accepted for compatibility
            with the legacy route but ignored: the local LLM always answers.
        req: The completion request.

    Returns:
        The same ``CompletionResponse`` that ``openai_completions`` produces.

    Raises:
        HTTPException: 400 when ``req.stream`` is truthy.
    """
    # Delegate rather than duplicate, so the two routes cannot drift apart.
    return openai_completions(req)
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate an answer for a single prompt with the local LLM.

    Returns:
        A ``GenerateResponse`` echoing the prompt, with the answer, the
        duration of the generation in milliseconds, and the model id and path.
    """
    service = _check_service()
    generation_args = {
        "prompt": req.prompt,
        "system": req.system,
        "max_new_tokens": req.max_new_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
    }
    started = time.time()
    answer = service.generate(**generation_args)
    elapsed_ms = (time.time() - started) * 1000.0
    return GenerateResponse(
        prompt=req.prompt,
        answer=answer,
        elapsed_ms=elapsed_ms,
        model=service.llm_model_id,
        model_path=service.llm_path,
    )
@app.post("/encode", response_model=EncodeResponse)
def encode(req: EncodeRequest):
    """Encode ``req.texts`` into embeddings.

    ``req.normalize`` is forwarded to the service's ``encode``.

    Returns:
        An ``EncodeResponse`` with the embeddings as nested lists, their
        dimension (second axis of the result) and the elapsed milliseconds.

    Raises:
        HTTPException: 400 when ``req.texts`` is empty.
    """
    service = _check_service()
    if not req.texts:
        raise HTTPException(status_code=400, detail="texts ne peut pas être vide")

    started = time.time()
    vectors = service.encode(req.texts, normalize=req.normalize)
    as_lists = vectors.tolist()
    width = int(vectors.shape[1])
    elapsed_ms = (time.time() - started) * 1000.0
    return EncodeResponse(embeddings=as_lists, dim=width, elapsed_ms=elapsed_ms)
@app.post("/index/build")
def index_build(req: BuildIndexRequest):
    """Build the index from ``req.documents`` (with optional metadata).

    The index is also saved when ``req.save_path`` is set.

    Returns:
        A dict with the number of documents, the elapsed time in milliseconds
        and the save path (or ``None``).

    Raises:
        HTTPException: 400 when ``req.documents`` is empty; 422 when metadata
            is given but its length differs from that of the documents.
    """
    service = _check_service()
    if not req.documents:
        raise HTTPException(status_code=400, detail="documents ne peut pas être vide")
    if req.metadata is not None:
        if len(req.metadata) != len(req.documents):
            raise HTTPException(status_code=422, detail="metadata doit avoir la même taille que documents")

    started = time.time()
    service.build_index(req.documents, req.metadata)

    # An empty or missing save path means "do not save".
    saved_to = req.save_path or None
    if saved_to is not None:
        service.save_index(saved_to)

    payload = {"status": "ok", "n_docs": len(req.documents)}
    payload["elapsed_ms"] = (time.time() - started) * 1000.0
    payload["saved_to"] = saved_to
    return payload
@app.post("/index/load")
def index_load(req: LoadIndexRequest):
    """Load an index from ``req.path`` and report its vector count.

    Returns:
        A dict with an ``"ok"`` status and ``index.ntotal`` of the loaded index.
    """
    service = _check_service()
    # No existence check here: the service is expected to validate the path
    # itself (per the original note, it answers a missing path with a 404).
    service.load_index(req.path)
    n_vectors = service.index.ntotal
    return {"status": "ok", "n_vectors": n_vectors}
@app.post("/search", response_model=SearchResponse)
def search(req: SearchRequest):
    """Search the index for ``req.query`` and return the top ``req.top_k`` hits.

    Returns:
        A ``SearchResponse`` echoing the query, with each service hit wrapped
        in a ``SearchResult`` and the elapsed time in milliseconds.
    """
    service = _check_service()
    started = time.time()
    raw_hits = service.search(req.query, top_k=req.top_k)
    hits = []
    for raw in raw_hits:
        hits.append(SearchResult(**raw))
    elapsed_ms = (time.time() - started) * 1000.0
    return SearchResponse(query=req.query, results=hits, elapsed_ms=elapsed_ms)
@app.post("/rag", response_model=RAGResponse)
def rag(req: RAGRequest):
    """Retrieve context for ``req.query`` and optionally generate an answer.

    Search hits are concatenated in order, each prefixed with its
    ``[doc#<id>]`` tag, until ``req.max_context_chars`` characters of document
    text have been used. The last snippet is truncated to fit; the tags and
    the separators between snippets are not counted against the budget.

    Returns:
        A ``RAGResponse`` with the query, the assembled context, the hits as
        ``SearchResult`` objects, the full prompt, and the generated answer
        (``None`` when ``req.generate_answer`` is falsy).
    """
    service = _check_service()
    service.require_index()
    hits = service.search(req.query, top_k=req.top_k)

    chunks = []
    used = 0
    for hit in hits:
        text = hit["text"]
        room = req.max_context_chars - used
        if room <= 0:
            break
        text = text[:room] if len(text) > room else text
        chunks.append(f"[doc#{hit['doc_id']}] {text}")
        used += len(text)
    context = "\n\n".join(chunks)

    prompt = "".join(
        [
            "Tu es un assistant qui répond UNIQUEMENT à partir du contexte fourni.\n",
            "Si la réponse n'est pas dans le contexte, dis : Information non disponible.\n\n",
            f"Contexte :\n{context}\n\n",
            f"Question : {req.query}\n",
            "Réponse :",
        ]
    )

    if req.generate_answer:
        answer = service.generate(
            prompt=prompt,
            system="Tu réponds en français, de manière concise et sourcée par le contexte.",
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
        )
    else:
        answer = None

    sources = []
    for hit in hits:
        sources.append(SearchResult(**hit))
    return RAGResponse(
        query=req.query,
        context=context,
        sources=sources,
        prompt_template=prompt,
        answer=answer,
    )
@app.post("/similarity", response_model=SimilarityResponse)
def similarity(req: SimilarityRequest):
    """Compute the cosine similarity between ``req.text_a`` and ``req.text_b``.

    Returns:
        A ``SimilarityResponse`` whose ``cosine_similarity`` is the dot
        product of the two embeddings.
    """
    svc = _check_service()
    # A dot product is a cosine similarity only for normalized vectors, so
    # ask for normalization explicitly instead of relying on encode()'s
    # default. Assumes normalize=True yields unit-length rows — TODO confirm.
    embs = svc.encode([req.text_a, req.text_b], normalize=True)
    sim = float(np.dot(embs[0], embs[1]))
    return SimilarityResponse(cosine_similarity=sim)
# -----------------------------------------------------------------------------
# Global error handler — converts any leftover unhandled exception (such as a
# RuntimeError) into a JSON response with a readable message instead of a raw
# Python traceback.
# -----------------------------------------------------------------------------
@app.exception_handler(Exception)
async def all_exception_handler(request: Request, exc: Exception):
    """Turn an exception into a JSON error response.

    Args:
        request: The request being processed when the exception occurred.
        exc: The exception to report.

    Returns:
        A ``JSONResponse`` whose body holds ``detail`` and the request
        ``path``. An ``HTTPException`` keeps its own status code, detail and
        headers; any other exception yields a 500 with ``str(exc)`` as detail.
    """
    path = str(request.url.path)
    if isinstance(exc, HTTPException):
        # An HTTPException is a deliberate error response, not a crash: log it
        # without a traceback and keep any headers it carries.
        log.warning("HTTPException %s sur %s", exc.status_code, path)
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail, "path": path},
            headers=getattr(exc, "headers", None),
        )
    # Pass the exception explicitly so the traceback is logged even if this
    # handler is not running inside an active `except` block.
    log.error("Erreur non gérée sur %s", path, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"detail": str(exc), "path": path},
    )
# =============================================================================
# 4. ENTRY POINT
# =============================================================================
def main() -> None:
    """Parse CLI options, export them as environment variables, start uvicorn.

    The options are written to ``RAG_*`` / ``OVH_PUBLIC_URL`` environment
    variables before the server is started.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080, help="OVH AI Notebooks expose le port 8080")
    parser.add_argument("--ckpt", default="")
    parser.add_argument("--tokenizer-dir", default="")
    parser.add_argument("--no-encoder", action="store_true", help="Désactive les endpoints encode/search/rag")
    parser.add_argument("--llm-model", default=TINYLLAMA_REPO_ID)
    parser.add_argument("--llm-dir", default=DEFAULT_LLM_DIR, help="Dossier local de TinyLlama")
    parser.add_argument("--download-llm", action="store_true", help="Télécharge TinyLlama si le dossier local est vide")
    # argparse %-formats help strings, so a literal percent sign must be
    # written "%%"; a bare "%" makes `--help` raise.
    parser.add_argument("--offline-llm", action="store_true", help="Force le chargement 100%% local")
    parser.add_argument("--model-files-dir", default="./models", help="Dossier exposé par /files")
    parser.add_argument("--public-url", default="", help="URL OVH publique, sans slash final")
    parser.add_argument("--workers", type=int, default=1, help="Garde 1 avec GPU pour éviter de dupliquer la VRAM")
    args = parser.parse_args()

    # Empty paths are exported as-is; per the original note they are meant to
    # disable the encoder cleanly instead of crashing at startup.
    os.environ["RAG_CKPT"] = args.ckpt
    os.environ["RAG_TOKENIZER_DIR"] = args.tokenizer_dir
    os.environ["RAG_LOAD_ENCODER"] = "false" if args.no_encoder else "true"
    os.environ["RAG_LLM_MODEL"] = args.llm_model
    os.environ["RAG_LLM_DIR"] = args.llm_dir
    # Without the flag, an existing environment value wins. Download defaults
    # to "true" and offline to "false".
    os.environ["RAG_LLM_DOWNLOAD"] = "true" if args.download_llm else os.environ.get("RAG_LLM_DOWNLOAD", "true")
    os.environ["RAG_LLM_OFFLINE"] = "true" if args.offline_llm else os.environ.get("RAG_LLM_OFFLINE", "false")
    os.environ["RAG_MODEL_FILES_DIR"] = args.model_files_dir
    if args.public_url:
        os.environ["OVH_PUBLIC_URL"] = clean_public_url(args.public_url)

    import uvicorn

    if args.workers != 1:
        log.warning("workers>1 peut dupliquer le modèle en VRAM. Recommandé : --workers 1")

    # uvicorn refuses to start several workers from an app object: it needs
    # an import string ("module:attribute") that each worker can import.
    target = f"{Path(__file__).stem}:app" if args.workers > 1 else app
    uvicorn.run(target, host=args.host, port=args.port, workers=args.workers, log_level="info")
# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()