import os
import sys
import zipfile

# --- 1. SQLITE FIX ---
# Swap in pysqlite3 (if installed) so Chroma sees a recent SQLite version.
try:
    __import__('pysqlite3')
    sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
except ImportError:
    pass

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_chroma import Chroma
from typing import Dict, Any, List

# --- 2. UNZIP & AUTO-DETECT PATH ---
print("⏳ Checking for Database...")

# Unzip if the zip exists
if os.path.exists("./chroma_db.zip"):
    print("📦 Found zip file! Unzipping...")
    with zipfile.ZipFile("./chroma_db.zip", 'r') as zip_ref:
        zip_ref.extractall(".")
    print("✅ Unzip complete.")

# SMART DETECTION: find where the database landed
db_path = ""
if os.path.exists("./chroma_db/chroma.sqlite3"):
    # Case A: it's inside the folder (perfect)
    db_path = "./chroma_db"
    print(f"📂 Found database in folder: {db_path}")
elif os.path.exists("./chroma.sqlite3"):
    # Case B: it spilled into the root directory
    db_path = "."
    print(f"📂 Found database in root directory: {db_path}")
elif os.path.exists("./content/chroma_db/chroma.sqlite3"):
    # Case C: it's inside a 'content' folder (common Colab issue)
    db_path = "./content/chroma_db"
    print(f"📂 Found database in content folder: {db_path}")
else:
    # Case D: panic — list the files to help debugging
    print("❌ ERROR: Cannot find chroma.sqlite3. Current files in folder:")
    print(os.listdir("."))
    raise ValueError("Could not find the database file after unzipping!")

# --- 3. MODEL SETUP ---
print("⏳ Loading Embeddings...")
embedding_function = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True, "device": "cpu"}
)

print(f"⏳ Loading Database from {db_path}...")
vector_db = Chroma(
    persist_directory=db_path,
    embedding_function=embedding_function
)

print("⏳ Loading TinyLlama Model...")
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    repetition_penalty=1.15,
    temperature=0.1,
    do_sample=True
)
llm = HuggingFacePipeline(pipeline=pipe)

# --- 4. RAG CHAIN ---
class ManualQAChain:
    def __init__(self, vector_store: Chroma, llm_pipeline: HuggingFacePipeline):
        self.retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        self.llm = llm_pipeline

    def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        query = inputs.get("query", "")

        # Retrieval
        docs = self.retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs]) if docs else "No context found."

        # Prompt
        prompt = f"""<|system|>
You are a helpful medical assistant. Use ONLY the context below.
If the answer is not in the context, say "I cannot find the answer."

Context:
{context[:2000]}
<|user|>
{query}
<|assistant|>
"""

        # Generation
        response = self.llm.invoke(prompt)
        text = response[0]['generated_text'] if isinstance(response, list) else str(response)

        # The pipeline echoes the prompt, so keep only the text after the assistant tag
        if "<|assistant|>" in text:
            final_answer = text.split("<|assistant|>")[-1].strip()
        else:
            final_answer = text.strip()

        return {"result": final_answer, "source_documents": docs}

# Initialize
qa_chain = ManualQAChain(vector_db, llm)

# --- 5. UI ---
def medical_rag_chat(message, history):
    if not message:
        return "Please ask a question."
    try:
        response = qa_chain.invoke({"query": message})

        sources = "\n\n---\n**Retrieved Context:**\n"
        if response.get('source_documents'):
            for i, doc in enumerate(response['source_documents']):
                topic = doc.metadata.get('focus_area', 'Protocol')
                sources += f"**{i+1}. [{topic}]** {doc.page_content[:300]}...\n"
        else:
            sources += "(No context found)"

        return response['result'] + sources
    except Exception as e:
        return f"Error: {str(e)}"

demo = gr.ChatInterface(
    fn=medical_rag_chat,
    title="Cardio-Oncology RAG Assistant",
    description="TinyLlama-1.1B + MedQuAD RAG",
    examples=["What are the symptoms of Lung Cancer?", "Who is at risk for Heart Failure?"]
)

if __name__ == "__main__":
    demo.launch()