"""REFRAME — Live cognitive restructuring studio.
Main Gradio 6 application. Dual-panel: chat + live thought card.
"""
# ZeroGPU: the `spaces` package has to be imported before torch, because it
# patches torch for GPU scheduling. A ZeroGPU Space ships the real package.
# Locally it is absent, so a do-nothing stand-in is installed instead and every
# @spaces.GPU decoration simply returns the original function.
try:
    import spaces  # noqa: E402 (intentionally first — before torch)
except ImportError:
    class _SpacesShim:
        """Local stand-in for the `spaces` package whose GPU decorator is inert."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Mimic both ``@spaces.GPU`` and ``@spaces.GPU(...)`` as identity decorators.

            In the bare form (exactly one callable positional argument and no
            keyword arguments) that callable is handed straight back. Any other
            call shape returns a decorator that leaves its function unchanged.
            """
            used_bare = not kwargs and len(args) == 1 and callable(args[0])
            if used_bare:
                return args[0]

            def _passthrough(fn):
                return fn

            return _passthrough

    spaces = _SpacesShim()
import os
import warnings
from pathlib import Path
import gradio as gr
import config
import inference
from atmosphere import detect_emotion
from card_engine import ActiveCard, CardState, reset_card, should_show_card, update_card
from components.card_deck import render_deck
from components.thought_card import render_card, render_empty_card, render_intro_card
from crisis import detect_crisis, get_crisis_banner_html, get_crisis_response
from patterns import get_patterns_html
from prompts import build_system_prompt, get_greeting
from session import (
Experiment,
SessionState,
ThoughtCard,
deserialize_session,
get_session_context,
save_card,
save_experiment,
serialize_session,
start_new_session,
)
# --- Optional STT import (fail-safe) ---
try:
import stt
_stt_available = config.STT_ENABLED and stt.is_available()
except ImportError:
_stt_available = False
# Silence a noisy (harmless) third-party deprecation: Gradio 6.18 uses Starlette's
# old HTTP_422_UNPROCESSABLE_ENTITY constant (renamed to *_CONTENT). Fires on every
# request. Remove once Gradio ships a fix.
warnings.filterwarnings("ignore", message=r".*HTTP_422_UNPROCESSABLE_ENTITY.*")
# Load the custom stylesheet that is later passed to demo.launch(css=...).
# A missing file simply means "no custom CSS". The file is decoded as UTF-8
# explicitly: Path.read_text() otherwise uses the locale's preferred encoding
# (e.g. cp1252 on Windows), which can mis-decode or reject non-ASCII CSS.
CSS_PATH = Path(__file__).parent / "static" / "theme.css"
CUSTOM_CSS = CSS_PATH.read_text(encoding="utf-8") if CSS_PATH.exists() else ""
def _model_badges_html() -> str:
    """Header status pills — [label] · [exact id], ordered Modal → Gemma → STT.

    The STT pill's label is derived from ``config.STT_MODEL`` (case-insensitive):
    "Cohere" if the id contains "cohere", "Whisper" if it contains "whisper",
    otherwise the generic "STT". The pill text always shows the full model id.
    """
    stt_lower = config.STT_MODEL.lower()
    # NOTE(review): stt_cls is assigned in every branch but is not referenced by
    # the template as it appears here — confirm the badge markup still uses it.
    if "cohere" in stt_lower:
        stt_name, stt_cls = "Cohere", "model-badge badge-cohere"
    elif "whisper" in stt_lower:
        stt_name, stt_cls = "Whisper", "model-badge"
    else:
        stt_name, stt_cls = "STT", "model-badge"
    return f"""
Modal · NVIDIA H100 80GB
Gemma 4 12B · google/gemma-4-12B-it · QLoRA fine-tuned
{stt_name} · {config.STT_MODEL}
"""
def _models_card_html() -> str:
    """Return the tech-stack / models showcase shown in the "🛠️ Stack" tab (prize transparency).

    The content is static and does not consult ``config``.
    NOTE(review): the voice line always names Cohere Transcribe, even when
    config.STT_MODEL selects another backend (the header badges do detect
    Whisper) — confirm this is intended.
    """
    return """
A fine-tuned Gemma model that helps you reframe thoughts — trained in the cloud, runs locally.
Gemma 4 12B — Google's open LLM, fine-tuned on mental-health counseling data
QLoRA + unsloth — efficient, low-cost fine-tuning
Modal · H100 80GB — cloud GPU the training ran on
GGUF Q4_K_M — compact, quantized model file
llama.cpp — runs the model locally, low latency
Ollama — serves llama.cpp as a local API
Cohere Transcribe — turns your voice into text
Gradio — the app's web interface
Fine-tuning datasets — mental-health counseling, empathetic dialogue & crisis responses
"""
# --- Status animation HTML (wave for STT, dots for thinking) ---
# Both are shown in the `voice_status` slot by the voice-input chain:
# WAVE_HTML while the recording is transcribed (_show_wave), and DOTS_HTML
# after the message is posted and while the reply is generated (_show_dots).
WAVE_HTML = """
Transcribing your voice...
"""
DOTS_HTML = """
"""
def _show_wave():
    """Make the voice-status slot visible with the 'Transcribing your voice...' markup."""
    wave_update = gr.update(value=WAVE_HTML, visible=True)
    return wave_update
