KeenWoo committed on
Commit
8ae6cf5
·
verified ·
1 Parent(s): 962e049

Delete alz_companion/agent.py

Browse files
Files changed (1) hide show
  1. alz_companion/agent.py +0 -724
alz_companion/agent.py DELETED
@@ -1,724 +0,0 @@
1
- from __future__ import annotations
2
- import os
3
- import json
4
- import base64
5
- import time
6
- import tempfile
7
- import re
8
-
9
- from typing import List, Dict, Any, Optional
10
-
11
- try:
12
- from openai import OpenAI
13
- except Exception:
14
- OpenAI = None
15
-
16
- from langchain.schema import Document
17
- from langchain_community.vectorstores import FAISS
18
- from langchain_community.embeddings import HuggingFaceEmbeddings
19
-
20
- try:
21
- from gtts import gTTS
22
- except Exception:
23
- gTTS = None
24
-
25
- from .prompts import (
26
- SYSTEM_TEMPLATE, ANSWER_TEMPLATE_CALM,
27
- ANSWER_TEMPLATE_ADQ,
28
- # --- ADD YOUR NEW PROMPTS HERE ---
29
- ANSWER_TEMPLATE_ADQ_MODERATE,
30
- ANSWER_TEMPLATE_ADQ_ADVANCED,
31
- # --- END OF ADDITION ---
32
- SAFETY_GUARDRAILS, RISK_FOOTER, render_emotion_guidelines,
33
- NLU_ROUTER_PROMPT, SPECIALIST_CLASSIFIER_PROMPT,
34
- ROUTER_PROMPT,
35
- ANSWER_TEMPLATE_FACTUAL,
36
- ANSWER_TEMPLATE_GENERAL_KNOWLEDGE,
37
- ANSWER_TEMPLATE_GENERAL,
38
- ANSWER_TEMPLATE_FACTUAL_MULTI,
39
- ANSWER_TEMPLATE_SUMMARIZE,
40
- QUERY_EXPANSION_PROMPT
41
- )
42
-
43
- _BEHAVIOR_ALIASES = {
44
- "repeating questions": "repetitive_questioning", "repetitive questions": "repetitive_questioning",
45
- "confusion": "confusion", "wandering": "wandering", "agitation": "agitation",
46
- "accusing people": "false_accusations", "false accusations": "false_accusations",
47
- "memory loss": "address_memory_loss", "seeing things": "hallucinations_delusions",
48
- "hallucinations": "hallucinations_delusions", "delusions": "hallucinations_delusions",
49
- "trying to leave": "exit_seeking", "wanting to go home": "exit_seeking",
50
- "aphasia": "aphasia", "word finding": "aphasia", "withdrawn": "withdrawal",
51
- "apathy": "apathy", "affection": "affection", "sleep problems": "sleep_disturbance",
52
- "anxiety": "anxiety", "sadness": "depression_sadness", "depression": "depression_sadness",
53
- "checking orientation": "orientation_check", "misidentification": "misidentification",
54
- "sundowning": "sundowning_restlessness", "restlessness": "sundowning_restlessness",
55
- "losing things": "object_misplacement", "misplacing things": "object_misplacement",
56
- "planning": "goal_breakdown", "reminiscing": "reminiscence_prompting",
57
- "communication strategy": "caregiver_communication_template",
58
- }
59
-
60
- def _canon_behavior_list(xs: list[str] | None, opts: list[str]) -> list[str]:
61
- out = []
62
- for x in (xs or []):
63
- y = _BEHAVIOR_ALIASES.get(x.strip().lower(), x.strip())
64
- if y in opts and y not in out:
65
- out.append(y)
66
- return out
67
-
68
- _TOPIC_ALIASES = {
69
- "home safety": "treatment_option:home_safety", "long-term care": "treatment_option:long_term_care",
70
- "music": "treatment_option:music_therapy", "reassure": "treatment_option:reassurance",
71
- "routine": "treatment_option:routine_structuring", "validation": "treatment_option:validation_therapy",
72
- "caregiving advice": "caregiving_advice", "medical": "medical_fact",
73
- "research": "research_update", "story": "personal_story",
74
- }
75
- _CONTEXT_ALIASES = {
76
- "mild": "disease_stage_mild", "moderate": "disease_stage_moderate", "advanced": "disease_stage_advanced",
77
- "care home": "setting_care_home", "hospital": "setting_clinic_or_hospital", "home": "setting_home_or_community",
78
- "group": "interaction_mode_group_activity", "1:1": "interaction_mode_one_to_one", "one to one": "interaction_mode_one_to_one",
79
- "family": "relationship_family", "spouse": "relationship_spouse", "staff": "relationship_staff_or_caregiver",
80
- }
81
-
82
- def _canon_topic(x: str, opts: list[str]) -> str:
83
- if not x: return "None"
84
- y = _TOPIC_ALIASES.get(x.strip().lower(), x.strip())
85
- return y if y in opts else "None"
86
-
87
- def _canon_context_list(xs: list[str] | None, opts: list[str]) -> list[str]:
88
- out = []
89
- for x in (xs or []):
90
- y = _CONTEXT_ALIASES.get(x.strip().lower(), x.strip())
91
- if y in opts and y not in out: out.append(y)
92
- return out
93
-
94
-
95
# Regex keyphrases that flag comparative / multi-hop questions for the pre-router.
# Each entry is compiled case-insensitively into _MH_PATTERNS below.
MULTI_HOP_KEYPHRASES = [
    r"\bcompare\b", r"\bvs\.?\b", r"\bversus\b", r"\bdifference between\b",
    r"\b(more|less|fewer) (than|visitors|agitated)\b", r"\bchange after\b",
    r"\bafter.*(vs|before)\b", r"\bbefore.*(vs|after)\b", r"\b(who|which) .*(more|less)\b",
    # --- START: REVISED & MORE ROBUST PATTERNS ---
    r"\b(did|was|is)\b .*\b(where|when|who)\b",  # Catches MH1_new ("Did X happen where Y happened?")
    r"\bconsidering\b .*\bhow long\b",  # Catches MH2_new
    r"\b(but|and)\b who was the other person\b",  # Catches MH3_new
    r"what does the journal say about"  # Catches MH4_new
    # --- END: REVISED & MORE ROBUST PATTERNS ---
]
# Pre-compiled, case-insensitive patterns consumed by _pre_router_multi_hop.
_MH_PATTERNS = [re.compile(p, re.IGNORECASE) for p in MULTI_HOP_KEYPHRASES]
107
-
108
-
109
- # Add this near the top of agent.py with the other keyphrase lists
110
- SUMMARIZATION_KEYPHRASES = [
111
- r"^\b(summarize|summarise|recap)\b", r"^\b(give me a summary|create a short summary)\b"
112
- ]
113
- _SUM_PATTERNS = [re.compile(p, re.IGNORECASE) for p in SUMMARIZATION_KEYPHRASES]
114
-
115
- def _pre_router_summarization(query: str) -> str | None:
116
- q = (query or "")
117
- for pat in _SUM_PATTERNS:
118
- if re.search(pat, q): return "summarization"
119
- return None
120
-
121
-
122
- CARE_KEYPHRASES = [
123
- r"\bwhere am i\b", r"\byou('?| ha)ve stolen my\b|\byou'?ve stolen my\b",
124
- r"\bi lost (the )?word\b|\bword-finding\b|\bcan.?t find the word\b",
125
- r"\bshe didn('?| no)t know me\b|\bhe didn('?| no)t know me\b",
126
- r"\bdisorient(?:ed|ation)\b|\bagitation\b|\bconfus(?:ed|ion)\b",
127
- r"\bcare home\b|\bnursing home\b|\bthe.*home\b",
128
- r"\bplaylist\b|\bsongs?\b.*\b(memories?|calm|soothe|familiar)\b",
129
- r"\bi want to keep teaching\b|\bi want to keep driving\b|\bi want to go home\b",
130
- r"music therapy",
131
- # --- ADD THESE LINES for handle test cases ---
132
- r"music therapy"
133
- r"\bremembering the\b", # Catches P7
134
- r"\bmissed you so much\b" # Catches P4
135
- r"\b(i forgot my job|what did i work as|do you remember my job)\b" # Catches queries about forgetting profession
136
- ]
137
- _CARE_PATTERNS = [re.compile(p) for p in CARE_KEYPHRASES]
138
-
139
-
140
-
141
# (pattern, replacement) pairs applied case-insensitively to scrub LLM boilerplate
# ("as an AI...", "based on the context...", "I hope this helps", etc.).
_STRIP_PATTERNS = [
    (r'^\s*(your\s+(final\s+)?answer|your\s+response)\s+in\s+[A-Za-z\-]+\s*:?\s*', ''),
    (r'\bbased on (?:the |any )?(?:provided )?(?:context|information|details)(?: provided)?(?:,|\.)?\s*', ''),
    (r'^\s*as an ai\b.*?(?:,|\.)\s*', ''),
    (r'\b(according to|from)\s+(the\s+)?(sources?|context)\b[:,]?\s*', ''),
    (r'\bI hope this helps[.!]?\s*$', ''),
]

