duniele committed
Commit 2d4b79a · verified · 1 Parent(s): d4a35c6

Update app.py

Files changed (1)
  1. app.py +34 -46
app.py CHANGED
@@ -1,9 +1,8 @@
 import os
 import sys
 
-# --- 1. SQLITE FIX FOR HUGGING FACE ---
-# ChromaDB requires a newer version of SQLite than what comes with Python.
-# This forces the system to use pysqlite3-binary.
+# --- 1. SQLITE FIX FOR HUGGING FACE SPACES ---
+# This ensures ChromaDB works on the cloud server
 try:
     __import__('pysqlite3')
     sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
@@ -18,31 +17,27 @@ from langchain_chroma import Chroma
 from typing import Dict, Any, List
 
 # --- 2. SETUP & MODEL LOADING ---
-print("⏳ Loading Models...")
-
-# Initialize Embeddings (CPU is fine for this)
+print("⏳ Loading Embeddings...")
 embedding_function = HuggingFaceEmbeddings(
     model_name="nomic-ai/nomic-embed-text-v1.5",
     model_kwargs={"trust_remote_code": True, "device": "cpu"}
 )
 
-# Load Vector Database
-# CRITICAL FIX: We look for the file in the current directory (".")
-# because you uploaded 'chroma.sqlite3' directly, not inside a folder.
-if not os.path.exists("./chroma.sqlite3"):
-    print("⚠️ Warning: chroma.sqlite3 not found. App may crash if DB is missing.")
+print("⏳ Loading Database...")
+# FIX: Now we look for the FOLDER './chroma_db'
+if not os.path.exists("./chroma_db"):
+    raise ValueError("❌ Error: 'chroma_db' folder not found! Did you run 'git push' correctly?")
 
 vector_db = Chroma(
-    persist_directory=".",
+    persist_directory="./chroma_db",
     embedding_function=embedding_function
 )
 
-# Load LLM (TinyLlama)
+print("⏳ Loading TinyLlama Model...")
 model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(model_id)
 
-# Create HF Pipeline
 pipe = pipeline(
     "text-generation",
     model=model,
@@ -55,7 +50,7 @@ pipe = pipeline(
 
 llm = HuggingFacePipeline(pipeline=pipe)
 
-# --- 3. DEFINE MANUAL QA CHAIN ---
+# --- 3. DEFINE RAG CHAIN ---
 class ManualQAChain:
     def __init__(self, vector_store: Chroma, llm_pipeline: HuggingFacePipeline):
         self.retriever = vector_store.as_retriever(search_kwargs={"k": 2})
@@ -63,33 +58,33 @@ class ManualQAChain:
 
     def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         query = inputs.get("query", "")
-
-        # 1. RETRIEVAL
+
+        # 1. Retrieval
         docs = self.retriever.invoke(query)
-        context = "\n\n".join([d.page_content for d in docs])
+
+        if docs:
+            context = "\n\n".join([d.page_content for d in docs])
+        else:
+            context = "No relevant medical context found."
 
-        # 2. PROMPT CREATION
-        max_context_length = 2000
+        # 2. Prompt
         prompt = f"""<|system|>
-You are a helpful and accurate medical assistant.
-Use ONLY the following context to answer the user's question.
-If the context does not contain the answer, say: "I cannot find the answer in the provided context."
+You are a helpful medical assistant. Use ONLY the context below.
+If the answer is not in the context, say "I cannot find the answer."
 
 Context:
-{context[:max_context_length]}
+{context[:2000]}
 </s>
 <|user|>
 {query}
 </s>
 <|assistant|>
 """
-        # 3. GENERATION
+        # 3. Generation
         response = self.llm.invoke(prompt)
-
-        # Handle Output format
         text = response[0]['generated_text'] if isinstance(response, list) else str(response)
-
-        # Clean output
+
+        # Cleanup
         if "<|assistant|>" in text:
             final_answer = text.split("<|assistant|>")[-1].strip()
         else:
@@ -99,37 +94,30 @@ Context:
 
 # Initialize Chain
 qa_chain = ManualQAChain(vector_db, llm)
-print("✅ RAG Pipeline is ready.")
 
-# --- 4. GRADIO UI FUNCTION ---
+# --- 4. GRADIO UI ---
 def medical_rag_chat(message, history):
-    if not message:
-        return "Please ask a medical question."
+    if not message: return "Please ask a question."
     try:
         response = qa_chain.invoke({"query": message})
-        answer_text = response['result']
+        sources = "\n\n---\n**Retrieved Context:**\n"
 
-        # Format Sources
-        sources_text = "\n\n---\n**Retrieved Context:**\n"
         if response.get('source_documents'):
             for i, doc in enumerate(response['source_documents']):
-                topic = doc.metadata.get('focus_area', 'Medical Protocol')
-                snippet = doc.page_content.replace('\n', ' ').strip()
-                sources_text += f"**{i+1}. [{topic}]** *\"{snippet[:500]}...\"*\n"
+                topic = doc.metadata.get('focus_area', 'Protocol')
+                sources += f"**{i+1}. [{topic}]** {doc.page_content[:300]}...\n"
         else:
-            sources_text += "(No context found.)"
-
-        return answer_text + sources_text
+            sources += "(No context found)"
+
+        return response['result'] + sources
     except Exception as e:
-        return f"⚠️ Error: {str(e)}"
+        return f"Error: {str(e)}"
 
-# --- 5. LAUNCH UI ---
 demo = gr.ChatInterface(
     fn=medical_rag_chat,
     title="Cardio-Oncology RAG Assistant",
     description="TinyLlama-1.1B + MedQuAD RAG",
-    examples=["What are the symptoms of Lung Cancer?", "Who is at risk for Heart Failure?"],
-    concurrency_limit=2
+    examples=["What are the symptoms of Lung Cancer?", "Who is at risk for Heart Failure?"]
 )
 
 if __name__ == "__main__":
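
Note on the SQLite shim kept at the top of the file: Spaces images can ship a Python whose bundled SQLite is older than the 3.35+ that ChromaDB requires, and the `sys.modules` swap makes the stdlib name `sqlite3` resolve to `pysqlite3-binary` instead. A minimal way to verify the swap took effect (illustrative snippet, not part of this commit):

```python
import sys

# Same trick as in app.py: alias the stdlib name to pysqlite3's newer build.
__import__('pysqlite3')
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import sqlite3
# ChromaDB needs SQLite >= 3.35.0; this should now report pysqlite3's version.
print(sqlite3.sqlite_version)
```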
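The new `./chroma_db` check assumes the persisted Chroma folder was built offline and pushed with the repo. A sketch of such a build step (the `docs` contents and the `langchain_huggingface` import are assumptions, not shown in this commit):

```python
# Hypothetical offline build script: writes a persisted Chroma index into
# ./chroma_db, which is then committed and pushed next to app.py.
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

embedding_function = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True, "device": "cpu"},
)

# Placeholder corpus; the real app indexes MedQuAD-style Q&A records.
docs = [Document(page_content="...", metadata={"focus_area": "Heart Failure"})]

# langchain_chroma persists automatically when persist_directory is set.
Chroma.from_documents(
    documents=docs,
    embedding=embedding_function,
    persist_directory="./chroma_db",  # must match the path app.py checks
)
```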
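The hand-built `<|system|>` / `<|user|>` / `<|assistant|>` prompt in `invoke` follows TinyLlama-Chat's Zephyr-style chat template. An equivalent string can be produced from the tokenizer itself, which avoids drift if the template ever changes (sketch only; the `messages` contents are illustrative):

```python
# Sketch: let the tokenizer render the Zephyr-style turns instead of
# hand-formatting the <|system|>/<|user|>/<|assistant|> markers.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "system", "content": "You are a helpful medical assistant. Use ONLY the context below."},
    {"role": "user", "content": "What are the symptoms of Lung Cancer?"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # appends the trailing <|assistant|> turn
)
print(prompt)
```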