"""REFRAME — Live cognitive restructuring studio. Main Gradio 6 application. Dual-panel: chat + live thought card. """ # ZeroGPU: `spaces` MUST be imported before torch (it patches torch for GPU # scheduling). On a ZeroGPU Space the real package is present; locally it isn't, # so we fall back to a no-op shim and every Space branch becomes dead code. try: import spaces # noqa: E402 (intentionally first — before torch) except ImportError: class _SpacesShim: @staticmethod def GPU(*args, **kwargs): # Support both @spaces.GPU and @spaces.GPU(duration=...) if len(args) == 1 and callable(args[0]) and not kwargs: return args[0] def _decorator(fn): return fn return _decorator spaces = _SpacesShim() import os import warnings from pathlib import Path import gradio as gr import config import inference from atmosphere import detect_emotion from card_engine import ActiveCard, CardState, reset_card, should_show_card, update_card from components.card_deck import render_deck from components.thought_card import render_card, render_empty_card, render_intro_card from crisis import detect_crisis, get_crisis_banner_html, get_crisis_response from patterns import get_patterns_html from prompts import build_system_prompt, get_greeting from session import ( Experiment, SessionState, ThoughtCard, deserialize_session, get_session_context, save_card, save_experiment, serialize_session, start_new_session, ) # --- Optional STT import (fail-safe) --- try: import stt _stt_available = config.STT_ENABLED and stt.is_available() except ImportError: _stt_available = False # Silence a noisy (harmless) third-party deprecation: Gradio 6.18 uses Starlette's # old HTTP_422_UNPROCESSABLE_ENTITY constant (renamed to *_CONTENT). Fires on every # request. Remove once Gradio ships a fix. 
warnings.filterwarnings("ignore", message=r".*HTTP_422_UNPROCESSABLE_ENTITY.*")

# Load CSS
CSS_PATH = Path(__file__).parent / "static" / "theme.css"
# Decode explicitly as UTF-8 so the stylesheet loads the same way on platforms
# whose locale default encoding is something else.
CUSTOM_CSS = CSS_PATH.read_text(encoding="utf-8") if CSS_PATH.exists() else ""


def _model_badges_html() -> str:
    """Header status pills — [label] · [exact id], ordered Modal → Gemma → Cohere.

    The speech-to-text pill is labelled from ``config.STT_MODEL``: "Cohere" or
    "Whisper" when the id contains that word (case-insensitive), else "STT".
    """
    stt_lower = config.STT_MODEL.lower()
    # NOTE(review): stt_cls is not referenced by the template below — confirm
    # whether the badge markup is supposed to use it.
    if "cohere" in stt_lower:
        stt_name, stt_cls = "Cohere", "model-badge badge-cohere"
    elif "whisper" in stt_lower:
        stt_name, stt_cls = "Whisper", "model-badge"
    else:
        stt_name, stt_cls = "STT", "model-badge"
    return f"""
Modal · NVIDIA H100 80GB Gemma 4 12B · google/gemma-4-12B-it · QLoRA fine-tuned {stt_name} · {config.STT_MODEL}
"""


def _models_card_html() -> str:
    """Tech-stack / models showcase for the right pane (prize transparency)."""
    return """
A fine-tuned Gemma model that helps you reframe thoughts — trained in the cloud, runs locally.
Gemma 4 12B — Google's open LLM, fine-tuned on mental-health counseling data
QLoRA + unsloth — efficient, low-cost fine-tuning
Modal · H100 80GB — cloud GPU the training ran on
GGUF Q4_K_M — compact, quantized model file
llama.cpp — runs the model locally, low latency
Ollama — serves llama.cpp as a local API
Cohere Transcribe — turns your voice into text
Gradio — the app's web interface
Fine-tuning datasets — mental-health counseling, empathetic dialogue & crisis responses
"""


# --- Status animation HTML (wave for STT, dots for thinking) ---
WAVE_HTML = """
Transcribing your voice...
"""

DOTS_HTML = """
Thinking...
"""

# Text placed in the input box when a recording cannot be transcribed.
_STT_FALLBACK_TEXT = "Couldn't hear that — try typing instead."


def _show_wave():
    """Reveal the status area with the "transcribing" animation."""
    return gr.update(value=WAVE_HTML, visible=True)