def _show_dots():
    """Make the voice-status slot visible with the 'thinking' dots markup."""
    dots_update = gr.update(value=DOTS_HTML, visible=True)
    return dots_update
def _hide_status():
    """Empty the voice-status slot and hide it again."""
    cleared = gr.update(value="", visible=False)
    return cleared
@spaces.GPU(duration=120)
def _transcribe_and_clear(audio_filepath):
    """Turn a recorded clip into textbox text and reset the audio widget.

    Returns a 2-tuple for (msg_input, audio_input). The first element is the
    transcript. It is replaced by the hint "Couldn't hear that — try typing
    instead." when there is no clip, STT is unavailable, or the transcript is
    empty. The second element is None, which clears the recorder.

    On a ZeroGPU Space the decorator runs this in a GPU-allocated fork (Cohere on
    CUDA). Locally ``spaces.GPU`` is the no-op shim and STT runs on XPU/CPU.

    NOTE(review): the voice chain feeds msg_input straight into user_submit, so
    the fallback hint is posted to the chat as if the user had said it.
    """
    fallback_hint = "Couldn't hear that — try typing instead."
    can_transcribe = bool(audio_filepath) and _stt_available
    transcript = stt.transcribe(audio_filepath) if can_transcribe else None
    return (transcript if transcript else fallback_hint), None
def _extract_user_message(history: list[dict]) -> str:
    """Return the text of the pending user turn at the end of *history*, or "".

    "" means there is nothing to answer. That covers an empty history, a last
    entry whose role is explicitly not "user", and a turn with no text. The
    non-user case happens when a blank submit makes user_submit leave history
    unchanged, so the last entry is the assistant's previous reply.
    Gradio 6 message content may be a plain string or a list of part-dicts such
    as ``{"text": ..., "type": "text"}``. For lists, only the text fields are
    joined and the result is stripped.
    """
    if not history:
        return ""
    last_entry = history[-1]
    # Handle both dict format and possible list/tuple format
    if isinstance(last_entry, dict):
        # Never treat the assistant's own message as new user input. A missing
        # role is tolerated and treated as a user turn.
        if last_entry.get("role", "user") != "user":
            return ""
        raw_content = last_entry.get("content", "")
    elif isinstance(last_entry, (list, tuple)):
        raw_content = last_entry[0] if last_entry else ""
    else:
        raw_content = str(last_entry)
    # Extract the text field rather than str()-ing a whole part-dict.
    if isinstance(raw_content, list):
        parts = []
        for part in raw_content:
            if isinstance(part, dict):
                parts.append(str(part.get("text", "")))
            elif part:
                parts.append(str(part))
        return " ".join(p for p in parts if p).strip()
    return str(raw_content) if raw_content else ""


