# Snapshot from Hugging Face Space (author: gnumanth)
# Revert to known working version (da8393f), commit 436150b
import gradio as gr
import torch
import numpy as np
import spaces
from datetime import datetime
import random
import string
# Short random tag so concurrent live sessions can be told apart in the logs.
SESSION_ID = f"LIVE_{''.join(random.choices(string.ascii_uppercase + string.digits, k=4))}"
# --- BACKEND LOGIC ---
print(f"[{datetime.now().strftime('%H:%M:%S')}] --- SYSTEM STARTUP ---", flush=True)
try:
    print("Loading NeMo ASR Model...", flush=True)
    # NeMo import is deferred into the try-block so a missing/broken install
    # is logged below instead of crashing the whole module import.
    import nemo.collections.asr as nemo_asr
    ASR_MODEL = nemo_asr.models.ASRModel.from_pretrained(model_name="nvidia/nemotron-speech-streaming-en-0.6b")
    ASR_MODEL.eval()  # inference mode (disables dropout etc.)
    if torch.cuda.is_available():
        print("Moving model to CUDA...", flush=True)
        ASR_MODEL = ASR_MODEL.cuda()
    else:
        print("WARNING: CUDA not available, running on CPU (Slow)", flush=True)
    print("Model Loaded Successfully.", flush=True)
except Exception as e:
    # Keep the app alive with ASR_MODEL = None; transcribe() checks for this
    # before attempting inference.
    print(f"CRITICAL MODEL LOAD ERROR: {e}", flush=True)
    ASR_MODEL = None
@spaces.GPU(duration=120)
def transcribe(audio, state):
    """Streaming callback: buffer the incoming mic chunk and re-transcribe
    the most recent ~2 s of audio.

    Args:
        audio: Gradio streaming payload ``(sample_rate, np.ndarray)``, or
            None when no audio has arrived yet.
        state: per-session dict with keys 'transcript' (list of str lines),
            'buffer' (float32 np.ndarray or None), 'counter' (chunks seen).
            None on the first call; a fresh dict is created.

    Returns:
        (state, text): the updated session state and the latest transcript
        line, or "Listening..." when nothing has been recognized yet.
    """
    if state is None:
        state = {'transcript': [], 'buffer': None, 'counter': 0}
        print(f"[SESSION START] {SESSION_ID}", flush=True)
    if audio is None:
        return state, "Listening..."
    try:
        sr, data = audio
        if len(data) > 0:
            # Normalize to float32 in [-1, 1] BEFORE taking the peak:
            # np.abs on raw int16 overflows for -32768 (abs(-32768) wraps
            # back to -32768), and the %.4f log format expects a float level.
            if data.dtype == np.int16:
                data = data.astype(np.float32) / 32768.0
            elif data.dtype == np.int32:
                data = data.astype(np.float32) / 2147483648.0
            else:
                data = data.astype(np.float32)
            peak = np.abs(data).max()
            if state['counter'] % 10 == 0:  # throttle: log every 10th chunk
                print(f"[AUDIO RECV] Step {state['counter']} | Shape: {data.shape} | Peak: {peak:.4f}", flush=True)
            # Down-mix multi-channel audio to mono.
            if data.ndim > 1:
                data = data.mean(axis=1)
            # The model expects 16 kHz input.
            if sr != 16000:
                import librosa
                data = librosa.resample(data, orig_sr=sr, target_sr=16000)
            # Accumulate into the rolling session buffer.
            if state['buffer'] is None:
                state['buffer'] = data
            else:
                state['buffer'] = np.concatenate((state['buffer'], data))
            state['counter'] += 1
            # Infer once we have at least 0.2 s (3200 samples @ 16 kHz),
            # using only the last 2 s (32000 samples) as context.
            if len(state['buffer']) >= 3200:
                if ASR_MODEL is not None:  # explicit: model truthiness is not a contract
                    with torch.no_grad():
                        context = state['buffer'][-32000:]
                        results = ASR_MODEL.transcribe([context])
                        print(f"[INFER] Context: {len(context)} | Raw: {results}", flush=True)
                        if results and len(results) > 0:
                            hyp = results[0]
                            text = ""
                            # NeMo may return plain strings or Hypothesis-like
                            # objects exposing .text / .pred_text.
                            if isinstance(hyp, str):
                                text = hyp
                            elif hasattr(hyp, 'text'):
                                text = hyp.text
                            elif hasattr(hyp, 'pred_text'):
                                text = hyp.pred_text
                            if text and text.strip():
                                print(f" >>> [TXT] {text}", flush=True)
                                current_lines = state['transcript']
                                # Overwrite the running line rather than append:
                                # each pass re-transcribes the same context.
                                if not current_lines:
                                    current_lines.append(text)
                                else:
                                    current_lines[-1] = text
            # Cap the buffer at the 2 s context window to bound memory.
            if len(state['buffer']) > 32000:
                state['buffer'] = state['buffer'][-32000:]
    except Exception as e:
        # Never let one bad chunk kill the stream; log and keep listening.
        print(f"[CRITICAL PROCESSING ERROR] {e}", flush=True)
        import traceback
        traceback.print_exc()
    valid = [l for l in state['transcript'] if l]
    current = valid[-1] if valid else "Listening..."
    return state, current
