reframe / app.py
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8
Raw
History Blame Contribute Delete
22.6 kB
"""REFRAME — Live cognitive restructuring studio.
Main Gradio 6 application. Dual-panel: chat + live thought card.
"""
# ZeroGPU: `spaces` MUST be imported before torch (it patches torch for GPU
# scheduling). On a ZeroGPU Space the real package is present; locally it isn't,
# so we fall back to a no-op shim and every Space branch becomes dead code.
try:
    import spaces  # noqa: E402 (intentionally first — before torch)
except ImportError:
    class _SpacesShim:
        """Local stand-in for the ZeroGPU ``spaces`` package.

        Only ``GPU`` is provided, and it never wraps anything: decorated
        functions run exactly as written.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """Accept both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``."""
            decorated_directly = not kwargs and len(args) == 1 and callable(args[0])
            if decorated_directly:
                # Bare form: the function itself is the sole positional argument.
                return args[0]

            # Parameterised form: hand back an identity decorator.
            def _identity(fn):
                return fn

            return _identity

    spaces = _SpacesShim()
import os
import warnings
from pathlib import Path
import gradio as gr
import config
import inference
from atmosphere import detect_emotion
from card_engine import ActiveCard, CardState, reset_card, should_show_card, update_card
from components.card_deck import render_deck
from components.thought_card import render_card, render_empty_card, render_intro_card
from crisis import detect_crisis, get_crisis_banner_html, get_crisis_response
from patterns import get_patterns_html
from prompts import build_system_prompt, get_greeting
from session import (
Experiment,
SessionState,
ThoughtCard,
deserialize_session,
get_session_context,
save_card,
save_experiment,
serialize_session,
start_new_session,
)
# --- Optional STT import (fail-safe) ---
try:
import stt
_stt_available = config.STT_ENABLED and stt.is_available()
except ImportError:
_stt_available = False
# Silence a noisy (harmless) third-party deprecation: Gradio 6.18 uses Starlette's
# old HTTP_422_UNPROCESSABLE_ENTITY constant (renamed to *_CONTENT). Fires on every
# request. Remove once Gradio ships a fix.
warnings.filterwarnings("ignore", message=r".*HTTP_422_UNPROCESSABLE_ENTITY.*")
# Load CSS
CSS_PATH = Path(__file__).parent / "static" / "theme.css"
CUSTOM_CSS = CSS_PATH.read_text() if CSS_PATH.exists() else ""
def _model_badges_html() -> str:
    """Build the header's row of status pills.

    Each pill reads "[label] · [exact id]"; the order is Modal, then Gemma,
    then the active speech-to-text engine (Cohere, Whisper, or a generic STT).
    """
    engine = config.STT_MODEL.lower()
    # (substring to look for, pill label, pill CSS classes); first match wins.
    known_engines = (
        ("cohere", "Cohere", "model-badge badge-cohere"),
        ("whisper", "Whisper", "model-badge"),
    )
    stt_name, stt_cls = next(
        ((label, css) for needle, label, css in known_engines if needle in engine),
        ("STT", "model-badge"),
    )
    return f"""
