"""Retrieval-augmented generation chain for clinical genomics questions.

The chain retrieves candidate documents, asks the LLM to grade each one for
relevance to the query, keeps the best-scoring documents, and answers the
question strictly from that context.
"""

import re

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from src.retrieve import get_hybrid_retriever
from src import config

SYSTEM_PROMPT = """You are a Clinical Genomic Assistant.

You MUST answer strictly using the provided context.

RULES:
1. Do NOT use external knowledge.
2. If answer is not clearly present, say: "Insufficient evidence."
3. Always prioritize exact values (numbers, percentages, trial IDs, dates).
4. If multiple pieces of context exist, combine them carefully.
5. Be precise and clinical — no vague statements.

When relevant, include:
- Drug name
- Mutation
- Clinical outcome

Context:
{context}
"""

# Bounds of the relevance scale requested from the grading prompt.
_MIN_RELEVANCE_SCORE = 1
_MAX_RELEVANCE_SCORE = 10
# Neutral score used when the grader's reply contains no number at all.
_DEFAULT_RELEVANCE_SCORE = 5
_FIRST_INTEGER = re.compile(r"\d+")


def format_docs(docs) -> str:
    """Join the ``page_content`` of each document, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)


def _parse_relevance_score(raw) -> int:
    """Extract a 1-10 relevance score from the grader's free-text reply.

    The grader is asked for a bare number but models often answer with
    variants such as ``"8/10"`` or ``"Score: 8"``. The first integer in the
    reply is used and clamped to the valid range; a reply without any digits
    falls back to the neutral default instead of raising.
    """
    match = _FIRST_INTEGER.search(str(raw))
    if match is None:
        return _DEFAULT_RELEVANCE_SCORE
    return max(_MIN_RELEVANCE_SCORE, min(_MAX_RELEVANCE_SCORE, int(match.group())))


def _build_score_prompt(query: str, doc) -> str:
    """Return the grading prompt asking the LLM to rate *doc* against *query*."""
    return (
        "\n"
        f"Query: {query}\n"
        "\n"
        f"Document: {doc.page_content}\n"
        "\n"
        "Score relevance from 1 to 10 (only number).\n"
    )


def _describe_context(doc) -> dict:
    """Return the content of *doc* together with a source/page label."""
    page = doc.metadata.get("page", "N/A")
    # The metadata key varies by loader: TextLoader uses 'source',
    # PyMuPDFLoader uses 'file_path'.
    source = doc.metadata.get("source", doc.metadata.get("file_path", "Unknown"))
    return {
        "content": doc.page_content,
        "metadata": f"Source: {source} | Page: {page}",
    }


def get_rag_chain(top_k: int = 5):
    """Build the question-answering callable.

    Args:
        top_k: Maximum number of re-ranked documents passed to the answering
            prompt. Defaults to 5.

    Returns:
        A function taking the query string and returning a dict with keys
        ``"answer"`` (the generated answer) and ``"contexts"`` (a list of
        ``{"content", "metadata"}`` dicts, one per document used).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    retriever = get_hybrid_retriever()
    llm = ChatOpenAI(model=config.LLM_MODEL, temperature=0)

    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("user", "{question}"),
    ])
    # The answering pipeline does not depend on the query, so build it once
    # here rather than on every call.
    answer_chain = prompt | llm | StrOutputParser()

    def chain_with_source(query: str):
        """Answer *query* from the most relevant retrieved documents."""
        candidates = retriever.invoke(query)

        # LLM-based relevance scoring: one grading call per candidate.
        scored_docs = [
            (
                _parse_relevance_score(
                    llm.invoke(_build_score_prompt(query, doc)).content
                ),
                doc,
            )
            for doc in candidates
        ]

        # sorted() is stable, so equally scored documents keep the
        # retriever's original ordering.
        ranked = sorted(scored_docs, key=lambda pair: pair[0], reverse=True)
        docs = [doc for _, doc in ranked[:top_k]]

        answer = answer_chain.invoke({
            "context": format_docs(docs),
            "question": query,
        })

        return {
            "answer": answer,
            "contexts": [_describe_context(doc) for doc in docs],
        }

    return chain_with_source