Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import shutil | |
| import gradio as gr | |
| import tempfile | |
| from datetime import datetime | |
| from typing import List, Dict, Any, Optional | |
| from pytube import YouTube | |
| from pathlib import Path # <-- Add this import at the top of your file with the other imports | |
| import re | |
| # --- Agent Imports & Safe Fallbacks --- | |
# The real agent lives in the project-local ``alz_companion`` package. If it
# (or anything it needs) cannot be imported, same-named stand-ins are defined
# below so the UI can still load in a demo mode.
try:
    from alz_companion.agent import (
        bootstrap_vectorstore, make_rag_chain, answer_query, synthesize_tts,
        transcribe_audio, detect_tags_from_query, describe_image, build_or_load_vectorstore,
        _default_embeddings
    )
    from alz_companion.prompts import BEHAVIOUR_TAGS, EMOTION_STYLES
    from langchain.schema import Document
    from langchain_community.vectorstores import FAISS
    AGENT_OK = True
except Exception as e:
    AGENT_OK = False

    # --- Fallback functions and classes (demo mode only) ---
    def bootstrap_vectorstore(sample_paths=None, index_path="data/"):
        """Return an opaque placeholder; the demo chain never queries it."""
        return object()

    def build_or_load_vectorstore(docs, index_path, is_personal=False):
        """Return an in-memory store holding *docs* (nothing is read from disk)."""
        return FAISS.from_documents(docs or [])

    def make_rag_chain(vs_general, vs_personal, **kwargs):
        """Return a callable that echoes the question back as the answer."""
        return lambda q, **k: {"answer": f"(Demo) You asked: {q}", "sources": []}

    def answer_query(chain, q, **kwargs):
        """Invoke *chain* with the question and any extra keyword arguments."""
        return chain(q, **kwargs)

    def synthesize_tts(text: str, lang: str = "en"):
        """No speech is produced in demo mode."""
        return None

    def transcribe_audio(filepath: str, lang: str = "en"):
        """Return a fixed placeholder transcription."""
        return "This is a transcribed message."

    def detect_tags_from_query(query: str, behavior_options: list = None, emotion_options: list = None,
                               topic_options: list = None, context_options: list = None):
        """Return an "nothing detected" result.

        Accepts every keyword the callers in this file pass
        (``topic_options`` and ``context_options`` included) and returns every
        key they read, both the singular and the list-valued ones.
        """
        return {
            "detected_behavior": "None",
            "detected_behaviors": [],
            "detected_emotion": "None",
            "detected_topic": "None",
            "detected_contexts": [],
        }

    def describe_image(image_path: str):
        """Return a fixed placeholder description."""
        return "This is a description of an image."

    def _default_embeddings():
        """No embedding model is available in demo mode."""
        return None

    class Document:
        """Minimal stand-in carrying the two attributes this file reads."""

        def __init__(self, page_content, metadata):
            self.page_content = page_content
            self.metadata = metadata

    class FAISS:
        """In-memory stand-in exposing only the vector-store calls this file makes."""

        def __init__(self):
            # Mirrors the ``store.docstore._dict`` access used by the memory callbacks.
            self.docstore = type('obj', (object,), {'_dict': {}})()

        @classmethod
        def from_documents(cls, documents, embedding=None):
            """Build a store containing *documents*; *embedding* is ignored."""
            store = cls()
            store.add_documents(documents)
            return store

        def add_documents(self, documents):
            """Append *documents* to the in-memory docstore."""
            for doc in documents:
                self.docstore._dict[str(len(self.docstore._dict))] = doc

        def save_local(self, folder_path):
            """Persisting is a no-op in demo mode."""
            return None

    BEHAVIOUR_TAGS = {"None": []}
    EMOTION_STYLES = {"None": {}}
    print(f"WARNING: Could not import from alz_companion ({e}). Running in UI-only demo mode.")
| # --- Centralized Configuration --- | |
# Option lists shown in the UI and passed to the tag-detection calls.
CONFIG = {
    "themes": ["All", "The Father", "Still Alice", "Away from Her", "Alive Inside", "General Caregiving"],
    "roles": ["patient", "caregiver"],
    # "None" is always offered first; a "None" key already present in the tag
    # tables is skipped so the option is not listed twice.
    "behavior_tags": ["None"] + [tag for tag in BEHAVIOUR_TAGS if tag != "None"],
    "emotion_tags": ["None"] + [tag for tag in EMOTION_STYLES if tag != "None"],
    "topic_tags": [
        "None",
        "caregiving_advice",
        "medical_fact",
        "personal_story",
        "research_update",
        "treatment_option:home_safety",
        "treatment_option:long_term_care",
        "treatment_option:music_therapy",
        "treatment_option:reassurance",
        "treatment_option:routine_structuring",
        "treatment_option:validation_therapy",
    ],
    "context_tags": [
        "None", "disease_stage_mild",
        "disease_stage_moderate", "disease_stage_advanced",
        "disease_stage_unspecified", "interaction_mode_one_to_one",
        "interaction_mode_small_group", "interaction_mode_group_activity",
        "relationship_family", "relationship_spouse",
        "relationship_staff_or_caregiver", "relationship_unspecified",
        "setting_home_or_community", "setting_care_home",
        "setting_clinic_or_hospital",
    ],
    # Display name -> language code.
    "languages": {"English": "en", "Chinese": "zh", "Malay": "ms", "French": "fr", "Spanish": "es"},
    "tones": ["warm", "neutral", "formal", "playful"],
}
| # --- File Management & Vector Store Logic --- | |
| # --- Persistent storage root --- CG5 | |
def _storage_root() -> Path:
    """
    Choose a durable home for runtime artefacts.

    Candidates are tried in order and the first one that can be created and
    written to is returned:
      1) $SPACE_STORAGE          -> custom mount, only when the variable is set
      2) /data                   -> Hugging Face Spaces persistent volume
      3) ~/.cache/alz_companion  -> portable fallback
    If none is writable, a directory under the system temp dir is used
    (not persistent, but avoids crashing).
    """
    candidates = []
    # Path("") is Path(".") and Path objects are always truthy, so an unset
    # variable must be filtered out here or the working directory would be
    # picked ahead of /data.
    env_root = os.getenv("SPACE_STORAGE", "").strip()
    if env_root:
        candidates.append(Path(env_root))
    candidates.append(Path("/data"))
    try:
        candidates.append(Path.home() / ".cache" / "alz_companion")
    except RuntimeError:
        # Path.home() raises when the home directory cannot be resolved.
        pass

    for candidate in candidates:
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            # Creating the directory is not enough: confirm a file can be written.
            probe = candidate / ".write_test"
            with open(probe, "w") as f:
                f.write("ok")
            probe.unlink(missing_ok=True)
            return candidate
        except OSError:
            continue

    # Last resort: temp (not persistent, but avoids crashing)
    tmp = Path(tempfile.gettempdir()) / "alz_companion"
    tmp.mkdir(parents=True, exist_ok=True)
    return tmp