def _clean_surface_text(text: str) -> str:
    """Strip boilerplate phrases, collapse runs of 3+ newlines, and trim whitespace."""
    cleaned = text or ""
    for pattern, replacement in _STRIP_PATTERNS:
        cleaned = re.sub(pattern, replacement, cleaned, flags=re.IGNORECASE)
    return re.sub(r'\n{3,}', '\n\n', cleaned).strip()
149
-
150
-
151
-
152
-
153
- # Utilities
154
def _openai_client() -> Optional[OpenAI]:
    """Build an OpenAI client from OPENAI_API_KEY, or None if the key or SDK is missing."""
    key = os.getenv("OPENAI_API_KEY", "").strip()
    if not key or not OpenAI:
        return None
    return OpenAI(api_key=key)
157
-
158
def describe_image(image_path: str) -> str:
    """Describe an image for the memory journal using GPT-4o vision.

    Reads the file at `image_path`, base64-encodes it, and asks the model for a
    one-sentence description. Returns a human-readable error string (never raises)
    when the API key is missing or any step fails.
    """
    client = _openai_client()
    if not client: return "(Image description failed: OpenAI API key not configured.)"
    try:
        # Derive a MIME type from the extension; .jpg/.jpeg both map to image/jpeg.
        extension = os.path.splitext(image_path)[1].lower()
        mime_type = f"image/{'jpeg' if extension in ['.jpg', '.jpeg'] else extension.strip('.')}"
        with open(image_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode('utf-8')
        # Vision request: image is inlined as a data URL; response capped at 100 tokens.
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": [{"type": "text", "text": "Describe this image concisely for a memory journal. Focus on people, places, and key objects. Example: 'A photo of John and Mary smiling on a bench at the park.'"},{"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}}]}], max_tokens=100)
        return response.choices[0].message.content or "No description available."
    except Exception as e:
        # Best-effort: surface the failure inline instead of raising to the UI.
        return f"[Image description error: {e}]"