def _show_dots():
    """Reveal the status area with the "thinking" animation."""
    return gr.update(value=DOTS_HTML, visible=True)


def _hide_status():
    """Clear and hide the status area."""
    return gr.update(value="", visible=False)


@spaces.GPU(duration=120)
def _transcribe_and_clear(audio_filepath):
    """Transcribe audio, return (text_for_input, None_to_clear_audio).

    On a ZeroGPU Space this runs in a GPU-allocated fork (Cohere on CUDA);
    locally @spaces.GPU is a no-op shim and STT runs on XPU/CPU.

    When there is no recording, STT is unavailable, or the transcript comes
    back empty, a fixed fallback sentence is returned instead.

    NOTE(review): the voice event chain feeds this text straight into
    ``user_submit``, so the fallback sentence is posted to the chat as if the
    user had said it — confirm that is intended.
    """
    if not audio_filepath or not _stt_available:
        return _STT_FALLBACK_TEXT, None
    text = stt.transcribe(audio_filepath)
    if not text:
        return _STT_FALLBACK_TEXT, None
    return text, None


def _entry_text(entry) -> str:
    """Return the plain text carried by one chat-history entry.

    Accepts a message dict (``{"content": ...}``), a list/tuple whose first
    item is the content, or any other object (stringified). Content may itself
    be a list of part-dicts in Gradio 6, e.g. ``{"text": ..., "type": "text"}``;
    the ``text`` fields are joined with spaces rather than str()-ing the dicts.
    """
    if isinstance(entry, dict):
        raw_content = entry.get("content", "")
    elif isinstance(entry, (list, tuple)):
        raw_content = entry[0] if entry else ""
    else:
        raw_content = str(entry)

    if isinstance(raw_content, list):
        parts = []
        for part in raw_content:
            if isinstance(part, dict):
                # `or ""` keeps a missing/None text from becoming the word "None".
                parts.append(str(part.get("text") or ""))
            elif part:
                parts.append(str(part))
        return " ".join(p for p in parts if p).strip()
    return str(raw_content) if raw_content else ""


def respond(
    history: list[dict],
    session_json: str,
    card_state_json: str,
):
    """Main chat response handler with streaming.

    Args:
        history: Chat history whose last entry is the user message that
            ``user_submit`` has just appended.
        session_json: Serialized session state.
        card_state_json: Serialized active-card state.

    Yields:
        (history, card_html, atmosphere_class, crisis_html, session_json, card_state_json)
        once per streamed partial response, then once more after the stream
        ends (with a completed card saved to the session and the card reset).
    """
    # Get the user message from the last history entry (already added by user_submit)
    if not history:
        yield history, render_empty_card(), "atmosphere-neutral", "", session_json, card_state_json
        return

    last_entry = history[-1]

    # user_submit leaves history untouched when the input is empty, in which
    # case the last entry is still an assistant message. Replying to it would
    # make the model answer its own text, so do nothing — and leave the card
    # and any crisis banner exactly as they are.
    if isinstance(last_entry, dict) and last_entry.get("role", "user") != "user":
        yield history, gr.update(), "atmosphere-neutral", gr.update(), session_json, card_state_json
        return

    user_message = _entry_text(last_entry)
    if not user_message.strip():
        yield history, render_empty_card(), "atmosphere-neutral", "", session_json, card_state_json
        return

    # Deserialize state
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)

    # Build system prompt with session context
    context = get_session_context(session)
    system_prompt = build_system_prompt(context)

    # Check for crisis
    crisis_html = ""
    if detect_crisis(user_message):
        crisis_html = get_crisis_banner_html()

    # Stream model response (history already has user message from user_submit)
    full_response = ""
    for partial in inference.stream_response(user_message, history[:-1], system_prompt):
        full_response = partial

        # Update card engine
        # NOTE(review): this runs once per streamed partial, feeding the
        # previous iteration's card back in — confirm update_card is safe to
        # call repeatedly within a single turn.
        active_card = update_card(active_card, full_response, user_message)
        card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()

        # Detect atmosphere
        atmos = detect_emotion(user_message + " " + full_response)

        # Build current history with streaming response
        current_history = history + [{"role": "assistant", "content": full_response}]

        yield (
            current_history,
            card_html_val,
            atmos,
            crisis_html,
            serialize_session(session),
            _serialize_card_state(active_card),
        )

    # If card completed, save it
    if active_card.state == CardState.COMPLETE:
        session = save_card(session, active_card.card)
        active_card = reset_card()

    # Final yield
    card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
    final_history = history + [{"role": "assistant", "content": full_response}]

    yield (
        final_history,
        card_html_val,
        detect_emotion(user_message + " " + full_response),
        crisis_html,
        serialize_session(session),
        _serialize_card_state(active_card),
    )