def clear_session():
    """Reset the session: return a fresh state dict and the placeholder text."""
    print("[SESSION RESET]", flush=True)
    fresh_state = {'transcript': [], 'buffer': None, 'counter': 0}
    return fresh_state, "Listening..."
def log_connection():
    """Emit a marker line in the server log when a browser client attaches."""
    banner = ">>> CLIENT CONNECTED <<<"
    print(banner, flush=True)
# --- CUSTOM THEME CSS ---
# Injected into gr.Blocks(css=...) below: dark gradient background, NVIDIA-green
# accents, a large centered transcript box, and a hidden Gradio footer.
# NOTE: this is a runtime string — element ids must match the elem_id values
# used in the UI section (transcript-box, mic-button, reset-button, ...).
custom_css = """
.gradio-container {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f0f23 100%) !important;
min-height: 100vh;
}
#title-text {
text-align: center;
color: #76b900;
font-size: 2em;
font-weight: bold;
margin-bottom: 10px;
}
#subtitle-text {
text-align: center;
color: #888;
font-size: 1em;
margin-bottom: 30px;
}
#session-info {
text-align: center;
color: #76b900;
font-size: 0.9em;
padding: 10px;
background: rgba(118, 185, 0, 0.1);
border-radius: 20px;
display: inline-block;
}
#transcript-box {
min-height: 200px;
font-size: 1.5em;
text-align: center;
padding: 40px 20px;
background: rgba(255, 255, 255, 0.05);
border-radius: 15px;
border: 1px solid rgba(255, 255, 255, 0.1);
}
#transcript-box textarea {
background: transparent !important;
color: #ffffff !important;
font-size: 1.5em !important;
text-align: center !important;
border: none !important;
}
#mic-button {
margin: 20px auto;
display: block;
}
#reset-button {
background: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
}
footer {
display: none !important;
}
"""
# --- GRADIO UI ---
with gr.Blocks(css=custom_css, title="Nemotron Speech Streaming", theme=gr.themes.Soft(primary_hue="green")) as demo:
    # Per-client session state; mirrors the dict shape built in transcribe().
    state = gr.State({'transcript': [], 'buffer': None, 'counter': 0})
    # Header: title, subtitle, and the random session tag for log correlation.
    gr.HTML(f"""
    <div id="title-text">Nemotron Speech Streaming</div>
    <div id="subtitle-text">Real-time speech recognition powered by NVIDIA NeMo</div>
    <div style="text-align: center; margin-bottom: 20px;">
    <span id="session-info">Session: {SESSION_ID}</span>
    </div>
    """)
    with gr.Row():
        with gr.Column():
            # Read-only transcript display, updated by the stream callback.
            transcript_display = gr.Textbox(
                value="Listening...",
                label="Transcript",
                elem_id="transcript-box",
                lines=6,
                max_lines=10,
                interactive=False,
                show_copy_button=True
            )
    with gr.Row():
        with gr.Column(scale=2):
            # Streaming microphone input: delivers (sample_rate, ndarray) chunks.
            audio = gr.Audio(
                sources=["microphone"],
                streaming=True,
                type="numpy",
                label="Click to Start Recording",
                elem_id="mic-button"
            )
        with gr.Column(scale=1):
            reset_btn = gr.Button("Reset", elem_id="reset-button", variant="secondary")
    gr.HTML("""
    <div style="text-align: center; margin-top: 30px; color: #666; font-size: 0.85em;">
    <p>Click the microphone to start speaking. Your speech will be transcribed in real-time.</p>
    <p>Model: <strong>nvidia/nemotron-speech-streaming-en-0.6b</strong></p>
    </div>
    """)
    # Events
    demo.load(fn=log_connection)  # log each page load / client connect
    # trigger_mode="always_last" coalesces chunks that arrive while the
    # previous transcribe() call is still running.
    audio.stream(
        fn=transcribe,
        inputs=[audio, state],
        outputs=[state, transcript_display],
        show_progress="hidden",
        trigger_mode="always_last"
    )
    reset_btn.click(fn=clear_session, outputs=[state, transcript_display])
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)