# NOTE(review): the lines "Spaces:" / "Runtime error" that preceded this file
# were Hugging Face Spaces status-banner text pasted in with the source, not
# code; preserved here as a comment so the file parses.
| # app.py | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from flask import Flask, request, Response, redirect | |
| from flask_cors import CORS | |
| import torch | |
| import torchaudio | |
# Transformers imports: the library itself is imported eagerly here; only the
# model *weights* are lazily loaded in ModelManager.load() to reduce startup overhead.
| from transformers import ( | |
| AutoProcessor, | |
| AutoModelForSpeechSeq2Seq, | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM | |
| ) | |
# ---------- Configuration ----------
# Model IDs are overridable via environment variables for easy swapping on deployment.
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "openai/whisper-small")
NLLB_MODEL = os.environ.get("NLLB_MODEL", "facebook/nllb-200-distilled-600M")
# language name -> (Whisper language code or None for auto-detect,
#                   NLLB source-language tag or None when no translation is needed)
LANG_MAP = {
    "akan": (None, "aka_Latn"),    # no Whisper code given; Whisper auto-detects
    "hausa": ("ha", "hau_Latn"),
    "swahili": ("sw", "swh_Latn"),
    "french": ("fr", "fra_Latn"),
    "arabic": ("ar", "arb_Arab"),
    "english": ("en", None),       # already English: NLLB step is skipped
}
DEVICE = torch.device("cpu")  # Free HF Spaces = CPU
app = Flask(__name__)
CORS(app)  # allow cross-origin browser requests to the REST endpoint
| # ---------- Model manager ---------- | |
| class ModelManager: | |
| def __init__(self): | |
| self.whisper_processor = None | |
| self.whisper_model = None | |
| self.nllb_tokenizer = None | |
| self.nllb_model = None | |
| self._loaded = False | |
| def load(self): | |
| if self._loaded: | |
| return | |
| print(f"Loading Whisper model: {WHISPER_MODEL}") | |
| try: | |
| self.whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL) | |
| self.whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(WHISPER_MODEL).to(DEVICE) | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load Whisper model ({WHISPER_MODEL}): {e}") | |
| print(f"Loading NLLB tokenizer/model: {NLLB_MODEL}") | |
| try: | |
| self.nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL) | |
| self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL).to(DEVICE) | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load NLLB model ({NLLB_MODEL}): {e}") | |
| self._loaded = True | |
| print("Models loaded successfully.") | |
| def transcribe(self, audio_path, whisper_language_arg=None): | |
| if self.whisper_processor is None or self.whisper_model is None: | |
| raise RuntimeError("Whisper model not loaded") | |
| waveform, sr = torchaudio.load(audio_path) | |
| if sr != 16000: | |
| waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform) | |
| sr = 16000 | |
| inputs = self.whisper_processor( | |
| waveform.squeeze().numpy(), | |
| sampling_rate=sr, | |
| return_tensors="pt", | |
| return_attention_mask=True, | |
| language=whisper_language_arg, | |
| task="transcribe" | |
| ).to(DEVICE) | |
| with torch.no_grad(): | |
| generated_ids = self.whisper_model.generate( | |
| input_features=inputs["input_features"], | |
| attention_mask=inputs.get("attention_mask"), | |
| ) | |
| decoded = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True) | |
| return decoded[0].strip() | |
| def translate_to_english(self, src_text, nllb_src_lang_tag): | |
| if not src_text: | |
| return "" | |
| if not nllb_src_lang_tag: | |
| return src_text | |
| if self.nllb_tokenizer is None or self.nllb_model is None: | |
| raise RuntimeError("NLLB model not loaded") | |
| try: | |
| self.nllb_tokenizer.src_lang = nllb_src_lang_tag | |
| except Exception: | |
| pass | |
| inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE) | |
| forced_bos = None | |
| try: | |
| forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn") | |
| except Exception: | |
| forced_bos = None | |
| gen_kwargs = { | |
| "max_length": 512, | |
| "num_beams": 4, | |
| "no_repeat_ngram_size": 2, | |
| "early_stopping": True | |
| } | |
| if forced_bos is not None: | |
| gen_kwargs["forced_bos_token_id"] = forced_bos | |
| with torch.no_grad(): | |
| translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs) | |
| translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True) | |
| return translated.strip() | |
# Module-level singleton shared across requests; weights load lazily on first use.
model_manager = ModelManager()
# ---------- REST endpoint ----------
@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """POST /transcribe — multipart `audio` file plus a `language` form/query field.

    Returns the English translation of the spoken audio as text/plain
    (204 on empty transcription, 4xx/5xx plain-text errors otherwise).
    """
    # BUG FIX: the view was never registered with Flask; without the
    # @app.route decorator the documented /transcribe endpoint 404'd.
    if "audio" not in request.files:
        return Response("No audio file provided", status=400, mimetype="text/plain")
    audio_file = request.files["audio"]
    # Language may arrive as form data or a query parameter; default to English.
    language = (request.form.get("language") or request.args.get("language") or "english").lower()
    if language not in LANG_MAP:
        return Response(f"Unsupported language: {language}", status=400, mimetype="text/plain")
    whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
    try:
        model_manager.load()  # lazy and idempotent
    except Exception as e:
        return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
    # Persist the upload to a temp file torchaudio can open; keep the original
    # extension so the audio backend can sniff the container format.
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=Path(audio_file.filename).suffix or ".wav")
    os.close(tmp_fd)
    audio_file.save(tmp_path)
    try:
        transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
        if not transcription:
            return Response("", status=204, mimetype="text/plain")
        translation = model_manager.translate_to_english(transcription, nllb_src_tag)
        return Response(translation, status=200, mimetype="text/plain")
    except Exception as e:
        return Response(f"Processing failed: {e}", status=500, mimetype="text/plain")
    finally:
        # Best-effort temp-file cleanup; never mask the response with a cleanup error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
