KeenWoo committed
Commit ecf3b35 · verified · 1 Parent(s): ca28c39

Update alz_companion/agent.py

Files changed (1)
  1. alz_companion/agent.py +22 -48
alz_companion/agent.py CHANGED
@@ -30,7 +30,7 @@ from .prompts import (
 
 
 # -----------------------------
-# Multimodal Processing Functions (NEW)
+# Multimodal Processing Functions
 # -----------------------------
 
 def _openai_client() -> Optional[OpenAI]:
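This hunk shows only the `_openai_client` signature. For orientation, a minimal sketch of the guard pattern that signature implies — assuming the key is read from the `OPENAI_API_KEY` environment variable; the real body is outside this diff:

```python
import os
from typing import Optional

from openai import OpenAI

def openai_client_sketch() -> Optional[OpenAI]:
    """Return a client when an API key is configured; None lets callers degrade gracefully."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return None
    return OpenAI(api_key=api_key)
```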
@@ -68,7 +68,6 @@ def describe_image(image_path: str) -> str:
 # -----------------------------
 # NLU Classification Function
 # -----------------------------
-# (Unchanged from before)
 def detect_tags_from_query(query: str, behavior_options: list, emotion_options: list) -> Dict[str, Optional[str]]:
     """Uses an LLM call to classify the user's query into a behavior and emotion tag."""
     behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
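The rest of `detect_tags_from_query` is outside the hunk. A hedged sketch of how such an LLM classifier is commonly wired up, continuing from the `behavior_str` line above (the prompt wording, model name, and JSON handling are assumptions, not the file's actual code):

```python
import json
from typing import Dict, List, Optional

from openai import OpenAI

def detect_tags_sketch(query: str, behavior_options: List[str], emotion_options: List[str]) -> Dict[str, Optional[str]]:
    # Quote the allowed labels so the model must pick one verbatim (or null).
    behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
    emotion_str = ", ".join(f'"{opt}"' for opt in emotion_options if opt != "None")
    prompt = (
        "Classify the caregiver's message.\n"
        f"Allowed behaviors: {behavior_str}\n"
        f"Allowed emotions: {emotion_str}\n"
        'Answer with JSON only, e.g. {"behavior": "wandering", "emotion": "anxious"}; use null when unsure.\n'
        f"Message: {query}"
    )
    raw = OpenAI().chat.completions.create(
        model="gpt-4o-mini",  # model choice is an assumption
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,  # classification should be deterministic
    ).choices[0].message.content or "{}"
    try:
        data = json.loads(raw)
        return {"behavior": data.get("behavior"), "emotion": data.get("emotion")}
    except json.JSONDecodeError:
        return {"behavior": None, "emotion": None}
```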
@@ -89,7 +88,6 @@ def detect_tags_from_query(query: str, behavior_options: list, emotion_options:
 # -----------------------------
 # Embeddings & VectorStore
 # -----------------------------
-# (Unchanged from before)
 
 def _default_embeddings():
     """Lightweight, widely available model."""
@@ -104,7 +102,6 @@ def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal
     except Exception:
         pass
 
-    # If it's a new personal vector store with no docs, create a placeholder
     if is_personal and not docs:
         docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
 
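The placeholder document matters because most vector stores cannot build an index from zero documents. A sketch of the surrounding build-or-load flow, assuming a FAISS backend (the store type is not visible in this hunk):

```python
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

def build_or_load_sketch(docs, index_path: str, embeddings, is_personal: bool = False):
    # Prefer loading an existing index; fall back to building a fresh one.
    try:
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    except Exception:
        pass
    if is_personal and not docs:
        # FAISS cannot create an index from an empty document list, hence the placeholder.
        docs = [Document(page_content="(This is the start of the personal memory journal.)",
                         metadata={"source": "placeholder"})]
    vs = FAISS.from_documents(docs, embeddings)
    vs.save_local(index_path)
    return vs
```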
@@ -148,7 +145,6 @@ def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str
 # -----------------------------
 # LLM Call
 # -----------------------------
-# (Unchanged from before)
 def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
     """Call OpenAI Chat Completions if available; else return a fallback."""
     client = _openai_client()
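For reference, the fallback pattern `call_llm` describes ("if available; else return a fallback") usually looks like this; the model name and fallback strings are illustrative:

```python
from typing import Dict, List

from openai import OpenAI

def call_llm_sketch(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
    try:
        client = OpenAI()  # raises when OPENAI_API_KEY is not configured
    except Exception:
        return "[LLM unavailable: no API key configured]"  # degrade instead of crashing the app
    try:
        resp = client.chat.completions.create(model="gpt-4o-mini", messages=messages, temperature=temperature)
        return resp.choices[0].message.content or ""
    except Exception as e:
        return f"[LLM API Error: {e}]"
```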
@@ -162,7 +158,7 @@ def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
     return f"[LLM API Error: {e}]"
 
 # -----------------------------
-# Prompting & RAG Chain (HEAVILY MODIFIED)
+# Prompting & RAG Chain
 # -----------------------------
 
 def _format_sources(docs: List[Document]) -> List[str]:
@@ -181,70 +177,56 @@ def make_rag_chain(
 ):
     """Returns a callable that performs the complete, two-tiered RAG process."""
 
-    retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
-    retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
-
     def _format_docs(docs: List[Document], default_msg: str) -> str:
         if not docs: return default_msg
         return "\n".join([f"- {d.page_content.strip()}" for d in docs])
 
     def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None) -> Dict[str, Any]:
 
-        # --- Step 1: Search the Personal Knowledge Base ---
-        personal_docs = retriever_personal.invoke(query)
-        personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
-
-        # --- Step 2: Search the General Knowledge Base with filters ---
+        # Build a dynamic filter that will be used for BOTH knowledge bases
         search_filter = {}
         if scenario_tag and scenario_tag != "None":
             search_filter["behaviors"] = scenario_tag.lower()
         if emotion_tag and emotion_tag != "None":
             search_filter["emotion"] = emotion_tag.lower()
 
+        # Use the filter on both searches if available
         if search_filter:
+            personal_docs = vs_personal.similarity_search(query, k=3, filter=search_filter)
             general_docs = vs_general.similarity_search(query, k=3, filter=search_filter)
         else:
+            # If no filters, perform standard semantic search on both
+            retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
+            retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
+            personal_docs = retriever_personal.invoke(query)
             general_docs = retriever_general.invoke(query)
+
+        personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
         general_context = _format_docs(general_docs, "(No general guidance found.)")
 
-        # --- Step 3: Determine Emotion for Response Guidelines ---
+        # Determine emotion for the response guidelines
        first_emotion = None
-        # Prioritize emotion from personal memories, then general context
        all_docs = personal_docs + general_docs
        for doc in all_docs:
            if "emotion" in doc.metadata and doc.metadata["emotion"]:
                emotion_data = doc.metadata["emotion"]
-                if isinstance(emotion_data, list):
-                    first_emotion = emotion_data[0]
-                else:
-                    first_emotion = emotion_data
+                if isinstance(emotion_data, list): first_emotion = emotion_data[0]
+                else: first_emotion = emotion_data
            if first_emotion: break
 
        emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
 
-        # --- Step 4: Assemble and Call the LLM ---
-        is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None")
+        # Assemble and Call the LLM
+        is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
        template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
 
-        # Note the new placeholders: general_context and personal_context
        if template == ANSWER_TEMPLATE_ADQ:
-            user_prompt = template.format(
-                general_context=general_context,
-                personal_context=personal_context,
-                question=query,
-                scenario_tag=scenario_tag,
-                emotions_context=emotions_context,
-                role=role,
-                language=language
-            )
-        else: # Calm template only uses a single combined context
+            user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=scenario_tag, emotions_context=emotions_context, role=role, language=language)
+        else:
            combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
            user_prompt = template.format(context=combined_context, question=query, language=language)
 
-        system_message = SYSTEM_TEMPLATE.format(
-            tone=tone, language=language, patient_name=patient_name or "the patient",
-            caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS,
-        )
+        system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
 
        messages = [{"role": "system", "content": system_message}]
        messages.extend(chat_history)
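The substance of this commit is that the metadata filter now gates *both* stores, and the untagged path falls back to plain retrieval. A self-contained sketch of those two search modes, assuming a FAISS store and MiniLM embeddings (both assumptions; the diff never names the backend):

```python
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
docs = [
    Document(page_content="Redirect gently and offer a familiar activity.",
             metadata={"behaviors": "wandering", "emotion": "anxious"}),
    Document(page_content="Keep the evening routine calm and consistent.",
             metadata={"behaviors": "sundowning", "emotion": "agitated"}),
]
vs = FAISS.from_documents(docs, emb)

# Tagged path: the exact-match metadata filter narrows candidates before ranking.
hits = vs.similarity_search("she keeps trying to leave the house", k=3,
                            filter={"behaviors": "wandering"})

# Untagged path: plain semantic retrieval, as in the else-branch above.
hits = vs.as_retriever(search_kwargs={"k": 3}).invoke("she keeps trying to leave the house")
```

One trade-off worth flagging in review: LangChain's FAISS wrapper applies the filter by post-filtering fetched candidates, so a strict filter combined with a small `k` can return nothing even when relevant unfiltered matches exist.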
@@ -261,13 +243,8 @@ def make_rag_chain(
     return _answer_fn
 
 def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
-    if not callable(chain):
-        return {"answer": "[Error: RAG chain is not callable]", "sources": []}
-
-    chat_history = kwargs.get("chat_history", [])
-    scenario_tag = kwargs.get("scenario_tag")
-    emotion_tag = kwargs.get("emotion_tag")
-
+    if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
+    chat_history, scenario_tag, emotion_tag = kwargs.get("chat_history", []), kwargs.get("scenario_tag"), kwargs.get("emotion_tag")
     try:
         return chain(question, chat_history=chat_history, scenario_tag=scenario_tag, emotion_tag=emotion_tag)
     except Exception as e:
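Usage of the compacted `answer_query` wrapper, for orientation; the keyword names for the two stores and the remaining `make_rag_chain` parameters are inferred from the hunk bodies, not from the (elided) signature:

```python
chain = make_rag_chain(vs_general=vs_general, vs_personal=vs_personal)  # other params elided in this diff
result = answer_query(
    chain,
    "Mom gets anxious every evening and insists on 'going home'. What should I do?",
    chat_history=[],
    scenario_tag="Wandering",  # illustrative tag values
    emotion_tag="Anxious",
)
print(result["answer"])
print(result.get("sources"))
```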
@@ -277,8 +254,6 @@ def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
 # -----------------------------
 # TTS & Transcription
 # -----------------------------
-# (Unchanged)
-
 def synthesize_tts(text: str, lang: str = "en"):
     if not text or gTTS is None: return None
     try:
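`synthesize_tts` guards on `gTTS is None`, i.e. an optional import. A sketch of the typical in-memory gTTS call behind such a function; the exact return type in the file is not shown:

```python
from io import BytesIO

from gtts import gTTS  # optional dependency; the file treats a failed import as gTTS = None

def synthesize_tts_sketch(text: str, lang: str = "en"):
    if not text:
        return None
    try:
        buf = BytesIO()
        gTTS(text=text, lang=lang).write_to_fp(buf)  # MP3 bytes, no temp file required
        return buf.getvalue()
    except Exception:
        return None
```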
@@ -294,11 +269,10 @@ def transcribe_audio(filepath: str, lang: str = "en"):
     client = _openai_client()
     if not client:
         return "[Transcription failed: API key not configured]"
-
     api_args = {"model": "whisper-1"}
     if lang and lang != "auto":
         api_args["language"] = lang
-
     with open(filepath, "rb") as audio_file:
         transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
     return transcription.text
+
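The `lang != "auto"` branch works because Whisper auto-detects the language when the `language` argument is omitted. A minimal standalone equivalent of the call above (the file name is hypothetical):

```python
from openai import OpenAI

client = OpenAI()  # key from OPENAI_API_KEY
with open("caregiver_note.m4a", "rb") as f:  # hypothetical input file
    text = client.audio.transcriptions.create(model="whisper-1", file=f, language="en").text
print(text)
```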