def respond(
    history: list[dict],
    session_json: str,
    card_state_json: str,
):
    """Stream the assistant's reply to the latest user turn and update the live card.

    Args:
        history: Chatbot messages. ``user_submit`` has already appended the
            user turn, if there was one.
        session_json: Serialized SessionState (from the BrowserState).
        card_state_json: Serialized ActiveCard (see ``_serialize_card_state``).

    Yields:
        (history, card_html, atmosphere_class, crisis_html, session_json, card_state_json),
        once per streamed chunk. A final tuple follows after a card marked
        COMPLETE is saved to the session and reset. If there is no pending
        user turn, yields the inputs back once with an empty card, a neutral
        atmosphere and no crisis banner.
    """
    user_message = _extract_user_message(history)
    if not user_message.strip():
        # Nothing to answer (empty history, blank submit, or the last turn is the
        # assistant's). Do not generate a reply to the model's own message.
        yield history, render_empty_card(), "atmosphere-neutral", "", session_json, card_state_json
        return

    # Deserialize state and build the system prompt with session context.
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)
    system_prompt = build_system_prompt(get_session_context(session))

    crisis_html = get_crisis_banner_html() if detect_crisis(user_message) else ""

    # history[:-1] drops the user turn, which is passed separately as user_message.
    full_response = ""
    for partial in inference.stream_response(user_message, history[:-1], system_prompt):
        # Assigned, not appended: each chunk is used as the whole reply so far.
        full_response = partial
        # NOTE(review): update_card is re-applied to the accumulated card on every
        # streamed chunk; confirm card_engine tolerates repeated calls per turn.
        active_card = update_card(active_card, full_response, user_message)
        card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
        atmos = detect_emotion(user_message + " " + full_response)
        yield (
            history + [{"role": "assistant", "content": full_response}],
            card_html_val,
            atmos,
            crisis_html,
            serialize_session(session),
            _serialize_card_state(active_card),
        )

    # Persist a card the engine marked complete, then start a fresh one.
    if active_card.state == CardState.COMPLETE:
        session = save_card(session, active_card.card)
        active_card = reset_card()

    # Final yield. This also covers the case where the stream produced no chunks.
    card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
    yield (
        history + [{"role": "assistant", "content": full_response}],
        card_html_val,
        detect_emotion(user_message + " " + full_response),
        crisis_html,
        serialize_session(session),
        _serialize_card_state(active_card),
    )
def initialize_session(session_json: str):
    """Handle app load: resume the stored session and open the chat with a greeting.

    The browser-stored session is deserialized and passed through
    start_new_session. The chat is then seeded with a single assistant
    greeting built from the session context.

    Returns:
        (chat history list, re-serialized session JSON).
    """
    session = start_new_session(deserialize_session(session_json))
    opening_line = get_greeting(get_session_context(session))
    return [{"role": "assistant", "content": opening_line}], serialize_session(session)
def get_deck_html(session_json: str) -> str:
    """Return the card-deck HTML for the cards stored in the serialized session."""
    return render_deck(deserialize_session(session_json).cards)
def get_progress_html(session_json: str) -> str:
    """Return the patterns/progress panel HTML for the serialized session."""
    return get_patterns_html(deserialize_session(session_json))
def save_current_card(session_json: str, card_state_json: str):
    """Manually save the in-progress thought card to the deck.

    The card is saved only when it has left the IDLE state and has an
    automatic thought recorded. After a save the active card is reset.

    Returns:
        (session_json, card_state_json, card_html, deck_html), matching the four
        outputs wired to the Save button.
    """
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)
    if active_card.state != CardState.IDLE and active_card.card.automatic_thought:
        session = save_card(session, active_card.card)
        active_card = reset_card()
    # Render whatever card is active now, using the same rule `respond` uses.
    # Previously the panel was always blanked. That hid an in-progress card
    # whenever the save was rejected, even though card_state still held it.
    card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
    return (
        serialize_session(session),
        _serialize_card_state(active_card),
        card_html_val,
        render_deck(session.cards),
    )
