# app.py
#
# Flask server exposing a speech pipeline:
#   * Whisper (ASR)  -> transcribe uploaded audio
#   * NLLB    (MT)   -> translate the transcription into English
# An optional Gradio UI is mounted at /ui when gradio is importable.

import os
import tempfile
from pathlib import Path

from flask import Flask, request, Response, redirect
from flask_cors import CORS
import torch
import torchaudio

# Transformers imports (model weights are lazy loaded in ModelManager.load
# to reduce startup overhead)
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

# ---------- Configuration ----------
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "openai/whisper-small")
NLLB_MODEL = os.environ.get("NLLB_MODEL", "facebook/nllb-200-distilled-600M")

# language name -> (Whisper language hint or None, NLLB source tag or None)
# English has no NLLB tag (no translation needed); Akan has no Whisper hint
# (not in Whisper's language set), so Whisper auto-detects for it.
LANG_MAP = {
    "akan": (None, "aka_Latn"),
    "hausa": ("ha", "hau_Latn"),
    "swahili": ("sw", "swh_Latn"),
    "french": ("fr", "fra_Latn"),
    "arabic": ("ar", "arb_Arab"),
    "english": ("en", None),
}

DEVICE = torch.device("cpu")  # Free HF Spaces = CPU

app = Flask(__name__)
CORS(app)


# ---------- Model manager ----------
class ModelManager:
    """Lazily loads and holds the Whisper (ASR) and NLLB (MT) models."""

    def __init__(self):
        self.whisper_processor = None
        self.whisper_model = None
        self.nllb_tokenizer = None
        self.nllb_model = None
        self._loaded = False

    def load(self):
        """Load both models once; subsequent calls are no-ops.

        Raises:
            RuntimeError: if either model fails to download or initialize.
        """
        if self._loaded:
            return
        print(f"Loading Whisper model: {WHISPER_MODEL}")
        try:
            self.whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL)
            self.whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                WHISPER_MODEL
            ).to(DEVICE)
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model ({WHISPER_MODEL}): {e}")
        print(f"Loading NLLB tokenizer/model: {NLLB_MODEL}")
        try:
            self.nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
            self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
                NLLB_MODEL
            ).to(DEVICE)
        except Exception as e:
            raise RuntimeError(f"Failed to load NLLB model ({NLLB_MODEL}): {e}")
        self._loaded = True
        print("Models loaded successfully.")

    def transcribe(self, audio_path, whisper_language_arg=None):
        """Transcribe the audio file at *audio_path* with Whisper.

        Args:
            audio_path: path to an audio file readable by torchaudio.
            whisper_language_arg: Whisper language code (e.g. "sw") or None
                to let Whisper auto-detect the language.

        Returns:
            The transcription as a stripped string.

        Raises:
            RuntimeError: if the Whisper model has not been loaded.
        """
        if self.whisper_processor is None or self.whisper_model is None:
            raise RuntimeError("Whisper model not loaded")
        waveform, sr = torchaudio.load(audio_path)
        if sr != 16000:
            # Whisper's feature extractor expects 16 kHz input.
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
            sr = 16000
        inputs = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        # BUGFIX: `language`/`task` were previously passed to the processor
        # call, where the feature extractor silently ignores them (so the
        # per-language hint never took effect). They are generate() kwargs.
        gen_kwargs = {}
        if whisper_language_arg:
            gen_kwargs["language"] = whisper_language_arg
            gen_kwargs["task"] = "transcribe"
        with torch.no_grad():
            generated_ids = self.whisper_model.generate(
                input_features=inputs["input_features"],
                attention_mask=inputs.get("attention_mask"),
                **gen_kwargs,
            )
        decoded = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
        return decoded[0].strip()

    def translate_to_english(self, src_text, nllb_src_lang_tag):
        """Translate *src_text* to English with NLLB.

        Args:
            src_text: source-language text; empty input returns "".
            nllb_src_lang_tag: NLLB tag such as "swh_Latn", or None/"" to
                skip translation and return *src_text* unchanged (used for
                English input).

        Returns:
            The English translation as a stripped string.

        Raises:
            RuntimeError: if the NLLB model has not been loaded.
        """
        if not src_text:
            return ""
        if not nllb_src_lang_tag:
            # English (or unknown) source: nothing to translate.
            return src_text
        if self.nllb_tokenizer is None or self.nllb_model is None:
            raise RuntimeError("NLLB model not loaded")
        try:
            # Setting src_lang steers NLLB tokenization; best-effort because
            # not every tokenizer class exposes the attribute.
            self.nllb_tokenizer.src_lang = nllb_src_lang_tag
        except Exception:
            pass
        inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)
        try:
            # Force English as the decoder's first token so NLLB targets eng_Latn.
            forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
        except Exception:
            forced_bos = None
        gen_kwargs = {
            "max_length": 512,
            "num_beams": 4,
            "no_repeat_ngram_size": 2,
            "early_stopping": True,
        }
        if forced_bos is not None:
            gen_kwargs["forced_bos_token_id"] = forced_bos
        with torch.no_grad():
            translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs)
        translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        return translated.strip()


