# SSR — DeStammerer demo app (Hugging Face Space: SSR/app.py)
# app.py
import os
import json
import http.client
from io import BytesIO
import gradio as gr
from dotenv import load_dotenv
from elevenlabs.client import ElevenLabs
# ----------------------------
# Config & clients
# ----------------------------
load_dotenv()  # supports local .env; on HF Spaces, set secrets in the UI

# API credentials, read once at import time (empty string when unset).
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY", "")
API_KEY_302 = os.getenv("API_KEY_302", "")

# ElevenLabs client, or None when no key is configured.
elevenlabs_client = ElevenLabs(api_key=ELEVENLABS_API_KEY) if ELEVENLABS_API_KEY else None
# ----------------------------
# Prompt templates
# ----------------------------
# Baseline diagnosis (single call): flag likely stutter triggers given the
# original script and the (imperfect) ASR transcript. Placeholders:
# {original_text}, {transcribed_text}.
PROMPT_TEMPLATE_1 = """\
You are a speech-language assistant. Given the ORIGINAL script and the TRANSCRIPT (imperfect ASR),
list words/phrases likely to trigger stuttering (e.g., consonant clusters, long multisyllabic words).
Output a short, structured summary and diagnosis for easy-to-stutter scenarios.
ORIGINAL:
{original_text}
TRANSCRIPT:
{transcribed_text}
Never give any suggestion. Only return a concise, principled diagnosis notes with easy-to-stutter scenarios.
"""
# Rewrite step: produce a lower-stutter-risk revision of the script, guided by
# the diagnosis notes. Placeholders: {notes}, {original_text}.
PROMPT_TEMPLATE_2 = """\
You are a speech-language assistant. Rewrite the ORIGINAL script to reduce stuttering risk, while
preserving meaning and tone. Prefer simpler synonyms, shorter clauses, easier onsets. Keep it concise.
Diagnosis notes on easy-to-stutter scenarios:
{notes}
ORIGINAL:
{original_text}
Only return the revised full script, nothing else.
"""
# New: IPA-only prompt (Baseline+IPA, step 1)
# Placeholders: {original_text}, {transcribed_text}.
PROMPT_TEMPLATE_IPA = """\
Convert BOTH the ORIGINAL script and the ASR TRANSCRIPT into IPA with syllable boundaries.
Return ONLY the IPA text in a clearly labeled, compact format, such as:
ORIGINAL_IPA:
<ipa for original with syllable markers>
TRANSCRIPT_IPA:
<ipa for transcript with syllable markers>
Do not include any additional commentary.
ORIGINAL:
{original_text}
TRANSCRIPT:
{transcribed_text}
"""
# New: Diagnosis that uses IPA as extra signal (Baseline+IPA, step 2)
# Placeholders: {original_text}, {transcribed_text}, {ipa_text}.
PROMPT_TEMPLATE_1_WITH_IPA = """\
You are a speech-language assistant. Given the ORIGINAL script, the TRANSCRIPT (imperfect ASR),
and their IPA annotations, list words/phrases likely to trigger stuttering (e.g., consonant clusters,
long multisyllabic words, difficult onsets). Output a short, structured summary and diagnosis for
easy-to-stutter scenarios.
ORIGINAL:
{original_text}
TRANSCRIPT:
{transcribed_text}
IPA_ANNOTATIONS:
{ipa_text}
Never give any suggestion. Only return a concise, principled diagnosis notes with easy-to-stutter scenarios.
"""
# ----------------------------
# Helpers: STT & LLM calls
# ----------------------------
def transcribe_audio(record_path: str | None) -> str:
"""
Returns the transcribed text (or an error message).
"""
audio_path = record_path
if not audio_path:
return "No audio provided. Please upload or record audio."
if not ELEVENLABS_API_KEY:
return "ELEVENLABS_API_KEY not set. Please configure your environment."
try:
with open(audio_path, "rb") as f:
audio_data = BytesIO(f.read())
except Exception as e:
return f"Failed to read audio: {e}"
try:
transcription = elevenlabs_client.speech_to_text.convert(
file=audio_data,
model_id="scribe_v1",
tag_audio_events=True,
language_code="eng",
diarize=True,
)
return transcription.text or ""
except Exception as e:
return f"Transcription error: {e}"
def call_llm_302(model: str, prompt: str) -> str:
    """
    Minimal wrapper around the 302.ai /v1/chat/completions endpoint.

    Args:
        model: Model identifier, e.g. "gpt-4o-mini".
        prompt: Single user-turn prompt text.

    Returns:
        The assistant's reply text, or an error string (never raises).
    """
    if not API_KEY_302:
        return "API_KEY_302 not set. Please configure your environment."
    conn = None
    try:
        # Timeout so a hung endpoint cannot block the Gradio handler forever.
        conn = http.client.HTTPSConnection("api.302.ai", timeout=60)
        payload = json.dumps({
            "model": model,
            "messages": [
                {"role": "user", "content": prompt}
            ]
        })
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {API_KEY_302}",
            "Content-Type": "application/json",
        }
        conn.request("POST", "/v1/chat/completions", payload, headers)
        res = conn.getresponse()
        raw = res.read().decode("utf-8")
        output = json.loads(raw)
        # Some models return `content`, others `text`; fall back to repr of the
        # whole message so the UI always shows something diagnosable.
        msg = output.get("choices", [{}])[0].get("message", {})
        text = msg.get("content") or msg.get("text") or str(msg)
        return text.strip()
    except Exception as e:
        return f"LLM API error: {e}"
    finally:
        # Always release the socket, even when the request or JSON parse fails
        # (the original leaked the connection on any exception).
        if conn is not None:
            conn.close()