<div class="model-badges">
<span class="model-badge badge-modal" title="Training platform · single GPU">
<span class="badge-dot"></span>Modal · NVIDIA H100 80GB
</span>
<span class="model-badge" title="Runtime: {config.OLLAMA_MODEL} (Ollama / llama.cpp)">
<span class="badge-dot"></span>Gemma 4 12B · google/gemma-4-12B-it · QLoRA fine-tuned
</span>
<span class="{stt_cls}" title="Speech-to-text (active)">
<span class="badge-dot"></span>{stt_name} · {config.STT_MODEL}
</span>
</div>
"""
def _models_card_html() -> str:
"""Tech-stack / models showcase for the right pane (prize transparency)."""
return """
<div class="stack-card">
<div class="stack-summary">
A fine-tuned Gemma model that helps you reframe thoughts — trained in the cloud, runs locally.
</div>
<div class="stack-line"><b>Gemma 4 12B</b> — Google's open LLM, fine-tuned on mental-health counseling data</div>
<div class="stack-line"><b>QLoRA + unsloth</b> — efficient, low-cost fine-tuning</div>
<div class="stack-line"><b>Modal · H100 80GB</b> — cloud GPU the training ran on</div>
<div class="stack-line"><b>GGUF Q4_K_M</b> — compact, quantized model file</div>
<div class="stack-line"><b>llama.cpp</b> — runs the model locally, low latency</div>
<div class="stack-line"><b>Ollama</b> — serves llama.cpp as a local API</div>
<div class="stack-line"><b>Cohere Transcribe</b> — turns your voice into text</div>
<div class="stack-line"><b>Gradio</b> — the app's web interface</div>
<div class="stack-line"><b>Fine-tuning datasets</b> — mental-health counseling, empathetic dialogue &amp; crisis responses
<ul class="stack-list">
<li><a href="https://huggingface.co/datasets/ShenLab/MentalChat16K" target="_blank">MentalChat16K</a></li>
<li><a href="https://huggingface.co/datasets/Amod/mental_health_counseling_conversations" target="_blank">Mental Health Counseling Conversations</a></li>
<li><a href="https://huggingface.co/datasets/nbertagnolli/counsel-chat" target="_blank">CounselChat</a></li>
<li><a href="https://huggingface.co/datasets/Estwld/empathetic_dialogues_llm" target="_blank">EmpatheticDialogues</a></li>
<li><a href="https://huggingface.co/datasets/arnaiztech/llms-mental-health-crisis-responses" target="_blank">Mental-Health Crisis Responses</a></li>
</ul>
</div>
</div>
"""
# --- Status animation HTML (wave for STT, dots for thinking) ---
# Both snippets are shown in the `voice_status` slot during a voice turn. All
# styling is inline, and each snippet carries its own <style> block defining
# the @keyframes it animates with.
# Five bars with staggered heights and 0.1 s-staggered delays, animated by the
# `wave` keyframes (vertical scale 0.4 -> 1 -> 0.4).
WAVE_HTML: str = """
<div style="display:flex;align-items:center;gap:8px;padding:8px 12px;
background:#1a1a2e;border-radius:8px;border:1px solid #2a2a4a;">
<div style="display:flex;gap:3px;align-items:end;height:20px;">
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out infinite;height:8px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.1s infinite;height:14px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.2s infinite;height:20px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.3s infinite;height:14px;"></span>
<span style="width:3px;background:#a78bfa;border-radius:2px;
animation:wave 0.8s ease-in-out 0.4s infinite;height:8px;"></span>
</div>
<span style="color:#c4b5fd;font-size:0.85rem;">Transcribing your voice...</span>
</div>
<style>
@keyframes wave{0%,100%{transform:scaleY(0.4)}50%{transform:scaleY(1)}}
</style>
"""
# Three dots pulsing in opacity and scale, 0.2 s apart, animated by the `dots`
# keyframes.
DOTS_HTML: str = """
<div style="display:flex;align-items:center;gap:8px;padding:8px 12px;
background:#1a1a2e;border-radius:8px;border:1px solid #2a2a4a;">
<div style="display:flex;gap:4px;">
<span style="width:6px;height:6px;background:#8899aa;border-radius:50%;
animation:dots 1.2s ease-in-out infinite;"></span>
<span style="width:6px;height:6px;background:#8899aa;border-radius:50%;
animation:dots 1.2s ease-in-out 0.2s infinite;"></span>
<span style="width:6px;height:6px;background:#8899aa;border-radius:50%;
animation:dots 1.2s ease-in-out 0.4s infinite;"></span>
</div>
<span style="color:#8899aa;font-size:0.85rem;">Thinking...</span>
</div>
<style>
@keyframes dots{0%,100%{opacity:0.3;transform:scale(0.8)}50%{opacity:1;transform:scale(1.2)}}
</style>
"""
def _status_update(html: str, visible: bool):
    """Build a Gradio update setting the status slot's HTML and visibility."""
    return gr.update(value=html, visible=visible)


def _show_wave():
    """Reveal the status slot with the 'Transcribing your voice...' waveform."""
    return _status_update(WAVE_HTML, True)


