Spaces:
Running
Running
| import os | |
| from dotenv import load_dotenv | |
| from langchain_chroma import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_groq import ChatGroq | |
| from sentence_transformers import CrossEncoder | |
# Setup configuration
# Directory holding the persisted Chroma collection that main() reopens.
CHROMA_DB_DIR = "vectorstore"
# Embedding model name passed to HuggingFaceEmbeddings when reopening the store.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Cross-encoder used to re-score (query, chunk) pairs after retrieval.
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
LLM_MODEL = "llama-3.1-8b-instant"  # Use a currently active Groq model
def _rerank(query, docs, cross_encoder, top_n):
    """Return the ``top_n`` documents scored highest for ``query``.

    Each document in ``docs`` is paired with the query and scored by
    ``cross_encoder.predict``. The score is stored on the document under
    ``metadata['relevance_score']`` and the documents are returned in
    descending score order. An empty ``docs`` yields an empty list without
    calling the model.
    """
    if not docs:
        return []

    pairs = [[query, doc.page_content] for doc in docs]
    scores = cross_encoder.predict(pairs)

    for doc, score in zip(docs, scores):
        # Coerce to a plain float so the metadata does not carry whatever
        # scalar type predict() happens to return.
        doc.metadata["relevance_score"] = float(score)

    ranked = sorted(
        docs,
        key=lambda d: d.metadata["relevance_score"],
        reverse=True,
    )
    return ranked[:top_n]


def main(
    query: str = "What is the company policy for remote work?",
    *,
    retrieve_k: int = 5,
    top_n: int = 3,
) -> None:
    """Answer ``query`` from the persisted vector store and print the result.

    Pipeline: load environment variables, reopen the Chroma store, retrieve
    ``retrieve_k`` candidate chunks, re-rank them with a cross-encoder, keep
    the best ``top_n``, and ask the Groq-hosted LLM to answer strictly from
    that context. The answer and the source snippets are printed to stdout.

    Args:
        query: The question to answer.
        retrieve_k: Number of chunks requested from the base retriever.
        top_n: Number of re-ranked chunks passed to the LLM as context.

    Returns:
        None. Returns early (after printing a message) when GROQ_API_KEY is
        missing or when the retriever yields no documents.
    """
    load_dotenv()

    # Fail fast: check the credential before loading any model, so a
    # misconfigured environment is reported immediately.
    if not os.environ.get("GROQ_API_KEY"):
        print("ERROR: GROQ_API_KEY not found in environment!")
        return

    # 1. Initialize embeddings and reload the vector store
    print("Loading vector store & embedding model...")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    vectorstore = Chroma(
        persist_directory=CHROMA_DB_DIR,
        embedding_function=embeddings,
    )

    # 2. Setup the base retriever to get the top retrieve_k chunks
    base_retriever = vectorstore.as_retriever(search_kwargs={"k": retrieve_k})

    # 3. Setup ReRanker for relevance ordering
    print("Initializing CrossEncoder ReRanker...")
    cross_encoder = CrossEncoder(RERANKER_MODEL)

    # 4. Craft strict RAG prompt
    template = """You are a factual assistant. Answer ONLY using the context below.
If the answer isn't in the context, say "I don't know."
Context: {context}
Question: {question}"""
    prompt = PromptTemplate.from_template(template)

    # 5. Initialize the Groq LLM
    print("Initializing LLM via Groq...")
    llm = ChatGroq(model_name=LLM_MODEL, temperature=0)

    # The query workflow
    print(f"\nQUERY: {query}\n")
    print("Retrieving and re-ranking documents...")
    initial_docs = base_retriever.invoke(query)

    # Without any retrieved chunk there is nothing to re-rank and no context
    # to ground the answer in, so skip the LLM call entirely.
    if not initial_docs:
        print("No documents were retrieved from the vector store.")
        return

    top_docs = _rerank(query, initial_docs, cross_encoder, top_n)

    # Format the context text from the retrieved docs
    context_text = "\n\n".join(doc.page_content for doc in top_docs)

    print("Generating response...")
    # Format prompt and call LLM
    chain = prompt | llm
    response = chain.invoke({"context": context_text, "question": query})

    print("\n--- FINAL ANSWER ---")
    print(response.content)

    print("\n--- SOURCES ---")
    for idx, doc in enumerate(top_docs, start=1):
        print(f"\n[Source {idx}] Score: {doc.metadata['relevance_score']:.4f}")
        snippet = doc.page_content[:150]
        # Only mark truncation when text was actually cut off.
        if len(doc.page_content) > 150:
            snippet += "..."
        print(snippet)


if __name__ == "__main__":
    main()