# stt-gpu-service-python-v4 / app_gradio_stt.py
# Author: Peter Michael Gits
# Last commit: "Fix Dockerfile directory permissions - create /app as root before switching users" (26096f4)
import gradio as gr
import numpy as np
import time
import torch
import logging
from typing import Optional
# Version tracking (COMMIT_SHA is a placeholder, presumably patched in at build time — TODO confirm)
VERSION = "1.2.0"
COMMIT_SHA = "TBD"
# Configure logging for the whole process; module-level logger per stdlib convention
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model state, populated by load_stt_model() at startup.
# Each may be None (not loaded yet), the string "mock" (all loads failed),
# or a real processor/model object; device is "cuda"/"cpu" once set.
model = None
processor = None
device = None
def load_stt_model():
    """Load the STT model at startup, with graceful fallbacks.

    Tries the Kyutai STT model first, falls back to Whisper, and finally
    to the string sentinel ``"mock"`` so the UI stays usable even when no
    model can be loaded (offline, missing deps, no GPU weights, etc.).

    Side effects:
        Sets the module-level ``model``, ``processor`` and ``device`` globals.

    Returns:
        A human-readable status string describing which model (if any) loaded.
    """
    global model, processor, device
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Lazy %-style args: the message is only formatted if the level is enabled.
        logger.info("Loading STT model on %s...", device)
        # Preferred backend: the Kyutai STT model.
        try:
            from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
            model_id = "kyutai/stt-1b-en_fr"
            processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id).to(device)
            logger.info("βœ… %s loaded successfully on %s", model_id, device)
            return f"βœ… Model loaded: {model_id} on {device}"
        except Exception as model_error:
            logger.warning("Could not load Kyutai model: %s", model_error)
            # Fallback backend: Whisper.
            try:
                from transformers import WhisperProcessor, WhisperForConditionalGeneration
                model_id = "openai/whisper-base"
                processor = WhisperProcessor.from_pretrained(model_id)
                model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
                logger.info("βœ… Fallback model loaded: %s on %s", model_id, device)
                return f"βœ… Fallback model loaded: {model_id} on {device}"
            except Exception as whisper_error:
                logger.error("Both Kyutai and Whisper failed: %s", whisper_error)
                # "mock" sentinel: transcribe_audio() checks for this exact string.
                model = "mock"
                processor = "mock"
                return "⚠️ Using mock STT (models failed to load)"
    except Exception as e:
        # Catch-all boundary so startup never crashes the app; fall back to mock.
        logger.error("Error in load_stt_model: %s", e)
        model = "mock"
        processor = "mock"
        return f"❌ Error: {str(e)}"
    
def transcribe_audio(audio_input, progress=gr.Progress()):
    """Transcribe audio using the loaded STT model.

    Args:
        audio_input: Either a ``(sample_rate, samples)`` tuple as produced
            by ``gr.Audio(type="numpy")``, or a bare sample array
            (16 kHz is then assumed). Stereo ``(samples, channels)`` input
            is downmixed to mono.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        "πŸŽ™οΈ <text>" on success, or a "❌ ..." error/status message.
    """
    if audio_input is None:
        return "❌ No audio provided"
    progress(0.1, desc="Processing audio...")
    try:
        # Extract audio data
        if isinstance(audio_input, tuple):
            sample_rate, audio_data = audio_input
        else:
            sample_rate = 16000  # Default
            audio_data = audio_input
        if audio_data is None or len(audio_data) == 0:
            return "❌ Empty audio data"
        progress(0.3, desc="Running STT model...")
        # Accept lists too, and downmix stereo (samples, channels) to mono:
        # the STT processors expect a 1-D waveform.
        audio_data = np.asarray(audio_data)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        # Convert to float32 if needed
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Peak-normalize to [-1, 1]; compute the peak once instead of twice.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        if model == "mock":
            # Mock transcription: report basic stats so the UI stays usable.
            duration = len(audio_data) / sample_rate
            progress(1.0, desc="Complete!")
            return f"πŸŽ™οΈ Mock transcription: {duration:.2f}s audio at {sample_rate}Hz ({len(audio_data)} samples)"
        # Real transcription
        progress(0.5, desc="Model inference...")
        # Resample if needed (Kyutai expects 24kHz, Whisper expects 16kHz)
        target_sr = 24000 if "Kyutai" in str(type(model)) else 16000
        if sample_rate != target_sr:
            import librosa
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sr)
            sample_rate = target_sr
        # Prepare inputs and move every tensor onto the model's device
        inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        progress(0.8, desc="Generating transcription...")
        # Generate transcription (inference only, so no autograd bookkeeping)
        with torch.no_grad():
            generated_ids = model.generate(**inputs)
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        progress(1.0, desc="Complete!")
        return f"πŸŽ™οΈ {transcription}"
    except Exception as e:
        logger.error("Transcription error: %s", e)
        return f"❌ Error: {str(e)}"
