KeenWoo committed on
Commit
4e0c1a7
·
verified ·
1 Parent(s): 758f234

Create agent.py

Browse files
Files changed (1) hide show
  1. alz_companion/agent.py +333 -0
alz_companion/agent.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ import json
4
+ import base64
5
+ import time
6
+ import tempfile
7
+ from typing import List, Dict, Any, Optional
8
+ import re
9
+
10
+ # OpenAI for LLM (optional)
11
+ try:
12
+ from openai import OpenAI
13
+ except Exception:
14
+ OpenAI = None
15
+
16
+ # LangChain & RAG
17
+ from langchain.schema import Document
18
+ from langchain_community.vectorstores import FAISS
19
+ from langchain_community.embeddings import HuggingFaceEmbeddings
20
+
21
+ # TTS
22
+ try:
23
+ from gtts import gTTS
24
+ except Exception:
25
+ gTTS = None
26
+
27
+ from .prompts import (
28
+ SYSTEM_TEMPLATE, ANSWER_TEMPLATE_CALM, ANSWER_TEMPLATE_ADQ,
29
+ SAFETY_GUARDRAILS, RISK_FOOTER, render_emotion_guidelines,
30
+ GOAL_ROUTER_PROMPT, SPECIALIST_CLASSIFICATION_PROMPT,
31
+ ANSWER_TEMPLATE_FACTUAL, ANSWER_TEMPLATE_GENERAL_KNOWLEDGE, ANSWER_TEMPLATE_GENERAL,
32
+ ROUTER_PROMPT as RAG_ROUTER_PROMPT
33
+ )
34
+
35
+
36
+ # -----------------------------
37
+ # Multimodal Processing Functions
38
+ # -----------------------------
39
+
40
def _openai_client() -> Optional[OpenAI]:
    """Build an OpenAI client from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        A new ``OpenAI`` client, or ``None`` when the key is unset/blank or
        when the optional ``openai`` import failed (``OpenAI`` is ``None``).
    """
    key = os.getenv("OPENAI_API_KEY", "").strip()
    # Check the key first so a missing key never touches the optional import.
    if not key or not OpenAI:
        return None
    return OpenAI(api_key=key)
43
+
44
def describe_image(image_path: str) -> str:
    """Describe an image with a vision-capable chat model.

    The file is read, base64-encoded and sent inline as a ``data:`` URL
    together with a short instruction to describe people, places and objects.

    Args:
        image_path: Path of the image file to describe.

    Returns:
        The model's description; ``"No description available."`` when the
        model returns empty content; or an error message string when no
        client is configured or anything in the request fails.
    """
    client = _openai_client()
    if not client:
        return "(Image description failed: OpenAI API key not configured.)"

    # Label the payload with its real MIME type. Previously every non-PNG
    # file (e.g. .gif, .webp) was sent as image/jpeg.
    mime_by_extension = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }

    try:
        extension = os.path.splitext(image_path)[1].lower()
        # Unknown extensions keep the original JPEG fallback.
        mime_type = mime_by_extension.get(extension, "image/jpeg")

        with open(image_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode('utf-8')

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": [
                {"type": "text", "text": "Describe this image concisely for a memory journal. Focus on people, places, and key objects."},
                {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}}
            ]}],
            max_tokens=100,
        )
        return response.choices[0].message.content or "No description available."
    except Exception as e:
        # Best-effort: file and API errors are reported as text, not raised.
        return f"[Image description error: {e}]"
70
+
71
+ # -----------------------------
72
+ # NLU Classification Function (Router/Specialist Model)
73
+ # -----------------------------
74
def _keep_allowed(values: Any, allowed: list) -> list:
    """Return the entries of *values* that appear in *allowed*.

    A bare string is treated as a one-element list; any other non-list
    value (e.g. JSON ``null``) yields an empty list.
    """
    if isinstance(values, str):
        values = [values]
    if not isinstance(values, list):
        return []
    return [v for v in values if v in allowed]


