Spaces:
Sleeping
Sleeping
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8 | """REFRAME — Live cognitive restructuring studio. | |
| Main Gradio 6 application. Dual-panel: chat + live thought card. | |
| """ | |
# ZeroGPU: `spaces` MUST be imported before torch (it patches torch for GPU
# scheduling). On a ZeroGPU Space the real package is present; locally it isn't,
# so we fall back to a no-op shim and every Space branch becomes dead code.
try:
    import spaces  # noqa: E402  (intentionally first — before torch)
except ImportError:
    class _SpacesShim:
        """Local stand-in for the `spaces` package; `GPU` decorates nothing."""

        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for `spaces.GPU`.

            Supports both the bare form (`@spaces.GPU`) and the parametrized
            form (`@spaces.GPU(duration=...)`). In either case the decorated
            function is returned unchanged.
            """
            # This has to be a staticmethod: `spaces` below is an *instance*,
            # so a plain method would receive `self` as args[0] and the
            # bare-decorator check would never match, replacing the decorated
            # function with `_decorator`.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesShim()
| import os | |
| import warnings | |
| from pathlib import Path | |
| import gradio as gr | |
| import config | |
| import inference | |
| from atmosphere import detect_emotion | |
| from card_engine import ActiveCard, CardState, reset_card, should_show_card, update_card | |
| from components.card_deck import render_deck | |
| from components.thought_card import render_card, render_empty_card, render_intro_card | |
| from crisis import detect_crisis, get_crisis_banner_html, get_crisis_response | |
| from patterns import get_patterns_html | |
| from prompts import build_system_prompt, get_greeting | |
| from session import ( | |
| Experiment, | |
| SessionState, | |
| ThoughtCard, | |
| deserialize_session, | |
| get_session_context, | |
| save_card, | |
| save_experiment, | |
| serialize_session, | |
| start_new_session, | |
| ) | |
# --- Optional STT import (fail-safe) ---
# Voice input is enabled only when the `stt` module imports, the config flag
# is set, and the module reports itself available. If the import fails the
# name `stt` is never bound, so later uses of `stt` must be guarded by
# `_stt_available`.
try:
    import stt
    _stt_available = config.STT_ENABLED and stt.is_available()
except ImportError:
    _stt_available = False
# Silence a noisy (harmless) third-party deprecation: Gradio 6.18 uses Starlette's
# old HTTP_422_UNPROCESSABLE_ENTITY constant (renamed to *_CONTENT). Fires on every
# request. Remove once Gradio ships a fix.
warnings.filterwarnings("ignore", message=r".*HTTP_422_UNPROCESSABLE_ENTITY.*")

# Load CSS. The stylesheet is optional: a missing file yields an empty string.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# locale encoding.
CSS_PATH = Path(__file__).parent / "static" / "theme.css"
CUSTOM_CSS = CSS_PATH.read_text(encoding="utf-8") if CSS_PATH.exists() else ""
def _model_badges_html() -> str:
    """Build the header status pills: training platform, LLM, then speech-to-text.

    The speech-to-text pill's label and CSS class are chosen by looking for a
    known vendor name inside `config.STT_MODEL` (case-insensitive); anything
    unrecognised falls back to a generic "STT" pill.
    """
    model_id = config.STT_MODEL.lower()

    # First match wins, mirroring vendor priority: Cohere, then Whisper.
    known_vendors = (
        ("cohere", "Cohere", "model-badge badge-cohere"),
        ("whisper", "Whisper", "model-badge"),
    )
    stt_name, stt_cls = "STT", "model-badge"
    for needle, label, css_class in known_vendors:
        if needle in model_id:
            stt_name, stt_cls = label, css_class
            break

    return f"""
<div class="model-badges">
<span class="model-badge badge-modal" title="Training platform · single GPU">
<span class="badge-dot"></span>Modal · NVIDIA H100 80GB
</span>
<span class="model-badge" title="Runtime: {config.OLLAMA_MODEL} (Ollama / llama.cpp)">
<span class="badge-dot"></span>Gemma 4 12B · google/gemma-4-12B-it · QLoRA fine-tuned
</span>
<span class="{stt_cls}" title="Speech-to-text (active)">
<span class="badge-dot"></span>{stt_name} · {config.STT_MODEL}
</span>
</div>
"""
def _models_card_html() -> str:
    """Return the static tech-stack / models showcase HTML (prize transparency).

    The markup is a fixed literal with no interpolation: a one-line summary, one
    `stack-line` row per technology, and a linked list of fine-tuning datasets.
    """
    return """