# ---------- Robust Gradio UI mount ----------
# Best-effort: the REST API must keep working even if gradio is missing, too
# old/new, or fails to mount, so every step here is wrapped in try/except.
gradio_mounted = False
if os.environ.get("DISABLE_GRADIO", "0") != "1":
    try:
        import gradio as gr
        import soundfile as sf
        import numpy as np

        def _ui_transcribe(audio, language):
            """Gradio callback: normalize `audio` (filepath, (sr, samples)
            tuple, bare ndarray, or file-like object) to a file path, then
            transcribe and translate. Returns (transcription, translation)."""
            if audio is None:
                return "No audio", ""
            audio_path = None
            if isinstance(audio, str) and Path(audio).exists():
                audio_path = audio
            elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
                # (sample_rate, samples) tuple from numpy-type Audio components.
                sr, data = audio[0], audio[1]
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                sf.write(tmp.name, data, sr)
                audio_path = tmp.name
            elif isinstance(audio, (np.ndarray,)) or hasattr(audio, "shape"):
                # Bare sample array with no rate attached; assume Whisper's 16 kHz.
                sr = 16000
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                sf.write(tmp.name, audio, sr)
                audio_path = tmp.name
            else:
                try:
                    audio_path = getattr(audio, "name", None)
                except Exception:
                    audio_path = None
            if not audio_path:
                return "Unsupported audio format from Gradio", ""
            try:
                model_manager.load()
                whisper_lang, nllb_tag = LANG_MAP.get(language.lower(), (None, None))
                transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
                translation = model_manager.translate_to_english(transcription, nllb_tag)
                return transcription, translation
            finally:
                # Only delete temp copies we created, never a user-supplied path.
                try:
                    if audio_path and Path(audio_path).exists() and "/tmp" in str(audio_path):
                        os.remove(audio_path)
                except Exception:
                    pass

        def _make_audio_component(cls):
            """Build a microphone Audio component across gradio versions:
            gradio 4.x renamed the `source=` kwarg to `sources=[...]`, and
            older versions only warn on unknown kwargs, so the constructor
            signature is inspected instead of relying on TypeError."""
            import inspect
            if "sources" in inspect.signature(cls.__init__).parameters:
                return cls(sources=["microphone"], type="filepath")
            return cls(source="microphone", type="filepath")

        demo = None
        # Create components robustly across gradio versions.
        audio_component = None
        dropdown_component = None
        textbox_out1 = None
        textbox_out2 = None
        # Option A: modern top-level component API (gr.Audio / gr.Dropdown / gr.Textbox).
        try:
            if hasattr(gr, "Audio"):
                audio_component = _make_audio_component(gr.Audio)
            elif hasattr(gr, "components") and hasattr(gr.components, "Audio"):
                audio_component = _make_audio_component(gr.components.Audio)
        except Exception:
            audio_component = None
        # Dropdown
        try:
            if hasattr(gr, "Dropdown"):
                dropdown_component = gr.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
            elif hasattr(gr, "components") and hasattr(gr.components, "Dropdown"):
                dropdown_component = gr.components.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
        except Exception:
            dropdown_component = None
        # Output textboxes
        try:
            if hasattr(gr, "Textbox"):
                textbox_out1 = gr.Textbox(label="Transcription")
                textbox_out2 = gr.Textbox(label="Translation (English)")
            elif hasattr(gr, "components") and hasattr(gr.components, "Textbox"):
                textbox_out1 = gr.components.Textbox(label="Transcription")
                textbox_out2 = gr.components.Textbox(label="Translation (English)")
        except Exception:
            textbox_out1 = textbox_out2 = None
        # Option B: legacy gradio 2.x/3.x gr.inputs / gr.outputs API.
        if audio_component is None or dropdown_component is None or textbox_out1 is None:
            try:
                # BUG FIX: the original condition checked hasattr(gr, "inputs")
                # twice; this branch dereferences gr.outputs as well, so both
                # namespaces must exist before we use them.
                if hasattr(gr, "inputs") and hasattr(gr, "outputs"):
                    audio_component = getattr(gr.inputs, "Audio")(source="microphone", type="filepath")
                    dropdown_component = getattr(gr.inputs, "Dropdown")(choices=list(LANG_MAP.keys()), default="english")
                    textbox_out1 = getattr(gr.outputs, "Textbox")()
                    textbox_out2 = getattr(gr.outputs, "Textbox")()
            except Exception:
                pass
        # If we have the required components, create the Interface.
        if audio_component is not None and dropdown_component is not None and textbox_out1 is not None:
            try:
                demo = gr.Interface(
                    fn=_ui_transcribe,
                    inputs=[audio_component, dropdown_component],
                    outputs=[textbox_out1, textbox_out2],
                    title="Multilingual Transcriber (server)"
                )
            except Exception as e:
                print("Failed to create gr.Interface:", e)
                demo = None
        if demo is not None:
            try:
                # NOTE(review): gr.mount_gradio_app is documented for
                # FastAPI/Starlette apps; mounting onto a Flask app likely
                # raises here and is swallowed below — confirm, or launch the
                # Gradio demo separately instead of mounting it.
                app = gr.mount_gradio_app(app, demo, path="/ui")
                gradio_mounted = True
                print("Gradio mounted at /ui")
            except Exception as e:
                print("Failed to mount Gradio app:", e)
                gradio_mounted = False
        else:
            print("Gradio demo not created; continuing without mounted UI.")
    except Exception as e:
        print("Gradio UI unavailable or failed to mount:", e)
        gradio_mounted = False
else:
    print("Gradio mounting disabled via DISABLE_GRADIO=1")
    gradio_mounted = False
# Root endpoint
@app.route("/")
def index():
    """Redirect to the Gradio UI when it is mounted, else report server status."""
    # BUG FIX: without @app.route("/") this view was never registered and the
    # root URL returned 404 despite advertising the /transcribe endpoint.
    if gradio_mounted:
        return redirect("/ui")
    return Response("Server running. REST endpoint available at /transcribe", status=200, mimetype="text/plain")
if __name__ == "__main__":
    # Hosting platforms inject PORT (7860 is the HF Spaces default); bind all
    # interfaces so the containerized server is reachable.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)