def _show_dots():
    """Reveal the status slot with the 'Thinking...' dots."""
    return _status_update(DOTS_HTML, True)


def _hide_status():
    """Empty and hide the status slot."""
    return _status_update("", False)
@spaces.GPU(duration=120)
def _transcribe_and_clear(audio_filepath):
    """Transcribe a recorded clip for the chat input.

    Returns ``(text_for_input, None)``: the first value goes into the message
    textbox, the ``None`` clears the audio component. On a ZeroGPU Space this
    runs in a GPU-allocated fork (Cohere on CUDA); locally ``@spaces.GPU`` is a
    no-op shim and STT runs on XPU/CPU.

    Every failure mode — no recording, STT unavailable, an empty or
    whitespace-only transcript, or an exception from the STT backend — returns
    the same fallback message instead of raising.
    """
    # NOTE(review): the voice event chain passes whatever lands in the textbox
    # straight to user_submit, so this fallback text is currently submitted as
    # a user chat turn. Consider signalling failure separately.
    fallback = "Couldn't hear that — try typing instead."
    if not audio_filepath or not _stt_available:
        return fallback, None
    try:
        text = stt.transcribe(audio_filepath)
    except Exception:
        # Boundary of a UI handler: raising here would abort the step without
        # clearing the recording, so log and degrade to the fallback instead.
        import logging

        logging.getLogger(__name__).exception(
            "Speech-to-text failed for %s", audio_filepath
        )
        return fallback, None
    # Treat whitespace-only output like no transcript at all.
    cleaned = (text or "").strip()
    if not cleaned:
        return fallback, None
    return cleaned, None
def _message_text(entry) -> str:
    """Return the plain text of one chat-history entry ("" when there is none).

    Accepts a messages-format dict (``{"role": ..., "content": ...}``), a legacy
    list/tuple entry (its first element is used), or any other value
    (``str()``-ed). In Gradio 6 ``content`` may itself be a list of part-dicts
    such as ``{"text": ..., "type": "text"}``; their ``text`` fields are joined
    with single spaces rather than ``str()``-ing the whole dicts.
    """
    if isinstance(entry, dict):
        raw_content = entry.get("content", "")
    elif isinstance(entry, (list, tuple)):
        raw_content = entry[0] if entry else ""
    else:
        raw_content = str(entry)
    if isinstance(raw_content, list):
        parts = []
        for part in raw_content:
            if isinstance(part, dict):
                parts.append(str(part.get("text", "")))
            elif part:
                parts.append(str(part))
        return " ".join(p for p in parts if p).strip()
    return str(raw_content) if raw_content else ""


def respond(
    history: list[dict],
    session_json: str,
    card_state_json: str,
):
    """Main chat response handler with streaming.

    Answers the newest user turn in *history* (``user_submit`` has already
    appended it), streaming the assistant reply while updating the live thought
    card, the atmosphere class and the crisis banner.

    Yields:
        (history, card_html, atmosphere_class, crisis_html, session_json, card_state_json)
    """

    def _idle_outputs():
        # Nothing to answer: hand the state back unchanged and keep showing the
        # card currently in progress rather than blanking the card pane.
        idle_card = _deserialize_card_state(card_state_json)
        card_view = render_card(idle_card) if should_show_card(idle_card) else render_empty_card()
        return history, card_view, "atmosphere-neutral", "", session_json, card_state_json

    if not history:
        yield _idle_outputs()
        return
    last_entry = history[-1]
    # Only a trailing *user* turn needs a reply. When user_submit ignores a blank
    # message the chain still calls respond, and the last entry is then the
    # previous assistant reply; answering it would make the model reply to itself.
    if isinstance(last_entry, dict) and last_entry.get("role", "user") != "user":
        yield _idle_outputs()
        return
    user_message = _message_text(last_entry)
    if not user_message.strip():
        yield _idle_outputs()
        return

    # Deserialize state
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)
    # Build system prompt with session context
    system_prompt = build_system_prompt(get_session_context(session))
    # Crisis detection only adds the banner; the model's reply is not altered here.
    crisis_html = get_crisis_banner_html() if detect_crisis(user_message) else ""

    # Stream the model reply. history[:-1] drops the user turn, which is passed
    # separately as user_message.
    full_response = ""
    for partial in inference.stream_response(user_message, history[:-1], system_prompt):
        # Each partial replaces full_response, i.e. it appears to be the
        # cumulative reply so far — TODO confirm against inference.stream_response.
        full_response = partial
        # NOTE(review): update_card runs once per streamed chunk and is fed its own
        # previous result; this assumes it tolerates repeated calls on a growing
        # response (e.g. doesn't count each chunk as a new turn) — verify.
        active_card = update_card(active_card, full_response, user_message)
        card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
        atmos = detect_emotion(user_message + " " + full_response)
        yield (
            history + [{"role": "assistant", "content": full_response}],
            card_html_val,
            atmos,
            crisis_html,
            serialize_session(session),
            _serialize_card_state(active_card),
        )

    # A completed card is filed into the session's deck and the live card resets.
    if active_card.state == CardState.COMPLETE:
        session = save_card(session, active_card.card)
        active_card = reset_card()

    # Final yield reflects the post-save card state.
    card_html_val = render_card(active_card) if should_show_card(active_card) else render_empty_card()
    yield (
        history + [{"role": "assistant", "content": full_response}],
        card_html_val,
        detect_emotion(user_message + " " + full_response),
        crisis_html,
        serialize_session(session),
        _serialize_card_state(active_card),
    )
