amiraghhh committed
Commit 831deda · verified · 1 Parent(s): 72533b2

Update model.py

Files changed (1)
  1. model.py +59 -20
model.py CHANGED
@@ -1,6 +1,5 @@
 import chromadb
 import traceback
-import os
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from retriever import retrieve
 from utils import build_prompt, refine_response
@@ -12,24 +11,18 @@ from utils import build_prompt, refine_response
 
 _vector_store = None
 _finetuned_llm = None
+_base_model = None
 
 def get_vector_store():
     """Load vector store (lazy-loaded on first use)"""
     global _vector_store
-
     if _vector_store is None:
-        # Create directory if it doesn't exist
-        db_path = "./MedQuAD_db"
-        os.makedirs(db_path, exist_ok=True)
-
-        db_client = chromadb.PersistentClient(path=db_path)
-
+        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
         try:
             _vector_store = db_client.get_collection("medical_rag")
         except:
-            # Collection doesn't exist yet - create it
+            # If collection doesn't exist, create it
             _vector_store = db_client.create_collection(name="medical_rag")
-
     return _vector_store
 
 def get_finetuned_llm():
@@ -49,8 +42,13 @@ def get_finetuned_llm():
     return _finetuned_llm
 
 
+# ============================================================================
+# MAIN RAG FUNCTION
+# ============================================================================
+
 def rag(user_query):
     """Main RAG function: retrieve context and generate answer.
+    Takes a question string and returns an answer string with confidence.
     Returns: str(generated_answer)"""
 
     try:
@@ -58,22 +56,58 @@ def rag(user_query):
         vector_store = get_vector_store()
         finetuned_llm = get_finetuned_llm()
 
-        # Check for emergency keywords
-        emergency_keywords = ["emergency", "severe pain", "bleeding", "blind", "lose consciousness", "pass out"]
+        # 1. Check for emergency keywords
+        emergency_keywords = ["emergency", "severe pain", "bleeding",
+                              "blind", "lose consciousness", "pass out"]
 
         if any(keyword in user_query.lower() for keyword in emergency_keywords):
             emergency_msg = """I am an AI and cannot provide medical advice for emergencies.
 PLEASE contact emergency services or a medical professional immediately."""
-            return emergency_msg
+
+            try:
+                # Still generate answer for context
+                contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
+
+                if not contexts:
+                    return f"{emergency_msg}\n\nNo relevant information found for your query."
+
+                prompt = build_prompt(user_query, contexts)
+                result = finetuned_llm(
+                    prompt,
+                    max_new_tokens=70,
+                    num_beams=3,
+                    early_stopping=True,
+                    do_sample=False,
+                    repetition_penalty=1.4,
+                    eos_token_id=finetuned_llm.tokenizer.eos_token_id
+                )
+
+                answer = result[0]['generated_text'].strip()
+                answer = refine_response(answer)
+
+                # Calculate confidence
+                if contexts:
+                    avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
+                    confidence_score = (1 - avg_distance) * 100
+                    confidence_score = max(0, min(100, confidence_score))
+                else:
+                    confidence_score = 0
+
+                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"
+
+            except Exception as e:
+                return f"{emergency_msg}\n\nError generating answer: {str(e)}"
 
-        # Retrieve relevant contexts
+        # 2. Retrieve relevant contexts
         contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
 
         if not contexts:
            return "I'm not confident about my answer (0%).\n\nCouldn't find relevant information to answer your question."
 
-        # Build prompt and generate answer
+        # 3. Build prompt with context
         prompt = build_prompt(user_query, contexts)
+
+        # 4. Generate answer
         result = finetuned_llm(
             prompt,
             max_new_tokens=70,
@@ -87,18 +121,23 @@ PLEASE contact emergency services or a medical professional immediately."""
         answer = result[0]['generated_text'].strip()
         answer = refine_response(answer)
 
-        # Calculate confidence score
+        # 5. Calculate confidence score based on retrieval quality
         if contexts and len(contexts) > 0:
             avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
             confidence_score = (1 - avg_distance) * 100
             confidence_score = max(0, min(100, confidence_score))
 
+            # Build final response with confidence
             if confidence_score < 40:
-                return f"I'm not confident ({confidence_score:.1f}%).\n\n{answer}"
+                final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
             else:
-                return f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
+                final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
+        else:
+            final_response = "I'm not confident about my answer (0%).\n\n" + answer
 
-        return answer
+        return final_response
 
     except Exception as e:
-        return f"ERROR: {str(e)}"
+        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
+        print(error_msg)
+        return error_msg
 
 
 
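Finally, a minimal smoke test for the updated rag() entry point, assuming it is run from the repo root next to model.py (both queries are illustrative):

# Exercises both code paths of rag(): normal retrieval and the emergency branch.
from model import rag

if __name__ == "__main__":
    # Normal path: retrieval + generation, answer carries a confidence suffix.
    print(rag("What are the symptoms of type 2 diabetes?"))

    # Emergency path: "bleeding" matches the keyword list, so the emergency
    # notice is prepended to the generated answer.
    print(rag("I have severe bleeding, what should I do?"))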