Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from dotenv import load_dotenv | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_classic.chains import RetrievalQA | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.documents import Document | |
# Pull variables from a local .env file; override=True lets the .env values
# replace anything already present in the process environment.
load_dotenv(override=True)

# --- Configuration ----------------------------------------------------------
# Gemini credentials; None (or "") means LLM features will be disabled later.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Report key status at import time, echoing only a 5-character prefix.
print(
    f"RAG Model: API Key loaded (starting with {GOOGLE_API_KEY[:5]}...)"
    if GOOGLE_API_KEY
    else "RAG Model: No API Key found in environment!"
)

# JSON produced by preprocess.py, and the on-disk location of the Chroma index.
DATA_FILE = "preprocessed_docs.json"
PERSIST_DIRECTORY = "chroma_db"
def build_rag_system():
    """Build and persist a Chroma vector store from the preprocessed documents.

    Reads ``DATA_FILE`` (a JSON list whose items carry ``"content"`` and
    ``"metadata"`` keys) and wraps each item in a LangChain ``Document``.
    It splits the documents into 1000-character chunks with 100 characters of
    overlap and embeds them with the ``all-MiniLM-L6-v2`` HuggingFace model.
    The chunks are stored in a Chroma collection under ``PERSIST_DIRECTORY``.

    Returns:
        The populated ``Chroma`` vector store, or ``None`` when the data file
        is missing or produces no chunks to index.
    """
    if not os.path.exists(DATA_FILE):
        print("Data file not found. Please run preprocess.py first.")
        return None

    # Explicit UTF-8: product names and reviews may contain non-ASCII text,
    # and the platform default encoding (e.g. cp1252 on Windows) can fail on it.
    with open(DATA_FILE, "r", encoding="utf-8") as f:
        docs_data = json.load(f)
    print(f"Loading {len(docs_data)} documents...")

    # Convert the raw JSON records into LangChain documents.
    documents = [
        Document(page_content=d["content"], metadata=d["metadata"])
        for d in docs_data
    ]

    # Chunk long documents so each embedding covers a focused span of text.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_documents(documents)
    print(f"Created {len(chunks)} chunks.")

    # Nothing to index: chromadb rejects an empty batch with an opaque error,
    # so fail the same friendly way as a missing data file.
    if not chunks:
        print("No document content to index. Please check the output of preprocess.py.")
        return None

    print("Initializing embeddings (HuggingFace)...")
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    print("Building vector store...")
    # NOTE(review): if PERSIST_DIRECTORY already holds a collection, this
    # presumably adds to it rather than replacing it; the script entry point
    # only calls this function when the directory does not exist.
    vectorstore = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        persist_directory=PERSIST_DIRECTORY,
    )
    # Explicit flush for older chromadb releases; newer releases persist
    # automatically, and LangChain marks this method as deprecated.
    vectorstore.persist()
    print("Vector store built and persisted.")
    return vectorstore
def get_qa_chain(vectorstore):
    """Assemble a Gemini-backed RetrievalQA chain over ``vectorstore``.

    Args:
        vectorstore: Vector store whose ``as_retriever()`` supplies the
            context documents for each question.

    Returns:
        A ``RetrievalQA`` chain using the "stuff" strategy, the retail-assistant
        prompt, and ``return_source_documents=True``; or ``None`` when
        ``GOOGLE_API_KEY`` is not set.
    """
    # Guard: without an API key there is no usable LLM, so no chain is built.
    if not GOOGLE_API_KEY:
        print("Warning: GOOGLE_API_KEY not found. LLM functionality will not work.")
        return None

    # Prompt filled by the "stuff" chain: {context} receives the retrieved
    # documents and {question} receives the user's query.
    retail_prompt = PromptTemplate.from_template(
        """You are a helpful and expert Retail Product Assistant.
Context (Product Details & Reviews):
{context}
Rules:
1. If the user says "hi", "hello" or greets you, greet them back warmly and mention 2-3 popular products from the context to get started.
2. Use the provided context to answer specific questions.
3. If the answer is not in the context, politely say you don't have that specific information.
4. Maintain a professional yet friendly tone.
5. Always use Markdown for formatting (bolding, lists, etc.) to make it easy to read.
6. Use bullet points if listing features or pros/cons.
7. IMPORTANT: Convert all prices from USD to Indian Rupees (INR) using an approximate exchange rate of 1 USD = ₹83 INR. Always display the price in INR (₹).
Question: {question}
Answer:"""
    )

    gemini = ChatGoogleGenerativeAI(
        model="gemini-flash-latest",
        google_api_key=GOOGLE_API_KEY,
    )
    doc_retriever = vectorstore.as_retriever()

    return RetrievalQA.from_chain_type(
        llm=gemini,
        chain_type="stuff",
        retriever=doc_retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": retail_prompt},
    )
if __name__ == "__main__":
    # Reuse a previously persisted index when present; otherwise build one
    # from DATA_FILE.
    if os.path.exists(PERSIST_DIRECTORY):
        print("Vector store already exists. Loading...")
        # Must match the embedding model that build_rag_system indexes with.
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        vectorstore = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
    else:
        vectorstore = build_rag_system()

    if vectorstore is not None:
        qa_chain = get_qa_chain(vectorstore)
        if qa_chain is not None:
            query = "Which Kindle model has the best resolution according to reviews?"
            print(f"\nQuestion: {query}")
            # invoke() is the supported entry point; calling the chain
            # directly (Chain.__call__) is deprecated in LangChain.
            result = qa_chain.invoke({"query": query})
            print(f"\nAnswer: {result['result']}")
            print("\nSources:")
            for doc in result["source_documents"][:2]:
                # Metadata comes straight from the preprocessed JSON, so
                # tolerate records that lack a name or rating.
                name = doc.metadata.get("name", "Unknown")
                rating = doc.metadata.get("rating", "N/A")
                print(f"- {name} (Rating: {rating})")