173
-
174
- # --- MODIFICATION 1: Use the new, corrected NLU function ---
175
def detect_tags_from_query(
    query: str,
    nlu_vectorstore: FAISS,
    behavior_options: list,
    emotion_options: list,
    topic_options: list,
    context_options: list,
    settings: dict | None = None
) -> Dict[str, Any]:
    """Uses a dynamic two-step NLU process: Route -> Retrieve Examples -> Classify.

    Step 1 asks the LLM router for the query's primary goal (practical planning
    vs. emotional support). Step 2 retrieves few-shot examples filtered by that
    goal from `nlu_vectorstore`. Step 3 runs a specialist classifier prompt and
    canonicalizes the returned tags against the provided option lists.

    Returns a dict with keys detected_behaviors / detected_emotion /
    detected_topics / detected_contexts; on any parse failure the partially
    filled (or default) dict is returned rather than raising.
    """
    result_dict = {"detected_behaviors": [], "detected_emotion": "None", "detected_topics": [], "detected_contexts": []}
    # Step 1: coarse routing — anything containing "practical" counts as planning.
    router_prompt = NLU_ROUTER_PROMPT.format(query=query)
    primary_goal_raw = call_llm([{"role": "user", "content": router_prompt}], temperature=0.0).strip().lower()
    goal_for_filter = "practical_planning" if "practical" in primary_goal_raw else "emotional_support"
    goal_for_prompt = "Practical Planning" if "practical" in primary_goal_raw else "Emotional Support"

    if settings and settings.get("debug_mode"):
        print(f"\n--- NLU Router ---\nGoal: {goal_for_prompt} (Filter: '{goal_for_filter}')\n------------------\n")

    # Step 2: fetch 2 few-shot examples matching the goal; fall back to an
    # unfiltered search when the filtered retriever returns nothing.
    retriever = nlu_vectorstore.as_retriever(search_kwargs={"k": 2, "filter": {"primary_goal": goal_for_filter}})
    retrieved_docs = retriever.invoke(query)
    if not retrieved_docs:
        retrieved_docs = nlu_vectorstore.as_retriever(search_kwargs={"k": 2}).invoke(query)

    # Each example doc carries its gold classification in metadata["classification"].
    selected_examples = "\n".join(
        f"User Query: \"{doc.page_content}\"\n{json.dumps(doc.metadata['classification'], indent=4)}"
        for doc in retrieved_docs
    )
    if not selected_examples:
        selected_examples = "(No relevant examples found)"
        if settings and settings.get("debug_mode"):
            print("WARNING: NLU retriever found no examples for this query.")

    # Render the allowed label sets (minus "None") as quoted, comma-separated lists.
    behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
    emotion_str = ", ".join(f'"{opt}"' for opt in emotion_options if opt != "None")
    topic_str = ", ".join(f'"{opt}"' for opt in topic_options if opt != "None")
    context_str = ", ".join(f'"{opt}"' for opt in context_options if opt != "None")

    prompt = SPECIALIST_CLASSIFIER_PROMPT.format(
        primary_goal=goal_for_prompt, examples=selected_examples,
        behavior_options=behavior_str, emotion_options=emotion_str,
        topic_options=topic_str, context_options=context_str, query=query
    )

    # Step 3: specialist classification, forced into JSON output mode.
    messages = [{"role": "system", "content": "You are a helpful NLU classification assistant."}, {"role": "user", "content": prompt}]
    response_str = call_llm(messages, temperature=0.0, response_format={"type": "json_object"})

    if settings and settings.get("debug_mode"):
        print(f"\n--- NLU Specialist Full Response ---\n{response_str}\n----------------------------------\n")

    try:
        # Tolerate prose around the JSON: slice from the first '{' to the last '}'.
        start_brace = response_str.find('{')
        end_brace = response_str.rfind('}')
        if start_brace == -1 or end_brace <= start_brace:
            raise json.JSONDecodeError("No valid JSON object found in response.", response_str, 0)

        json_str = response_str[start_brace : end_brace + 1]
        result = json.loads(json_str)

        result_dict["detected_emotion"] = result.get("detected_emotion") or "None"

        behaviors_raw = result.get("detected_behaviors")
        behaviors_canon = _canon_behavior_list(behaviors_raw, behavior_options)
        if behaviors_canon:
            result_dict["detected_behaviors"] = behaviors_canon

        # The model sometimes returns "detected_topic" (singular) or a bare string.
        topics_raw = result.get("detected_topics") or result.get("detected_topic")
        detected_topics = []
        if isinstance(topics_raw, list):
            for t in topics_raw:
                ct = _canon_topic(t, topic_options)
                if ct != "None": detected_topics.append(ct)
        elif isinstance(topics_raw, str):
            ct = _canon_topic(topics_raw, topic_options)
            if ct != "None": detected_topics.append(ct)
        result_dict["detected_topics"] = detected_topics

        contexts_raw = result.get("detected_contexts")
        contexts_canon = _canon_context_list(contexts_raw, context_options)
        if contexts_canon:
            result_dict["detected_contexts"] = contexts_canon

        return result_dict

    except (json.JSONDecodeError, AttributeError) as e:
        # Fail soft: log and return whatever defaults/partial results we have.
        print(f"ERROR parsing NLU Specialist JSON: {e}")
        return result_dict