def detect_tags_from_query(
    query: str,
    behavior_options: list, emotion_options: list, topic_options: list, context_options: list,
    example_retriever: Optional[FAISS] = None, **kwargs
) -> Dict[str, Any]:
    """Classify a user query into behavior/emotion/topic/context tags via the LLM.

    When ``example_retriever`` is given, similar fixtures are fetched with
    ``example_retriever.invoke(query)`` and rendered as few-shot examples in
    the specialist prompt. The first ``{`` .. last ``}`` span of the model's
    reply is parsed as JSON, and only values present in the supplied option
    lists are kept.

    Args:
        query: The user's utterance to classify.
        behavior_options: Allowed values for ``detected_behaviors``.
        emotion_options: Allowed values for ``detected_emotion``.
        topic_options: Allowed values for ``detected_topic``.
        context_options: Allowed values for ``detected_contexts``.
        example_retriever: Optional retriever whose documents carry a
            ``full_fixture`` metadata entry (with ``turns`` and ``expected``).
        **kwargs: Accepted and ignored.

    Returns:
        A dict with keys ``detected_behaviors`` (list), ``detected_emotion``
        (str), ``detected_topic`` (str) and ``detected_contexts`` (list).
        Anything that could not be parsed or validated falls back to the
        defaults ``[]`` / ``"None"``.
    """

    print("\n--- RUNNING NLU V4 (Dynamic Few-Shot) ---")

    examples_str = "No examples provided."
    if example_retriever:
        try:
            similar_docs = example_retriever.invoke(query)
            examples = [doc.metadata["full_fixture"] for doc in similar_docs]
            examples_str = "\n\n".join([
                f"User Query: \"{ex['turns'][0]['text']}\"\n<thinking>\n{ex['expected'].get('reasoning', 'No reasoning provided.')}\n</thinking>\nJSON Response:\n{json.dumps(ex['expected'], indent=4)}"
                for ex in examples
            ])
            print(f"Dynamically retrieved {len(examples)} examples for the prompt.")
        except Exception as e:
            # Few-shot examples are optional; continue with the placeholder.
            print(f"Could not retrieve examples: {e}")

    specialist_prompt = SPECIALIST_CLASSIFICATION_PROMPT.format(
        behavior_options=", ".join(f'"{opt}"' for opt in behavior_options),
        emotion_options=", ".join(f'"{opt}"' for opt in emotion_options),
        topic_options=", ".join(f'"{opt}"' for opt in topic_options),
        context_options=", ".join(f'"{opt}"' for opt in context_options),
        examples=examples_str,
        query=query
    )

    messages = [{"role": "user", "content": specialist_prompt}]
    response_str = call_llm(messages, temperature=0.1)

    print(f"\n--- NLU Full Response ---\n{response_str}\n-----------------------\n")

    result_dict = {"detected_behaviors": [], "detected_emotion": "None", "detected_topic": "None", "detected_contexts": []}

    try:
        # The reply may wrap the JSON in extra text, so cut out the outermost braces.
        start_brace = response_str.find('{')
        end_brace = response_str.rfind('}')
        if start_brace != -1 and end_brace != -1 and end_brace > start_brace:
            result = json.loads(response_str[start_brace:end_brace + 1])

            # List fields may come back as null or a bare string; normalise
            # them instead of crashing or iterating a string char by char.
            result_dict["detected_behaviors"] = _keep_allowed(result.get("detected_behaviors"), behavior_options)
            emotion = result.get("detected_emotion")
            result_dict["detected_emotion"] = emotion if emotion in emotion_options else "None"
            topic = result.get("detected_topic")
            result_dict["detected_topic"] = topic if topic in topic_options else "None"
            result_dict["detected_contexts"] = _keep_allowed(result.get("detected_contexts"), context_options)
    except (json.JSONDecodeError, AttributeError, TypeError) as e:
        print(f"ERROR parsing Specialist JSON: {e}")

    return result_dict
128
+
129
+ # -----------------------------
130
+ # Embeddings & VectorStore
131
+ # -----------------------------
132
+
133
def _default_embeddings():
    """Create the embedding model used for every vector store.

    The model name comes from the ``EMBEDDINGS_MODEL`` environment variable
    and defaults to ``sentence-transformers/all-MiniLM-L6-v2``.
    """
    return HuggingFaceEmbeddings(
        model_name=os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    )