<div class="stack-card">
<div class="stack-summary">
A fine-tuned Gemma model that helps you reframe thoughts — trained in the cloud, runs locally.
</div>
<div class="stack-line"><b>Gemma 4 12B</b> — Google's open LLM, fine-tuned on mental-health counseling data</div>
<div class="stack-line"><b>QLoRA + unsloth</b> — efficient, low-cost fine-tuning</div>
<div class="stack-line"><b>Modal · H100 80GB</b> — cloud GPU the training ran on</div>
<div class="stack-line"><b>GGUF Q4_K_M</b> — compact, quantized model file</div>
<div class="stack-line"><b>llama.cpp</b> — runs the model locally, low latency</div>
<div class="stack-line"><b>Ollama</b> — serves llama.cpp as a local API</div>
<div class="stack-line"><b>Cohere Transcribe</b> — turns your voice into text</div>
<div class="stack-line"><b>Gradio</b> — the app's web interface</div>
<div class="stack-line"><b>Fine-tuning datasets</b> — mental-health counseling, empathetic dialogue & crisis responses
<ul class="stack-list">
<li><a href="https://huggingface.co/datasets/ShenLab/MentalChat16K" target="_blank">MentalChat16K</a></li>
<li><a href="https://huggingface.co/datasets/Amod/mental_health_counseling_conversations" target="_blank">Mental Health Counseling Conversations</a></li>
<li><a href="https://huggingface.co/datasets/nbertagnolli/counsel-chat" target="_blank">CounselChat</a></li>
<li><a href="https://huggingface.co/datasets/Estwld/empathetic_dialogues_llm" target="_blank">EmpatheticDialogues</a></li>
<li><a href="https://huggingface.co/datasets/arnaiztech/llms-mental-health-crisis-responses" target="_blank">Mental-Health Crisis Responses</a></li>
</ul>
</div>
</div>
"""
# --- Status animation HTML (wave for STT, dots for thinking) ---
# Both snippets are self-contained: inline styles plus their own <style> block
# defining the @keyframes they reference, so they animate without theme.css.

# Five bars with staggered animation delays (0–0.4s) and the caption
# "Transcribing your voice...".
WAVE_HTML = """
<div style="display:flex;align-items:center;gap:8px;padding:8px 12px;
background:#1a1a2e;border-radius:8px;border:1px solid #2a2a4a;">
<div style="display:flex;gap:3px;align-items:end;height:20px;">
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out infinite;height:8px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.1s infinite;height:14px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.2s infinite;height:20px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.3s infinite;height:14px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.4s infinite;height:8px;"></span>
</div>
<span style="color:#c4b5fd;font-size:0.85rem;">Transcribing your voice...</span>
</div>
<style>
@keyframes wave{0%,100%{transform:scaleY(0.4)}50%{transform:scaleY(1)}}
</style>
"""

# Three pulsing dots with staggered delays (0–0.4s) and the caption "Thinking...".
DOTS_HTML = """
<div style="display:flex;align-items:center;gap:8px;padding:8px 12px;
background:#1a1a2e;border-radius:8px;border:1px solid #2a2a4a;">
<div style="display:flex;gap:4px;">
<span style="width:6px;height:6px;background:#8899aa;border-radius:50%;
animation:dots 1.2s ease-in-out infinite;"></span>
<span style="width:6px;height:6px;background:#8899aa;border-radius:50%;
animation:dots 1.2s ease-in-out 0.2s infinite;"></span>
<span style="width:6px;height:6px;background:#8899aa;border-radius:50%;
animation:dots 1.2s ease-in-out 0.4s infinite;"></span>
</div>
<span style="color:#8899aa;font-size:0.85rem;">Thinking...</span>
</div>
<style>
@keyframes dots{0%,100%{opacity:0.3;transform:scale(0.8)}50%{opacity:1;transform:scale(1.2)}}
</style>
"""
def _show_wave():
    """Return an update that fills the status slot with the wave animation and shows it."""
    wave_update = gr.update(visible=True, value=WAVE_HTML)
    return wave_update
def _show_dots():
    """Return an update that fills the status slot with the thinking dots and shows it."""
    dots_update = gr.update(visible=True, value=DOTS_HTML)
    return dots_update
def _hide_status():
    """Return an update that empties the status slot and hides it."""
    cleared = gr.update(visible=False, value="")
    return cleared
| def _transcribe_and_clear(audio_filepath): | |
| """Transcribe audio, return (text_for_input, None_to_clear_audio). | |
| On a ZeroGPU Space this runs in a GPU-allocated fork (Cohere on CUDA); | |
| locally @spaces.GPU is a no-op shim and STT runs on XPU/CPU. | |
| """ | |
| if not audio_filepath or not _stt_available: | |
| return "Couldn't hear that — try typing instead.", None | |
| text = stt.transcribe(audio_filepath) | |
| if not text: | |
| return "Couldn't hear that — try typing instead.", None | |
| return text, None | |
def respond(
    history: list[dict],
    session_json: str,
    card_state_json: str,
):
    """Main chat response handler with streaming.

    Answers the trailing user turn in `history` (appended beforehand by the
    submit handler), streaming the assistant text and refreshing the thought
    card as chunks arrive.

    Args:
        history: Chat messages so far; the last entry is expected to be the
            user turn to answer.
        session_json: Serialized session (see `serialize_session`).
        card_state_json: Serialized active card (see `_serialize_card_state`).

    Yields:
        (history, card_html, atmosphere_class, crisis_html, session_json,
        card_state_json) — one tuple per streamed chunk, then a final tuple
        after a completed card has been saved and reset.
    """
    # Get the user message from the last history entry (already added by user_submit)
    if not history:
        yield history, render_empty_card(), "atmosphere-neutral", "", session_json, card_state_json
        return

    last_entry = history[-1]

    # Only a trailing *user* turn needs a reply. On an empty submit nothing is
    # appended, so the last entry is still the previous assistant message;
    # without this guard it would be fed back to the model as if the user had
    # typed it. Leave every output untouched (in particular, do not clear the
    # crisis banner or the live card).
    if isinstance(last_entry, dict) and last_entry.get("role", "user") != "user":
        yield history, gr.update(), gr.update(), gr.update(), session_json, card_state_json
        return

    # Handle both dict format and possible list/tuple format
    if isinstance(last_entry, dict):
        raw_content = last_entry.get("content", "")
    elif isinstance(last_entry, (list, tuple)):
        raw_content = last_entry[0] if last_entry else ""
    else:
        raw_content = str(last_entry)

    # Content may be a list of part-dicts in Gradio 6, e.g. {"text": ..., "type": "text"}.
    # Extract the text field rather than str()-ing the whole dict.
    if isinstance(raw_content, list):
        parts = []
        for part in raw_content:
            if isinstance(part, dict):
                parts.append(str(part.get("text", "")))
            elif part:
                parts.append(str(part))
        user_message = " ".join(p for p in parts if p).strip()
    else:
        user_message = str(raw_content) if raw_content else ""

    if not user_message.strip():
        yield history, render_empty_card(), "atmosphere-neutral", "", session_json, card_state_json
        return

    # Deserialize state
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)

    # Build system prompt with session context
    context = get_session_context(session)
    system_prompt = build_system_prompt(context)

    # Check for crisis
    crisis_html = ""
    if detect_crisis(user_message):
        crisis_html = get_crisis_banner_html()

    # Stream model response (history already has user message from user_submit)
    full_response = ""
    for partial in inference.stream_response(user_message, history[:-1], system_prompt):
        # Each item replaces (rather than extends) the running text.
        full_response = partial

        # Update card engine.
        # NOTE(review): this runs once per streamed chunk, not once per turn —
        # confirm `update_card` is safe to call repeatedly with growing text.
        active_card = update_card(active_card, full_response, user_message)
        card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()

        # Detect atmosphere
        atmos = detect_emotion(user_message + " " + full_response)

        # Build current history with streaming response
        current_history = history + [{"role": "assistant", "content": full_response}]
        yield (
            current_history,
            card_html_val,
            atmos,
            crisis_html,
            serialize_session(session),
            _serialize_card_state(active_card),
        )

    # If card completed, save it
    if active_card.state == CardState.COMPLETE:
        session = save_card(session, active_card.card)
        active_card = reset_card()

    # Final yield
    card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
    final_history = history + [{"role": "assistant", "content": full_response}]
    yield (
        final_history,
        card_html_val,
        detect_emotion(user_message + " " + full_response),
        crisis_html,
        serialize_session(session),
        _serialize_card_state(active_card),
    )
def initialize_session(session_json: str):
    """Start a session on app load and seed the chat with a greeting.

    Returns:
        (history, session_json): a one-message history holding the assistant
        greeting, and the re-serialized session.
    """
    session = start_new_session(deserialize_session(session_json))
    opening_line = get_greeting(get_session_context(session))
    opening_history = [{"role": "assistant", "content": opening_line}]
    return opening_history, serialize_session(session)
def get_deck_html(session_json: str) -> str:
    """Return the card-deck HTML for the cards stored in the serialized session."""
    saved_cards = deserialize_session(session_json).cards
    return render_deck(saved_cards)
def get_progress_html(session_json: str) -> str:
    """Return the patterns/progress panel HTML for the serialized session."""
    return get_patterns_html(deserialize_session(session_json))
def save_current_card(session_json: str, card_state_json: str):
    """Manually save the current card to the deck.

    The card is saved (and the active card reset) only when it is past the
    idle state and already has an automatic thought. Otherwise nothing is
    saved and the active card is kept as-is.

    Returns:
        (session_json, card_state_json, card_html, deck_html)
    """
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)

    if active_card.state != CardState.IDLE and active_card.card.automatic_thought:
        session = save_card(session, active_card.card)
        active_card = reset_card()

    # Render from the resulting state, the same way `respond` does, so the
    # panel never shows an empty card while an unsaved card is still active.
    card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()

    return (
        serialize_session(session),
        _serialize_card_state(active_card),
        card_html_val,
        render_deck(session.cards),
    )
# --- Card state serialization helpers ---
def _serialize_card_state(active: ActiveCard) -> str:
    """Encode an ActiveCard as a JSON string.

    The payload has three keys: `state` (the enum's value), `card` (the card
    dataclass flattened to a dict) and `turn_count`.
    """
    import json
    from dataclasses import asdict

    payload = {
        "state": active.state.value,
        "card": asdict(active.card),
        "turn_count": active.turn_count,
    }
    return json.dumps(payload)
def _deserialize_card_state(data: str | None) -> ActiveCard:
    """Deserialize ActiveCard from a JSON string.

    Accepts the string produced by `_serialize_card_state` (or an
    already-decoded dict). Anything empty, malformed, or of the wrong shape
    yields a fresh default `ActiveCard` instead of raising.
    """
    import json

    if not data:
        return ActiveCard()
    try:
        payload = json.loads(data) if isinstance(data, str) else data
        # Valid JSON that is not an object ("[]", "null", "3", ...) has no
        # .get(); treat it like any other unusable state.
        if not isinstance(payload, dict):
            return ActiveCard()
        card = ThoughtCard(**payload.get("card", {}))
        state = CardState(payload.get("state", "idle"))
        return ActiveCard(state=state, card=card, turn_count=payload.get("turn_count", 0))
    except (json.JSONDecodeError, TypeError, ValueError):
        return ActiveCard()
# --- Build the Gradio App ---
with gr.Blocks(
    title=config.APP_TITLE,
    fill_height=True,
    fill_width=True,
) as demo:
    # Hidden state components.
    # session_state is browser-backed under the key "reframe_session"; the other
    # two are plain gr.State values.
    session_state = gr.BrowserState("", storage_key="reframe_session")
    card_state = gr.State("")
    atmosphere_state = gr.State("atmosphere-neutral")

    # Top-center credit
    gr.HTML('<div class="app-credit">Created with 🍁 in Canada</div>')

    # Header — title, subtitle, then model pills (horizontal, after the string)
    with gr.Row():
        gr.HTML(f"""