def initialize_session(session_json: str):
    """Start (or resume) the session when the page loads.

    Returns a one-message chat history holding the assistant's greeting, and
    the re-serialized session.
    """
    session = start_new_session(deserialize_session(session_json))
    opening_line = get_greeting(get_session_context(session))
    return [{"role": "assistant", "content": opening_line}], serialize_session(session)
def get_deck_html(session_json: str) -> str:
    """Return the card-deck HTML for the session stored in *session_json*."""
    return render_deck(deserialize_session(session_json).cards)
def get_progress_html(session_json: str) -> str:
    """Return the patterns/progress panel HTML for the stored session."""
    return get_patterns_html(deserialize_session(session_json))
def save_current_card(session_json: str, card_state_json: str):
    """Manually save the current card to the deck.

    The card is saved only when it has left the idle state and already has an
    automatic thought; a successful save resets the live card.

    Returns:
        (session_json, card_state_json, card_html, deck_html)
    """
    session = deserialize_session(session_json)
    active_card = _deserialize_card_state(card_state_json)
    if active_card.state != CardState.IDLE and active_card.card.automatic_thought:
        session = save_card(session, active_card.card)
        active_card = reset_card()
        card_view = render_empty_card()
    else:
        # Nothing was saved: keep showing the card as it stands instead of
        # blanking a card that is still being built.
        card_view = render_card(active_card) if should_show_card(active_card) else render_empty_card()
    return (
        serialize_session(session),
        _serialize_card_state(active_card),
        card_view,
        render_deck(session.cards),
    )
# --- Card state serialization helpers ---
def _serialize_card_state(active: ActiveCard) -> str:
    """Encode an ActiveCard as a JSON string for the card gr.State slot.

    The object has three keys, in order: "state" (the CardState's value),
    "card" (the ThoughtCard converted with dataclasses.asdict) and "turn_count".
    """
    import json
    from dataclasses import asdict

    payload = {}
    payload["state"] = active.state.value
    payload["card"] = asdict(active.card)
    payload["turn_count"] = active.turn_count
    return json.dumps(payload)
def _deserialize_card_state(data: str | None) -> ActiveCard:
    """Decode card state written by ``_serialize_card_state``.

    Accepts the JSON string or an already-decoded dict. Empty or unusable input
    yields a fresh ``ActiveCard()`` rather than raising. Unusable input covers:
    invalid JSON, JSON that is not an object, card fields ThoughtCard rejects,
    and a state value CardState rejects.
    """
    import json

    if not data:
        return ActiveCard()
    try:
        d = json.loads(data) if isinstance(data, str) else data
        if not isinstance(d, dict):
            # e.g. "null" or "[]": valid JSON, but not the object we write;
            # calling .get on it would raise an uncaught AttributeError.
            return ActiveCard()
        card = ThoughtCard(**d.get("card", {}))
        state = CardState(d.get("state", "idle"))
        return ActiveCard(state=state, card=card, turn_count=d.get("turn_count", 0))
    except (json.JSONDecodeError, TypeError, ValueError):
        return ActiveCard()