# ----------------------------
# Button handlers (shared)
# ----------------------------
def on_click_transcribe(record_path):
    """Row 1 handler: run ASR on the recorded audio and push the text to the UI."""
    return gr.update(value=transcribe_audio(record_path))
def on_click_analyze_baseline(selected_model, original_text, transcribed_text):
    """
    Baseline tab handler: single LLM call built from PROMPT_TEMPLATE_1.
    """
    filled = PROMPT_TEMPLATE_1.format(
        original_text=original_text or "",
        transcribed_text=transcribed_text or "",
    )
    return gr.update(value=call_llm_302(selected_model, filled))
def on_click_analyze_ipa(selected_model, original_text, transcribed_text):
    """
    Baseline+IPA tab handler: two chained LLM calls.

    Step 1 generates IPA annotations for script + transcript; step 2 feeds the
    IPA back in alongside the raw texts to produce an IPA-aware diagnosis.

    Returns:
        (ipa_box_update, summary_update) — two Gradio component updates.
    """
    script = original_text or ""
    transcript = transcribed_text or ""

    # Step 1: IPA annotations.
    ipa_text = call_llm_302(
        selected_model,
        PROMPT_TEMPLATE_IPA.format(original_text=script, transcribed_text=transcript),
    )

    # Step 2: diagnosis conditioned on the IPA.
    summary = call_llm_302(
        selected_model,
        PROMPT_TEMPLATE_1_WITH_IPA.format(
            original_text=script,
            transcribed_text=transcript,
            ipa_text=ipa_text or "",
        ),
    )
    return gr.update(value=ipa_text), gr.update(value=summary)
def on_click_rewrite(selected_model, original_text, _transcribed_text_unused, summary):
    """
    Row 3 handler: rewrite the script to lower stuttering risk (PROMPT_TEMPLATE_2).
    The transcript argument is accepted for wiring symmetry but unused.
    """
    filled = PROMPT_TEMPLATE_2.format(
        notes=summary or "",
        original_text=original_text or "",
    )
    return gr.update(value=call_llm_302(selected_model, filled))
# Simple pass-through to mirror recorded file into a Gradio File component
def passthrough_file(path):
    """Identity: hand the recorded file's path straight to a File component."""
    return path