model_manager = ModelManager()


# ---------- REST endpoint ----------
@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """Accept a multipart upload ('audio' file, optional 'language' field)
    and return the English translation of the transcription as text/plain.

    Responses:
        200 translated text; 204 empty transcription; 400 bad request;
        500 model-loading or processing failure.
    """
    if "audio" not in request.files:
        return Response("No audio file provided", status=400, mimetype="text/plain")
    audio_file = request.files["audio"]
    language = (request.form.get("language") or request.args.get("language") or "english").lower()
    if language not in LANG_MAP:
        return Response(f"Unsupported language: {language}", status=400, mimetype="text/plain")
    whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
    try:
        model_manager.load()
    except Exception as e:
        return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
    # BUGFIX: FileStorage.filename may be None; Path(None) raises TypeError.
    suffix = Path(audio_file.filename or "").suffix or ".wav"
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    os.close(tmp_fd)
    audio_file.save(tmp_path)
    try:
        transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
        if not transcription:
            return Response("", status=204, mimetype="text/plain")
        translation = model_manager.translate_to_english(transcription, nllb_src_tag)
        return Response(translation, status=200, mimetype="text/plain")
    except Exception as e:
        return Response(f"Processing failed: {e}", status=500, mimetype="text/plain")
    finally:
        # Always remove the temp upload, even on failure.
        try:
            os.remove(tmp_path)
        except Exception:
            pass