def get_health_status():
    """Return a snapshot dict of service health, version and model state."""
    # A "real" model is anything other than the unset/"mock" sentinel values.
    has_real_model = model is not None and model != "mock"
    return dict(
        status="healthy",
        timestamp=time.time(),
        version=VERSION,
        commit_sha=COMMIT_SHA,
        model_loaded=has_real_model,
        device=str(device) if device else "unknown",
        model_type=str(type(model)) if model else "none",
    )
def format_health_status():
    """Render the current health snapshot as display-ready Markdown."""
    h = get_health_status()
    # Pre-format the epoch timestamp as local wall-clock time.
    stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(h['timestamp']))
    return f"""
πŸ“Š **System Status**: {h['status']}
πŸ•’ **Timestamp**: {stamp}
πŸ”’ **Version**: {h['version']}
πŸ”— **Commit SHA**: {h['commit_sha']}
πŸ€– **Model Loaded**: {h['model_loaded']}
πŸ’» **Device**: {h['device']}
🧠 **Model Type**: {h['model_type']}
"""
# Load model on startup (module import time) so the UI can show the result.
startup_message = load_stt_model()
# Create Gradio interface
with gr.Blocks(
    title="STT GPU Service Python v4",
    theme=gr.themes.Soft(),
    css="""
    .version-info {
        font-size: 0.8em;
        color: #666;
        text-align: center;
        margin-top: 20px;
    }
    """
) as demo:
    gr.Markdown("# πŸŽ™οΈ STT GPU Service Python v4")
    gr.Markdown("**Real-time Speech-to-Text with kyutai/stt-1b-en_fr**")
    # Startup status
    gr.Markdown(f"**Startup Status**: {startup_message}")
    with gr.Tabs():
        with gr.Tab("🎀 Speech Transcription"):
            gr.Markdown("### Real-time Speech-to-Text")
            gr.Markdown("Record audio or upload a file to transcribe with STT model")
            with gr.Row():
                with gr.Column():
                    # Microphone input (type="numpy" yields (sample_rate, ndarray))
                    mic_input = gr.Audio(
                        sources=["microphone"],
                        type="numpy",
                        label="🎀 Record Audio",
                        format="wav"
                    )
                    # File upload (same numpy tuple shape as the mic input)
                    file_input = gr.Audio(
                        sources=["upload"],
                        type="numpy",
                        label="πŸ“ Upload Audio File",
                        format="wav"
                    )
                    transcribe_mic_btn = gr.Button("πŸŽ™οΈ Transcribe Microphone", variant="primary")
                    transcribe_file_btn = gr.Button("πŸ“ Transcribe File", variant="secondary")
                with gr.Column():
                    output_text = gr.Textbox(
                        label="πŸ“ Transcription Output",
                        placeholder="Transcription will appear here...",
                        lines=10,
                        max_lines=20
                    )
        with gr.Tab("⚑ Health Check"):
            gr.Markdown("### System Health Status")
            health_btn = gr.Button("πŸ” Check System Health")
            health_output = gr.Markdown()
        with gr.Tab("πŸ“‹ API Info"):
            gr.Markdown("""
            ### API Endpoints
            **WebSocket Streaming** (Planned):
            - `ws://space-url/ws/stream` - Real-time audio streaming
            - Expected: 80ms chunks at 24kHz (1920 samples per chunk)
            **REST API** (Planned):
            - `POST /api/transcribe` - Single audio file transcription
            **Current Implementation**:
            - Gradio interface with real-time transcription
            - Supports microphone input and file upload
            - Uses kyutai/stt-1b-en_fr model with Whisper fallback
            """)
    # Event handlers: both buttons share the same transcription callback,
    # differing only in which audio component feeds it.
    transcribe_mic_btn.click(
        fn=transcribe_audio,
        inputs=[mic_input],
        outputs=[output_text],
        show_progress=True
    )
    transcribe_file_btn.click(
        fn=transcribe_audio,
        inputs=[file_input],
        outputs=[output_text],
        show_progress=True
    )
    health_btn.click(
        fn=format_health_status,
        outputs=[health_output]
    )
    # Version info footer
    gr.Markdown(
        f'<div class="version-info">v{VERSION} (SHA: {COMMIT_SHA}) - STT GPU Service Python v4</div>',
        elem_classes=["version-info"]
    )
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 (the Hugging Face Spaces / Docker
    # convention) and expose the auto-generated API docs.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True
    )