import gradio as gr import numpy as np import time import torch import logging from typing import Optional # Version tracking VERSION = "1.2.0" COMMIT_SHA = "TBD" # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Global model variables model = None processor = None device = None def load_stt_model(): """Load STT model on startup""" global model, processor, device try: device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Loading STT model on {device}...") # Try to load the actual Kyutai STT model try: from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration model_id = "kyutai/stt-1b-en_fr" processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id).to(device) logger.info(f"✅ {model_id} loaded successfully on {device}") return f"✅ Model loaded: {model_id} on {device}" except Exception as model_error: logger.warning(f"Could not load Kyutai model: {model_error}") # Fallback to Whisper if Kyutai fails try: from transformers import WhisperProcessor, WhisperForConditionalGeneration model_id = "openai/whisper-base" processor = WhisperProcessor.from_pretrained(model_id) model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device) logger.info(f"✅ Fallback model loaded: {model_id} on {device}") return f"✅ Fallback model loaded: {model_id} on {device}" except Exception as whisper_error: logger.error(f"Both Kyutai and Whisper failed: {whisper_error}") model = "mock" processor = "mock" return f"⚠️ Using mock STT (models failed to load)" except Exception as e: logger.error(f"Error in load_stt_model: {e}") model = "mock" processor = "mock" return f"❌ Error: {str(e)}" def transcribe_audio(audio_input, progress=gr.Progress()): """Transcribe audio using STT model""" if audio_input is None: return "❌ No audio provided" progress(0.1, desc="Processing audio...") try: # Extract audio data if isinstance(audio_input, tuple): sample_rate, audio_data = audio_input else: sample_rate = 16000 # Default audio_data = audio_input if audio_data is None or len(audio_data) == 0: return "❌ Empty audio data" progress(0.3, desc="Running STT model...") # Convert to float32 if needed if audio_data.dtype != np.float32: audio_data = audio_data.astype(np.float32) # Normalize audio if np.max(np.abs(audio_data)) > 0: audio_data = audio_data / np.max(np.abs(audio_data)) if model == "mock": # Mock transcription duration = len(audio_data) / sample_rate progress(1.0, desc="Complete!") return f"🎙️ Mock transcription: {duration:.2f}s audio at {sample_rate}Hz ({len(audio_data)} samples)" # Real transcription progress(0.5, desc="Model inference...") # Resample if needed (Kyutai expects 24kHz, Whisper expects 16kHz) target_sr = 24000 if "Kyutai" in str(type(model)) else 16000 if sample_rate != target_sr: import librosa audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sr) sample_rate = target_sr # Prepare inputs inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} progress(0.8, desc="Generating transcription...") # Generate transcription with torch.no_grad(): generated_ids = model.generate(**inputs) transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] progress(1.0, desc="Complete!") return f"🎙️ {transcription}" except Exception as e: logger.error(f"Transcription error: {e}") return f"❌ Error: {str(e)}" def get_health_status(): """Get system health status""" return { "status": "healthy", "timestamp": time.time(), "version": VERSION, "commit_sha": COMMIT_SHA, "model_loaded": model is not None and model != "mock", "device": str(device) if device else "unknown", "model_type": str(type(model)) if model else "none" } def format_health_status(): """Format health status for display""" health = get_health_status() status_text = f""" 📊 **System Status**: {health['status']} 🕒 **Timestamp**: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(health['timestamp']))} 🔢 **Version**: {health['version']} 🔗 **Commit SHA**: {health['commit_sha']} 🤖 **Model Loaded**: {health['model_loaded']} 💻 **Device**: {health['device']} 🧠 **Model Type**: {health['model_type']} """ return status_text # Load model on startup startup_message = load_stt_model() # Create Gradio interface with gr.Blocks( title="STT GPU Service Python v4", theme=gr.themes.Soft(), css=""" .version-info { font-size: 0.8em; color: #666; text-align: center; margin-top: 20px; } """ ) as demo: gr.Markdown("# 🎙️ STT GPU Service Python v4") gr.Markdown("**Real-time Speech-to-Text with kyutai/stt-1b-en_fr**") # Startup status gr.Markdown(f"**Startup Status**: {startup_message}") with gr.Tabs(): with gr.Tab("🎤 Speech Transcription"): gr.Markdown("### Real-time Speech-to-Text") gr.Markdown("Record audio or upload a file to transcribe with STT model") with gr.Row(): with gr.Column(): # Microphone input mic_input = gr.Audio( sources=["microphone"], type="numpy", label="🎤 Record Audio", format="wav" ) # File upload file_input = gr.Audio( sources=["upload"], type="numpy", label="📁 Upload Audio File", format="wav" ) transcribe_mic_btn = gr.Button("🎙️ Transcribe Microphone", variant="primary") transcribe_file_btn = gr.Button("📁 Transcribe File", variant="secondary") with gr.Column(): output_text = gr.Textbox( label="📝 Transcription Output", placeholder="Transcription will appear here...", lines=10, max_lines=20 ) with gr.Tab("⚡ Health Check"): gr.Markdown("### System Health Status") health_btn = gr.Button("🔍 Check System Health") health_output = gr.Markdown() with gr.Tab("📋 API Info"): gr.Markdown(""" ### API Endpoints **WebSocket Streaming** (Planned): - `ws://space-url/ws/stream` - Real-time audio streaming - Expected: 80ms chunks at 24kHz (1920 samples per chunk) **REST API** (Planned): - `POST /api/transcribe` - Single audio file transcription **Current Implementation**: - Gradio interface with real-time transcription - Supports microphone input and file upload - Uses kyutai/stt-1b-en_fr model with Whisper fallback """) # Event handlers transcribe_mic_btn.click( fn=transcribe_audio, inputs=[mic_input], outputs=[output_text], show_progress=True ) transcribe_file_btn.click( fn=transcribe_audio, inputs=[file_input], outputs=[output_text], show_progress=True ) health_btn.click( fn=format_health_status, outputs=[health_output] ) # Version info gr.Markdown( f'