262
-
263
def _default_embeddings():
    """Embeddings backend; model is selectable via EMBEDDINGS_MODEL (MiniLM by default)."""
    return HuggingFaceEmbeddings(
        model_name=os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    )
267
-
268
def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal: bool = False) -> FAISS:
    """Load a FAISS index from disk when present, otherwise build one from `docs` and save it."""
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    index_file_exists = os.path.isdir(index_path) and os.path.exists(os.path.join(index_path, "index.faiss"))
    if index_file_exists:
        try:
            return FAISS.load_local(index_path, _default_embeddings(), allow_dangerous_deserialization=True)
        except Exception:
            pass  # corrupt/incompatible index: fall through and rebuild
    if is_personal and not docs:
        # Seed an empty personal journal so FAISS has at least one document to index.
        docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
    store = FAISS.from_documents(docs, _default_embeddings())
    store.save_local(index_path)
    return store
280
-
281
def texts_from_jsonl(path: str) -> List[Document]:
    """Load one Document per JSONL line.

    Each record must carry a non-blank "text" field; selected tag fields
    (behaviors/emotion/topic_tags/context_tags) are copied into metadata along
    with the file basename and the 0-based line index as "chunk".

    Blank or malformed lines are skipped individually; an unreadable file (or a
    record of an unexpected shape) yields [].
    """
    out: List[Document] = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                # BUGFIX: previously the whole loop sat inside one try/except, so a
                # single malformed line discarded every already-parsed document.
                try:
                    obj = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                txt = obj.get("text") or ""
                if not txt.strip():
                    continue
                md = {"source": os.path.basename(path), "chunk": i}
                for k in ("behaviors", "emotion", "topic_tags", "context_tags"):
                    if k in obj and obj[k]:
                        md[k] = obj[k]
                out.append(Document(page_content=txt, metadata=md))
    except Exception:
        # Unreadable file or non-dict record: preserve the original contract of [].
        return []
    return out
296
-
297
- # Some vectorstores might return duplicates.
298
- # This is useful when top-k cutoff might otherwise include near-duplicates from query expansion
299
def dedup_docs(scored_docs):
    """Drop duplicate (doc, score) pairs, keeping the first occurrence.

    Two docs are duplicates when they share both a "source" metadata value and
    the same (stripped) page text — useful after query expansion, where near-
    identical hits can crowd out the top-k cutoff.
    """
    seen_keys = set()
    deduped = []
    for doc, score in scored_docs:
        key = doc.metadata.get("source", "") + "::" + doc.page_content.strip()
        if key in seen_keys:
            continue
        seen_keys.add(key)
        deduped.append((doc, score))
    return deduped
308
-
309
-
310
def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
    """Build (or load) a FAISS index seeded from a list of files.

    JSONL files are parsed record-by-record via texts_from_jsonl; any other file
    becomes a single document. Unreadable paths are skipped, and an empty corpus
    falls back to one placeholder document so the index is never empty.
    """
    collected: List[Document] = []
    for path in sample_paths or []:
        try:
            if path.lower().endswith(".jsonl"):
                collected.extend(texts_from_jsonl(path))
            else:
                with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                    collected.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(path)}))
        except Exception:
            continue  # best-effort: one bad seed file should not block startup
    if not collected:
        collected = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
    return build_or_load_vectorstore(collected, index_path=index_path)
324
-
325
def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Optional[List[str]] = None, response_format: Optional[dict] = None) -> str:
    """Send a chat-completion request and return the stripped reply text.

    Model comes from OPENAI_CHAT_MODEL (default gpt-4o-mini). Raises
    RuntimeError when no client can be built (missing API key).
    """
    client = _openai_client()
    if client is None:
        raise RuntimeError("OpenAI client not configured (missing API key?).")
    request = {
        "model": os.getenv("OPENAI_CHAT_MODEL", "gpt-4o-mini"),
        "messages": messages,
        # An explicit None temperature falls back to the 0.6 default.
        "temperature": float(temperature if temperature is not None else 0.6),
    }
    if stop:
        request["stop"] = stop
    if response_format:
        request["response_format"] = response_format
    resp = client.chat.completions.create(**request)
    content = ""
    try:
        content = resp.choices[0].message.content or ""
    except Exception:
        # Some SDK variants expose the message as a plain dict instead of an object.
        message = getattr(resp.choices[0], "message", None)
        if isinstance(message, dict):
            content = message.get("content") or ""
    return content.strip()