# --- Card state serialization helpers ---
def _serialize_card_state(active: ActiveCard) -> str:
    """Encode *active* as the JSON string kept in the per-session ``card_state``.

    The payload carries three keys:
    - "state": the CardState's ``.value``;
    - "card": the ThoughtCard converted with ``dataclasses.asdict``;
    - "turn_count": copied as-is.
    """
    from dataclasses import asdict
    from json import dumps

    return dumps(dict(
        state=active.state.value,
        card=asdict(active.card),
        turn_count=active.turn_count,
    ))
def _deserialize_card_state(data: str | None) -> ActiveCard:
    """Decode card-state JSON back into an ActiveCard, falling back to a fresh card.

    Accepts the JSON string written by ``_serialize_card_state``, an
    already-decoded dict, or an empty value. A fresh ``ActiveCard()`` is
    returned instead of raising when the payload is malformed. Malformed means
    one of:
    - invalid JSON;
    - a JSON value that is not an object (e.g. "null", "[]", "3"), which
      previously raised an uncaught AttributeError on ``.get``;
    - a "card" whose fields don't fit ThoughtCard;
    - an unknown "state" value.
    """
    import json

    if not data:
        return ActiveCard()
    try:
        d = json.loads(data) if isinstance(data, str) else data
    except json.JSONDecodeError:
        return ActiveCard()
    if not isinstance(d, dict):
        return ActiveCard()
    try:
        card = ThoughtCard(**d.get("card", {}))
        state = CardState(d.get("state", "idle"))
        return ActiveCard(state=state, card=card, turn_count=d.get("turn_count", 0))
    except (TypeError, ValueError):
        # TypeError: "card" is not a mapping or has unexpected/missing fields.
        # ValueError: "state" is not a valid CardState value.
        return ActiveCard()