136
+
137
def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal: bool = False) -> FAISS:
    """Load the FAISS index stored at ``index_path`` or build it from ``docs``.

    An existing index (a directory containing ``index.faiss``) is loaded.
    If it is absent or fails to load, a new index is built from ``docs``
    and saved to ``index_path``.

    Args:
        docs: Documents to index when no usable index exists on disk.
        index_path: Directory where the index is stored.
        is_personal: When true and ``docs`` is empty, a single placeholder
            document is indexed so the store can still be created.

    Returns:
        The loaded or newly built vector store.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when index_path actually has one (e.g. not for "faiss_index").
    parent_dir = os.path.dirname(index_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.isdir(index_path) and os.path.exists(os.path.join(index_path, "index.faiss")):
        try:
            # SECURITY NOTE: allow_dangerous_deserialization opts in to
            # pickle-based loading; only point this at indexes written by
            # this application, never at untrusted files.
            return FAISS.load_local(index_path, _default_embeddings(), allow_dangerous_deserialization=True)
        except Exception as e:
            # Fall through and rebuild from docs.
            print(f"Could not load existing vector store at {index_path}: {e}")

    if is_personal and not docs:
        docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]

    vs = FAISS.from_documents(docs, _default_embeddings())
    vs.save_local(index_path)
    return vs
151
+
152
def texts_from_jsonl(path: str) -> List[Document]:
    """Read a JSONL file into documents, one per line with a non-empty ``text``.

    Each document's metadata holds ``source`` (the file's base name),
    ``chunk`` (the zero-based line index) and, when present on the record,
    ``behaviors``, ``emotion``, ``topic_tags`` and ``context_tags``.

    Blank lines, malformed JSON lines, non-object records and records
    without a non-empty string ``text`` are skipped. Previously one bad line
    discarded every document already parsed from the file.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        The parsed documents, or ``[]`` if the file cannot be read.
    """
    out: List[Document] = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping malformed JSONL line {i + 1} in {path}: {e}")
                    continue
                if not isinstance(obj, dict):
                    # Only JSON objects can carry a "text" field.
                    continue
                txt = obj.get("text") or ""
                if not isinstance(txt, str) or not txt.strip():
                    continue
                md = {"source": os.path.basename(path), "chunk": i}
                for k in ("behaviors", "emotion", "topic_tags", "context_tags"):
                    if k in obj:
                        md[k] = obj[k]
                out.append(Document(page_content=txt, metadata=md))
    except Exception as e:
        # File-level failure (missing file, decode error): keep the original
        # contract of returning nothing for this file.
        print(f"Error reading from JSONL file {path}: {e}")
        return []
    return out
170
+
171
def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
    """Collect documents from sample files and build or load a vector store.

    ``.jsonl`` paths are parsed with ``texts_from_jsonl``; every other path
    is read as one UTF-8 text document (undecodable bytes are ignored).
    Files that cannot be processed are skipped and reported. If no document
    is collected, a single ``"(empty index)"`` placeholder is used.

    Args:
        sample_paths: Files to ingest; ``None`` is treated as empty.
        index_path: Directory of the FAISS index.

    Returns:
        The vector store returned by ``build_or_load_vectorstore``.
    """
    docs: List[Document] = []
    for p in (sample_paths or []):
        try:
            if p.lower().endswith(".jsonl"):
                docs.extend(texts_from_jsonl(p))
            else:
                with open(p, "r", encoding="utf-8", errors="ignore") as fh:
                    docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
        except Exception as e:
            # Still best-effort, but report the skipped file instead of
            # silently dropping it.
            print(f"Skipping sample file {p}: {e}")
            continue
    if not docs:
        docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
    return build_or_load_vectorstore(docs, index_path=index_path)
185
+
186
+ # -----------------------------
187
+ # LLM Call
188
+ # -----------------------------
189
def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Optional[List[str]] = None) -> str:
    """Send chat messages to the configured OpenAI model and return its reply.

    The model name is read from ``OPENAI_MODEL`` (default ``gpt-4o-mini``).

    Args:
        messages: Chat messages in OpenAI ``{"role", "content"}`` form.
        temperature: Sampling temperature; ``None`` falls back to 0.6.
        stop: Optional stop sequences, sent only when non-empty.

    Returns:
        The stripped reply text; an offline-mode notice when no client is
        available; or a bracketed error message if the request fails.
    """
    llm_client = _openai_client()
    if not llm_client:
        return "(Offline Mode: OpenAI API key not configured.)"

    model_name = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
    try:
        effective_temperature = 0.6 if temperature is None else temperature
        request = dict(
            model=model_name,
            messages=messages,
            temperature=float(effective_temperature),
        )
        if stop:
            request["stop"] = stop
        completion = llm_client.chat.completions.create(**request)
        reply = completion.choices[0].message.content
        return (reply or "").strip()
    except Exception as exc:
        return f"[LLM API Error: {exc}]"
204
+
205
+ # -----------------------------
206
+ # Prompting & RAG Chain
207
+ # -----------------------------
208
+
209
def _format_sources(docs: List[Document]) -> List[str]:
    """Return the distinct ``source`` metadata values of ``docs``.

    Sources are listed in first-seen order, so the result is stable between
    runs (a plain ``set`` gave an arbitrary order). Documents without a
    ``source`` entry are reported as ``"unknown"``.
    """
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(d.metadata.get("source", "unknown") for d in docs))
211
+
212
def make_rag_chain(
    vs_general: FAISS,
    vs_personal: FAISS,
    *,
    role: str = "patient",
    temperature: float = 0.6,
    language: str = "English",
    patient_name: str = "the patient",
    caregiver_name: str = "the caregiver",
    tone: str = "warm",
    **kwargs
):
    """Build the answer function that routes a query and runs the RAG flow.

    Args:
        vs_general: Vector store holding general guidance.
        vs_personal: Vector store holding personal memories.
        role: Passed to the ADQ answer template.
        temperature: Sampling temperature for answer generation.
        language: Language passed to the answer templates.
        patient_name: Name used in the system prompt (blank -> "the patient").
        caregiver_name: Name used in the system prompt (blank -> "the caregiver").
        tone: Tone passed to the system prompt.
        **kwargs: Accepted and ignored.

    Returns:
        A callable ``(query, chat_history, scenario_tag=None,
        emotion_tag=None, topic_tag=None)`` returning a dict with the keys
        ``"answer"`` (str) and ``"sources"`` (list of str).
    """

    def _format_docs(docs: List[Document], default_msg: str) -> str:
        """Render documents as a bullet list, de-duplicated by page content."""
        if not docs:
            return default_msg
        unique_docs = {doc.page_content: doc for doc in docs}.values()
        return "\n".join([f"- {d.page_content.strip()}" for d in unique_docs])

    def _answer_fn(
        query: str,
        chat_history: List[Dict[str, str]],
        scenario_tag: str | List[str] | None = None,
        emotion_tag: Optional[str] = None,
        topic_tag: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Classify the query with the router prompt and answer accordingly."""

        router_messages = [{"role": "user", "content": RAG_ROUTER_PROMPT.format(query=query)}]
        query_type = call_llm(router_messages, temperature=0.0).strip().lower()

        system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
        messages = [{"role": "system", "content": system_message}]
        # Tolerate a missing history instead of failing in extend().
        messages.extend(chat_history or [])

        if "factual_question" in query_type:
            retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 2})
            retriever_general = vs_general.as_retriever(search_kwargs={"k": 2})
            all_docs = retriever_personal.invoke(query) + retriever_general.invoke(query)
            context = _format_docs(all_docs, "(No relevant information found in the memory journal.)")
            user_prompt = ANSWER_TEMPLATE_FACTUAL.format(context=context, query=query, language=language)
            messages.append({"role": "user", "content": user_prompt})
            answer = call_llm(messages, temperature=temperature)
            return {"answer": answer, "sources": _format_sources(all_docs)}

        elif "general_knowledge_question" in query_type:
            user_prompt = ANSWER_TEMPLATE_GENERAL_KNOWLEDGE.format(query=query, language=language)
            messages.append({"role": "user", "content": user_prompt})
            answer = call_llm(messages, temperature=temperature)
            return {"answer": answer, "sources": ["General Knowledge"]}

        elif "general_conversation" in query_type:
            user_prompt = ANSWER_TEMPLATE_GENERAL.format(query=query, language=language)
            messages.append({"role": "user", "content": user_prompt})
            answer = call_llm(messages, temperature=temperature)
            return {"answer": answer, "sources": []}

        else:  # Default to caregiving scenario
            search_filter = {}
            if scenario_tag and isinstance(scenario_tag, list):
                search_filter["behaviors"] = [s.lower() for s in scenario_tag]

            if emotion_tag and emotion_tag != "None":
                search_filter["emotion"] = emotion_tag.lower()
            if topic_tag and topic_tag != "None":
                search_filter["topic_tags"] = topic_tag.lower()

            personal_docs = vs_personal.similarity_search(query, k=3, filter=search_filter if search_filter else None)
            general_docs = vs_general.similarity_search(query, k=3, filter=search_filter if search_filter else None)

            personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
            general_context = _format_docs(general_docs, "(No general guidance found.)")

            all_docs = personal_docs + general_docs
            # Prefer an emotion recorded on a retrieved document; otherwise
            # fall back to the caller-supplied tag.
            first_emotion = next((doc.metadata.get("emotion") for doc in all_docs if doc.metadata.get("emotion")), emotion_tag)
            emotions_context = render_emotion_guidelines(first_emotion)

            # NOTE(review): an emotion_tag of "None" (the string) is truthy
            # here although it is treated as "no tag" for the search filter
            # above -- confirm whether it should select the ADQ template.
            is_tagged_scenario = scenario_tag or emotion_tag or first_emotion
            template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM

            display_scenario_tag = (scenario_tag[0] if isinstance(scenario_tag, list) and scenario_tag else scenario_tag) or ""

            if template == ANSWER_TEMPLATE_ADQ:
                user_prompt = template.format(general_context=general_context, personal_context=personal_context, query=query, scenario_tag=display_scenario_tag, emotions_context=emotions_context, role=role, language=language)
            else:
                combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
                user_prompt = template.format(context=combined_context, query=query, language=language)

            messages.append({"role": "user", "content": user_prompt})
            answer = call_llm(messages, temperature=temperature)

            # Safety: check every supplied scenario tag. Looking only at the
            # first one dropped the footer for e.g. ["agitation", "wandering"].
            high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
            if isinstance(scenario_tag, list):
                supplied_tags = scenario_tag
            elif scenario_tag:
                supplied_tags = [scenario_tag]
            else:
                supplied_tags = []
            if any(isinstance(tag, str) and tag.lower() in high_risk_scenarios for tag in supplied_tags):
                answer += f"\n\n---\n{RISK_FOOTER}"

            return {"answer": answer, "sources": _format_sources(all_docs)}

    return _answer_fn