# --- Build the Gradio App ---
with gr.Blocks(
    title=config.APP_TITLE,
    fill_height=True,
    fill_width=True,
) as demo:
    # Hidden state components. The session is browser-persisted
    # (gr.BrowserState, key "reframe_session"); card and atmosphere state are
    # plain gr.State values.
    session_state = gr.BrowserState("", storage_key="reframe_session")
    card_state = gr.State("")
    atmosphere_state = gr.State("atmosphere-neutral")
    # Top-center credit
    gr.HTML('<div class="app-credit">Created with 🍁 in Canada</div>')
    # Header — title, subtitle, then model pills (horizontal, after the string)
    with gr.Row():
        gr.HTML(f"""
<div class="app-header">
<div class="app-title">{config.APP_TITLE}</div>
<div class="app-subtitle-row">
<div class="app-subtitle">{config.APP_SUBTITLE}</div>
{_model_badges_html()}
</div>
</div>
""")
    # Main layout: Chat (60%) + Card Panel (40%)
    with gr.Row():
        # Left: Chat
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                elem_id="reframe-chat",
                height=430,
                show_label=False,
                placeholder="What's on your mind today?",
                buttons=["copy"],
            )
            msg_input = gr.Textbox(
                placeholder="Type here...",
                show_label=False,
                container=False,
                scale=7,
                submit_btn=True,
            )
            # Voice input (only if STT deps available)
            if _stt_available:
                _voice_label = (
                    "🎙️ Speak your thoughts — powered by Cohere"
                    if "cohere" in config.STT_MODEL.lower()
                    else "🎤 Or speak your thoughts"
                )
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    label=_voice_label,
                    show_label=True,
                )
                voice_status = gr.HTML(value="", visible=False)
            else:
                audio_input = None
                voice_status = None
        # Right: tabs at top (How it works default); live card lives in the Deck tab
        with gr.Column(scale=2, min_width=280):
            # Crisis banner (hidden by default) — kept above tabs so it always shows
            crisis_html = gr.HTML(value="")
            with gr.Tabs():
                with gr.Tab("How it works"):
                    gr.HTML(value=render_intro_card())
                with gr.Tab("🛠️ Stack"):
                    gr.HTML(value=_models_card_html())
                with gr.Tab("Deck"):
                    card_html = gr.HTML(value=render_empty_card())
                    save_btn = gr.Button("✨ Save to Deck", variant="secondary", size="sm")
                    deck_html = gr.HTML(value="")
                with gr.Tab("Patterns"):
                    patterns_html_component = gr.HTML(value="")

    # --- Event Wiring ---
    # Helper to add user msg to chat immediately and clear input
    def user_submit(message, history):
        """Immediately show user message and clear input.

        Blank messages leave the history unchanged.
        """
        if not message or not message.strip():
            return "", history
        return "", history + [{"role": "user", "content": message}]

    # Main chat submit: clear input + show msg instantly, then stream response.
    # The textbox is locked while respond streams and unlocked afterwards.
    submit_event = msg_input.submit(
        fn=user_submit,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
    ).then(
        fn=lambda: gr.update(interactive=False),
        outputs=[msg_input],
    ).then(
        fn=respond,
        inputs=[chatbot, session_state, card_state],
        outputs=[chatbot, card_html, atmosphere_state, crisis_html, session_state, card_state],
    ).then(
        fn=lambda: gr.update(interactive=True),
        outputs=[msg_input],
    )
    # Voice input: record → transcribe (with wave) → submit → respond (with dots)
    if _stt_available and audio_input is not None:
        audio_input.stop_recording(
            fn=_show_wave,
            outputs=[voice_status],
        ).then(
            fn=_transcribe_and_clear,
            inputs=[audio_input],
            outputs=[msg_input, audio_input],
        ).then(
            fn=user_submit,
            inputs=[msg_input, chatbot],
            outputs=[msg_input, chatbot],
        ).then(
            fn=_show_dots,
            outputs=[voice_status],
        ).then(
            fn=lambda: gr.update(interactive=False),
            outputs=[msg_input],
        ).then(
            fn=respond,
            inputs=[chatbot, session_state, card_state],
            outputs=[chatbot, card_html, atmosphere_state, crisis_html, session_state, card_state],
        ).then(
            fn=lambda: gr.update(interactive=True),
            outputs=[msg_input],
        ).then(
            fn=_hide_status,
            outputs=[voice_status],
        )
    # Save card button
    save_btn.click(
        fn=save_current_card,
        inputs=[session_state, card_state],
        outputs=[session_state, card_state, card_html, deck_html],
    )
    # Initialize on load
    demo.load(
        fn=initialize_session,
        inputs=[session_state],
        outputs=[chatbot, session_state],
    )
    # Refresh deck and patterns whenever the session changes (this also covers
    # the initial load, which writes session_state).
    session_state.change(
        fn=get_deck_html,
        inputs=[session_state],
        outputs=[deck_html],
    )
    session_state.change(
        fn=get_progress_html,
        inputs=[session_state],
        outputs=[patterns_html_component],
    )