# ----------------------------
# Gradio UI (Tabs)
# ----------------------------
# Two tabs share identical layout and LLM-selector semantics; the IPA tab adds
# an extra IPA textbox and a two-step analyze handler.
with gr.Blocks(title="DeStammerer: AI-assisted Speech Script Revision") as demo:
    with gr.Tabs():
        # ------------------------ Tab 1: Baseline ------------------------
        with gr.Tab("Baseline"):
            # Row 1: Record + Download + Transcribe
            with gr.Row():
                audio_record_b = gr.Audio(label="Record Audio", sources=["microphone"], type="filepath")
                audio_download_b = gr.File(label="Audio Download", interactive=False)
                btn_transcribe_b = gr.Button("1) Transcribe")
            # Row 2: ASR, Original, Model selector, Analyze
            with gr.Row():
                txt_transcribed_b = gr.Textbox(label="Transcribed Text (ASR)", interactive=False, lines=6, placeholder="ASR output appears here.")
                txt_original_b = gr.Textbox(label="Original Script (input)", lines=6, placeholder="Paste your original script here.")
                model_selector_b = gr.Dropdown(
                    choices=["gpt-4o-mini", "gpt-5"],
                    value="gpt-4o-mini",
                    label="LLM Model"
                )
                btn_analyze_b = gr.Button("2) Analyze")
            # Row 3: Summary, Revised, Revise button
            with gr.Row():
                txt_summary_b = gr.Textbox(label="LLM Summary: Easy-to-Stutter Words", lines=8, placeholder="Analysis will appear here.")
                txt_revised_b = gr.Textbox(label="Revised Script", lines=8, placeholder="Rewritten script will appear here.")
                btn_rewrite_b = gr.Button("3) Revise Script")
            # Row 4: Post-hoc audio record and download
            with gr.Row():
                posthoc_record_b = gr.Audio(label="Post-hoc Record Audio", sources=["microphone"], type="filepath")
                posthoc_download_b = gr.File(label="Post-hoc Audio Download", interactive=False)

            # Wiring (Baseline)
            audio_record_b.change(fn=passthrough_file, inputs=audio_record_b, outputs=audio_download_b)
            btn_transcribe_b.click(fn=on_click_transcribe, inputs=[audio_record_b], outputs=[txt_transcribed_b])
            btn_analyze_b.click(
                fn=on_click_analyze_baseline,
                inputs=[model_selector_b, txt_original_b, txt_transcribed_b],
                outputs=[txt_summary_b],
            )
            btn_rewrite_b.click(
                fn=on_click_rewrite,
                inputs=[model_selector_b, txt_original_b, txt_transcribed_b, txt_summary_b],
                outputs=[txt_revised_b],
            )
            posthoc_record_b.change(fn=passthrough_file, inputs=posthoc_record_b, outputs=posthoc_download_b)

        # -------------------- Tab 2: Baseline+IPA --------------------
        with gr.Tab("Baseline+IPA"):
            # Row 1: Record + Download + Transcribe
            with gr.Row():
                audio_record_i = gr.Audio(label="Record Audio", sources=["microphone"], type="filepath")
                audio_download_i = gr.File(label="Audio Download", interactive=False)
                btn_transcribe_i = gr.Button("1) Transcribe")
            # Row 2: ASR, Original, IPA box, Model selector, Analyze
            with gr.Row():
                txt_transcribed_i = gr.Textbox(label="Transcribed Text (ASR)", interactive=False, lines=6, placeholder="ASR output appears here.")
                txt_original_i = gr.Textbox(label="Original Script (input)", lines=6, placeholder="Paste your original script here.")
                txt_ipa_i = gr.Textbox(label="IPA Annotations (LLM Output)", interactive=False, lines=6, placeholder="IPA for Original & Transcript will appear here.")
                model_selector_i = gr.Dropdown(
                    choices=["gpt-4o-mini", "gpt-5"],
                    value="gpt-4o-mini",
                    label="LLM Model"
                )
                btn_analyze_i = gr.Button("2) Analyze (IPA → Diagnosis)")
            # Row 3: Summary, Revised, Revise button
            with gr.Row():
                txt_summary_i = gr.Textbox(label="LLM Summary: Easy-to-Stutter Words (IPA-aware)", lines=8, placeholder="Analysis will appear here.")
                txt_revised_i = gr.Textbox(label="Revised Script", lines=8, placeholder="Rewritten script will appear here.")
                btn_rewrite_i = gr.Button("3) Revise Script")
            # Row 4: Post-hoc audio record and download
            with gr.Row():
                posthoc_record_i = gr.Audio(label="Post-hoc Record Audio", sources=["microphone"], type="filepath")
                posthoc_download_i = gr.File(label="Post-hoc Audio Download", interactive=False)

            # Wiring (Baseline+IPA)
            audio_record_i.change(fn=passthrough_file, inputs=audio_record_i, outputs=audio_download_i)
            btn_transcribe_i.click(fn=on_click_transcribe, inputs=[audio_record_i], outputs=[txt_transcribed_i])
            # The handler already returns (ipa_update, summary_update); wire it
            # directly instead of going through a redundant local wrapper.
            btn_analyze_i.click(
                fn=on_click_analyze_ipa,
                inputs=[model_selector_i, txt_original_i, txt_transcribed_i],
                outputs=[txt_ipa_i, txt_summary_i],
            )
            btn_rewrite_i.click(
                fn=on_click_rewrite,
                inputs=[model_selector_i, txt_original_i, txt_transcribed_i, txt_summary_i],
                outputs=[txt_revised_i],
            )
            posthoc_record_i.change(fn=passthrough_file, inputs=posthoc_record_i, outputs=posthoc_download_i)

if __name__ == "__main__":
    demo.launch()