Spaces:
Sleeping
Sleeping
Update alz_companion/agent.py
Browse files
- alz_companion/agent.py +22 -48
alz_companion/agent.py
CHANGED
|
@@ -30,7 +30,7 @@ from .prompts import (
|
|
| 30 |
|
| 31 |
|
| 32 |
# -----------------------------
|
| 33 |
-
# Multimodal Processing Functions
|
| 34 |
# -----------------------------
|
| 35 |
|
| 36 |
def _openai_client() -> Optional[OpenAI]:
|
|
@@ -68,7 +68,6 @@ def describe_image(image_path: str) -> str:
|
|
| 68 |
# -----------------------------
|
| 69 |
# NLU Classification Function
|
| 70 |
# -----------------------------
|
| 71 |
-
# (Unchanged from before)
|
| 72 |
def detect_tags_from_query(query: str, behavior_options: list, emotion_options: list) -> Dict[str, Optional[str]]:
|
| 73 |
"""Uses an LLM call to classify the user's query into a behavior and emotion tag."""
|
| 74 |
behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
|
|
@@ -89,7 +88,6 @@ def detect_tags_from_query(query: str, behavior_options: list, emotion_options:
|
|
| 89 |
# -----------------------------
|
| 90 |
# Embeddings & VectorStore
|
| 91 |
# -----------------------------
|
| 92 |
-
# (Unchanged from before)
|
| 93 |
|
| 94 |
def _default_embeddings():
|
| 95 |
"""Lightweight, widely available model."""
|
|
@@ -104,7 +102,6 @@ def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal
|
|
| 104 |
except Exception:
|
| 105 |
pass
|
| 106 |
|
| 107 |
-
# If it's a new personal vector store with no docs, create a placeholder
|
| 108 |
if is_personal and not docs:
|
| 109 |
docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
|
| 110 |
|
|
@@ -148,7 +145,6 @@ def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str
|
|
| 148 |
# -----------------------------
|
| 149 |
# LLM Call
|
| 150 |
# -----------------------------
|
| 151 |
-
# (Unchanged from before)
|
| 152 |
def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
|
| 153 |
"""Call OpenAI Chat Completions if available; else return a fallback."""
|
| 154 |
client = _openai_client()
|
|
@@ -162,7 +158,7 @@ def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
|
|
| 162 |
return f"[LLM API Error: {e}]"
|
| 163 |
|
| 164 |
# -----------------------------
|
| 165 |
-
# Prompting & RAG Chain
|
| 166 |
# -----------------------------
|
| 167 |
|
| 168 |
def _format_sources(docs: List[Document]) -> List[str]:
|
|
@@ -181,70 +177,56 @@ def make_rag_chain(
|
|
| 181 |
):
|
| 182 |
"""Returns a callable that performs the complete, two-tiered RAG process."""
|
| 183 |
|
| 184 |
-
retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
|
| 185 |
-
retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
|
| 186 |
-
|
| 187 |
def _format_docs(docs: List[Document], default_msg: str) -> str:
|
| 188 |
if not docs: return default_msg
|
| 189 |
return "\n".join([f"- {d.page_content.strip()}" for d in docs])
|
| 190 |
|
| 191 |
def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None) -> Dict[str, Any]:
|
| 192 |
|
| 193 |
-
#
|
| 194 |
-
personal_docs = retriever_personal.invoke(query)
|
| 195 |
-
personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
|
| 196 |
-
|
| 197 |
-
# --- Step 2: Search the General Knowledge Base with filters ---
|
| 198 |
search_filter = {}
|
| 199 |
if scenario_tag and scenario_tag != "None":
|
| 200 |
search_filter["behaviors"] = scenario_tag.lower()
|
| 201 |
if emotion_tag and emotion_tag != "None":
|
| 202 |
search_filter["emotion"] = emotion_tag.lower()
|
| 203 |
|
|
|
|
| 204 |
if search_filter:
|
|
|
|
| 205 |
general_docs = vs_general.similarity_search(query, k=3, filter=search_filter)
|
| 206 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
general_docs = retriever_general.invoke(query)
|
|
|
|
|
|
|
| 208 |
general_context = _format_docs(general_docs, "(No general guidance found.)")
|
| 209 |
|
| 210 |
-
#
|
| 211 |
first_emotion = None
|
| 212 |
-
# Prioritize emotion from personal memories, then general context
|
| 213 |
all_docs = personal_docs + general_docs
|
| 214 |
for doc in all_docs:
|
| 215 |
if "emotion" in doc.metadata and doc.metadata["emotion"]:
|
| 216 |
emotion_data = doc.metadata["emotion"]
|
| 217 |
-
if isinstance(emotion_data, list):
|
| 218 |
-
|
| 219 |
-
else:
|
| 220 |
-
first_emotion = emotion_data
|
| 221 |
if first_emotion: break
|
| 222 |
|
| 223 |
emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
|
| 224 |
|
| 225 |
-
#
|
| 226 |
-
is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None")
|
| 227 |
template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
|
| 228 |
|
| 229 |
-
# Note the new placeholders: general_context and personal_context
|
| 230 |
if template == ANSWER_TEMPLATE_ADQ:
|
| 231 |
-
user_prompt = template.format(
|
| 232 |
-
|
| 233 |
-
personal_context=personal_context,
|
| 234 |
-
question=query,
|
| 235 |
-
scenario_tag=scenario_tag,
|
| 236 |
-
emotions_context=emotions_context,
|
| 237 |
-
role=role,
|
| 238 |
-
language=language
|
| 239 |
-
)
|
| 240 |
-
else: # Calm template only uses a single combined context
|
| 241 |
combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
|
| 242 |
user_prompt = template.format(context=combined_context, question=query, language=language)
|
| 243 |
|
| 244 |
-
system_message = SYSTEM_TEMPLATE.format(
|
| 245 |
-
tone=tone, language=language, patient_name=patient_name or "the patient",
|
| 246 |
-
caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS,
|
| 247 |
-
)
|
| 248 |
|
| 249 |
messages = [{"role": "system", "content": system_message}]
|
| 250 |
messages.extend(chat_history)
|
|
@@ -261,13 +243,8 @@ def make_rag_chain(
|
|
| 261 |
return _answer_fn
|
| 262 |
|
| 263 |
def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
|
| 264 |
-
if not callable(chain):
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
chat_history = kwargs.get("chat_history", [])
|
| 268 |
-
scenario_tag = kwargs.get("scenario_tag")
|
| 269 |
-
emotion_tag = kwargs.get("emotion_tag")
|
| 270 |
-
|
| 271 |
try:
|
| 272 |
return chain(question, chat_history=chat_history, scenario_tag=scenario_tag, emotion_tag=emotion_tag)
|
| 273 |
except Exception as e:
|
|
@@ -277,8 +254,6 @@ def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
|
|
| 277 |
# -----------------------------
|
| 278 |
# TTS & Transcription
|
| 279 |
# -----------------------------
|
| 280 |
-
# (Unchanged)
|
| 281 |
-
|
| 282 |
def synthesize_tts(text: str, lang: str = "en"):
|
| 283 |
if not text or gTTS is None: return None
|
| 284 |
try:
|
|
@@ -294,11 +269,10 @@ def transcribe_audio(filepath: str, lang: str = "en"):
|
|
| 294 |
client = _openai_client()
|
| 295 |
if not client:
|
| 296 |
return "[Transcription failed: API key not configured]"
|
| 297 |
-
|
| 298 |
api_args = {"model": "whisper-1"}
|
| 299 |
if lang and lang != "auto":
|
| 300 |
api_args["language"] = lang
|
| 301 |
-
|
| 302 |
with open(filepath, "rb") as audio_file:
|
| 303 |
transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
|
| 304 |
return transcription.text
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
# -----------------------------
|
| 33 |
+
# Multimodal Processing Functions
|
| 34 |
# -----------------------------
|
| 35 |
|
| 36 |
def _openai_client() -> Optional[OpenAI]:
|
|
|
|
| 68 |
# -----------------------------
|
| 69 |
# NLU Classification Function
|
| 70 |
# -----------------------------
|
|
|
|
| 71 |
def detect_tags_from_query(query: str, behavior_options: list, emotion_options: list) -> Dict[str, Optional[str]]:
|
| 72 |
"""Uses an LLM call to classify the user's query into a behavior and emotion tag."""
|
| 73 |
behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
|
|
|
|
| 88 |
# -----------------------------
|
| 89 |
# Embeddings & VectorStore
|
| 90 |
# -----------------------------
|
|
|
|
| 91 |
|
| 92 |
def _default_embeddings():
|
| 93 |
"""Lightweight, widely available model."""
|
|
|
|
| 102 |
except Exception:
|
| 103 |
pass
|
| 104 |
|
|
|
|
| 105 |
if is_personal and not docs:
|
| 106 |
docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
|
| 107 |
|
|
|
|
| 145 |
# -----------------------------
|
| 146 |
# LLM Call
|
| 147 |
# -----------------------------
|
|
|
|
| 148 |
def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
|
| 149 |
"""Call OpenAI Chat Completions if available; else return a fallback."""
|
| 150 |
client = _openai_client()
|
|
|
|
| 158 |
return f"[LLM API Error: {e}]"
|
| 159 |
|
| 160 |
# -----------------------------
|
| 161 |
+
# Prompting & RAG Chain
|
| 162 |
# -----------------------------
|
| 163 |
|
| 164 |
def _format_sources(docs: List[Document]) -> List[str]:
|
|
|
|
| 177 |
):
|
| 178 |
"""Returns a callable that performs the complete, two-tiered RAG process."""
|
| 179 |
|
|
|
|
|
|
|
|
|
|
| 180 |
def _format_docs(docs: List[Document], default_msg: str) -> str:
|
| 181 |
if not docs: return default_msg
|
| 182 |
return "\n".join([f"- {d.page_content.strip()}" for d in docs])
|
| 183 |
|
| 184 |
def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None) -> Dict[str, Any]:
|
| 185 |
|
| 186 |
+
# Build a dynamic filter that will be used for BOTH knowledge bases
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
search_filter = {}
|
| 188 |
if scenario_tag and scenario_tag != "None":
|
| 189 |
search_filter["behaviors"] = scenario_tag.lower()
|
| 190 |
if emotion_tag and emotion_tag != "None":
|
| 191 |
search_filter["emotion"] = emotion_tag.lower()
|
| 192 |
|
| 193 |
+
# Use the filter on both searches if available
|
| 194 |
if search_filter:
|
| 195 |
+
personal_docs = vs_personal.similarity_search(query, k=3, filter=search_filter)
|
| 196 |
general_docs = vs_general.similarity_search(query, k=3, filter=search_filter)
|
| 197 |
else:
|
| 198 |
+
# If no filters, perform standard semantic search on both
|
| 199 |
+
retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
|
| 200 |
+
retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
|
| 201 |
+
personal_docs = retriever_personal.invoke(query)
|
| 202 |
general_docs = retriever_general.invoke(query)
|
| 203 |
+
|
| 204 |
+
personal_context = _format_docs(personal_docs, "(No relevant personal memories found.)")
|
| 205 |
general_context = _format_docs(general_docs, "(No general guidance found.)")
|
| 206 |
|
| 207 |
+
# Determine emotion for the response guidelines
|
| 208 |
first_emotion = None
|
|
|
|
| 209 |
all_docs = personal_docs + general_docs
|
| 210 |
for doc in all_docs:
|
| 211 |
if "emotion" in doc.metadata and doc.metadata["emotion"]:
|
| 212 |
emotion_data = doc.metadata["emotion"]
|
| 213 |
+
if isinstance(emotion_data, list): first_emotion = emotion_data[0]
|
| 214 |
+
else: first_emotion = emotion_data
|
|
|
|
|
|
|
| 215 |
if first_emotion: break
|
| 216 |
|
| 217 |
emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
|
| 218 |
|
| 219 |
+
# Assemble and Call the LLM
|
| 220 |
+
is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
|
| 221 |
template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
|
| 222 |
|
|
|
|
| 223 |
if template == ANSWER_TEMPLATE_ADQ:
|
| 224 |
+
user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=scenario_tag, emotions_context=emotions_context, role=role, language=language)
|
| 225 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
|
| 227 |
user_prompt = template.format(context=combined_context, question=query, language=language)
|
| 228 |
|
| 229 |
+
system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
messages = [{"role": "system", "content": system_message}]
|
| 232 |
messages.extend(chat_history)
|
|
|
|
| 243 |
return _answer_fn
|
| 244 |
|
| 245 |
def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
|
| 246 |
+
if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
|
| 247 |
+
chat_history, scenario_tag, emotion_tag = kwargs.get("chat_history", []), kwargs.get("scenario_tag"), kwargs.get("emotion_tag")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
try:
|
| 249 |
return chain(question, chat_history=chat_history, scenario_tag=scenario_tag, emotion_tag=emotion_tag)
|
| 250 |
except Exception as e:
|
|
|
|
| 254 |
# -----------------------------
|
| 255 |
# TTS & Transcription
|
| 256 |
# -----------------------------
|
|
|
|
|
|
|
| 257 |
def synthesize_tts(text: str, lang: str = "en"):
|
| 258 |
if not text or gTTS is None: return None
|
| 259 |
try:
|
|
|
|
| 269 |
client = _openai_client()
|
| 270 |
if not client:
|
| 271 |
return "[Transcription failed: API key not configured]"
|
|
|
|
| 272 |
api_args = {"model": "whisper-1"}
|
| 273 |
if lang and lang != "auto":
|
| 274 |
api_args["language"] = lang
|
|
|
|
| 275 |
with open(filepath, "rb") as audio_file:
|
| 276 |
transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
|
| 277 |
return transcription.text
|
| 278 |
+
|