# ---------- Robust Gradio UI mount ----------
gradio_mounted = False
if os.environ.get("DISABLE_GRADIO", "0") != "1":
    try:
        import gradio as gr
        import soundfile as sf
        import numpy as np

        def _ui_transcribe(audio, language):
            """Gradio callback: normalize the audio input to a file path,
            then run transcription + translation.

            Gradio versions differ in what they pass for audio: a filepath
            string, a (sample_rate, ndarray) tuple, a raw ndarray, or a
            file-like object with a .name — handle each in turn.
            """
            if audio is None:
                return "No audio", ""
            audio_path = None
            if isinstance(audio, str) and Path(audio).exists():
                audio_path = audio
            elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
                sr, data = audio[0], audio[1]
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                sf.write(tmp.name, data, sr)
                audio_path = tmp.name
            elif isinstance(audio, (np.ndarray,)) or hasattr(audio, "shape"):
                # Raw samples without a rate: assume 16 kHz (Whisper's rate).
                sr = 16000
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                sf.write(tmp.name, audio, sr)
                audio_path = tmp.name
            else:
                try:
                    audio_path = getattr(audio, "name", None)
                except Exception:
                    audio_path = None
            if not audio_path:
                return "Unsupported audio format from Gradio", ""
            try:
                model_manager.load()
                # BUGFIX: some Gradio versions pass None for an unset
                # dropdown; guard before .lower().
                whisper_lang, nllb_tag = LANG_MAP.get((language or "").lower(), (None, None))
                transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
                translation = model_manager.translate_to_english(transcription, nllb_tag)
                return transcription, translation
            finally:
                # Only delete files we created ourselves (they live in /tmp);
                # never delete a caller-owned path.
                try:
                    if audio_path and Path(audio_path).exists() and "/tmp" in str(audio_path):
                        os.remove(audio_path)
                except Exception:
                    pass

        demo = None
        # Create components robustly across gradio versions.
        audio_component = None
        dropdown_component = None
        textbox_out1 = None
        textbox_out2 = None

        # Option A: modern simple API (gr.Audio)
        try:
            if hasattr(gr, "Audio"):
                audio_component = gr.Audio(source="microphone", type="filepath")
            elif hasattr(gr, "components") and hasattr(gr.components, "Audio"):
                audio_component = gr.components.Audio(source="microphone", type="filepath")
        except Exception:
            audio_component = None

        # Dropdown
        try:
            if hasattr(gr, "Dropdown"):
                dropdown_component = gr.Dropdown(
                    choices=list(LANG_MAP.keys()), value="english", label="Language"
                )
            elif hasattr(gr, "components") and hasattr(gr.components, "Dropdown"):
                dropdown_component = gr.components.Dropdown(
                    choices=list(LANG_MAP.keys()), value="english", label="Language"
                )
        except Exception:
            dropdown_component = None

        # Output textboxes
        try:
            if hasattr(gr, "Textbox"):
                textbox_out1 = gr.Textbox(label="Transcription")
                textbox_out2 = gr.Textbox(label="Translation (English)")
            elif hasattr(gr, "components") and hasattr(gr.components, "Textbox"):
                textbox_out1 = gr.components.Textbox(label="Transcription")
                textbox_out2 = gr.components.Textbox(label="Translation (English)")
        except Exception:
            textbox_out1 = textbox_out2 = None

        # If any component missing, try old 'inputs/outputs' API as final fallback.
        if audio_component is None or dropdown_component is None or textbox_out1 is None:
            try:
                # BUGFIX: second hasattr previously re-checked "inputs";
                # it must check "outputs" for the gr.outputs accesses below.
                if hasattr(gr, "inputs") and hasattr(gr, "outputs"):
                    audio_component = getattr(gr.inputs, "Audio")(source="microphone", type="filepath")
                    dropdown_component = getattr(gr.inputs, "Dropdown")(
                        choices=list(LANG_MAP.keys()), default="english"
                    )
                    textbox_out1 = getattr(gr.outputs, "Textbox")()
                    textbox_out2 = getattr(gr.outputs, "Textbox")()
            except Exception:
                pass

        # If we have required components, create the Interface.
        if audio_component is not None and dropdown_component is not None and textbox_out1 is not None:
            try:
                demo = gr.Interface(
                    fn=_ui_transcribe,
                    inputs=[audio_component, dropdown_component],
                    outputs=[textbox_out1, textbox_out2],
                    title="Multilingual Transcriber (server)",
                )
            except Exception as e:
                print("Failed to create gr.Interface:", e)
                demo = None

        if demo is not None:
            try:
                app = gr.mount_gradio_app(app, demo, path="/ui")
                gradio_mounted = True
                print("Gradio mounted at /ui")
            except Exception as e:
                print("Failed to mount Gradio app:", e)
                gradio_mounted = False
        else:
            print("Gradio demo not created; continuing without mounted UI.")
    except Exception as e:
        print("Gradio UI unavailable or failed to mount:", e)
        gradio_mounted = False
else:
    print("Gradio mounting disabled via DISABLE_GRADIO=1")
    gradio_mounted = False


# Root endpoint
@app.route("/")
def index():
    """Redirect to the Gradio UI when mounted, else report server status."""
    if gradio_mounted:
        return redirect("/ui")
    return Response(
        "Server running. REST endpoint available at /transcribe",
        status=200,
        mimetype="text/plain",
    )


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)