KeenWoo committed
Commit f367bf7 · verified · 1 Parent(s): ba9ac68

Delete alz_companion/agent.py

Files changed (1)
  1. alz_companion/agent.py +0 -340
alz_companion/agent.py DELETED
@@ -1,340 +0,0 @@
- from __future__ import annotations
- import os
- import json
- import base64
- import time
- import tempfile
- from typing import List, Dict, Any, Optional
-
- # OpenAI for LLM (optional)
- try:
-     from openai import OpenAI
- except Exception:  # pragma: no cover
-     OpenAI = None  # type: ignore
-
- # LangChain & RAG
- from langchain.schema import Document
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
-
- # TTS
- try:
-     from gtts import gTTS
- except Exception:  # pragma: no cover
-     gTTS = None  # type: ignore
-
-
- from .prompts import (
-     SYSTEM_TEMPLATE, ANSWER_TEMPLATE_CALM, ANSWER_TEMPLATE_ADQ,
-     SAFETY_GUARDRAILS, RISK_FOOTER, render_emotion_guidelines, CLASSIFICATION_PROMPT,
-     ROUTER_PROMPT,
-     ANSWER_TEMPLATE_FACTUAL,
-     ANSWER_TEMPLATE_GENERAL,
- )
-
- # -----------------------------
- # Multimodal Processing Functions
- # -----------------------------
-
- def _openai_client() -> Optional[OpenAI]:
-     api_key = os.getenv("OPENAI_API_KEY", "").strip()
-     return OpenAI(api_key=api_key) if api_key and OpenAI else None
-
- def describe_image(image_path: str) -> str:
-     """Uses a vision model to describe an image for context."""
-     client = _openai_client()
-     if not client:
-         return "(Image description failed: OpenAI API key not configured.)"
-
-     try:
-         # Determine the MIME type from the file extension
-         extension = os.path.splitext(image_path)[1].lower()
-         if extension == ".png":
-             mime_type = "image/png"
-         elif extension in [".jpg", ".jpeg"]:
-             mime_type = "image/jpeg"
-         elif extension == ".gif":
-             mime_type = "image/gif"
-         elif extension == ".webp":
-             mime_type = "image/webp"
-         else:
-             # Default to JPEG; the branches above cover the most common cases
-             mime_type = "image/jpeg"
-
-         with open(image_path, "rb") as image_file:
-             base64_image = base64.b64encode(image_file.read()).decode("utf-8")
-
-         response = client.chat.completions.create(
-             model="gpt-4o",
-             messages=[
-                 {
-                     "role": "user",
-                     "content": [
-                         {"type": "text", "text": "Describe this image in a concise, factual way for a memory journal. Focus on people, places, and key objects. For example: 'A photo of John and Mary smiling on a bench at the park.'"},
-                         {
-                             "type": "image_url",
-                             # Use the dynamically determined MIME type
-                             "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
-                         },
-                     ],
-                 }
-             ],
-             max_tokens=100,
-         )
-         return response.choices[0].message.content or "No description available."
-     except Exception as e:
-         return f"[Image description error: {e}]"
-
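# Usage sketch (illustrative; the path is hypothetical). The helper reads the file,
# base64-encodes it, and passes it to gpt-4o inline as a data URL, so the image
# never needs to be hosted anywhere:
#
#     caption = describe_image("photos/park_bench.jpg")
#     # e.g. "A photo of John and Mary smiling on a bench at the park."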
- # -----------------------------
- # NLU Classification Function
- # -----------------------------
- def detect_tags_from_query(query: str, behavior_options: list, emotion_options: list) -> Dict[str, Optional[str]]:
-     """Uses an LLM call to classify the user's query into a behavior and emotion tag."""
-     behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
-     emotion_str = ", ".join(f'"{opt}"' for opt in emotion_options if opt != "None")
-     prompt = CLASSIFICATION_PROMPT.format(behavior_options=behavior_str, emotion_options=emotion_str, query=query)
-     messages = [
-         {"role": "system", "content": "You are a helpful NLU classification assistant. Respond only with the JSON object requested."},
-         {"role": "user", "content": prompt},
-     ]
-     response_str = call_llm(messages, temperature=0.1)
-     try:
-         clean_response = response_str.strip().replace("```json", "").replace("```", "")
-         result = json.loads(clean_response)
-         behavior = result.get("detected_behavior")
-         emotion = result.get("detected_emotion")
-         return {
-             "detected_behavior": behavior if behavior in behavior_options else "None",
-             "detected_emotion": emotion if emotion in emotion_options else "None",
-         }
-     except (json.JSONDecodeError, AttributeError):
-         return {"detected_behavior": "None", "detected_emotion": "None"}
-
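# A minimal sketch of the round trip (the option lists are hypothetical). The model
# is asked for a bare JSON object, which is then validated against the allowed
# options before use:
#
#     tags = detect_tags_from_query(
#         "She keeps trying to leave the house after dark",
#         behavior_options=["None", "wandering", "agitation"],
#         emotion_options=["None", "anxious", "angry"],
#     )
#     # -> {"detected_behavior": "wandering", "detected_emotion": "anxious"}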
-
- # -----------------------------
- # Embeddings & VectorStore
- # -----------------------------
-
- def _default_embeddings():
-     """Lightweight, widely available model."""
-     model_name = os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
-     return HuggingFaceEmbeddings(model_name=model_name)
-
- def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal: bool = False) -> FAISS:
-     os.makedirs(os.path.dirname(index_path), exist_ok=True)
-     if os.path.isdir(index_path) and os.path.exists(os.path.join(index_path, "index.faiss")):
-         try:
-             return FAISS.load_local(index_path, _default_embeddings(), allow_dangerous_deserialization=True)
-         except Exception:
-             pass
-
-     if is_personal and not docs:
-         docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
-
-     vs = FAISS.from_documents(docs, _default_embeddings())
-     vs.save_local(index_path)
-     return vs
-
- def texts_from_jsonl(path: str) -> List[Document]:
-     out: List[Document] = []
-     try:
-         with open(path, "r", encoding="utf-8") as f:
-             for i, line in enumerate(f):
-                 line = line.strip()
-                 if not line:
-                     continue
-                 obj = json.loads(line)
-                 txt = obj.get("text") or ""
-                 if not isinstance(txt, str) or not txt.strip():
-                     continue
-                 md = {"source": os.path.basename(path), "chunk": i}
-                 for k in ("behaviors", "emotion"):
-                     if k in obj:
-                         md[k] = obj[k]
-                 out.append(Document(page_content=txt, metadata=md))
-     except Exception:
-         return []
-     return out
-
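# A sketch of the JSONL records this loader expects, one JSON object per line
# (the values are hypothetical). "behaviors" and "emotion" are optional and, when
# present, are copied into document metadata for filtered retrieval:
#
#     {"text": "Gently redirect her when she asks to go home.", "behaviors": "wandering", "emotion": "anxious"}
#     {"text": "He enjoys talking about his years as a schoolteacher."}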
- def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
-     docs: List[Document] = []
-     for p in (sample_paths or []):
-         try:
-             if p.lower().endswith(".jsonl"):
-                 docs.extend(texts_from_jsonl(p))
-             else:
-                 with open(p, "r", encoding="utf-8", errors="ignore") as fh:
-                     docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
-         except Exception:
-             continue
-     if not docs:
-         docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
-     return build_or_load_vectorstore(docs, index_path=index_path)
-
- # -----------------------------
- # LLM Call
- # -----------------------------
- def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
-     """Call OpenAI Chat Completions if available; else return a fallback."""
-     client = _openai_client()
-     model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
-     if not client:
-         return "(Offline Mode: OpenAI API key not configured.)"
-     try:
-         # Fall back to the default temperature if None was passed in
-         temp_value = temperature if temperature is not None else 0.6
-         resp = client.chat.completions.create(model=model, messages=messages, temperature=float(temp_value))
-         return (resp.choices[0].message.content or "").strip()
-     except Exception as e:
-         return f"[LLM API Error: {e}]"
-
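# Usage sketch (the message content is hypothetical). Failures come back as plain
# strings rather than exceptions, so callers can display the result directly:
#
#     reply = call_llm(
#         [{"role": "system", "content": "You are a helpful assistant."},
#          {"role": "user", "content": "Say hello."}],
#         temperature=0.2,
#     )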
- # -----------------------------
- # Prompting & RAG Chain
- # -----------------------------
-
- def _format_sources(docs: List[Document]) -> List[str]:
-     return list(set(d.metadata.get("source", "unknown") for d in docs))
-
- # Routes each query to one of three paths: factual lookup, general conversation,
- # or the original caregiving-scenario logic.
- def make_rag_chain(
-     vs_general: FAISS,
-     vs_personal: FAISS,
-     *,
-     role: str = "patient",
-     temperature: float = 0.6,
-     language: str = "English",
-     patient_name: str = "the patient",
-     caregiver_name: str = "the caregiver",
-     tone: str = "warm",
- ):
-     """Returns a callable that performs the complete, intelligent RAG process."""
-
-     def _format_docs(docs: List[Document], default_msg: str) -> str:
-         if not docs:
-             return default_msg
-         return "\n".join([f"- {d.page_content.strip()}" for d in docs])
-
-     # This is the core function that will be returned
-     def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None) -> Dict[str, Any]:
-
-         # 1. Classify the user's intent to decide which path to take.
-         router_messages = [{"role": "user", "content": ROUTER_PROMPT.format(query=query)}]
-         query_type = call_llm(router_messages, temperature=0.0).strip().lower()
-         print(f"Query classified as: {query_type}")  # For debugging
-
-         system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
-         messages = [{"role": "system", "content": system_message}]
-         messages.extend(chat_history)
-
-         # --- PATH 1: Factual question ---
-         if "factual_question" in query_type:
-             # For factual questions, search both knowledge bases and combine the results.
-             retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 2})
-             retriever_general = vs_general.as_retriever(search_kwargs={"k": 2})
-             personal_docs = retriever_personal.invoke(query)
-             general_docs = retriever_general.invoke(query)
-
-             all_docs = personal_docs + general_docs
-             context = _format_docs(all_docs, "(No relevant information found in the memory journal.)")
-
-             user_prompt = ANSWER_TEMPLATE_FACTUAL.format(context=context, question=query, language=language)
-             messages.append({"role": "user", "content": user_prompt})
-
-             answer = call_llm(messages, temperature=temperature)
-             return {"answer": answer, "sources": _format_sources(all_docs)}
-
-         # --- PATH 2: General conversation ---
-         elif "general_conversation" in query_type:
-             # For chit-chat, no RAG is needed; just call the LLM with a simple prompt.
-             user_prompt = ANSWER_TEMPLATE_GENERAL.format(question=query, language=language)
-             messages.append({"role": "user", "content": user_prompt})
-
-             answer = call_llm(messages, temperature=temperature)
-             return {"answer": answer, "sources": []}
-
-         # --- PATH 3: Caregiving scenario ---
-         else:  # Default to the original caregiving logic
-             search_filter = {}
-             if scenario_tag and scenario_tag != "None":
-                 search_filter["behaviors"] = scenario_tag.lower()
-             if emotion_tag and emotion_tag != "None":
-                 search_filter["emotion"] = emotion_tag.lower()
-
-             if search_filter:
-                 personal_docs = vs_personal.similarity_search(query, k=3, filter=search_filter)
-                 general_docs = vs_general.similarity_search(query, k=3, filter=search_filter)
-             else:
-                 retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
-                 retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
-                 personal_docs = retriever_personal.invoke(query)
-                 general_docs = retriever_general.invoke(query)
-
-             personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
-             general_context = _format_docs(general_docs, "(No general guidance found.)")
-
-             # Surface the first emotion tag found in the retrieved documents, if any
-             first_emotion = None
-             all_docs_care = personal_docs + general_docs
-             for doc in all_docs_care:
-                 if "emotion" in doc.metadata and doc.metadata["emotion"]:
-                     emotion_data = doc.metadata["emotion"]
-                     first_emotion = emotion_data[0] if isinstance(emotion_data, list) else emotion_data
-                 if first_emotion:
-                     break
-
-             emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
-
-             is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
-             template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
-
-             if template == ANSWER_TEMPLATE_ADQ:
-                 user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=scenario_tag, emotions_context=emotions_context, role=role, language=language)
-             else:
-                 combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
-                 user_prompt = template.format(context=combined_context, question=query, language=language)
-
-             messages.append({"role": "user", "content": user_prompt})
-
-             answer = call_llm(messages, temperature=temperature)
-
-             high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
-             if scenario_tag and scenario_tag.lower() in high_risk_scenarios:
-                 answer += f"\n\n---\n{RISK_FOOTER}"
-
-             return {"answer": answer, "sources": _format_sources(all_docs_care)}
-
-     return _answer_fn
-
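# End-to-end sketch (paths, names, and tags are hypothetical): build the two
# indexes, wire them into a chain, and ask one question:
#
#     vs_general = bootstrap_vectorstore(["data/guidance.jsonl"], index_path="data/faiss_general")
#     vs_personal = bootstrap_vectorstore(["data/journal.jsonl"], index_path="data/faiss_personal")
#     chain = make_rag_chain(vs_general, vs_personal, role="caregiver",
#                            patient_name="Mary", caregiver_name="John")
#     result = chain("She is pacing by the front door again", chat_history=[],
#                    scenario_tag="Wandering")
#     print(result["answer"], result["sources"])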
-
- def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
-     if not callable(chain):
-         return {"answer": "[Error: RAG chain is not callable]", "sources": []}
-     chat_history, scenario_tag, emotion_tag = kwargs.get("chat_history", []), kwargs.get("scenario_tag"), kwargs.get("emotion_tag")
-     try:
-         return chain(question, chat_history=chat_history, scenario_tag=scenario_tag, emotion_tag=emotion_tag)
-     except Exception as e:
-         print(f"ERROR in answer_query: {e}")
-         return {"answer": f"[Error executing chain: {e}]", "sources": []}
-
- # -----------------------------
- # TTS & Transcription
- # -----------------------------
- def synthesize_tts(text: str, lang: str = "en"):
-     if not text or gTTS is None:
-         return None
-     try:
-         fd, path = tempfile.mkstemp(suffix=".mp3")
-         os.close(fd)
-         tts = gTTS(text=text, lang=(lang or "en"))
-         tts.save(path)
-         return path
-     except Exception:
-         return None
-
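# Usage sketch: returns the path of a temporary .mp3, or None when gTTS is
# unavailable. The file is never deleted here, so callers should remove it
# once the audio has been played:
#
#     mp3_path = synthesize_tts("Hello Mary, it's time for lunch.", lang="en")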
- def transcribe_audio(filepath: str, lang: str = "en"):
-     client = _openai_client()
-     if not client:
-         return "[Transcription failed: API key not configured]"
-     api_args = {"model": "whisper-1"}
-     if lang and lang != "auto":
-         api_args["language"] = lang
-     with open(filepath, "rb") as audio_file:
-         transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
-     return transcription.text
-
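# Usage sketch (the path is hypothetical). Note that, unlike the other helpers,
# any API error here propagates to the caller as an exception:
#
#     text = transcribe_audio("recordings/question.wav", lang="en")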