341
-
342
# NOTE(fix): a duplicate `MULTI_HOP_KEYPHRASES` / `_MH_PATTERNS` assignment was removed
# here. It re-bound both names with the shorter base pattern list, silently discarding
# the "REVISED & MORE ROBUST PATTERNS" entries added in the definitions near the top of
# the module, so those patterns never took effect at runtime.
344
-
345
def _pre_router_multi_hop(query: str) -> str | None:
    """Return "multi_hop" when any comparative/multi-hop keyphrase matches, else None."""
    text = query or ""
    if any(pattern.search(text) for pattern in _MH_PATTERNS):
        return "multi_hop"
    return None
351
-
352
def _pre_router(query: str) -> str | None:
    """Return "caregiving_scenario" when a care keyphrase matches, else None."""
    # _CARE_PATTERNS are compiled case-sensitively, so match against the lowercased query.
    text = (query or "").lower()
    if any(pattern.search(text) for pattern in _CARE_PATTERNS):
        return "caregiving_scenario"
    return None
358
-
359
def _llm_route_with_prompt(query: str, temperature: float = 0.0) -> str:
    """Ask the LLM router to classify the query; returns the lowercased label text."""
    prompt = ROUTER_PROMPT.format(query=query)
    raw_label = call_llm([{"role": "user", "content": prompt}], temperature=temperature)
    return raw_label.strip().lower()
364
-
365
- # OLD use this new pre-router and place it in the correct order of priority.
366
- # OLD def route_query_type(query: str) -> str:
367
- # NEW: the severity override only applies to the moderate or advanced stages
368
def route_query_type(query: str, severity: str = "Normal / Unspecified"):
    """Classify a query into a route label.

    Possible results: "multi_hop", "summarization", "caregiving_scenario", or
    whatever label the LLM router returns as a fallback. When severity is
    Moderate/Advanced, most queries are forced to "caregiving_scenario" unless
    they are clearly summarization or multi-hop requests.
    """
    # Stage-aware override: only applies to moderate/advanced severities.
    if severity in ["Moderate Stage", "Advanced Stage"]:
        if not _pre_router_summarization(query) and not _pre_router_multi_hop(query):
            print(f"Query classified as: caregiving_scenario (severity override)")
            return "caregiving_scenario"

    # Normal path — fixed priority: structural multi-hop queries first, then
    # explicit summarize commands, then caregiving keywords.
    for label, origin in (
        (_pre_router_multi_hop(query), "multi-hop pre-router"),
        (_pre_router_summarization(query), "summarization pre-router"),
        (_pre_router(query), "caregiving pre-router"),
    ):
        if label:
            print(f"Query classified as: {label} ({origin})")
            return label

    # Fallback: let the LLM do the nuanced classification.
    query_type = _llm_route_with_prompt(query, temperature=0.0)
    print(f"Query classified as: {query_type} (LLM router)")
    return query_type
401
-
402
-
403
- # helper: put near other small utils in agent.py
404
- # In agent.py, replace the _source_ids_for_eval function
405
-
406
def _source_ids_for_eval(docs, cap=5):
    """Collect up to `cap` unique source identifiers for evaluation output.

    - JSONL sources use their 'scene_id' metadata when present, else their
      numeric 'chunk' index; JSONL docs with neither are skipped.
    - Every non-JSONL source is reported as the generic "Text Input".
    - Docs whose source is 'placeholder' are excluded entirely.
    """
    ids, seen = [], set()
    for doc in docs or []:
        metadata = getattr(doc, "metadata", {}) or {}
        source = str(metadata.get("source", "")).lower()

        if source == 'placeholder':
            continue

        if source.endswith(".jsonl"):
            if 'scene_id' in metadata:
                identifier = str(metadata['scene_id'])
            elif 'chunk' in metadata and isinstance(metadata['chunk'], int):
                identifier = str(metadata['chunk'])
            else:
                identifier = None  # jsonl doc without a usable id: skip it
        else:
            identifier = "Text Input"

        if identifier and identifier not in seen:
            seen.add(identifier)
            ids.append(str(identifier))
            if len(ids) >= cap:
                break
    return ids
