import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from prompt_engineering import build_prompt
# ── Global state ───────────────────────────────────────────────────────────────
rag_chain = None
retriever = None
# NOTE: On Hugging Face Spaces, set MISTRAL_API_KEY in your Space's Settings > Secrets.
# Do NOT use python-dotenv or .env files on HF Spaces.
# ── Helpers: context formatting, sources, confidence ───────────────────────────
def format_docs(docs) -> str:
"""Convertit les documents rรฉcupรฉrรฉs en texte pour le prompt."""
return "\n\n".join(
f"[Source: {os.path.basename(doc.metadata.get('source', 'Inconnue'))}]\n{doc.page_content}"
for doc in docs
)
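# Hypothetical illustration of the formatted context (file name and snippet are
# invented for the example):
#   [Source: faq.pdf]
#   Returns are accepted within 30 days of delivery...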
def extract_sources(docs) -> list[str]:
"""Formate les sources depuis les mรฉtadonnรฉes des documents."""
sources = []
seen = set()
for doc in docs:
source = doc.metadata.get("source", "Inconnue")
page = doc.metadata.get("page")
label = (
f"{os.path.basename(source)}, page {page + 1}"
if page is not None
else os.path.basename(source)
)
if label not in seen:
sources.append(label)
seen.add(label)
return sources
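# Hypothetical output: ["faq.pdf, page 3", "cgv.pdf"]. The page + 1 above assumes
# a PyPDF-style loader that stores 0-indexed page numbers in metadata.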
def get_confidence(docs_with_scores: list) -> str:
"""Calcule le niveau de confiance selon les scores FAISS (distance L2)."""
if not docs_with_scores:
return "low"
avg_score = sum(s for _, s in docs_with_scores) / len(docs_with_scores)
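    # The 0.4 / 0.8 cut-offs are heuristics: FAISS returns L2-based distances
    # (lower = more similar), so tune these thresholds against your own corpus.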
if avg_score < 0.4:
return "high"
elif avg_score < 0.8:
return "medium"
return "low"
# ── Startup loading ────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag_chain, retriever
embedding = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={"device": "cpu"},
encode_kwargs={"normalize_embeddings": True}
)
vectorstore = FAISS.load_local(
"faiss_index",
embeddings=embedding,
allow_dangerous_deserialization=True
)
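    # Mistral's chat API is OpenAI-compatible, so ChatOpenAI can target it
    # through a custom base_url.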
llm = ChatOpenAI(
base_url="https://api.mistral.ai/v1",
api_key=os.environ["MISTRAL_API_KEY"], # set in HF Spaces Secrets
model_name="mistral-medium"
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
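    # LCEL chain: the dict branches run in parallel, feeding the prompt the
    # retrieved docs (formatted as context) and the raw question, then the
    # output flows prompt -> LLM -> plain string.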
rag_chain = (
{
"context": retriever | RunnableLambda(format_docs),
"question": RunnablePassthrough()
}
| build_prompt()
| llm
| StrOutputParser()
)
print("โ
RAG chain chargรฉe et prรชte.")
yield
print("๐ Arrรชt de l'API.")
# ── FastAPI application ────────────────────────────────────────────────────────
app = FastAPI(
title="ShopVite RAG API",
version="1.0.0",
lifespan=lifespan
)
# ── Schemas ────────────────────────────────────────────────────────────────────
class AskRequest(BaseModel):
question: str
class AskResponse(BaseModel):
answer: str
sources: list[str]
confidence: str # "high" | "medium" | "low" | "out_of_context"
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
if rag_chain is None:
        raise HTTPException(status_code=503, detail="RAG chain not initialized.")
return {
"status": "ok",
"model": "mistral-medium",
"vectorstore": "faiss_index"
}
@app.post("/ask", response_model=AskResponse)
def ask(body: AskRequest):
question = body.question.strip()
if not question:
raise HTTPException(status_code=400, detail="La question ne peut pas รชtre vide.")
if len(question) > 500:
raise HTTPException(status_code=400, detail="Question trop longue (max 500 caractรจres).")
    # Retrieve the documents together with their FAISS scores
docs_with_scores = retriever.vectorstore.similarity_search_with_score(question, k=3)
docs = [doc for doc, _ in docs_with_scores]
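    # Note: rag_chain re-runs retrieval internally, so documents are fetched
    # twice per request; this explicit call only exposes scores and sources.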
sources = extract_sources(docs)
    # Generate the answer through the RAG chain
try:
answer = rag_chain.invoke(question)
except Exception as e:
        raise HTTPException(status_code=500, detail=f"LLM error: {e}")
    # Detect out-of-context questions (sentinel emitted by the prompt)
if "HORS_CONTEXTE" in answer:
return AskResponse(
answer=(
"Je suis dรฉsolรฉ, cette information ne figure pas dans mes documents. "
"Pour toute question spรฉcifique, contactez notre support : "
"support@shopvite.fr | 01 23 45 67 89 (lun-ven, 9h-18h)."
),
sources=[],
confidence="out_of_context"
)
confidence = get_confidence(docs_with_scores)
return AskResponse(
answer=answer,
sources=sources,
confidence=confidence
    )
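
# ── Local usage (sketch) ───────────────────────────────────────────────────────
# Assuming this file is named app.py and MISTRAL_API_KEY is exported, the API
# can be served with uvicorn (port 7860 is the Hugging Face Spaces default):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/ask \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What is your return policy?"}'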