<div class="app-header">
<div class="app-title">{config.APP_TITLE}</div>
<div class="app-subtitle-row">
<div class="app-subtitle">{config.APP_SUBTITLE}</div>
{_model_badges_html()}
</div>
</div>
""")

    # Main layout: Chat (60%) + Card Panel (40%) — column scales 3 and 2.
    with gr.Row():
        # Left: Chat
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                elem_id="reframe-chat",
                height=430,
                show_label=False,
                placeholder="What's on your mind today?",
                buttons=["copy"],
            )
            msg_input = gr.Textbox(
                placeholder="Type here...",
                show_label=False,
                container=False,
                scale=7,
                submit_btn=True,
            )
            # Voice input (only if STT deps available)
            if _stt_available:
                _voice_label = (
                    "🎙️ Speak your thoughts — powered by Cohere"
                    if "cohere" in config.STT_MODEL.lower()
                    else "🎤 Or speak your thoughts"
                )
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    label=_voice_label,
                    show_label=True,
                )
                # Slot for the wave / dots status animations; hidden until used.
                voice_status = gr.HTML(value="", visible=False)
            else:
                # Keep the names bound so the wiring below can test for None.
                audio_input = None
                voice_status = None

        # Right: tabs at top (How it works default); live card lives in the Deck tab
        with gr.Column(scale=2, min_width=280):
            # Crisis banner (hidden by default) — kept above tabs so it always shows
            crisis_html = gr.HTML(value="")
            with gr.Tabs():
                with gr.Tab("How it works"):
                    gr.HTML(value=render_intro_card())
                with gr.Tab("🛠️ Stack"):
                    gr.HTML(value=_models_card_html())
                with gr.Tab("Deck"):
                    card_html = gr.HTML(value=render_empty_card())
                    save_btn = gr.Button("✨ Save to Deck", variant="secondary", size="sm")
                    deck_html = gr.HTML(value="")
                with gr.Tab("Patterns"):
                    patterns_html_component = gr.HTML(value="")

    # --- Event Wiring ---
    # Helper to add user msg to chat immediately and clear input
    def user_submit(message, history):
        """Immediately show user message and clear input.

        Returns ("", history): the empty string clears the textbox. Blank or
        whitespace-only messages leave the history unchanged.
        """
        if not message or not message.strip():
            return "", history
        return "", history + [{"role": "user", "content": message}]

    # Main chat submit: clear input + show msg instantly, then stream response.
    # The textbox is disabled while `respond` streams and re-enabled afterwards.
    # NOTE(review): the steps after user_submit still run when the message was
    # blank and nothing was appended — confirm `respond` copes with a history
    # whose last entry is not a user turn.
    submit_event = msg_input.submit(
        fn=user_submit,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
    ).then(
        fn=lambda: gr.update(interactive=False),
        outputs=[msg_input],
    ).then(
        fn=respond,
        inputs=[chatbot, session_state, card_state],
        outputs=[chatbot, card_html, atmosphere_state, crisis_html, session_state, card_state],
    ).then(
        fn=lambda: gr.update(interactive=True),
        outputs=[msg_input],
    )

    # Voice input: record → transcribe (with wave) → submit → respond (with dots)
    # NOTE(review): whatever `_transcribe_and_clear` returns is written to the
    # textbox and then submitted through `user_submit`, so its "couldn't hear
    # that" fallback text is also sent as a user turn — confirm this is intended.
    if _stt_available and audio_input is not None:
        audio_input.stop_recording(
            fn=_show_wave,
            outputs=[voice_status],
        ).then(
            fn=_transcribe_and_clear,
            inputs=[audio_input],
            outputs=[msg_input, audio_input],
        ).then(
            fn=user_submit,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot],
        ).then(
            fn=_show_dots,
            outputs=[voice_status],
        ).then(
            fn=lambda: gr.update(interactive=False),
            outputs=[msg_input],
        ).then(
            fn=respond,
            inputs=[chatbot, session_state, card_state],
            outputs=[chatbot, card_html, atmosphere_state, crisis_html, session_state, card_state],
        ).then(
            fn=lambda: gr.update(interactive=True),
            outputs=[msg_input],
        ).then(
            fn=_hide_status,
            outputs=[voice_status],
        )

    # Save card button
    save_btn.click(
        fn=save_current_card,
        inputs=[session_state, card_state],
        outputs=[session_state, card_state, card_html, deck_html],
    )

    # Refresh deck/patterns when tabs are clicked
    # NOTE(review): this listener has no handler (fn=None) and is attached to
    # deck_html, not to a tab; the actual refresh is done by the
    # session_state.change listeners below.
    deck_html.change(fn=None)  # placeholder

    # Initialize on load
    demo.load(
        fn=initialize_session,
        inputs=[session_state],
        outputs=[chatbot, session_state],
    )

    # Refresh deck and patterns on session change
    session_state.change(
        fn=get_deck_html,
        inputs=[session_state],
        outputs=[deck_html],
    )
    session_state.change(
        fn=get_progress_html,
        inputs=[session_state],
        outputs=[patterns_html_component],
    )
if __name__ == "__main__":
    # Preload the STT model locally for an instant first response. On a Space we
    # SKIP preload so the app starts immediately (no startup timeout downloading
    # ~4 GB); the model then loads lazily on the first transcription, inside the
    # @spaces.GPU context on ZeroGPU.
    if _stt_available and not os.environ.get("SPACE_ID"):
        stt.preload_model()

    # Always injected: audio-selector sizing, chat placeholder styling, and the
    # script that re-applies the placeholder style after Gradio re-renders.
    _head_base = """
    <style>
    /* Audio device selector — prevent truncation */
    [data-testid="audio"] select,
    [data-testid="audio"] .audio-source-select,
    [data-testid="audio"] button[aria-label*="microphone"],
    .audio-component select {
    min-width: 180px !important;
    max-width: 250px !important;
    }
    #reframe-chat .placeholder {
    font-size: 2.5rem !important;
    font-family: 'Comic Sans MS', 'Chalkboard SE', 'Comic Neue', cursive !important;
    color: #a78bfa !important;
    opacity: 1 !important;
    font-weight: 700 !important;
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
    }
    #reframe-chat .placeholder * {
    font-size: inherit !important;
    font-family: inherit !important;
    color: inherit !important;
    background: transparent !important;
    }
    </style>
    <script>
    // Force-style placeholder after Gradio renders
    function stylePlaceholder() {
    const el = document.querySelector('#reframe-chat .placeholder');
    if (el) {
    el.style.fontSize = '2.5rem';
    el.style.fontFamily = "'Comic Sans MS', 'Chalkboard SE', cursive";
    el.style.color = '#a78bfa';
    el.style.opacity = '1';
    el.style.fontWeight = '700';
    el.style.background = 'transparent';
    el.style.border = 'none';
    el.style.boxShadow = 'none';
    el.querySelectorAll('*').forEach(c => {
    c.style.fontSize = 'inherit';
    c.style.fontFamily = 'inherit';
    c.style.color = 'inherit';
    c.style.background = 'transparent';
    });
    }
    }
    setTimeout(stylePlaceholder, 500);
    setTimeout(stylePlaceholder, 1500);
    new MutationObserver(stylePlaceholder).observe(document.body, {childList: true, subtree: true});
    </script>
    """

    # Injected only when voice input exists: without STT there is no audio
    # component, so asking the visitor for microphone access would be a
    # pointless (and, for this app, unsettling) permission prompt.
    _head_mic = """
    <script>
    // Explicitly request microphone permission on first user interaction
    let micRequested = false;
    function requestMic() {
    if (micRequested) return;
    micRequested = true;
    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
    navigator.mediaDevices.getUserMedia({ audio: true })
    .then(stream => {
    console.log('Microphone access granted');
    stream.getTracks().forEach(t => t.stop());
    })
    .catch(err => {
    console.warn('Microphone access denied:', err.message);
    const audioEl = document.querySelector('[data-testid="audio"]');
    if (audioEl) {
    audioEl.style.opacity = '0.5';
    audioEl.title = 'Microphone blocked — check browser permissions';
    }
    });
    }
    }
    // Trigger on first click anywhere (browsers require user gesture)
    document.addEventListener('click', requestMic, { once: true });
    // Also try on page load (works if permission was previously granted)
    setTimeout(requestMic, 2000);
    </script>
    """

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        css=CUSTOM_CSS,
        head=_head_base + (_head_mic if _stt_available else ""),
    )