439
-
440
-
441
- # In agent.py, replace the ENTIRE make_rag_chain function with this one.
442
- # def make_rag_chain(vs_general: FAISS, vs_personal: FAISS, *, for_evaluation: bool = False, role: str = "patient", temperature: float = 0.6, language: str = "English", patient_name: str = "the patient", caregiver_name: str = "the caregiver", tone: str = "warm"):
443
- # NEW: accept the new disease_stage parameter.
444
- def make_rag_chain(vs_general: FAISS, vs_personal: FAISS, *, for_evaluation: bool = False, role: str = "patient", temperature: float = 0.6, language: str = "English", patient_name: str = "the patient", caregiver_name: str = "the caregiver", tone: str = "warm", disease_stage: str = "Normal / Unspecified"):
445
- """Returns a callable that performs the complete RAG process."""
446
-
447
- RELEVANCE_THRESHOLD = 0.85
448
- SCORE_MARGIN = 0.10 # Margin to decide if scores are "close enough" to blend.
449
-
450
- def _format_docs(docs: List[Document], default_msg: str) -> str:
451
- if not docs: return default_msg
452
- unique_docs = {doc.page_content: doc for doc in docs}.values()
453
- return "\n".join([f"- {d.page_content.strip()}" for d in unique_docs])
454
-
455
- # def _answer_fn(query: str, query_type: str, chat_history: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
456
- # NEW
457
- def _answer_fn(query: str, query_type: str, chat_history: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
458
-
459
- # --- ADD THIS LINE FOR VERIFICATION ---
460
- print(f"DEBUG: RAG chain received disease_stage = '{disease_stage}'")
461
- # --- END OF ADDITION ---
462
-
463
- # Create a local variable for test_temperature to avoid the UnboundLocalError.
464
- test_temperature = temperature
465
-
466
- p_name = patient_name or "the patient"
467
- c_name = caregiver_name or "the caregiver"
468
-
469
- perspective_line = (f"You are speaking directly to {p_name}, who is the patient...") if role == "patient" else (f"You are communicating with {c_name}, the caregiver, about {p_name}.")
470
- system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, perspective_line=perspective_line, guardrails=SAFETY_GUARDRAILS)
471
- messages = [{"role": "system", "content": system_message}]
472
- messages.extend(chat_history)
473
-
474
- if "general_knowledge_question" in query_type or "general_conversation" in query_type:
475
- template = ANSWER_TEMPLATE_GENERAL_KNOWLEDGE if "general_knowledge" in query_type else ANSWER_TEMPLATE_GENERAL
476
- user_prompt = template.format(question=query, language=language)
477
- messages.append({"role": "user", "content": user_prompt})
478
- raw_answer = call_llm(messages, temperature=test_temperature)
479
- answer = _clean_surface_text(raw_answer)
480
- sources = ["General Knowledge"] if "general_knowledge" in query_type else []
481
- return {"answer": answer, "sources": sources, "source_documents": []}
482
-
483
- expansion_prompt = QUERY_EXPANSION_PROMPT.format(question=query)
484
- expansion_response = call_llm([{"role": "user", "content": expansion_prompt}], temperature=0.1)
485
- try:
486
- search_queries = [query] + json.loads(expansion_response.strip().replace("```json", "").replace("```", ""))
487
- except json.JSONDecodeError:
488
- search_queries = [query]
489
-
490
- # NEW: Determine sourcing weight
491
- if disease_stage in ["Moderate Stage", "Advanced Stage"]:
492
- top_k_general = 5
493
- top_k_personal = 1
494
- else: # current default
495
- top_k_general = 2
496
- top_k_personal = 3
497
-
498
- # NEW: pass top_k_personal and top_k_general parameters
499
- personal_results_with_scores = [
500
- result for q in search_queries for result in vs_personal.similarity_search_with_score(q, k=top_k_personal)
501
- ]
502
- general_results_with_scores = [
503
- result for q in search_queries for result in vs_general.similarity_search_with_score(q, k=top_k_general)
504
- ]
505
-
506
- # NEW: Remove duplicates
507
- personal_results_with_scores = dedup_docs(personal_results_with_scores)
508
- general_results_with_scores = dedup_docs(general_results_with_scores)
509
-
510
- ## BEGIN DEBUGGING
511
- print(f"[DEBUG] Retrieved {len(personal_results_with_scores)} personal, {len(general_results_with_scores)} general results")
512
- if personal_results_with_scores:
513
- print(f"Top personal score: {max([s for _, s in personal_results_with_scores]):.3f}")
514
- if general_results_with_scores:
515
- print(f"Top general score: {max([s for _, s in general_results_with_scores]):.3f}")
516
-
517
- print("\n--- DEBUG: Personal Search Results with Scores (Before Filtering) ---")
518
- if personal_results_with_scores:
519
- for doc, score in personal_results_with_scores:
520
- print(f" - Score: {score:.4f} | Source: {doc.metadata.get('source', 'N/A')}")
521
- else:
522
- print(" - No results found.")
523
- print("-----------------------------------------------------------------")
524
-
525
- print("\n--- DEBUG: General Search Results with Scores (Before Filtering) ----")
526
- if general_results_with_scores:
527
- for doc, score in general_results_with_scores:
528
- print(f" - Score: {score:.4f} | Source: {doc.metadata.get('source', 'N/A')}")
529
- else:
530
- print(" - No results found.")
531
- print("-----------------------------------------------------------------")
532
- ## END DEBUGGING
533
-
534
- # Return the most relevant doc if not return the best score; and all strip OUT placehoder doc
535
- def get_best_docs_with_fallback(results_with_scores: list[tuple[Document, float]]) -> (list[Document], float):
536
- valid_results = [res for res in results_with_scores if res[0].metadata.get("source") != "placeholder"]
537
- if not valid_results:
538
- return [], float('inf')
539
-
540
- best_score = sorted(valid_results, key=lambda x: x[1])[0][1]
541
- filtered_docs = [doc for doc, score in valid_results if score < RELEVANCE_THRESHOLD]
542
-
543
- if not filtered_docs:
544
- return [sorted(valid_results, key=lambda x: x[1])[0][0]], best_score
545
-
546
- return filtered_docs, best_score
547
- # END def get_best_docs_with_fallback
548
-
549
- if disease_stage in ["Moderate Stage", "Advanced Stage"]:
550
- # Use top-k selection (e.g. top 5 for general, top 1 for personal)
551
- filtered_general_docs = [doc for doc, score in general_results_with_scores[:top_k_general]]
552
- best_general_score = general_results_with_scores[0][1] if general_results_with_scores else 0.0
553
-
554
- filtered_personal_docs = [doc for doc, score in personal_results_with_scores[:top_k_personal]]
555
- best_personal_score = personal_results_with_scores[0][1] if personal_results_with_scores else 0.0
556
- else:
557
- # Use standard fallback-based scoring
558
- filtered_personal_docs, best_personal_score = get_best_docs_with_fallback(personal_results_with_scores)
559
- filtered_general_docs, best_general_score = get_best_docs_with_fallback(general_results_with_scores)
560
-
561
- print("\n--- DEBUG: Filtered Personal Docs (After Threshold/Fallback) ---")
562
- if filtered_personal_docs:
563
- for doc in filtered_personal_docs:
564
- print(f" - Source: {doc.metadata.get('source', 'N/A')}")
565
- else:
566
- print(" - No documents met the criteria.")
567
- print("----------------------------------------------------------------")
568
-
569
- print("\n--- DEBUG: Filtered General Docs (After Threshold/Fallback) ----")
570
- if filtered_general_docs:
571
- for doc in filtered_general_docs:
572
- print(f" - Source: {doc.metadata.get('source', 'N/A')}")
573
- else:
574
- print(" - No documents met the criteria.")
575
- print("----------------------------------------------------------------")
576
-
577
- personal_memory_routes = ["factual", "multi_hop", "summarization"]
578
- is_personal_route = any(route_keyword in query_type for route_keyword in personal_memory_routes)
579
-
580
- all_retrieved_docs = []
581
- if is_personal_route:
582
- # --- MODIFIED AS PER YOUR SPECIFICATION ---
583
- # Implements the simple fallback logic for personal routes.
584
- # the logic of it always returns a personal doc unless it's not loaded with personal memory
585
- if filtered_personal_docs:
586
- all_retrieved_docs = filtered_personal_docs
587
- else:
588
- all_retrieved_docs = filtered_general_docs
589
- # --- END OF MODIFICATION ---
590
- else: # caregiving_scenario
591
- if disease_stage in ["Moderate Stage", "Advanced Stage"]:
592
- # --- STAGE-AWARE LOGIC FOR CAREGIVING SCENARIOS ---
593
- if filtered_general_docs:
594
- all_retrieved_docs = filtered_general_docs
595
- elif filtered_personal_docs:
596
- all_retrieved_docs = filtered_personal_docs
597
- else:
598
- all_retrieved_docs = []
599
- # --- END STAGE-AWARE BLOCK ---
600
- else:
601
- # --- NORMAL ROUTING LOGIC ---
602
- # Conditional Blending logic for caregiving remains.
603
- if abs(best_personal_score - best_general_score) <= SCORE_MARGIN:
604
- all_retrieved_docs = list({doc.page_content: doc for doc in filtered_personal_docs + filtered_general_docs}.values())[:4]
605
- elif best_personal_score < best_general_score:
606
- all_retrieved_docs = filtered_personal_docs
607
- else:
608
- all_retrieved_docs = filtered_general_docs
609
-
610
- # --- Prompt Generation and LLM Call ---
611
- answer = ""
612
- if is_personal_route:
613
- personal_context = _format_docs(all_retrieved_docs, "(No relevant personal memories found.)")
614
- # New modify for test evaluation, general_context is empty but use general context in live chat
615
- general_context = _format_docs([], "") if for_evaluation else _format_docs(filtered_general_docs, "(No general information found.)")
616
- # End
617
-
618
- template = ANSWER_TEMPLATE_SUMMARIZE if "summarization" in query_type else ANSWER_TEMPLATE_FACTUAL
619
- user_prompt = ""
620
- if "summarization" in query_type:
621
- if for_evaluation: # for evaluation, use only personal
622
- user_prompt = template.format(context=personal_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name, role=role)
623
- else: # for live chat, use more context
624
- combined_context = f"{personal_context}\n{general_context}".strip()
625
- user_prompt = template.format(context=combined_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name, role=role)
626
-
627
- else: # ANSWER_TEMPLATE_FACTUAL
628
- user_prompt = template.format(personal_context=personal_context, general_context=general_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name)
629
-
630
- messages.append({"role": "user", "content": user_prompt})
631
- if for_evaluation: # if evaluation test, set temperature (creativity) low from 0.6 input
632
- test_temperature = 0.0 # Modify the local variable
633
- raw_answer = call_llm(messages, temperature=test_temperature)
634
- answer = _clean_surface_text(raw_answer)
635
-
636
- else: # caregiving_scenario
637
- # --- MODIFICATION START: Integrate the severity-based logic ---
638
- # The disease_stage variable is available here from the outer function's scope
639
-
640
- # 1. Select the appropriate template based on the disease stage setting.
641
- if disease_stage == "Advanced Stage":
642
- template = ANSWER_TEMPLATE_ADQ_ADVANCED
643
- elif disease_stage == "Moderate Stage":
644
- template = ANSWER_TEMPLATE_ADQ_MODERATE
645
- else: # Normal / Unspecified or Mild Stage
646
- template = ANSWER_TEMPLATE_ADQ
647
-
648
- # 2. The rest of the logic remains the same. It will use the 'template' variable
649
- # that was just selected above.
650
- personal_sources = {'1 Complaints of a Dutiful Daughter.txt', 'Saved Chat', 'Text Input'}
651
- personal_context = _format_docs([d for d in all_retrieved_docs if d.metadata.get('source') in personal_sources], "(No relevant personal memories found.)")
652
- general_context = _format_docs([d for d in all_retrieved_docs if d.metadata.get('source') not in personal_sources], "(No general guidance found.)")
653
-
654
- first_emotion = next((d.metadata.get("emotion") for d in all_retrieved_docs if d.metadata.get("emotion")), None)
655
- emotions_context = render_emotion_guidelines(first_emotion or kwargs.get("emotion_tag"))
656
-
657
- # NEW: Add Emotion Tag
658
- user_prompt = template.format(general_context=general_context, personal_context=personal_context,
659
- question=query, scenario_tag=kwargs.get("scenario_tag"),
660
- emotions_context=emotions_context, role=role, language=language,
661
- patient_name=p_name, caregiver_name=c_name,
662
- emotion_tag=kwargs.get("emotion_tag"))
663
- messages.append({"role": "user", "content": user_prompt})
664
- # --- MODIFICATION END ---
665
-
666
- # OLD
667
- # template = ANSWER_TEMPLATE_ADQ
668
- # user_prompt = template.format(general_context=general_context, personal_context=personal_context,
669
- # question=query, scenario_tag=kwargs.get("scenario_tag"),
670
- # emotions_context=emotions_context, role=role, language=language,
671
- # patient_name=p_name, caregiver_name=c_name)
672
- # messages.append({"role": "user", "content": user_prompt})
673
-
674
- if for_evaluation: # if evaluation test, set temperature (creativity) low from 0.6 input
675
- test_temperature = 0.0 # Modify the local variable
676
- raw_answer = call_llm(messages, temperature=test_temperature)
677
- answer = _clean_surface_text(raw_answer)
678
-
679
- high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
680
- if kwargs.get("scenario_tag") and kwargs["scenario_tag"].lower() in high_risk_scenarios:
681
- answer += f"\n\n---\n{RISK_FOOTER}"
682
-
683
- if for_evaluation:
684
- sources = _source_ids_for_eval(all_retrieved_docs)
685
- else:
686
- sources = sorted(list(set(d.metadata.get("source", "unknown") for d in all_retrieved_docs if d.metadata.get("source") != "placeholder")))
687
-
688
- print("DEBUG Sources (After Filtering):", sources)
689
- return {"answer": answer, "sources": sources, "source_documents": all_retrieved_docs}
690
-
691
- return _answer_fn
692
-
693
- # END of make_rag_chain
694
-
695
def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
    """Invoke the RAG chain with *question*, shielding callers from failures.

    Returns the chain's result dict on success. If *chain* is not callable,
    or the invocation raises, an in-band error dict ("[Error...]" answer,
    empty sources) is returned instead of propagating the exception.
    """
    if not callable(chain):
        return {"answer": "[Error: RAG chain is not callable]", "sources": []}
    try:
        result = chain(question, **kwargs)
    except Exception as exc:
        # Surface the failure in the chat answer rather than crashing the UI.
        print(f"ERROR in answer_query: {exc}")
        return {"answer": f"[Error executing chain: {exc}]", "sources": []}
    return result