def initialize_session(session_json: str):
    """Called on app load — initialize or resume session.

    Returns:
        (history, session_json): a chat history holding only the assistant
        greeting, and the re-serialized session.
    """
    session = deserialize_session(session_json)
    session = start_new_session(session)
    greeting = get_greeting(get_session_context(session))
    history = [{"role": "assistant", "content": greeting}]
    return history, serialize_session(session)


def get_deck_html(session_json: str) -> str:
    """Render the card deck from session state."""
    session = deserialize_session(session_json)
    return render_deck(session.cards)


def get_progress_html(session_json: str) -> str:
    """Render the patterns/progress panel."""
    session = deserialize_session(session_json)
    return get_patterns_html(session)


def save_current_card(session_json: str, card_state_json: str):
    """Manually save the current card to deck.

    The card is saved only when it is not idle and already has an automatic
    thought. On a save the active card is reset and the card pane is emptied.
    Otherwise the unsaved card is rendered again, so the pane stays in step
    with the card state that is returned.

    Returns:
        (session_json, card_state_json, card_html, deck_html)
    """
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)

    if active_card.state != CardState.IDLE and active_card.card.automatic_thought:
        session = save_card(session, active_card.card)
        active_card = reset_card()
        card_view = render_empty_card()
    elif should_show_card(active_card):
        # Nothing was saved: keep showing the in-progress card.
        card_view = render_card(active_card)
    else:
        card_view = render_empty_card()

    return (
        serialize_session(session),
        _serialize_card_state(active_card),
        card_view,
        render_deck(session.cards),
    )


# --- Card state serialization helpers ---

def _serialize_card_state(active: ActiveCard) -> str:
    """Serialize ActiveCard to JSON string.

    The payload has three keys: ``state`` (the enum value), ``card`` (the
    card's dataclass fields) and ``turn_count``.
    """
    import json
    from dataclasses import asdict

    return json.dumps({
        "state": active.state.value,
        "card": asdict(active.card),
        "turn_count": active.turn_count,
    })


def _deserialize_card_state(data: str | None) -> ActiveCard:
    """Deserialize ActiveCard from JSON string.

    Empty, malformed, or structurally unexpected input yields a fresh
    ``ActiveCard()`` instead of raising.
    """
    import json

    if not data:
        return ActiveCard()
    try:
        d = json.loads(data) if isinstance(data, str) else data
        if not isinstance(d, dict):
            # Valid JSON that is not an object (null, list, number, ...) has
            # no .get(); treat it like any other unusable payload.
            return ActiveCard()
        card = ThoughtCard(**d.get("card", {}))
        state = CardState(d.get("state", "idle"))
        return ActiveCard(state=state, card=card, turn_count=d.get("turn_count", 0))
    except (json.JSONDecodeError, TypeError, ValueError):
        return ActiveCard()


# --- Build the Gradio App ---

