Spaces:
Running
Running
| from dotenv import load_dotenv | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_mistralai import ChatMistralAI | |
| from langchain_chroma import Chroma | |
| from langchain_core.messages import SystemMessage, HumanMessage, convert_to_messages | |
| from langchain_core.documents import Document | |
| from pydantic import BaseModel, Field | |
| from tenacity import retry, wait_exponential | |
# Load variables from a local .env file; override=True lets values from .env
# replace variables that are already set in the process environment.
load_dotenv(override=True)

# --- Configuration ----------------------------------------------------------
DB_NAME = "db"  # directory holding the persisted Chroma database
RETRIEVAL_K = 20  # documents fetched from the vector store per query
RETRIEVAL_AFTER_RERANK_K = 10  # documents kept for the answer prompt

# Chat model that writes the final answer. ChatMistralAI takes its API key
# from MISTRAL_API_KEY, so that variable must be defined (e.g. in .env).
chat_model = "mistral-large-latest"
llm = ChatMistralAI(temperature=0, model_name=chat_model)

# Embedding model, served through the HuggingFace integration.
embedding_model = "all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

vectorstore = Chroma(
    persist_directory=DB_NAME,
    embedding_function=embeddings,
    collection_name="documents",
)
# Set the result count when the retriever is built so that every user of
# `retriever` receives RETRIEVAL_K documents.
retriever = vectorstore.as_retriever(search_kwargs={"k": RETRIEVAL_K})

# Model used only for re-ranking the retrieved documents.
reranker_model = "ministral-14b-latest"
# Structured reply expected from the re-ranking model.
# NOTE: this is documented with comments rather than a class docstring on
# purpose: pydantic copies a model's docstring into its JSON schema, and that
# schema is embedded verbatim in the re-ranker prompt.
class RankOrder(BaseModel):
    # Required field: document ids, most relevant first.
    order: list[int] = Field(
        ...,
        description="The order of relevance of documents, from most relevant to least relevant, by document id number",
    )
# Chat client for the re-ranking model, wrapped so that its replies are
# returned as RankOrder instances instead of free text.
_reranker_chat = ChatMistralAI(model_name=reranker_model, temperature=0)
reranker_llm = _reranker_chat.with_structured_output(RankOrder)
def _apply_rank_order(docs, order):
    """Arrange ``docs`` by the 1-based ids in ``order``, tolerating bad ids."""
    used_positions = set()
    ranked = []
    for doc_id in order:
        position = doc_id - 1
        # Ignore ids outside 1..len(docs) and ids the model repeated.
        if 0 <= position < len(docs) and position not in used_positions:
            used_positions.add(position)
            ranked.append(docs[position])
    # Documents the model left out keep their retrieval order at the end.
    ranked.extend(
        doc for position, doc in enumerate(docs) if position not in used_positions
    )
    return ranked


def rerank(question, docs):
    """
    Re-order retrieved documents by relevance to the question using an LLM.

    Args:
        question: The text the documents are ranked against.
        docs: Retrieved documents in retrieval order; each must expose a
            ``page_content`` attribute.

    Returns:
        A new list containing every input document exactly once. Documents
        named by the model come first, in the model's order; documents the
        model omitted follow in their original order. Invalid or repeated
        ids in the model's reply are ignored.
    """
    docs = list(docs)
    # Zero or one document cannot be re-ordered, so skip the model call.
    if len(docs) < 2:
        return docs

    reranker_prompt = """
You are a document re-ranker.
The documents are provided in the order they were retrieved; this should be approximately ordered by relevance, but you may be able to improve on that.
You must rank order the provided documents by relevance to the question, with the most relevant document first.
Reply only with the list of ranked document ids, nothing else. Include all the document ids you are provided with, reranked.
"""
    # Documents are numbered from 1 in the prompt; the ids in the reply are
    # mapped back to list positions by _apply_rank_order.
    document_sections = "".join(
        f"# DOCUMENT ID: {doc_id}:\n\n{doc.page_content}\n\n"
        for doc_id, doc in enumerate(docs, start=1)
    )
    user_prompt = (
        f"The user has asked the following question:\n\n{question}\n\nOrder all the documents by relevance to the question, from most relevant to least relevant. Include all the document ids you are provided with, reranked.\n\n"
        "Here are the documents:\n\n"
        f"{document_sections}"
        f"Reply only with the list of ranked document ids, nothing else. Return them as valid JSON matching this schema:{RankOrder.model_json_schema()}"
    )
    response = reranker_llm.invoke(
        input=[
            SystemMessage(content=reranker_prompt),
            HumanMessage(content=user_prompt),
        ]
    )
    # Fall back to retrieval order if no usable ranking came back.
    ranked_ids = getattr(response, "order", None) or []
    return _apply_rank_order(docs, ranked_ids)
def fetch_context(question: str) -> list[Document]:
    """
    Query the module-level retriever for documents matching ``question``.
    """
    # The result count is taken from the retriever's search_kwargs rather
    # than passed to invoke(), so pin it to RETRIEVAL_K before querying.
    retriever.search_kwargs.update(k=RETRIEVAL_K)
    matches = retriever.invoke(question)
    return matches
def combined_question(question: str, history: "list[dict] | None" = None) -> str:
    """
    Combine all the user's messages into a single string.

    Args:
        question: The user's latest message.
        history: Earlier chat messages as dicts with ``"role"`` and
            ``"content"`` keys; only entries whose role is ``"user"`` are
            used. ``None`` or an empty list means there is no history.

    Returns:
        The earlier user messages followed by ``question``, one per line.
        With no earlier user messages the result is ``question`` unchanged,
        without a leading newline.
    """
    earlier_messages = history or []
    user_turns = [m["content"] for m in earlier_messages if m["role"] == "user"]
    user_turns.append(question)
    return "\n".join(user_turns)
def stream_answer_question(
    question: str, history: "list[dict] | None" = None, rerank_docs: bool = True
):
    """
    Generator function that streams the answer to a question.

    Context documents are retrieved for the question combined with the
    earlier user messages, optionally re-ranked, truncated to
    RETRIEVAL_AFTER_RERANK_K, and placed in the system prompt.

    Args:
        question: The user's latest message.
        history: Earlier chat messages as dicts with ``"role"`` and
            ``"content"`` keys. ``None`` means there is no history.
        rerank_docs: When true, re-rank the retrieved documents before
            truncating them.

    Yields:
        ``(chunk, docs)`` tuples: ``chunk`` is the content of one streamed
        piece of the answer, and ``docs`` is the list of context documents,
        the same list in every tuple.
    """
    # Use a fresh list per call instead of a shared mutable default.
    history = [] if history is None else history

    main_prompt = """
You are a knowledgeable, friendly assistant who answers questions about scientific documents.
You provide clear, concise, and accurate answers based on the provided context.
If relevant, use the given context to answer any question.
If you don't know the answer, say so.
Context:
{context}
"""
    # Retrieval and re-ranking use the question plus the earlier user turns.
    combined = combined_question(question, history)
    docs = fetch_context(combined)
    if rerank_docs:
        docs = rerank(combined, docs)
    docs = docs[:RETRIEVAL_AFTER_RERANK_K]

    context = "\n\n".join(doc.page_content for doc in docs)
    system_prompt = main_prompt.format(context=context)

    # The chat model receives the full history and the latest question only.
    messages = [SystemMessage(content=system_prompt)]
    messages.extend(convert_to_messages(history))
    messages.append(HumanMessage(content=question))

    # Stream the response
    for chunk in llm.stream(messages):
        yield chunk.content, docs