-
704
def synthesize_tts(text: str, lang: str = "en"):
    """Synthesize speech for *text* with gTTS.

    Returns the path of a temporary .mp3 file holding the audio, or None
    when there is no text, gTTS is unavailable, or synthesis fails.
    """
    if not text or gTTS is None:
        return None
    try:
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as handle:
            # delete=False: the file must outlive this call — its path is the result.
            gTTS(text=text, lang=(lang or "en")).save(handle.name)
            return handle.name
    except Exception:
        # Best-effort: any gTTS/network/file error degrades to "no audio".
        return None
-
715
def transcribe_audio(filepath: str, lang: str = "en"):
    """Transcribe the audio file at *filepath* via OpenAI's transcription API.

    Returns the transcription text, or an in-band "[Transcription failed...]"
    message when no API client can be constructed. The model defaults to
    "whisper-1" and can be overridden with the TRANSCRIBE_MODEL env var.
    """
    client = _openai_client()
    if not client:
        return "[Transcription failed: API key not configured]"
    request_kwargs = {"model": os.getenv("TRANSCRIBE_MODEL", "whisper-1")}
    # "auto" (or empty) means: let the API detect the spoken language.
    if lang and lang != "auto":
        request_kwargs["language"] = lang
    with open(filepath, "rb") as audio_file:
        response = client.audio.transcriptions.create(file=audio_file, **request_kwargs)
    return response.text