with gr.Blocks(
    title=config.APP_TITLE,
    fill_height=True,
    fill_width=True,
) as demo:

    # Hidden state components
    session_state = gr.BrowserState("", storage_key="reframe_session")
    card_state = gr.State("")
    atmosphere_state = gr.State("atmosphere-neutral")

    # Top-center credit
    gr.HTML('\nCreated with 🍁 in Canada\n')

    # Header — title, subtitle, then model pills (horizontal, after the string)
    with gr.Row():
        gr.HTML(f"""
{config.APP_TITLE}
{config.APP_SUBTITLE}
{_model_badges_html()}
""")

    # Main layout: Chat (60%) + Card Panel (40%)
    with gr.Row():
        # Left: Chat
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                elem_id="reframe-chat",
                height=430,
                show_label=False,
                placeholder="What's on your mind today?",
                buttons=["copy"],
            )
            msg_input = gr.Textbox(
                placeholder="Type here...",
                show_label=False,
                container=False,
                scale=7,
                submit_btn=True,
            )

            # Voice input (only if STT deps available)
            if _stt_available:
                _voice_label = (
                    "🎙️ Speak your thoughts — powered by Cohere"
                    if "cohere" in config.STT_MODEL.lower()
                    else "🎤 Or speak your thoughts"
                )
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    label=_voice_label,
                    show_label=True,
                )
                voice_status = gr.HTML(value="", visible=False)
            else:
                audio_input = None
                voice_status = None

        # Right: tabs at top (How it works default); live card lives in the Deck tab
        with gr.Column(scale=2, min_width=280):
            # Crisis banner (hidden by default) — kept above tabs so it always shows
            crisis_html = gr.HTML(value="")

            with gr.Tabs():
                with gr.Tab("How it works"):
                    gr.HTML(value=render_intro_card())
                with gr.Tab("🛠️ Stack"):
                    gr.HTML(value=_models_card_html())
                with gr.Tab("Deck"):
                    card_html = gr.HTML(value=render_empty_card())
                    save_btn = gr.Button("✨ Save to Deck", variant="secondary", size="sm")
                    deck_html = gr.HTML(value="")
                with gr.Tab("Patterns"):
                    patterns_html_component = gr.HTML(value="")

    # --- Event Wiring ---

    # Helper to add user msg to chat immediately and clear input
    def user_submit(message, history):
        """Immediately show user message and clear input."""
        if not message or not message.strip():
            return "", history
        return "", history + [{"role": "user", "content": message}]

    # The typed and the spoken path both end in the same respond() call, so
    # its inputs/outputs are spelled out once.
    respond_inputs = [chatbot, session_state, card_state]
    respond_outputs = [
        chatbot,
        card_html,
        atmosphere_state,
        crisis_html,
        session_state,
        card_state,
    ]

    # Main chat submit: clear input + show msg instantly, then stream response
    submit_event = msg_input.submit(
        fn=user_submit,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
    ).then(
        # Lock the textbox while the reply streams.
        fn=lambda: gr.update(interactive=False),
        outputs=[msg_input],
    ).then(
        fn=respond,
        inputs=respond_inputs,
        outputs=respond_outputs,
    ).then(
        fn=lambda: gr.update(interactive=True),
        outputs=[msg_input],
    )

    # Voice input: record → transcribe (with wave) → submit → respond (with dots)
    if _stt_available and audio_input is not None:
        audio_input.stop_recording(
            fn=_show_wave,
            outputs=[voice_status],
        ).then(
            fn=_transcribe_and_clear,
            inputs=[audio_input],
            outputs=[msg_input, audio_input],
        ).then(
            fn=user_submit,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot],
        ).then(
            fn=_show_dots,
            outputs=[voice_status],
        ).then(
            fn=lambda: gr.update(interactive=False),
            outputs=[msg_input],
        ).then(
            fn=respond,
            inputs=respond_inputs,
            outputs=respond_outputs,
        ).then(
            fn=lambda: gr.update(interactive=True),
            outputs=[msg_input],
        ).then(
            fn=_hide_status,
            outputs=[voice_status],
        )

    # Save card button
    save_btn.click(
        fn=save_current_card,
        inputs=[session_state, card_state],
        outputs=[session_state, card_state, card_html, deck_html],
    )

    # Initialize on load
    demo.load(
        fn=initialize_session,
        inputs=[session_state],
        outputs=[chatbot, session_state],
    )

    # Refresh deck and patterns on session change
    session_state.change(
        fn=get_deck_html,
        inputs=[session_state],
        outputs=[deck_html],
    )
    session_state.change(
        fn=get_progress_html,
        inputs=[session_state],
        outputs=[patterns_html_component],
    )


if __name__ == "__main__":
    # Preload the STT model locally for an instant first response. On a Space we
    # SKIP preload so the app starts immediately (no startup timeout downloading
    # ~4 GB); the model then loads lazily on the first transcription, inside the
    # @spaces.GPU context on ZeroGPU.
    if _stt_available and not os.environ.get("SPACE_ID"):
        stt.preload_model()

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        css=CUSTOM_CSS,
        head=""" """,
    )