Spaces:
Runtime error
Runtime error
Update alz_companion/agent.py
Browse files- alz_companion/agent.py +44 -15
alz_companion/agent.py
CHANGED
|
@@ -201,10 +201,14 @@ def texts_from_jsonl(path: str) -> List[Document]:
|
|
| 201 |
obj = json.loads(line)
|
| 202 |
txt = obj.get("text") or ""
|
| 203 |
if not isinstance(txt, str) or not txt.strip(): continue
|
|
|
|
|
|
|
| 204 |
md = {"source": os.path.basename(path), "chunk": i}
|
| 205 |
-
for k in ("behaviors", "emotion"):
|
| 206 |
-
if k in obj
|
|
|
|
| 207 |
out.append(Document(page_content=txt, metadata=md))
|
|
|
|
| 208 |
except Exception:
|
| 209 |
return []
|
| 210 |
return out
|
|
@@ -340,21 +344,40 @@ def make_rag_chain(
|
|
| 340 |
search_filter["behaviors"] = scenario_tag.lower()
|
| 341 |
if emotion_tag and emotion_tag != "None":
|
| 342 |
search_filter["emotion"] = emotion_tag.lower()
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
if search_filter:
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
|
| 349 |
-
retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
|
| 350 |
-
personal_docs = retriever_personal.invoke(query)
|
| 351 |
-
general_docs = retriever_general.invoke(query)
|
| 352 |
|
| 353 |
-
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
first_emotion = None
|
| 357 |
-
all_docs_care =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
for doc in all_docs_care:
|
| 359 |
if "emotion" in doc.metadata and doc.metadata["emotion"]:
|
| 360 |
emotion_data = doc.metadata["emotion"]
|
|
@@ -384,15 +407,21 @@ def make_rag_chain(
|
|
| 384 |
return _answer_fn
|
| 385 |
|
| 386 |
|
|
|
|
| 387 |
def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
|
| 388 |
if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
|
| 389 |
-
chat_history
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
try:
|
| 391 |
-
return chain(question, chat_history=chat_history, scenario_tag=scenario_tag, emotion_tag=emotion_tag)
|
| 392 |
except Exception as e:
|
| 393 |
print(f"ERROR in answer_query: {e}")
|
| 394 |
return {"answer": f"[Error executing chain: {e}]", "sources": []}
|
| 395 |
|
|
|
|
| 396 |
# -----------------------------
|
| 397 |
# TTS & Transcription
|
| 398 |
# -----------------------------
|
|
|
|
| 201 |
obj = json.loads(line)
|
| 202 |
txt = obj.get("text") or ""
|
| 203 |
if not isinstance(txt, str) or not txt.strip(): continue
|
| 204 |
+
|
| 205 |
+
# fix bugs by adding tags for topic and context
|
| 206 |
md = {"source": os.path.basename(path), "chunk": i}
|
| 207 |
+
for k in ("behaviors", "emotion", "topic_tags", "context_tags"):
|
| 208 |
+
if k in obj and obj[k]: # Ensure the key exists and is not empty
|
| 209 |
+
md[k] = obj[k]
|
| 210 |
out.append(Document(page_content=txt, metadata=md))
|
| 211 |
+
|
| 212 |
except Exception:
|
| 213 |
return []
|
| 214 |
return out
|
|
|
|
| 344 |
search_filter["behaviors"] = scenario_tag.lower()
|
| 345 |
if emotion_tag and emotion_tag != "None":
|
| 346 |
search_filter["emotion"] = emotion_tag.lower()
|
| 347 |
+
# fix bug by adding topic tag and context tag
|
| 348 |
+
if topic_tag and topic_tag != "None": # <-- ADD THESE TWO LINES
|
| 349 |
+
search_filter["topic_tags"] = topic_tag.lower()
|
| 350 |
+
if context_tags: # <-- ADD THESE TWO LINES
|
| 351 |
+
search_filter["context_tags"] = {"in": [tag.lower() for tag in context_tags]}
|
| 352 |
+
|
| 353 |
+
# --- Robust Search Strategy ---
|
| 354 |
+
# 1. Start with a general, unfiltered search to always get text-based matches.
|
| 355 |
+
retriever_personal = vs_personal.as_retriever(search_kwargs={"k": 3})
|
| 356 |
+
retriever_general = vs_general.as_retriever(search_kwargs={"k": 3})
|
| 357 |
+
|
| 358 |
+
personal_docs = retriever_personal.invoke(query)
|
| 359 |
+
general_docs = retriever_general.invoke(query)
|
| 360 |
+
|
| 361 |
+
# 2. If filters exist, perform a second, more specific search and add the results.
|
| 362 |
if search_filter:
|
| 363 |
+
print(f"Performing additional search with filter: {search_filter}")
|
| 364 |
+
personal_docs.extend(vs_personal.similarity_search(query, k=3, filter=search_filter))
|
| 365 |
+
general_docs.extend(vs_general.similarity_search(query, k=3, filter=search_filter))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
+
# 3. Combine and de-duplicate the results to get the best of both searches.
|
| 368 |
+
all_personal_docs = list({doc.page_content: doc for doc in personal_docs}.values())
|
| 369 |
+
all_general_docs = list({doc.page_content: doc for doc in general_docs}.values())
|
| 370 |
+
|
| 371 |
+
# 4. Define the context variables based on the new, combined results.
|
| 372 |
+
personal_context = _format_docs(all_personal_docs, "(No relevant personal memories found.)")
|
| 373 |
+
general_context = _format_docs(all_general_docs, "(No general guidance found.)")
|
| 374 |
|
| 375 |
first_emotion = None
|
| 376 |
+
all_docs_care = all_personal_docs + all_general_docs
|
| 377 |
+
|
| 378 |
+
# -- end of Robust Search Strategy
|
| 379 |
+
|
| 380 |
+
|
| 381 |
for doc in all_docs_care:
|
| 382 |
if "emotion" in doc.metadata and doc.metadata["emotion"]:
|
| 383 |
emotion_data = doc.metadata["emotion"]
|
|
|
|
| 407 |
return _answer_fn
|
| 408 |
|
| 409 |
|
| 410 |
+
def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
    """Invoke the RAG chain for *question*, forwarding optional retrieval tags.

    Recognised kwargs:
        chat_history (list): prior conversation turns; defaults to [].
        scenario_tag (str | None): behavior/scenario filter tag.
        emotion_tag (str | None): emotion filter tag.
        topic_tag (str | None): topic filter tag.
        context_tags (list[str] | None): context filter tags.

    Returns the chain's result dict, or a dict with "answer"/"sources"
    keys describing the error when the chain is not callable or raises.
    """
    if not callable(chain):
        return {"answer": "[Error: RAG chain is not callable]", "sources": []}
    chat_history = kwargs.get("chat_history", [])
    scenario_tag = kwargs.get("scenario_tag")
    emotion_tag = kwargs.get("emotion_tag")
    topic_tag = kwargs.get("topic_tag")
    context_tags = kwargs.get("context_tags")
    try:
        # BUG FIX: context_tags was extracted above but never forwarded, so the
        # "context_tags" search filter built inside the chain could never fire.
        return chain(
            question,
            chat_history=chat_history,
            scenario_tag=scenario_tag,
            emotion_tag=emotion_tag,
            topic_tag=topic_tag,
            context_tags=context_tags,
        )
    except Exception as e:
        print(f"ERROR in answer_query: {e}")
        return {"answer": f"[Error executing chain: {e}]", "sources": []}
|
| 423 |
|
| 424 |
+
|
| 425 |
# -----------------------------
|
| 426 |
# TTS & Transcription
|
| 427 |
# -----------------------------
|