STORAGE_ROOT = _storage_root()

# --- File Management & Vector Store Logic (persistent) ---
# Everything below lives under STORAGE_ROOT: one FAISS index directory per
# theme, an uploads area, and the personal-memory index.
INDEX_BASE = str(STORAGE_ROOT / "index")
PERSONAL_DATA_BASE = str(STORAGE_ROOT / "personal")
UPLOADS_BASE = os.path.join(INDEX_BASE, "uploads")
PERSONAL_INDEX_PATH = os.path.join(PERSONAL_DATA_BASE, "personal_faiss_index")
THEME_PATHS = {
    t: os.path.join(INDEX_BASE, "faiss_index_" + t.replace(' ', '').lower())
    for t in CONFIG["themes"]
}

# Create the uploads dir, the parent of the personal index, and each theme dir.
for p in (UPLOADS_BASE, os.path.dirname(PERSONAL_INDEX_PATH), *THEME_PATHS.values()):
    os.makedirs(p, exist_ok=True)

# Module-level state shared by the callbacks.
vectorstores = dict()          # canonical theme name -> vector store built by ensure_index
personal_vectorstore = None    # personal-memory store, rebound by the callbacks
test_fixtures = list()         # NLU test cases filled by load_test_fixtures

# --- Load existing personal index if present ---
try:
    personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
except Exception:
    # Stay graceful if the index is missing/corrupt; it can be rebuilt by adding memories.
    personal_vectorstore = None
def canonical_theme(tk: str) -> str:
    """Return *tk* if it is one of ``CONFIG["themes"]``, otherwise ``"All"``."""
    if tk in CONFIG["themes"]:
        return tk
    return "All"
def theme_upload_dir(theme: str) -> str:
    """Return the upload directory for *theme*, creating it if needed.

    The directory name is ``theme_`` followed by the canonical theme name,
    lower-cased with spaces removed, under ``UPLOADS_BASE``.
    """
    slug = canonical_theme(theme).replace(' ', '').lower()
    upload_dir = os.path.join(UPLOADS_BASE, f"theme_{slug}")
    os.makedirs(upload_dir, exist_ok=True)
    return upload_dir
def load_manifest(theme: str) -> Dict[str, Any]:
    """Load the ``manifest.json`` stored in the theme's upload directory.

    Returns:
        The parsed manifest, always containing a ``"files"`` dict. A missing,
        unreadable or malformed manifest yields ``{"files": {}}``; read and
        parse failures are reported on stdout instead of being dropped
        silently.
    """
    manifest_path = os.path.join(theme_upload_dir(theme), "manifest.json")
    try:
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
    except FileNotFoundError:
        # Normal on first use: no manifest has been written yet.
        return {"files": {}}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"WARNING: Could not read manifest '{manifest_path}' ({exc}). Using an empty manifest.")
        return {"files": {}}

    if not isinstance(manifest, dict):
        return {"files": {}}
    # Callers index man["files"] directly, so guarantee it exists and is a dict.
    if not isinstance(manifest.get("files"), dict):
        manifest["files"] = {}
    return manifest
