Spaces:
Sleeping
Sleeping
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from langchain_community.embeddings.fastembed import FastEmbedEmbeddings | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_groq import ChatGroq | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| from dotenv import load_dotenv | |
| from typing import List, Optional | |
| load_dotenv() | |
| # Define request body models | |
| class Message(BaseModel): | |
| response: str | |
| class QuestionRequest(BaseModel): | |
| question: str | |
| historique: Optional[List[Message]] = [] | |
| app = FastAPI() | |
| # Load Groq API key | |
| GROQ_API_KEY = os.getenv("GROQ_API_KEY") | |
| # Vectorstore persistence directory | |
| persist_directory = "chroma_storage" | |
| # Embedding model | |
| embed_model = FastEmbedEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
| # Load vectorstore | |
| vectorstore = Chroma( | |
| embedding_function=embed_model, | |
| persist_directory=persist_directory | |
| ) | |
| # Retriever setup | |
| retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3}) | |
| # RAG prompt template | |
| rag_template = """ | |
| Tu es un assistant médical spécialisé conçu pour aider les prescripteurs dans les hôpitaux et dispensaires en t’appuyant exclusivement sur le *Guide clinique et thérapeutique* de Médecins Sans Frontières. | |
| 🩺règle tres important a respecter | |
| • si c'est une salutation tu peux aussi répondre de manière simple et polie aux salutations (comme "bonjour" etc.) avec un ton professionnel et bienveillant. sinon ignorer cette regle | |
| • si la question a une pour intension 'de te remercier'(merci merci beaucoup etc ...) repondre au remerciement de la façon la plus simple (comme de rien ravie de pouvoir vous aider etc) sans rien ajouter car c'est pour montrer qu'il est satisfait des reponse que tu lui a fournie sinon ignorer cette regle | |
| • Pour toute question médicale utilise uniquement les informations disponibles dans le contexte. Si tu ne trouves pas la réponse dans ce contexte dis simplement : "Je ne sais pas". | |
| • la reponse devra etre claire précise pertinente sans avoir des information inutiles | |
| Question de l'utilisateur : | |
| {question} | |
| Contexte : | |
| {context} | |
| historique : | |
| cette historique te sert de memoire contextuelle, en effet elle contient le messages précédents de l'utilisateur et la réponses que tu lui a fournie . Utilise ces informations pour comprendre le contexte de la question si la question n'est pas claire ou ambigue ou trop vague. | |
| ❗regle obligatoire: | |
| 🚫 tu doit obligatoirement ignorer l'historique si la question poser n'a aucun rapport avec le contenue de l'historique | |
| 🚫 si la reponse a la question figure déjà dans l'historique alors il faut repondre a nouveau a la question | |
| 🚫 si il y a des elements pertinents dans l'historique permetant de fournir une reponse plus pertinente l'utiliser | |
| 🚫 en aucun cas tu doit preciser que la question n'as pas de rapport avec l'historique | |
| 🚫 meme si tu donne des reponse en te basant sur l'historique en aucun cas tu doit le preciser | |
| {historique} | |
| """ | |
| rag_prompt = ChatPromptTemplate.from_template(rag_template) | |
| # Chat model | |
| chat_model = ChatGroq( | |
| temperature=0.3, | |
| model_name="llama-3.1-8b-instant", | |
| api_key=GROQ_API_KEY | |
| ) | |
| # RAG chain | |
| naive_rag_chain = ( | |
| {"context": retriever, "question": RunnablePassthrough(), "historique": RunnablePassthrough()} | |
| | rag_prompt | |
| | chat_model | |
| | StrOutputParser() | |
| ) | |
| async def ask_question(request: QuestionRequest): | |
| if not request.question: | |
| raise HTTPException(status_code=400, detail="Question must be provided") | |
| try: | |
| inputs = { | |
| "question": request.question, | |
| "historique": "\n".join([msg.response for msg in request.historique]) | |
| } | |
| answer = naive_rag_chain.invoke(inputs) | |
| return {"question": request.question, "response": answer} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |