KeenWoo committed on
Commit 758f234 · verified
1 Parent(s): 049befa

Delete alz_companion/agent.py

Files changed (1)
  1. alz_companion/agent.py +0 -342
alz_companion/agent.py DELETED
@@ -1,342 +0,0 @@
- from __future__ import annotations
- import os
- import json
- import base64
- import time
- import tempfile
- from typing import List, Dict, Any, Optional
- import re
-
- # OpenAI for LLM (optional)
- try:
-     from openai import OpenAI
- except Exception:
-     OpenAI = None
-
- # LangChain & RAG
- from langchain.schema import Document
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
-
- # TTS
- try:
-     from gtts import gTTS
- except Exception:
-     gTTS = None
-
- # Import all necessary prompts from the final prompts.py
- from .prompts import (
-     SYSTEM_TEMPLATE, ANSWER_TEMPLATE_CALM, ANSWER_TEMPLATE_ADQ,
-     SAFETY_GUARDRAILS, RISK_FOOTER, render_emotion_guidelines,
-     GOAL_ROUTER_PROMPT,
-     NLU_EXAMPLES,
-     SPECIALIST_CLASSIFICATION_PROMPT,
-     ANSWER_TEMPLATE_FACTUAL,
-     ANSWER_TEMPLATE_GENERAL_KNOWLEDGE,
-     ANSWER_TEMPLATE_GENERAL,
-     ROUTER_PROMPT as RAG_ROUTER_PROMPT  # Alias the RAG router to avoid name conflicts
- )
-
-
- # -----------------------------
- # Multimodal Processing Functions
- # -----------------------------
-
- def _openai_client() -> Optional[OpenAI]:
-     api_key = os.getenv("OPENAI_API_KEY", "").strip()
-     return OpenAI(api_key=api_key) if api_key and OpenAI else None
-
- def describe_image(image_path: str) -> str:
-     """Uses a vision model to describe an image for context."""
-     client = _openai_client()
-     if not client:
-         return "(Image description failed: OpenAI API key not configured.)"
-
-     try:
-         extension = os.path.splitext(image_path)[1].lower()
-         mime_type = "image/jpeg"
-         if extension == ".png": mime_type = "image/png"
-         elif extension in [".jpg", ".jpeg"]: mime_type = "image/jpeg"
-
-         with open(image_path, "rb") as image_file:
-             base64_image = base64.b64encode(image_file.read()).decode('utf-8')
-
-         response = client.chat.completions.create(
-             model="gpt-4o",
-             messages=[{"role": "user", "content": [
-                 {"type": "text", "text": "Describe this image concisely for a memory journal. Focus on people, places, and key objects."},
-                 {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}}
-             ]}],
-             max_tokens=100,
-         )
-         return response.choices[0].message.content or "No description available."
-     except Exception as e:
-         return f"[Image description error: {e}]"
-
- # -----------------------------
- # NLU Classification Function (Router/Specialist Model)
- # -----------------------------
- def detect_tags_from_query(
-     query: str,
-     behavior_options: list, emotion_options: list, topic_options: list, context_options: list,
-     example_retriever: Optional[FAISS] = None, **kwargs
- ) -> Dict[str, Any]:
-     """Uses a two-step Router/Specialist model to classify the user's query."""
-
-     print("\n--- RUNNING NLU V3 (Router/Specialist Model) ---")
-
-     router_prompt = GOAL_ROUTER_PROMPT.format(query=query)
-     primary_goal = call_llm([{"role": "user", "content": router_prompt}], temperature=0).strip()
-     print(f"Detected Primary Goal: {primary_goal}")
-
-     result_dict = {"detected_behaviors": [], "detected_emotion": "None", "detected_topic": "None", "detected_contexts": []}
-
-     if primary_goal not in NLU_EXAMPLES:
-         print("Goal is not Practical Planning or Emotional Support. Skipping specialist.")
-         return result_dict
-
-     relevant_examples = NLU_EXAMPLES[primary_goal]
-     examples_str = "\n\n".join([
-         f"User Query: \"{ex['query']}\"\nJSON Response:\n{json.dumps(ex['json'], indent=4)}"
-         for ex in relevant_examples
-     ])
-     print(f"Selected {len(relevant_examples)} examples for goal: '{primary_goal}'")
-
-     specialist_prompt = SPECIALIST_CLASSIFICATION_PROMPT.format(examples=examples_str, query=query)
-
-     messages = [{"role": "user", "content": specialist_prompt}]
-     response_str = call_llm(messages, temperature=0.1)
-
-     print(f"\n--- NLU Full Response ---\n{response_str}\n-----------------------\n")
-
-     try:
-         start_brace = response_str.find('{')
-         end_brace = response_str.rfind('}')
-         if start_brace != -1 and end_brace != -1 and end_brace > start_brace:
-             json_str = response_str[start_brace : end_brace + 1]
-             result = json.loads(json_str)
-
-             result_dict["detected_behaviors"] = [b for b in result.get("detected_behaviors", []) if b in behavior_options]
-             result_dict["detected_emotion"] = result.get("detected_emotion") if result.get("detected_emotion") in emotion_options else "None"
-             result_dict["detected_topic"] = result.get("detected_topic") if result.get("detected_topic") in topic_options else "None"
-             result_dict["detected_contexts"] = [c for c in result.get("detected_contexts", []) if c in context_options]
-
-         return result_dict
-     except (json.JSONDecodeError, AttributeError) as e:
-         print(f"ERROR parsing Specialist JSON: {e}")
-         return result_dict
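- # Sketch of the two-step flow (query and tag values hypothetical; real options come
- # from the app's configured tag lists):
- #   detect_tags_from_query("He keeps trying to leave the house after dark",
- #                          behavior_options=["exit_seeking"], emotion_options=["anxious"],
- #                          topic_options=[], context_options=["evening"])
- #   Step 1 (router)     -> a goal such as "Practical Planning"
- #   Step 2 (specialist) -> {"detected_behaviors": ["exit_seeking"], "detected_emotion": "anxious",
- #                           "detected_topic": "None", "detected_contexts": ["evening"]}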
-
- # -----------------------------
- # Embeddings & VectorStore
- # -----------------------------
-
- def _default_embeddings():
-     model_name = os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
-     return HuggingFaceEmbeddings(model_name=model_name)
-
- def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal: bool = False) -> FAISS:
-     os.makedirs(os.path.dirname(index_path), exist_ok=True)
-     if os.path.isdir(index_path) and os.path.exists(os.path.join(index_path, "index.faiss")):
-         try:
-             return FAISS.load_local(index_path, _default_embeddings(), allow_dangerous_deserialization=True)
-         except Exception as e:
-             print(f"Could not load existing vector store at {index_path}: {e}")
-
-     if is_personal and not docs:
-         docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
-
-     vs = FAISS.from_documents(docs, _default_embeddings())
-     vs.save_local(index_path)
-     return vs
-
- def texts_from_jsonl(path: str) -> List[Document]:
-     out: List[Document] = []
-     try:
-         with open(path, "r", encoding="utf-8") as f:
-             for i, line in enumerate(f):
-                 line = line.strip()
-                 if not line: continue
-                 obj = json.loads(line)
-                 txt = obj.get("text") or ""
-                 if not isinstance(txt, str) or not txt.strip(): continue
-                 md = {"source": os.path.basename(path), "chunk": i}
-                 for k in ("behaviors", "emotion", "topic_tags", "context_tags"):
-                     if k in obj: md[k] = obj[k]
-                 out.append(Document(page_content=txt, metadata=md))
-     except Exception as e:
-         print(f"Error reading from JSONL file {path}: {e}")
-         return []
-     return out
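- # Expected JSONL record shape (values hypothetical; only "text" is required, and the
- # tag fields are copied into metadata when present):
- #   {"text": "Redirecting with a familiar activity helped at sundown.",
- #    "behaviors": ["exit_seeking"], "emotion": "anxious",
- #    "topic_tags": ["safety"], "context_tags": ["evening"]}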
-
- def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
-     docs: List[Document] = []
-     for p in (sample_paths or []):
-         try:
-             if p.lower().endswith(".jsonl"):
-                 docs.extend(texts_from_jsonl(p))
-             else:
-                 with open(p, "r", encoding="utf-8", errors="ignore") as fh:
-                     docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
-         except Exception:
-             continue
-     if not docs:
-         docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
-     return build_or_load_vectorstore(docs, index_path=index_path)
-
- # -----------------------------
- # LLM Call
- # -----------------------------
- def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Optional[List[str]] = None) -> str:
-     """Call OpenAI Chat Completions if available; else return a fallback."""
-     client = _openai_client()
-     model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
-     if not client:
-         return "(Offline Mode: OpenAI API key not configured.)"
-     try:
-         api_args = {
-             "model": model, "messages": messages,
-             "temperature": float(temperature if temperature is not None else 0.6)
-         }
-         if stop: api_args["stop"] = stop
-         resp = client.chat.completions.create(**api_args)
-         return (resp.choices[0].message.content or "").strip()
-     except Exception as e:
-         return f"[LLM API Error: {e}]"
-
- # -----------------------------
- # Prompting & RAG Chain
- # -----------------------------
-
- def _format_sources(docs: List[Document]) -> List[str]:
-     return list(set(d.metadata.get("source", "unknown") for d in docs))
-
- def make_rag_chain(
-     vs_general: FAISS,
-     vs_personal: FAISS,
-     *,
-     role: str = "patient",
-     temperature: float = 0.6,
-     language: str = "English",
-     patient_name: str = "the patient",
-     caregiver_name: str = "the caregiver",
-     tone: str = "warm",
-     # Custom prompt arguments
-     system_template: str,
-     factual_template: str,
-     general_knowledge_template: str,
-     general_conversation_template: str,
-     **kwargs
- ):
-     """Returns a callable that performs the complete RAG process."""
-
-     def _format_docs(docs: List[Document], default_msg: str) -> str:
-         if not docs: return default_msg
-         unique_docs = {doc.page_content: doc for doc in docs}.values()
-         return "\n".join([f"- {d.page_content.strip()}" for d in unique_docs])
-
-     def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None, topic_tag: Optional[str] = None) -> Dict[str, Any]:
-
-         router_messages = [{"role": "user", "content": RAG_ROUTER_PROMPT.format(query=query)}]
-         query_type = call_llm(router_messages, temperature=0.0).strip().lower()
-
-         system_message = system_template.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
-         messages = [{"role": "system", "content": system_message}]
-         messages.extend(chat_history)
-
-         if "factual_question" in query_type:
-             retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 2})
-             retriever_general = vs_general.as_retriever(search_kwargs={"k": 2})
-             all_docs = retriever_personal.invoke(query) + retriever_general.invoke(query)
-             context = _format_docs(all_docs, "(No relevant information found in the memory journal.)")
-             user_prompt = factual_template.format(context=context, query=query, language=language)
-             messages.append({"role": "user", "content": user_prompt})
-             answer = call_llm(messages, temperature=temperature)
-             return {"answer": answer, "sources": _format_sources(all_docs)}
-
-         elif "general_knowledge_question" in query_type:
-             user_prompt = general_knowledge_template.format(query=query, language=language)
-             messages.append({"role": "user", "content": user_prompt})
-             answer = call_llm(messages, temperature=temperature)
-             return {"answer": answer, "sources": ["General Knowledge"]}
-
-         elif "general_conversation" in query_type:
-             user_prompt = general_conversation_template.format(query=query, language=language)
-             messages.append({"role": "user", "content": user_prompt})
-             answer = call_llm(messages, temperature=temperature)
-             return {"answer": answer, "sources": []}
-
-         else:  # Default to caregiving scenario
-             search_filter = {}
-             # Note: scenario_tag from NLU is a list, but RAG filter expects a single string for now.
-             if isinstance(scenario_tag, list) and scenario_tag:
-                 search_filter["behaviors"] = scenario_tag[0].lower()
-             elif isinstance(scenario_tag, str):
-                 search_filter["behaviors"] = scenario_tag.lower()
-
-             if emotion_tag: search_filter["emotion"] = emotion_tag.lower()
-             if topic_tag: search_filter["topic_tags"] = topic_tag.lower()
-
-             personal_docs = vs_personal.similarity_search(query, k=3, filter=search_filter if search_filter else None)
-             general_docs = vs_general.similarity_search(query, k=3, filter=search_filter if search_filter else None)
-
-             personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
-             general_context = _format_docs(general_docs, "(No general guidance found.)")
-
-             all_docs = personal_docs + general_docs
-             first_emotion = next((doc.metadata.get("emotion") for doc in all_docs if doc.metadata.get("emotion")), emotion_tag)
-             emotions_context = render_emotion_guidelines(first_emotion)
-
-             is_tagged_scenario = scenario_tag or emotion_tag or first_emotion
-             template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
-
-             display_scenario_tag = scenario_tag[0] if isinstance(scenario_tag, list) and scenario_tag else scenario_tag
-
-             if template == ANSWER_TEMPLATE_ADQ:
-                 user_prompt = template.format(general_context=general_context, personal_context=personal_context, query=query, scenario_tag=display_scenario_tag, emotions_context=emotions_context, role=role, language=language)
-             else:
-                 combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
-                 user_prompt = template.format(context=combined_context, query=query, language=language)
-
-             messages.append({"role": "user", "content": user_prompt})
-             answer = call_llm(messages, temperature=temperature)
-
-             high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
-             if display_scenario_tag and display_scenario_tag.lower() in high_risk_scenarios:
-                 answer += f"\n\n---\n{RISK_FOOTER}"
-
-             return {"answer": answer, "sources": _format_sources(all_docs)}
-
-     return _answer_fn
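- # Typical wiring, as a sketch (paths hypothetical; the templates are the ones
- # imported from prompts.py above; answer_query is defined just below):
- #   vs_general = bootstrap_vectorstore(["data/guidance.jsonl"], index_path="data/faiss_index")
- #   vs_personal = build_or_load_vectorstore([], "data/personal_index", is_personal=True)
- #   chain = make_rag_chain(vs_general, vs_personal,
- #                          system_template=SYSTEM_TEMPLATE,
- #                          factual_template=ANSWER_TEMPLATE_FACTUAL,
- #                          general_knowledge_template=ANSWER_TEMPLATE_GENERAL_KNOWLEDGE,
- #                          general_conversation_template=ANSWER_TEMPLATE_GENERAL)
- #   result = answer_query(chain, "Where did we go last summer?", chat_history=[])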
-
- def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
-     if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
-     try:
-         return chain(question, **kwargs)
-     except Exception as e:
-         print(f"ERROR in answer_query: {e}")
-         return {"answer": f"[Error executing chain: {e}]", "sources": []}
-
- # -----------------------------
- # TTS & Transcription
- # -----------------------------
- def synthesize_tts(text: str, lang: str = "en"):
-     if not text or gTTS is None: return None
-     try:
-         fd, path = tempfile.mkstemp(suffix=".mp3")
-         os.close(fd)
-         tts = gTTS(text=text, lang=(lang or "en"))
-         tts.save(path)
-         return path
-     except Exception:
-         return None
-
- def transcribe_audio(filepath: str, lang: str = "en"):
-     client = _openai_client()
-     if not client:
-         return "[Transcription failed: API key not configured]"
-     api_args = {"model": "whisper-1"}
-     if lang and lang != "auto":
-         api_args["language"] = lang
-     with open(filepath, "rb") as audio_file:
-         transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
-     return transcription.text
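- # Round-trip sketch (file path hypothetical; gTTS needs network access and
- # transcription needs an OpenAI API key):
- #   mp3_path = synthesize_tts("Good morning! It's time for breakfast.", lang="en")
- #   heard = transcribe_audio("recordings/reply.wav", lang="en")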