import os
import sys
import zipfile

# --- 1. SQLITE FIX ---
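# Chroma requires SQLite >= 3.35, but some hosted images ship an older system
# build. If the pysqlite3-binary wheel is installed, swap it in as the stdlib
# "sqlite3" module before chromadb is imported.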
try:
    __import__('pysqlite3')
    sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
except ImportError:
    pass

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_chroma import Chroma
from typing import Dict, Any

# --- 2. UNZIP & AUTO-DETECT PATH ---
print("⏳ Checking for Database...")

# Unzip if the zip exists
if os.path.exists("./chroma_db.zip"):
    print("πŸ“¦ Found zip file! Unzipping...")
    with zipfile.ZipFile("./chroma_db.zip", 'r') as zip_ref:
        zip_ref.extractall(".")
    print("βœ… Unzip complete.")

# SMART DETECTION: Find where the database went
db_path = ""

if os.path.exists("./chroma_db/chroma.sqlite3"):
    # Case A: It's inside the folder (Perfect)
    db_path = "./chroma_db"
    print(f"πŸ“‚ Found database in folder: {db_path}")

elif os.path.exists("./chroma.sqlite3"):
    # Case B: It spilled into the root directory
    db_path = "."
    print(f"πŸ“‚ Found database in root directory: {db_path}")

elif os.path.exists("./content/chroma_db/chroma.sqlite3"):
    # Case C: It's inside a 'content' folder (Common Colab issue)
    db_path = "./content/chroma_db"
    print(f"πŸ“‚ Found database in content folder: {db_path}")

else:
    # Case D: nothing found anywhere we expect.
    # List the current directory to make the failure easier to debug.
    print("❌ ERROR: Cannot find chroma.sqlite3. Current files in folder:")
    print(os.listdir("."))
    raise ValueError("Could not find the database file after unzipping!")

# --- 3. MODEL SETUP ---
print("⏳ Loading Embeddings...")
embedding_function = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True, "device": "cpu"}
)
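# NOTE: this must be the same embedding model the index was built with;
# querying Chroma with a different embedding space yields meaningless matches.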

print(f"⏳ Loading Database from {db_path}...")
vector_db = Chroma(
    persist_directory=db_path, 
    embedding_function=embedding_function
)

print("⏳ Loading TinyLlama Model...")
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

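# Generation settings (a judgment call, not tuned): low temperature keeps the
# answer anchored to the retrieved context, and repetition_penalty discourages
# TinyLlama from looping.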
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    repetition_penalty=1.15,
    temperature=0.1,
    do_sample=True
)

llm = HuggingFacePipeline(pipeline=pipe)

# --- 4. RAG CHAIN ---
class ManualQAChain:
    def __init__(self, vector_store: Chroma, llm_pipeline: HuggingFacePipeline):
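        # k=2 keeps the stuffed prompt within TinyLlama's 2,048-token context window.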
        self.retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        self.llm = llm_pipeline

    def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        query = inputs.get("query", "")
        
        # Retrieval
        docs = self.retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs]) if docs else "No context found."

        # Prompt
        prompt = f"""<|system|>
You are a helpful medical assistant. Use ONLY the context below.
If the answer is not in the context, say "I cannot find the answer."

Context:
{context[:2000]}
</s>
<|user|>
{query}
</s>
<|assistant|>
"""
        # Generation
        response = self.llm.invoke(prompt)
        text = response[0]['generated_text'] if isinstance(response, list) else str(response)
        
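        # The text-generation pipeline returns the prompt plus the completion
        # by default, so keep only the text after the assistant tag.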
        if "<|assistant|>" in text:
            final_answer = text.split("<|assistant|>")[-1].strip()
        else:
            final_answer = text.strip()

        return {"result": final_answer, "source_documents": docs}

# Initialize
qa_chain = ManualQAChain(vector_db, llm)
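# Optional smoke test (illustrative query; assumes the index is populated):
# print(qa_chain.invoke({"query": "Who is at risk for Heart Failure?"})["result"])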

# --- 5. UI ---
def medical_rag_chat(message, history):
    if not message:
        return "Please ask a question."
    try:
        response = qa_chain.invoke({"query": message})
        sources = "\n\n---\n**Retrieved Context:**\n"
        
        if response.get('source_documents'):
            for i, doc in enumerate(response['source_documents']):
                topic = doc.metadata.get('focus_area', 'Protocol')
                sources += f"**{i+1}. [{topic}]** {doc.page_content[:300]}...\n"
        else:
            sources += "(No context found)"
            
        return response['result'] + sources
    except Exception as e:
        return f"Error: {str(e)}"

demo = gr.ChatInterface(
    fn=medical_rag_chat,
    title="Cardio-Oncology RAG Assistant",
    description="TinyLlama-1.1B + MedQuAD RAG",
    examples=["What are the symptoms of Lung Cancer?", "Who is at risk for Heart Failure?"]
)

if __name__ == "__main__":
    demo.launch()