if __name__ == "__main__":
# Preload the STT model locally for an instant first response. On a Space we
# SKIP preload so the app starts immediately (no startup timeout downloading
# ~4 GB); the model then loads lazily on the first transcription, inside the
# @spaces.GPU context on ZeroGPU.
if _stt_available and not os.environ.get("SPACE_ID"):
stt.preload_model()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
css=CUSTOM_CSS,
head="""
<style>
/* Audio device selector — prevent truncation */
[data-testid="audio"] select,
[data-testid="audio"] .audio-source-select,
[data-testid="audio"] button[aria-label*="microphone"],
.audio-component select {
min-width: 180px !important;
max-width: 250px !important;
}
#reframe-chat .placeholder {
font-size: 2.5rem !important;
font-family: 'Comic Sans MS', 'Chalkboard SE', 'Comic Neue', cursive !important;
color: #a78bfa !important;
opacity: 1 !important;
font-weight: 700 !important;
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
#reframe-chat .placeholder * {
font-size: inherit !important;
font-family: inherit !important;
color: inherit !important;
background: transparent !important;
}
</style>
<script>
// Force-style placeholder after Gradio renders
function stylePlaceholder() {
const el = document.querySelector('#reframe-chat .placeholder');
if (el) {
el.style.fontSize = '2.5rem';
el.style.fontFamily = "'Comic Sans MS', 'Chalkboard SE', cursive";
el.style.color = '#a78bfa';
el.style.opacity = '1';
el.style.fontWeight = '700';
el.style.background = 'transparent';
el.style.border = 'none';
el.style.boxShadow = 'none';
el.querySelectorAll('*').forEach(c => {
c.style.fontSize = 'inherit';
c.style.fontFamily = 'inherit';
c.style.color = 'inherit';
c.style.background = 'transparent';
});
}
}
setTimeout(stylePlaceholder, 500);
setTimeout(stylePlaceholder, 1500);
new MutationObserver(stylePlaceholder).observe(document.body, {childList: true, subtree: true});
// Explicitly request microphone permission on first user interaction
let micRequested = false;
function requestMic() {
if (micRequested) return;
micRequested = true;
if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
navigator.mediaDevices.getUserMedia({ audio: true })
.then(stream => {
console.log('Microphone access granted');
stream.getTracks().forEach(t => t.stop());
})
.catch(err => {
console.warn('Microphone access denied:', err.message);
const audioEl = document.querySelector('[data-testid="audio"]');
if (audioEl) {
audioEl.style.opacity = '0.5';
audioEl.title = 'Microphone blocked — check browser permissions';
}
});
}
}
// Trigger on first click anywhere (browsers require user gesture)
document.addEventListener('click', requestMic, { once: true });
// Also try on page load (works if permission was previously granted)
setTimeout(requestMic, 2000);
</script>
""",
)