def save_manifest(theme: str, man: Dict[str, Any]):
    """Write *man* as indented JSON to ``manifest.json`` in the theme's upload directory."""
    manifest_path = os.path.join(theme_upload_dir(theme), "manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as out:
        json.dump(man, out, indent=2)
def list_theme_files(theme: str) -> List[tuple[str, bool]]:
    """Reconcile the theme's manifest with its upload directory.

    Manifest entries whose file still exists keep their enabled flag; files on
    disk that the manifest does not know are added as disabled. The reconciled
    mapping is written back to the manifest.

    ``manifest.json`` itself lives in the same directory and is never reported
    as a knowledge file.

    Returns:
        ``(file name, enabled)`` pairs: manifest entries first, then newly
        discovered files in sorted order.
    """
    manifest_name = "manifest.json"
    man = load_manifest(theme)
    base = theme_upload_dir(theme)

    found = [
        (n, bool(e))
        for n, e in man.get("files", {}).items()
        if n != manifest_name and os.path.exists(os.path.join(base, n))
    ]
    known = {n for n, _ in found}
    for name in sorted(os.listdir(base)):
        if name == manifest_name or name in known:
            continue
        if os.path.isfile(os.path.join(base, name)):
            found.append((name, False))

    man["files"] = dict(found)
    save_manifest(theme, man)
    return found
def copy_into_theme(theme: str, src_path: str) -> str:
    """Copy *src_path* into the theme's upload directory and return the destination path."""
    file_name = os.path.basename(src_path)
    target = os.path.join(theme_upload_dir(theme), file_name)
    shutil.copy2(src_path, target)
    return target
def seed_files_into_theme(theme: str):
    """Copy the bundled sample files into the theme's upload directory.

    Seed files that do not exist, or that are already present in the upload
    directory, are skipped. Each copied file is recorded in the manifest with
    its enabled flag, and the manifest is saved only if something was copied.
    """
    seed_files = (
        ("sample_data/caregiving_tips.txt", True),
        ("sample_data/the_father_segments_enriched_harmonized_plus.jsonl", True),
        ("sample_data/still_alice_enriched_harmonized_plus.jsonl", True),
        ("sample_data/away_from_her_enriched_harmonized_plus.jsonl", True),
        ("sample_data/alive_inside_enriched_harmonized.jsonl", True),
    )
    manifest = load_manifest(theme)
    dirty = False
    for seed_path, enabled in seed_files:
        if os.path.exists(seed_path):
            seed_name = os.path.basename(seed_path)
            already_there = os.path.exists(os.path.join(theme_upload_dir(theme), seed_name))
            if not already_there:
                copy_into_theme(theme, seed_path)
                manifest["files"][seed_name] = bool(enabled)
                dirty = True
    if dirty:
        save_manifest(theme, manifest)
def ensure_index(theme='All'):
    """Return the vector store for *theme*, building and caching it on first use.

    The store is built by ``bootstrap_vectorstore`` from the files currently
    enabled for the theme and cached in ``vectorstores`` under the canonical
    theme name.
    """
    theme = canonical_theme(theme)
    try:
        return vectorstores[theme]
    except KeyError:
        pass  # not built yet

    upload_dir = theme_upload_dir(theme)
    enabled_files = []
    for file_name, is_enabled in list_theme_files(theme):
        if is_enabled:
            enabled_files.append(os.path.join(upload_dir, file_name))

    store = bootstrap_vectorstore(sample_paths=enabled_files, index_path=THEME_PATHS.get(theme))
    vectorstores[theme] = store
    return store
| # --- Gradio Callbacks --- | |
def collect_settings(*args):
    """Pair the positional UI values with their setting names.

    Values are matched to names in order; surplus values and names without a
    value are dropped.
    """
    setting_names = (
        "role", "patient_name", "caregiver_name", "tone", "language",
        "tts_lang", "temperature", "behaviour_tag", "emotion_tag",
        "topic_tag", "active_theme", "tts_on", "debug_mode",
    )
    return {name: value for name, value in zip(setting_names, args)}
def parse_and_tag_entries(text_content: str, source: str, settings: Optional[dict] = None) -> List[Document]:
    """Split *text_content* into entries and turn each into a tagged ``Document``.

    Entries are separated by a line holding only ``---``, ``--``, ``-``,
    ``-*-`` or ``-.-``. An entry whose first line contains ``title:`` provides
    its own title; the remaining lines are the content, with a leading
    ``Content:`` label removed when one starts a line.

    Args:
        text_content: Raw text holding one or more entries.
        source: Stored as ``metadata["source"]`` on every document.
        settings: Optional settings dict; ``debug_mode`` turns on a
            per-entry log line.

    Returns:
        One ``Document`` per non-empty entry, with ``source`` and ``title``
        metadata plus whichever tags ``detect_tags_from_query`` reported.
    """
    separator_pattern = r'\n(?:---|--|-|-\*-|-\.-)\n'
    # The option lists do not change per entry, so look them up once.
    behavior_options = CONFIG.get("behavior_tags", [])
    emotion_options = CONFIG.get("emotion_tags", [])
    topic_options = CONFIG.get("topic_tags", [])
    context_options = CONFIG.get("context_tags", [])
    debug = bool(settings and settings.get("debug_mode"))

    docs_to_add = []
    for raw_entry in re.split(separator_pattern, text_content):
        entry = raw_entry.strip()
        if not entry:
            continue

        title = "Untitled Text Entry"
        content = entry
        lines = entry.split('\n')
        if "title:" in lines[0].lower():
            title = lines[0].split(':', 1)[1].strip() or "Untitled"
            body = "\n".join(lines[1:])
            # Strip only a "Content:" label that starts a line. Splitting on the
            # first colon anywhere would cut text such as "Time 10:30 ..." in half.
            label = re.search(r'(?im)^[ \t]*content[ \t]*:', body)
            content = (body[label.end():] if label else body).strip()

        full_content = f"Title: {title}\n\nContent: {content}"
        if debug:
            print(f" - Parsing entry: '{title}'")

        detected_tags = detect_tags_from_query(
            content,
            behavior_options=behavior_options,
            emotion_options=emotion_options,
            topic_options=topic_options,
            context_options=context_options
        )

        metadata = {"source": source, "title": title}
        # Behaviours and contexts are read as lists; emotion and topic as single strings.
        detected_behaviors = detected_tags.get("detected_behaviors") or []
        if detected_behaviors:
            metadata["behaviors"] = [b.lower() for b in detected_behaviors]
        # A missing emotion comes back as None, which has no .lower().
        detected_emotion = detected_tags.get("detected_emotion")
        if detected_emotion and detected_emotion != "None":
            metadata["emotion"] = detected_emotion.lower()
        detected_topic = detected_tags.get("detected_topic")
        if detected_topic and detected_topic != "None":
            metadata["topic_tags"] = [detected_topic.lower()]
        detected_contexts = detected_tags.get("detected_contexts") or []
        if detected_contexts:
            metadata["context_tags"] = [c.lower() for c in detected_contexts]

        docs_to_add.append(Document(page_content=full_content, metadata=metadata))
    return docs_to_add
def handle_add_knowledge(title, text_input, file_input, image_input, yt_url, settings):
    """Add one memory source to the personal vector store.

    Exactly one input is used, in this priority order: typed text, an uploaded
    file (``.txt`` is parsed as text, anything else is transcribed), an image
    (described), then a YouTube URL (audio downloaded and transcribed).

    Args:
        title: Optional title; a per-source default is used when blank.
        text_input: Typed text.
        file_input: Path of an uploaded file.
        image_input: Path of an uploaded image.
        yt_url: YouTube link.
        settings: Settings dict forwarded to ``parse_and_tag_entries``.

    Returns:
        A status message for the UI.
    """
    global personal_vectorstore
    clean_title = (title or "").strip()
    docs_to_add = []

    if text_input and text_input.strip():
        # Default the title so a missing one is not rendered as "Title: None".
        final_title = clean_title or "Text Note"
        docs_to_add = parse_and_tag_entries(f"Title: {final_title}\n\nContent: {text_input}", "Text Input", settings=settings)
    elif file_input:
        content_source = os.path.basename(file_input)
        if file_input.lower().endswith('.txt'):
            with open(file_input, 'r', encoding='utf-8') as f:
                file_content = f.read()
            docs_to_add = parse_and_tag_entries(file_content, content_source, settings=settings)
        else:  # Handle audio/video
            final_title = clean_title or "Audio/Video Note"
            content_text = transcribe_audio(file_input)
            full_content = f"Title: {final_title}\n\nContent: {content_text}"
            docs_to_add = parse_and_tag_entries(full_content, content_source, settings=settings)
    elif image_input:
        final_title = clean_title or "Image Note"
        content_text = describe_image(image_input)
        full_content = f"Title: {final_title}\n\nContent: {content_text}"
        docs_to_add = parse_and_tag_entries(full_content, "Image Input", settings=settings)
    elif yt_url and ("youtube.com" in yt_url or "youtu.be" in yt_url):
        try:
            yt = YouTube(yt_url)
            video_title = yt.title
            final_title = clean_title or video_title
            audio_stream = yt.streams.get_audio_only()
            # Reserve a temp path and close the handle before downloading into it.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_audio_file:
                temp_audio_path = temp_audio_file.name
            try:
                audio_stream.download(filename=temp_audio_path)
                content_text = transcribe_audio(temp_audio_path)
            finally:
                # Remove the audio even when download or transcription fails.
                if os.path.exists(temp_audio_path):
                    os.remove(temp_audio_path)
            content_source = f"YouTube: {video_title}"
            full_content = f"Title: {final_title}\n\nContent: {content_text}"
            docs_to_add = parse_and_tag_entries(full_content, content_source, settings=settings)
        except Exception as e:
            return f"Error processing YouTube link: {e}"
    else:
        return "Please provide a title and content, or another input source."

    if not docs_to_add:
        return "No processable content found to add."

    if personal_vectorstore is None:
        personal_vectorstore = build_or_load_vectorstore(docs_to_add, PERSONAL_INDEX_PATH, is_personal=True)
    else:
        personal_vectorstore.add_documents(docs_to_add)
    personal_vectorstore.save_local(PERSONAL_INDEX_PATH)
    return f"Successfully added {len(docs_to_add)} new memory/memories."
def save_chat_to_memory(chat_history):
    """Store the current conversation as one document in the personal vector store.

    Messages are rendered as ``User: ...`` / ``Assistant: ...`` lines; the
    auto-detected-context notes are left out. Returns a status message.
    """
    global personal_vectorstore
    if not chat_history:
        return "Nothing to save."

    transcript_lines = []
    for message in chat_history:
        speaker = "Assistant" if message["role"] != "user" else "User"
        text = message["content"].strip()
        if not text.startswith("*(Auto-detected context:"):
            transcript_lines.append(f"{speaker}: {text}")

    conversation_text = "\n".join(transcript_lines)
    if not conversation_text:
        return "No conversation content to save."

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    title = f"Conversation from {timestamp}"
    doc_to_add = Document(
        page_content=f"Title: {title}\n\nContent:\n{conversation_text}",
        metadata={"source": "Saved Chat", "title": title},
    )

    if personal_vectorstore is not None:
        personal_vectorstore.add_documents([doc_to_add])
    else:
        personal_vectorstore = build_or_load_vectorstore([doc_to_add], PERSONAL_INDEX_PATH, is_personal=True)
    personal_vectorstore.save_local(PERSONAL_INDEX_PATH)

    print("Saved conversation to long-term memory.")
    return f"Conversation from {timestamp} saved successfully!"
def list_personal_memories():
    """Build the Gradio updates for the memory table and the delete dropdown.

    Returns a pair of ``gr.update`` objects: table rows of
    ``[title, source, content]`` and dropdown choices holding each memory's
    content. Placeholder values are returned when there is nothing stored.
    """
    global personal_vectorstore
    store = personal_vectorstore
    has_memories = (
        store is not None
        and hasattr(store.docstore, '_dict')
        and bool(store.docstore._dict)
    )
    if not has_memories:
        empty_table = gr.update(value=[["No memories to display", "", ""]])
        empty_dropdown = gr.update(choices=["No memories to select"], value=None)
        return empty_table, empty_dropdown

    rows = []
    choices = []
    for doc in list(store.docstore._dict.values()):
        rows.append([doc.metadata.get('title', 'Untitled'), doc.metadata.get('source', 'Unknown'), doc.page_content])
        choices.append(doc.page_content)
    return gr.update(value=rows), gr.update(choices=choices)
def delete_personal_memory(memory_to_delete):
    """Remove every stored memory whose content equals *memory_to_delete*.

    The index is rebuilt from the remaining documents. When none remain, the
    index directory is removed and a fresh store is created. Returns a status
    message for the UI.
    """
    global personal_vectorstore
    if personal_vectorstore is None or not memory_to_delete:
        return "Knowledge base is empty or no memory selected."

    all_docs = list(personal_vectorstore.docstore._dict.values())
    docs_to_keep = []
    for doc in all_docs:
        if doc.page_content != memory_to_delete:
            docs_to_keep.append(doc)

    if len(docs_to_keep) == len(all_docs):
        return "Error: Could not find the selected memory to delete."

    print(f"Deleting memory. {len(docs_to_keep)} memories remaining.")
    if docs_to_keep:
        rebuilt = FAISS.from_documents(docs_to_keep, _default_embeddings())
        rebuilt.save_local(PERSONAL_INDEX_PATH)
        personal_vectorstore = rebuilt
    else:
        # Nothing left: drop the on-disk index and start from an empty store.
        if os.path.isdir(PERSONAL_INDEX_PATH):
            shutil.rmtree(PERSONAL_INDEX_PATH)
        personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
    return "Successfully deleted memory. The list will now refresh."
def chat_fn(user_text, audio_file, settings, chat_history):
    """Handle one chat turn: get the question, pick tags, query the chain, reply.

    Args:
        user_text: Typed question; takes priority over audio.
        audio_file: Recorded question, transcribed only when no text was typed.
        settings: Settings dict as built by ``collect_settings``.
        chat_history: List of ``{"role", "content"}`` messages; extended in place.

    Returns:
        ``(cleared textbox value, audio output or update, chat history)``.
    """
    global personal_vectorstore
    settings = settings or {}
    if chat_history is None:
        chat_history = []

    question = (user_text or "").strip()
    if audio_file and not question:
        try:
            voice_lang_name = settings.get("tts_lang", "English")
            voice_lang_code = CONFIG["languages"].get(voice_lang_name, "en")
            question = (transcribe_audio(audio_file, lang=voice_lang_code) or "").strip()
        except Exception as e:
            err_msg = f"Audio Error: {e}" if settings.get("debug_mode") else "Sorry, I couldn't understand the audio."
            chat_history.append({"role": "assistant", "content": err_msg})
            return "", None, chat_history
    if not question:
        return "", None, chat_history

    # Snapshot the history BEFORE this turn's messages are appended. Slicing
    # off the last element afterwards drops the auto-detect note rather than
    # the question whenever a note was added.
    prior_history = list(chat_history)
    chat_history.append({"role": "user", "content": question})

    # --- Tag selection: manual dropdown values win; otherwise auto-detect ---
    # A cleared dropdown may deliver None, which is treated like "None".
    manual_behavior_tag = settings.get("behaviour_tag") or "None"
    manual_emotion_tag = settings.get("emotion_tag") or "None"
    manual_topic_tag = settings.get("topic_tag") or "None"
    scenario_tag = manual_behavior_tag
    emotion_tag = manual_emotion_tag
    topic_tag = manual_topic_tag

    auto_detect = manual_behavior_tag == "None" and manual_emotion_tag == "None" and manual_topic_tag == "None"
    detected_tags = {}
    if auto_detect:
        print("No manual tags set, running auto-detection...")
        detected_tags = detect_tags_from_query(
            question,
            behavior_options=CONFIG.get("behavior_tags", []),
            emotion_options=CONFIG.get("emotion_tags", []),
            topic_options=CONFIG.get("topic_tags", []),
            context_options=CONFIG.get("context_tags", [])
        ) or {}
        scenario_tag = detected_tags.get("detected_behavior") or "None"
        if scenario_tag == "None":
            # NOTE(review): other callers in this file read the list-valued
            # "detected_behaviors" key; fall back to its first entry. Confirm
            # against the detector's actual response schema.
            detected_behaviors = detected_tags.get("detected_behaviors") or []
            if detected_behaviors:
                scenario_tag = detected_behaviors[0]
        emotion_tag = detected_tags.get("detected_emotion") or "None"
        topic_tag = detected_tags.get("detected_topic") or "None"

        # Show the auto-detected tags in the chat.
        detected_parts = []
        if scenario_tag != "None":
            detected_parts.append(f"Behavior=`{scenario_tag}`")
        if emotion_tag != "None":
            detected_parts.append(f"Emotion=`{emotion_tag}`")
        if topic_tag != "None":
            detected_parts.append(f"Topic=`{topic_tag}`")
        if detected_parts:
            detected_msg = f"*(Auto-detected context: {', '.join(detected_parts)})*"
            chat_history.append({"role": "assistant", "content": detected_msg})
    else:
        print("Manual tags detected, skipping auto-detection.")

    # --- Build the chain ---
    active_theme = settings.get("active_theme", "All")
    vs_general = ensure_index(active_theme)
    if personal_vectorstore is None:
        personal_vectorstore = build_or_load_vectorstore([], PERSONAL_INDEX_PATH, is_personal=True)
    rag_chain_settings = {
        "role": settings.get("role"),
        "temperature": settings.get("temperature"),
        "language": settings.get("language"),
        "patient_name": settings.get("patient_name"),
        "caregiver_name": settings.get("caregiver_name"),
        "tone": settings.get("tone"),
    }
    chain = make_rag_chain(vs_general, personal_vectorstore, **rag_chain_settings)

    # "None" placeholders are passed to the agent as real None values.
    final_scenario_tag = scenario_tag if scenario_tag != "None" else None
    final_emotion_tag = emotion_tag if emotion_tag != "None" else None
    final_topic_tag = topic_tag if topic_tag != "None" else None
    # Context tags only exist when auto-detection ran; the whole list is passed on.
    final_context_tags = (detected_tags.get("detected_contexts") or []) if auto_detect else []

    response = answer_query(
        chain, question,
        chat_history=prior_history,
        scenario_tag=final_scenario_tag,
        emotion_tag=final_emotion_tag,
        topic_tag=final_topic_tag,
        context_tags=final_context_tags,
    )
    answer = response.get("answer", "[No answer found]")
    chat_history.append({"role": "assistant", "content": answer})

    # --- Source display (currently shown regardless of debug mode) ---
    sources = response.get("sources") or []
    valid_sources = [s for s in sources if s and s not in ["unknown", "placeholder"]]
    if valid_sources:
        source_msg = f"*(Sources used: {', '.join(str(s) for s in valid_sources)})*"
        chat_history.append({"role": "assistant", "content": source_msg})

    audio_out = None
    if settings.get("tts_on") and answer:
        tts_lang_code = CONFIG["languages"].get(settings.get("tts_lang"), "en")
        audio_out = synthesize_tts(answer, lang=tts_lang_code)
    return "", gr.update(value=audio_out, visible=bool(audio_out)), chat_history
def upload_knowledge(files, current_theme):
    """Copy the uploaded files into the theme and drop its cached index.

    A file that fails to copy is logged and skipped. Returns a status message
    for the UI.
    """
    if not files:
        return "No files were selected to upload."

    added = 0
    for upload in files:
        try:
            copy_into_theme(current_theme, upload.name)
        except Exception as e:
            print(f"Error uploading file {upload.name}: {e}")
        else:
            added += 1

    # Drop the cached store so the next ensure_index call rebuilds it.
    if added > 0 and current_theme in vectorstores:
        del vectorstores[current_theme]
    return f"Uploaded {added} file(s). Refreshing file list..."
def save_file_selection(current_theme, enabled_files):
    """Persist which of the theme's files are enabled and drop its cached index.

    Every file already in the manifest is marked enabled exactly when its name
    is in *enabled_files*. Returns a status message for the UI.
    """
    man = load_manifest(current_theme)
    for fname in list(man['files']):
        is_enabled = fname in enabled_files
        man['files'][fname] = is_enabled
    save_manifest(current_theme, man)
    # Remove any cached store so it is rebuilt from the new selection.
    vectorstores.pop(current_theme, None)
    return f"Settings saved. Index for theme '{current_theme}' will rebuild on the next query."
def refresh_file_list_ui(current_theme):
    """Return the checkbox update (all files as choices, enabled ones ticked) and a summary message."""
    names = []
    ticked = []
    for file_name, is_enabled in list_theme_files(current_theme):
        names.append(file_name)
        if is_enabled:
            ticked.append(file_name)
    msg = f"Found {len(names)} file(s). {len(ticked)} enabled."
    return gr.update(choices=names, value=ticked), msg
def auto_setup_on_load(current_theme):
    """Seed sample data on first run and return the initial UI state.

    Returns:
        ``(default settings dict, file-list checkbox update, status message)``.
    """
    theme_dir = theme_upload_dir(current_theme)
    if not os.listdir(theme_dir):
        print("First-time setup: Auto-seeding sample data...")
        seed_files_into_theme(current_theme)

    # One value per key of collect_settings, in the same order. The topic_tag
    # value was missing before, which shifted every later value by one slot
    # (topic_tag="All", active_theme=True, tts_on=False, no debug_mode).
    all_settings = collect_settings(
        "patient",   # role
        "",          # patient_name
        "",          # caregiver_name
        "warm",      # tone
        "English",   # language
        "English",   # tts_lang
        0.7,         # temperature
        "None",      # behaviour_tag
        "None",      # emotion_tag
        "None",      # topic_tag
        "All",       # active_theme
        True,        # tts_on
        False,       # debug_mode
    )
    files_ui, status_msg = refresh_file_list_ui(current_theme)
    return all_settings, files_ui, status_msg
def load_test_fixtures():
    """Load the NLU test cases and return a Gradio update listing their titles.

    Fixtures are read from ``conversation_test_fixtures.jsonl`` next to this
    script, one JSON object per line; blank lines are ignored. The module-level
    ``test_fixtures`` list is replaced only when the whole file loads cleanly,
    and is left empty otherwise.
    """
    global test_fixtures
    test_fixtures = []  # Reset fixtures on each load attempt
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        fixtures_path = os.path.join(script_dir, "conversation_test_fixtures.jsonl")
        if not os.path.exists(fixtures_path):
            print("WARNING: Test fixtures file not found.")
            return gr.update(choices=[])

        loaded = []
        with open(fixtures_path, "r", encoding="utf-8") as f:
            for line in f:
                # A trailing newline or spacer line is not a fixture.
                if line.strip():
                    loaded.append(json.loads(line))

        # Build the titles before publishing, so a fixture without a title
        # leaves test_fixtures empty rather than half-usable.
        test_titles = [fixture["title"] for fixture in loaded]
        test_fixtures = loaded
        return gr.update(choices=test_titles)
    except Exception as e:
        print(f"UNEXPECTED ERROR during file loading: {e}")
        return gr.update(choices=[])
def run_nlu_test(test_title: str):
    """Run one NLU test case and compare detected tags with the expected ones.

    Only the first turn of the fixture is sent to ``detect_tags_from_query``.
    A category passes when at least one expected tag was detected; categories
    without an expectation are not counted.

    Args:
        test_title: Title of the fixture to run.

    Returns:
        ``(status markdown, comparison rows)`` where each row is
        ``[category, expected, actual, verdict]``; the rows are ``None`` when
        no test could be run.
    """
    print("\n--- RUNNING NLU TEST (Definitive Version) ---")
    if not test_title or not test_fixtures:
        return "Please select a test case.", None
    selected_fixture = next((f for f in test_fixtures if f["title"] == test_title), None)
    if not selected_fixture:
        return f"Error: Could not find test case titled '{test_title}'.", None

    user_query = selected_fixture["turns"][0]["text"]
    expected_results = selected_fixture["expected"]
    print(f"Test Case: '{test_title}'")
    print(f"User Query: '{user_query}'")

    actual_results_raw = detect_tags_from_query(
        user_query,
        behavior_options=CONFIG.get("behavior_tags", []),
        emotion_options=CONFIG.get("emotion_tags", []),
        topic_options=CONFIG.get("topic_tags", []),
        context_options=CONFIG.get("context_tags", [])
    )
    print(f"\nRAW NLU RESULTS from detect_tags_from_query:\n{actual_results_raw}\n")

    # Emotion and topic are single values; wrap them so every category is a list.
    actual_results = {
        "emotion": [actual_results_raw.get("detected_emotion")],
        "behaviors": actual_results_raw.get("detected_behaviors", []),
        "topic_tags": [actual_results_raw.get("detected_topic")],
        "context_tags": actual_results_raw.get("detected_contexts", [])
    }

    pass_count = 0
    total_count = 0
    comparison_data = []
    all_keys = set(expected_results.keys()) | set(actual_results.keys())
    print("--- COMPARING RESULTS ---")
    for key in sorted(all_keys):
        expected_set = set(expected_results.get(key, []))
        actual_set = set(a for a in actual_results.get(key, []) if a and a != "None")
        # Only categories that have an expectation are counted.
        if not expected_set:
            continue
        total_count += 1

        # Flexible rule: any overlap between expected and actual tags passes.
        is_pass = len(expected_set.intersection(actual_set)) > 0
        print(f"Category: '{key}'")
        print(f" - Expected Set: {expected_set}")
        print(f" - Actual Set : {actual_set}")
        # The log now names the rule that is actually applied.
        print(f" - Logic : len(expected_set.intersection(actual_set)) > 0")
        print(f" - Result : {is_pass}")
        if is_pass:
            pass_count += 1
        comparison_data.append([
            key,
            ", ".join(sorted(expected_set)),
            ", ".join(sorted(actual_set)) if actual_set else "None",
            "✅ Pass" if is_pass else "❌ Fail"
        ])

    status = f"## Test Result: {pass_count} / {total_count} Categories Passed"
    print(f"Final Status: {pass_count}/{total_count} passed.")
    print("--- TEST COMPLETE ---\n")
    return status, comparison_data
| # add the new function that will run when the "Run All Tests" button is clicked. | |
def run_all_nlu_tests():
    """Run every loaded NLU fixture and summarise the outcome.

    Fixtures are loaded on demand when none are in memory. A category passes
    when at least one expected tag was detected. A fixture counts as passed
    only when all of its categories pass; above 65% it is reported as a
    partial pass.

    Returns:
        ``(summary markdown, rows)`` where each row is
        ``[title, overall result, "passed / total"]``.
    """
    if not test_fixtures:
        load_test_fixtures()
        if not test_fixtures:
            return "## Batch Test Summary: No test fixtures found. Please ensure `conversation_test_fixtures.jsonl` is present.", []

    print("\n--- RUNNING ALL NLU TESTS ---")
    tag_options = {
        "behavior_options": CONFIG.get("behavior_tags", []),
        "emotion_options": CONFIG.get("emotion_tags", []),
        "topic_options": CONFIG.get("topic_tags", []),
        "context_options": CONFIG.get("context_tags", []),
    }

    total_tests = len(test_fixtures)
    passed_tests = 0
    all_results_data = []
    for fixture in test_fixtures:
        query = fixture["turns"][0]["text"]
        expected = fixture["expected"]
        raw = detect_tags_from_query(query, **tag_options)
        # Emotion and topic are single values; wrap them so every category is a list.
        detected = {
            "emotion": [raw.get("detected_emotion")],
            "behaviors": raw.get("detected_behaviors", []),
            "topic_tags": [raw.get("detected_topic")],
            "context_tags": raw.get("detected_contexts", []),
        }

        categories_passed = 0
        categories_scored = 0
        for category in sorted(set(expected.keys())):
            wanted = set(expected.get(category, []))
            if wanted:
                categories_scored += 1
                got = {tag for tag in detected.get(category, []) if tag and tag != "None"}
                # Flexible pass logic: any overlap counts.
                if wanted & got:
                    categories_passed += 1

        # Overall verdict for this fixture; defaults to Fail.
        if categories_scored > 0 and categories_passed == categories_scored:
            passed_tests += 1
            overall_result = "✅ Pass"
        elif categories_scored > 0 and categories_passed / categories_scored > 0.65:
            overall_result = "⚠️ Partial Pass"
        else:
            overall_result = "❌ Fail"

        all_results_data.append([
            fixture["title"],
            overall_result,
            f"{categories_passed} / {categories_scored}"
        ])

    pass_rate = (passed_tests / total_tests) * 100 if total_tests > 0 else 0
    summary_md = f"## Batch Test Summary: {passed_tests} / {total_tests} Tests Passed ({pass_rate:.1f}%)"
    print(f"--- BATCH TEST COMPLETE: {summary_md} ---")
    return summary_md, all_results_data
| # In app.py, inside the Gradio Callbacks section for debugging | |
def test_save_file():
    """Write a timestamped marker file beside the personal index.

    The marker is named ``persistence_test.txt`` and is placed in the
    directory component of ``PERSONAL_INDEX_PATH``.

    Returns:
        A status string: a success message containing the marker path, or an
        error message containing the text of whatever exception was raised.
    """
    try:
        marker_path = os.path.join(
            os.path.dirname(PERSONAL_INDEX_PATH), "persistence_test.txt"
        )
        # The timestamp lets a later check tell which session wrote the file.
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(marker_path, "w", encoding="utf-8") as marker:
            marker.write(f"File saved successfully at: {stamp}")
    except Exception as err:
        return f"❌ Error! Failed to write file. Reason: {err}"
    return f"✅ Success! Wrote test file to: {marker_path}"
def check_test_file():
    """Look for the ``persistence_test.txt`` marker and report its contents.

    The marker is expected in the directory component of
    ``PERSONAL_INDEX_PATH``.

    Returns:
        A status string: the marker's contents when it exists, a "not found"
        message with the expected path when it does not, or an error message
        containing the exception text if the lookup itself fails.
    """
    try:
        marker_path = os.path.join(
            os.path.dirname(PERSONAL_INDEX_PATH), "persistence_test.txt"
        )
        # Guard clause: nothing to read if the marker is absent.
        if not os.path.exists(marker_path):
            return f"❌ Failure. Test file not found at: {marker_path}"
        with open(marker_path, "r", encoding="utf-8") as marker:
            saved_text = marker.read()
        return f"✅ Success! Found test file. Contents: '{saved_text}'"
    except Exception as err:
        return f"❌ Error! Failed to check for file. Reason: {err}"
| # --- UI Definition --- | |
# NOTE(review): the container nesting below (which widgets sit inside which
# Row / Group / Accordion) is a best-effort layout — confirm it against the
# rendered app.
CSS = ".gradio-container { font-size: 14px; } #chatbot { min-height: 250px; } #audio_out audio { max-height: 40px; } #audio_in audio { max-height: 40px; padding: 0; }"

with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    # Holds the collected values of every settings widget (see the wiring below).
    settings_state = gr.State({})

    # ------------------------------------------------------------------ Chat
    with gr.Tab("Chat"):
        user_text = gr.Textbox(show_label=False, placeholder="Type your message here...")
        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Voice Input", elem_id="audio_in")
        with gr.Row():
            submit_btn = gr.Button("Send", variant="primary")
            save_btn = gr.Button("Save to Memory")
            clear_btn = gr.Button("Clear")
        chat_status = gr.Markdown()
        audio_out = gr.Audio(label="Response Audio", autoplay=True, visible=True, elem_id="audio_out")
        chatbot = gr.Chatbot(elem_id="chatbot", label="Conversation", type="messages")

    # ----------------------------------------------------------- Personalize
    with gr.Tab("Personalize"):
        with gr.Accordion("Add to Personal Knowledge Base", open=True):
            gr.Markdown("Add personal notes, memories, or descriptions. A descriptive title helps the AI find memories more accurately.")
            personal_title = gr.Textbox(label="Title / Entry Name", placeholder="e.g., 'Dad's favorite songs'")
            personal_text = gr.Textbox(lines=5, label="Text Content (or use file upload)", placeholder="Type or paste text here. Use '—' on a new line to separate multiple entries.")
            personal_file = gr.File(label="Upload Audio/Video/Text File")
            personal_image = gr.Image(type="filepath", label="Upload Image")
            personal_yt_url = gr.Textbox(label="Or, provide a YouTube URL", placeholder="Paste a YouTube link here...")
            with gr.Row():
                personal_add_btn = gr.Button("Add Knowledge to Memory", variant="primary")
            personal_status = gr.Markdown()
        with gr.Accordion("Manage Personal Knowledge", open=False):
            personal_memory_display = gr.DataFrame(headers=["Title", "Source", "Content"], label="Saved Personal Memories", interactive=False, row_count=(5, "dynamic"))
            with gr.Row():
                personal_refresh_btn = gr.Button("Refresh Memories")
            with gr.Row():
                personal_delete_selector = gr.Dropdown(label="Select a memory to delete (by its full content)", scale=3, interactive=True)
                personal_delete_btn = gr.Button("Delete Selected Memory", variant="stop", scale=1)
            personal_delete_status = gr.Markdown()

    # --------------------------------------------------------------- Testing
    with gr.Tab("Testing"):
        gr.Markdown("## NLU Context Detection Tests")
        gr.Markdown("Select a single test case to run, or run the entire batch of fixtures to get a summary of the NLU's performance.")
        # Summary line written by the "Run All Tests" handler.
        batch_summary_md = gr.Markdown("### Batch Test Summary: Not yet run.")
        with gr.Row():
            test_case_dropdown = gr.Dropdown(label="Select Single Test Case", scale=2)
            run_test_btn = gr.Button("Run Single Test", variant="secondary", scale=1)
            run_all_btn = gr.Button("Run All Tests", variant="primary", scale=1)
        test_status_md = gr.Markdown("### Test Results")
        # Shared by the single-test and batch handlers; both write into it.
        test_results_df = gr.DataFrame(
            label="Test Results Comparison",
            headers=["Test Case Title", "Overall Result", "Categories Passed"],
            interactive=False
        )

    # -------------------------------------------------------------- Settings
    with gr.Tab("Settings"):
        with gr.Group():
            gr.Markdown("## Conversation & Persona Settings")
            with gr.Row():
                role = gr.Radio(CONFIG["roles"], value="caregiver", label="Your Role")
                temperature = gr.Slider(0.0, 1.2, value=0.7, step=0.1, label="Creativity")
                tone = gr.Dropdown(CONFIG["tones"], value="warm", label="Response Tone")
            with gr.Row():
                patient_name = gr.Textbox(label="Patient's Name", placeholder="e.g., 'Dad' or 'John'")
                caregiver_name = gr.Textbox(label="Caregiver's Name", placeholder="e.g., 'me' or 'Jane'")
            behaviour_tag = gr.Dropdown(CONFIG["behavior_tags"], value="None", label="Behaviour Filter (Manual Override)")
            emotion_tag = gr.Dropdown(CONFIG["emotion_tags"], value="None", label="Emotion Filter (Manual Override)")
            topic_tag = gr.Dropdown(CONFIG["topic_tags"], value="None", label="Topic Tag Filter (Manual Override)")
        with gr.Accordion("Language, Voice & Debugging", open=False):
            language = gr.Dropdown(list(CONFIG["languages"].keys()), value="English", label="Response Language")
            tts_lang = gr.Dropdown(list(CONFIG["languages"].keys()), value="English", label="Voice Language")
            tts_on = gr.Checkbox(True, label="Enable Voice Response (TTS)")
            debug_mode = gr.Checkbox(False, label="Show Debug Info")
        gr.Markdown("--- \n ## General Knowledge Base Management")
        active_theme = gr.Radio(CONFIG["themes"], value="All", label="Active Knowledge Theme")
        with gr.Row():
            with gr.Column(scale=1):
                files_in = gr.File(file_count="multiple", file_types=[".jsonl", ".txt"], label="Upload Knowledge Files")
                upload_btn = gr.Button("Upload to Theme", variant="secondary")
                seed_btn = gr.Button("Import Sample Data", variant="secondary")
            with gr.Column(scale=2):
                mgmt_status = gr.Markdown()
                files_box = gr.CheckboxGroup(choices=[], label="Enable Files for the Selected Theme")
                with gr.Row():
                    save_files_btn = gr.Button("Save Selection", variant="primary")
                    refresh_btn = gr.Button("Refresh List")
        with gr.Accordion("Persistence Test", open=False):
            gr.Markdown("Use this tool to verify that the Hugging Face persistent storage is working correctly. \n1. Click 'Run Test'. \n2. Manually restart the Space. \n3. Click 'Check for File'.")
            with gr.Row():
                test_save_btn = gr.Button("1. Run Persistence Test (Save File)")
                check_save_btn = gr.Button("3. Check for Test File")
            test_status = gr.Markdown()

    # --- Event Wiring ---
    # Order matters: the values are passed positionally to collect_settings.
    all_settings_components = [role, patient_name, caregiver_name, tone, language, tts_lang, temperature, behaviour_tag, emotion_tag, topic_tag, active_theme, tts_on, debug_mode]
    # Any change to any settings widget re-collects the whole set into settings_state.
    for component in all_settings_components:
        component.change(fn=collect_settings, inputs=all_settings_components, outputs=settings_state)

    # Chat tab
    submit_btn.click(fn=chat_fn, inputs=[user_text, audio_in, settings_state, chatbot], outputs=[user_text, audio_out, chatbot])
    save_btn.click(fn=save_chat_to_memory, inputs=[chatbot], outputs=[chat_status])
    # Reset the chat widgets. Each output component is listed exactly once;
    # the text box and status line are cleared to empty strings.
    clear_btn.click(
        lambda: ("", None, [], None, ""),
        outputs=[user_text, audio_out, chatbot, audio_in, chat_status],
    )

    # Personalize tab: settings_state is passed so the handler can see the
    # current settings (e.g. debug mode); the form is cleared afterwards.
    personal_add_btn.click(
        fn=handle_add_knowledge,
        inputs=[personal_title, personal_text, personal_file, personal_image, personal_yt_url, settings_state],
        outputs=[personal_status]
    ).then(
        lambda: (None, None, None, None, None),
        outputs=[personal_title, personal_text, personal_file, personal_image, personal_yt_url]
    )
    personal_refresh_btn.click(fn=list_personal_memories, inputs=None, outputs=[personal_memory_display, personal_delete_selector])
    # After a delete, refresh the table and the selector so they stay in sync.
    personal_delete_btn.click(fn=delete_personal_memory, inputs=[personal_delete_selector], outputs=[personal_delete_status]).then(fn=list_personal_memories, inputs=None, outputs=[personal_memory_display, personal_delete_selector])

    # Settings tab: general knowledge base management
    upload_btn.click(upload_knowledge, inputs=[files_in, active_theme], outputs=[mgmt_status]).then(refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    save_files_btn.click(save_file_selection, inputs=[active_theme, files_box], outputs=[mgmt_status])
    seed_btn.click(seed_files_into_theme, inputs=[active_theme]).then(refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    refresh_btn.click(refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    active_theme.change(refresh_file_list_ui, inputs=[active_theme], outputs=[files_box, mgmt_status])
    demo.load(auto_setup_on_load, inputs=[active_theme], outputs=[settings_state, files_box, mgmt_status])

    # Settings tab: persistence test
    test_save_btn.click(fn=test_save_file, inputs=None, outputs=[test_status])
    check_save_btn.click(fn=check_test_file, inputs=None, outputs=[test_status])

    # Testing tab
    demo.load(load_test_fixtures, outputs=[test_case_dropdown])
    run_test_btn.click(
        fn=run_nlu_test,
        inputs=[test_case_dropdown],
        outputs=[test_status_md, test_results_df]
    )
    run_all_btn.click(fn=run_all_nlu_tests, outputs=[batch_summary_md, test_results_df])
| # --- Startup Logic --- | |
def pre_load_indexes():
    """Load every general theme index and the personal index, logging progress.

    Calls ``ensure_index`` once per entry in ``CONFIG["themes"]`` and then
    rebinds the module-level ``personal_vectorstore`` to the result of
    ``build_or_load_vectorstore``. Loading is best-effort: a failure is
    printed and the remaining steps still run.
    """
    global personal_vectorstore

    print("Pre-loading all knowledge base indexes at startup...")

    for theme_name in CONFIG["themes"]:
        print(f" - Loading general index for theme: '{theme_name}'")
        try:
            ensure_index(theme_name)
            print(f" ...'{theme_name}' theme loaded successfully.")
        except Exception as theme_error:
            # Keep going so one failing theme does not stop the others.
            print(f" ...Error loading theme '{theme_name}': {theme_error}")

    print(" - Loading personal knowledge index...")
    try:
        # An empty document list is passed along with the personal index path.
        personal_vectorstore = build_or_load_vectorstore(
            [], PERSONAL_INDEX_PATH, is_personal=True
        )
        print(" ...Personal knowledge loaded successfully.")
    except Exception as personal_error:
        print(f" ...Error loading personal knowledge: {personal_error}")

    print("All indexes loaded. Application is ready.")
| # --- STARTUP LOGIC WITH DIAGNOSTICS --- | |
# Diagnostic banner: printed whenever this module is executed or imported.
print("\n--- SCRIPT LOADED. CHECKING IF RUNNING AS MAIN ---\n")

if __name__ == "__main__":
    print("--- STARTUP BLOCK (__name__ == '__main__') IS EXECUTING ---")

    # Ensure the default theme has its sample files before we try to build
    # an index.
    print("\nStep 1: Seeding sample files...")
    seed_files_into_theme('All')
    print("Step 1: Seeding complete.")

    print("\nStep 2: Pre-loading indexes...")
    pre_load_indexes()
    print("Step 2: Pre-loading complete.")

    print("\nStep 3: Launching Gradio interface...")
    demo.queue().launch(debug=True)
    print("Step 3: Gradio launch command issued.")
else:
    # Seeding, pre-loading and launching happen only when run as a script.
    print("--- WARNING: SCRIPT IS BEING IMPORTED, NOT RUN DIRECTLY ---")
    print("--- The if __name__ == '__main__' block was SKIPPED. ---")