301
+
302
def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
    """Run ``chain`` on ``question`` and always return an answer dict.

    Args:
        chain: Callable invoked as ``chain(question, **kwargs)``.
        question: The user's question.
        **kwargs: Forwarded unchanged to ``chain``.

    Returns:
        The chain's result. If ``chain`` is not callable or raises, a dict
        with an error message under ``"answer"`` and an empty ``"sources"``.
    """
    if callable(chain):
        try:
            return chain(question, **kwargs)
        except Exception as exc:
            print(f"ERROR in answer_query: {exc}")
            return {"answer": f"[Error executing chain: {exc}]", "sources": []}
    return {"answer": "[Error: RAG chain is not callable]", "sources": []}
309
+
310
+ # -----------------------------
311
+ # TTS & Transcription
312
+ # -----------------------------
313
def synthesize_tts(text: str, lang: str = "en"):
    """Synthesize ``text`` to an MP3 file with gTTS.

    Args:
        text: Text to speak; empty text produces no audio.
        lang: Language code passed to gTTS; a falsy value means ``"en"``.

    Returns:
        Path of the generated temporary ``.mp3`` file, or ``None`` when the
        text is empty, gTTS is unavailable, or synthesis fails. The caller
        owns the returned file.
    """
    if not text or gTTS is None:
        return None
    path = None
    try:
        fd, path = tempfile.mkstemp(suffix=".mp3")
        # gTTS writes by path, so release the descriptor from mkstemp.
        os.close(fd)
        tts = gTTS(text=text, lang=(lang or "en"))
        tts.save(path)
        return path
    except Exception:
        # Do not leak the empty/partial temp file when synthesis fails.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass  # Cleanup is best-effort; the result is None either way.
        return None
323
+
324
def transcribe_audio(filepath: str, lang: str = "en"):
    """Transcribe an audio file with the ``whisper-1`` model.

    Args:
        filepath: Path of the audio file to transcribe.
        lang: Language hint sent to the API; omitted when falsy or ``"auto"``.

    Returns:
        The transcribed text, or a bracketed failure message when no OpenAI
        client is configured. File and API errors propagate to the caller.
    """
    whisper_client = _openai_client()
    if not whisper_client:
        return "[Transcription failed: API key not configured]"

    request_options = {"model": "whisper-1"}
    wants_language_hint = bool(lang) and lang != "auto"
    if wants_language_hint:
        request_options["language"] = lang

    with open(filepath, "rb") as audio_file:
        result = whisper_client.audio.transcriptions.create(file=audio_file, **request_options)
    return result.text