# --- Build the Gradio App ---
# Layout: a chat column on the left, a tabbed info/card column on the right,
# and hidden state components. Handlers exchange session and card state as JSON
# strings (serialize_session / _serialize_card_state).
with gr.Blocks(
    title=config.APP_TITLE,
    fill_height=True,
    fill_width=True,
) as demo:
    # Hidden state components
    # gr.BrowserState keeps the session JSON in the user's browser under
    # storage_key, so initialize_session can resume it on the next load.
    session_state = gr.BrowserState("", storage_key="reframe_session")
    # JSON of the active thought card; "" deserializes to a fresh ActiveCard().
    card_state = gr.State("")
    # Written by `respond`. NOTE(review): no handler visible here reads it as an input.
    atmosphere_state = gr.State("atmosphere-neutral")
    # Top-center credit
    gr.HTML('Created with 🍁 in Canada
')
    # Header — title, subtitle, then model pills (horizontal, after the string)
    with gr.Row():
        gr.HTML(f"""
""")
    # Main layout: Chat (60%) + Card Panel (40%), i.e. column scales 3 : 2
    with gr.Row():
        # Left: Chat
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                elem_id="reframe-chat",
                height=430,
                show_label=False,
                placeholder="What's on your mind today?",
                buttons=["copy"],
            )
            msg_input = gr.Textbox(
                placeholder="Type here...",
                show_label=False,
                container=False,
                scale=7,
                submit_btn=True,
            )
            # Voice input is built only when STT is available. Otherwise both
            # handles are None and the voice wiring below is skipped.
            if _stt_available:
                _voice_label = (
                    "🎙️ Speak your thoughts — powered by Cohere"
                    if "cohere" in config.STT_MODEL.lower()
                    else "🎤 Or speak your thoughts"
                )
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    label=_voice_label,
                    show_label=True,
                )
                # Status slot for the wave/dots animations; hidden until used.
                voice_status = gr.HTML(value="", visible=False)
            else:
                audio_input = None
                voice_status = None
        # Right: tabs at top (the first tab, "How it works", is the default);
        # the live card lives in the Deck tab.
        with gr.Column(scale=2, min_width=280):
            # Crisis banner (empty by default) — kept above the tabs so it is
            # visible whichever tab is selected.
            crisis_html = gr.HTML(value="")
            with gr.Tabs():
                with gr.Tab("How it works"):
                    gr.HTML(value=render_intro_card())
                with gr.Tab("🛠️ Stack"):
                    gr.HTML(value=_models_card_html())
                with gr.Tab("Deck"):
                    card_html = gr.HTML(value=render_empty_card())
                    save_btn = gr.Button("✨ Save to Deck", variant="secondary", size="sm")
                    deck_html = gr.HTML(value="")
                with gr.Tab("Patterns"):
                    patterns_html_component = gr.HTML(value="")
    # --- Event Wiring ---
    # Helper to add user msg to chat immediately and clear input
    def user_submit(message, history):
        """Echo the user's message into the chat immediately and clear the textbox.

        Returns ("", history) unchanged when *message* is empty or whitespace,
        otherwise ("", history + [{"role": "user", "content": message}]).
        """
        if not message or not message.strip():
            return "", history
        return "", history + [{"role": "user", "content": message}]
    # Main chat submit: clear input and show the message instantly, then stream
    # the response. The textbox is disabled while `respond` streams. Because
    # `.then` steps run even if the previous step raised, it is always re-enabled.
    # NOTE(review): on a blank submit, user_submit leaves history unchanged but
    # `respond` still runs on it. Confirm respond ignores a history whose last
    # entry is not a user turn.
    submit_event = msg_input.submit(
        fn=user_submit,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
    ).then(
        fn=lambda: gr.update(interactive=False),
        outputs=[msg_input],
    ).then(
        fn=respond,
        inputs=[chatbot, session_state, card_state],
        outputs=[chatbot, card_html, atmosphere_state, crisis_html, session_state, card_state],
    ).then(
        fn=lambda: gr.update(interactive=True),
        outputs=[msg_input],
    )
    # Voice input: record → transcribe (with wave) → submit → respond (with dots)
    # NOTE(review): when transcription fails, _transcribe_and_clear places the
    # "Couldn't hear that — try typing instead." hint in msg_input. The
    # user_submit step below then posts it to the chat as if the user had said
    # it. Consider stopping the chain on failure.
    if _stt_available and audio_input is not None:
        audio_input.stop_recording(
            fn=_show_wave,
            outputs=[voice_status],
        ).then(
            fn=_transcribe_and_clear,
            inputs=[audio_input],
            outputs=[msg_input, audio_input],
        ).then(
            fn=user_submit,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot],
        ).then(
            fn=_show_dots,
            outputs=[voice_status],
        ).then(
            fn=lambda: gr.update(interactive=False),
            outputs=[msg_input],
        ).then(
            fn=respond,
            inputs=[chatbot, session_state, card_state],
            outputs=[chatbot, card_html, atmosphere_state, crisis_html, session_state, card_state],
        ).then(
            fn=lambda: gr.update(interactive=True),
            outputs=[msg_input],
        ).then(
            fn=_hide_status,
            outputs=[voice_status],
        )
    # Save card button
    save_btn.click(
        fn=save_current_card,
        inputs=[session_state, card_state],
        outputs=[session_state, card_state, card_html, deck_html],
    )
    # No-op listener (fn=None). The deck and patterns panels are actually
    # refreshed by the session_state.change handlers below, not by tab clicks.
    deck_html.change(fn=None) # placeholder
    # Initialize on load
    demo.load(
        fn=initialize_session,
        inputs=[session_state],
        outputs=[chatbot, session_state],
    )
    # Refresh deck and patterns on session change
    session_state.change(
        fn=get_deck_html,
        inputs=[session_state],
        outputs=[deck_html],
    )
    session_state.change(
        fn=get_progress_html,
        inputs=[session_state],
        outputs=[patterns_html_component],
    )
if __name__ == "__main__":
    # Preload the STT model locally for an instant first response. On a Space we
    # SKIP preload so the app starts immediately (no startup timeout downloading
    # ~4 GB); the model then loads lazily on the first transcription, inside the
    # @spaces.GPU context on ZeroGPU.
    # The SPACE_ID environment variable is the "running on a Space" signal here.
    if _stt_available and not os.environ.get("SPACE_ID"):
        stt.preload_model()
    # css/head are passed to launch() (Gradio 6 API) rather than to gr.Blocks().
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        css=CUSTOM_CSS,
        head="""
""",
    )