# Translator / app.py — Hugging Face Space (author: Gamortsey, commit ddb6bb8)
# app.py
import os
import tempfile
from pathlib import Path
from flask import Flask, request, Response, redirect
from flask_cors import CORS
import torch
import torchaudio
# Transformers imports (lazy loaded in ModelManager.load to reduce startup overhead)
from transformers import (
AutoProcessor,
AutoModelForSpeechSeq2Seq,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
# ---------- Configuration ----------
# Model ids are overridable via environment variables (Hugging Face Hub ids).
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "openai/whisper-small")
NLLB_MODEL = os.environ.get("NLLB_MODEL", "facebook/nllb-200-distilled-600M")
# Maps a UI language name -> (Whisper language code, NLLB source-language tag).
# A None Whisper code means "let Whisper auto-detect the language";
# a None NLLB tag means "skip translation" (the text is already English).
LANG_MAP = {
    "akan": (None, "aka_Latn"),
    "hausa": ("ha", "hau_Latn"),
    "swahili": ("sw", "swh_Latn"),
    "french": ("fr", "fra_Latn"),
    "arabic": ("ar", "arb_Arab"),
    "english": ("en", None),
}
DEVICE = torch.device("cpu")  # Free HF Spaces = CPU
app = Flask(__name__)
CORS(app)  # allow cross-origin requests so browser clients can call /transcribe
# ---------- Model manager ----------
class ModelManager:
    """Lazily loads and serves the Whisper (ASR) and NLLB (translation) models.

    Both models are loaded once, on first use, and kept in memory on DEVICE;
    a single instance is shared by the REST endpoint and the Gradio UI.
    """

    def __init__(self):
        self.whisper_processor = None
        self.whisper_model = None
        self.nllb_tokenizer = None
        self.nllb_model = None
        self._loaded = False

    def load(self):
        """Load both models onto DEVICE. Idempotent: later calls are no-ops.

        Raises:
            RuntimeError: if either model fails to download or initialise.
        """
        if self._loaded:
            return
        print(f"Loading Whisper model: {WHISPER_MODEL}")
        try:
            self.whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL)
            self.whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(WHISPER_MODEL).to(DEVICE)
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model ({WHISPER_MODEL}): {e}")
        print(f"Loading NLLB tokenizer/model: {NLLB_MODEL}")
        try:
            self.nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
            self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL).to(DEVICE)
        except Exception as e:
            raise RuntimeError(f"Failed to load NLLB model ({NLLB_MODEL}): {e}")
        self._loaded = True
        print("Models loaded successfully.")

    def transcribe(self, audio_path, whisper_language_arg=None):
        """Transcribe the audio file at *audio_path* and return the text.

        Args:
            audio_path: path to an audio file readable by torchaudio.
            whisper_language_arg: Whisper language code (e.g. "fr"), or None
                to let Whisper auto-detect the language.

        Raises:
            RuntimeError: if the Whisper model has not been loaded yet.
        """
        if self.whisper_processor is None or self.whisper_model is None:
            raise RuntimeError("Whisper model not loaded")
        waveform, sr = torchaudio.load(audio_path)
        # Downmix multi-channel audio to mono: the feature extractor expects a
        # single waveform; a (channels, n) array would be treated as a batch
        # and only the first channel's transcript would be returned.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Whisper is trained on 16 kHz audio; resample anything else.
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
            sr = 16000
        # BUGFIX: `language` and `task` are generation options, not
        # feature-extractor options; passing them to the processor __call__
        # raises a TypeError on recent transformers versions. They are
        # forwarded to generate() below instead.
        inputs = self.whisper_processor(
            waveform.squeeze().numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        gen_kwargs = {"task": "transcribe"}
        if whisper_language_arg:
            gen_kwargs["language"] = whisper_language_arg
        with torch.no_grad():
            generated_ids = self.whisper_model.generate(
                input_features=inputs["input_features"],
                attention_mask=inputs.get("attention_mask"),
                **gen_kwargs,
            )
        decoded = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
        return decoded[0].strip()

    def translate_to_english(self, src_text, nllb_src_lang_tag):
        """Translate *src_text* into English with NLLB.

        Returns "" for empty input, and the input unchanged when no NLLB
        source tag is given (i.e. the text is already English).

        Raises:
            RuntimeError: if the NLLB model has not been loaded yet.
        """
        if not src_text:
            return ""
        if not nllb_src_lang_tag:
            return src_text
        if self.nllb_tokenizer is None or self.nllb_model is None:
            raise RuntimeError("NLLB model not loaded")
        try:
            # Tell the tokenizer which language the source text is in.
            self.nllb_tokenizer.src_lang = nllb_src_lang_tag
        except Exception:
            pass
        inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)
        # Force the decoder to start with the English language token.
        forced_bos = None
        try:
            forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
            # BUGFIX: convert_tokens_to_ids returns the <unk> id (not None)
            # for unknown tokens; treat that as "no target token available"
            # rather than forcing the decoder to emit <unk>.
            if forced_bos == self.nllb_tokenizer.unk_token_id:
                forced_bos = None
        except Exception:
            forced_bos = None
        gen_kwargs = {
            "max_length": 512,
            "num_beams": 4,
            "no_repeat_ngram_size": 2,
            "early_stopping": True
        }
        if forced_bos is not None:
            gen_kwargs["forced_bos_token_id"] = forced_bos
        with torch.no_grad():
            translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs)
        translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        return translated.strip()
model_manager = ModelManager()
# ---------- REST endpoint ----------
@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """POST /transcribe: multipart 'audio' file + 'language' field -> English text.

    Responses (all text/plain):
        200 — translated (or, for English input, transcribed) text
        204 — transcription came back empty
        400 — missing audio file or unsupported language
        500 — model loading or audio processing failed
    """
    if "audio" not in request.files:
        return Response("No audio file provided", status=400, mimetype="text/plain")
    audio_file = request.files["audio"]
    # Language may come from the form body or the query string; default English.
    language = (request.form.get("language") or request.args.get("language") or "english").lower()
    if language not in LANG_MAP:
        return Response(f"Unsupported language: {language}", status=400, mimetype="text/plain")
    whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
    try:
        model_manager.load()
    except Exception as e:
        return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
    # BUGFIX: audio_file.filename can be None, which previously made
    # Path(...) raise a TypeError; fall back to an empty name, then ".wav".
    suffix = Path(audio_file.filename or "").suffix or ".wav"
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    os.close(tmp_fd)  # mkstemp returns an open fd; close it so save() can write
    audio_file.save(tmp_path)
    try:
        transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
        if not transcription:
            return Response("", status=204, mimetype="text/plain")
        translation = model_manager.translate_to_english(transcription, nllb_src_tag)
        return Response(translation, status=200, mimetype="text/plain")
    except Exception as e:
        return Response(f"Processing failed: {e}", status=500, mimetype="text/plain")
    finally:
        # Always remove the temp file; ignore errors (file may already be gone).
        try:
            os.remove(tmp_path)
        except Exception:
            pass
# ---------- Robust Gradio UI mount ----------
gradio_mounted = False
if os.environ.get("DISABLE_GRADIO", "0") != "1":
    try:
        import gradio as gr
        import soundfile as sf
        import numpy as np

        def _ui_transcribe(audio, language):
            """Gradio callback: normalise *audio* to a file path, run ASR + translation.

            Gradio's Audio component returns different types depending on the
            installed version: a filepath string, an (sample_rate, ndarray)
            tuple, a bare ndarray, or a file-like object with a .name.
            Returns (transcription, english_translation).
            """
            if audio is None:
                return "No audio", ""
            audio_path = None
            if isinstance(audio, str) and Path(audio).exists():
                audio_path = audio
            elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
                sr, data = audio[0], audio[1]
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                sf.write(tmp.name, data, sr)
                audio_path = tmp.name
            elif isinstance(audio, (np.ndarray,)) or hasattr(audio, "shape"):
                # Bare array: no sample rate supplied — assume 16 kHz (Whisper's rate).
                sr = 16000
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                sf.write(tmp.name, audio, sr)
                audio_path = tmp.name
            else:
                try:
                    audio_path = getattr(audio, "name", None)
                except Exception:
                    audio_path = None
            if not audio_path:
                return "Unsupported audio format from Gradio", ""
            try:
                model_manager.load()
                whisper_lang, nllb_tag = LANG_MAP.get(language.lower(), (None, None))
                transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
                translation = model_manager.translate_to_english(transcription, nllb_tag)
                return transcription, translation
            finally:
                # Delete only files that live in the temp directory, never user files.
                # BUGFIX: compare against tempfile.gettempdir() instead of the
                # hard-coded "/tmp", which leaked temp files on macOS/Windows.
                try:
                    if audio_path and Path(audio_path).exists() and str(audio_path).startswith(tempfile.gettempdir()):
                        os.remove(audio_path)
                except Exception:
                    pass

        demo = None
        # Create components robustly across gradio versions.
        audio_component = None
        dropdown_component = None
        textbox_out1 = None
        textbox_out2 = None
        # Option A: modern simple API (gr.Audio)
        try:
            if hasattr(gr, "Audio"):
                audio_component = gr.Audio(source="microphone", type="filepath")
            elif hasattr(gr, "components") and hasattr(gr.components, "Audio"):
                audio_component = gr.components.Audio(source="microphone", type="filepath")
        except Exception:
            audio_component = None
        # Dropdown
        try:
            if hasattr(gr, "Dropdown"):
                dropdown_component = gr.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
            elif hasattr(gr, "components") and hasattr(gr.components, "Dropdown"):
                dropdown_component = gr.components.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
        except Exception:
            dropdown_component = None
        # Output textboxes
        try:
            if hasattr(gr, "Textbox"):
                textbox_out1 = gr.Textbox(label="Transcription")
                textbox_out2 = gr.Textbox(label="Translation (English)")
            elif hasattr(gr, "components") and hasattr(gr.components, "Textbox"):
                textbox_out1 = gr.components.Textbox(label="Transcription")
                textbox_out2 = gr.components.Textbox(label="Translation (English)")
        except Exception:
            textbox_out1 = textbox_out2 = None
        # If any component is missing, try the legacy gr.inputs / gr.outputs API.
        if audio_component is None or dropdown_component is None or textbox_out1 is None:
            try:
                # BUGFIX: the second hasattr() previously re-checked "inputs";
                # it must check "outputs", which the fallback below relies on.
                if hasattr(gr, "inputs") and hasattr(gr, "outputs"):
                    audio_component = getattr(gr.inputs, "Audio")(source="microphone", type="filepath")
                    dropdown_component = getattr(gr.inputs, "Dropdown")(choices=list(LANG_MAP.keys()), default="english")
                    textbox_out1 = getattr(gr.outputs, "Textbox")()
                    textbox_out2 = getattr(gr.outputs, "Textbox")()
            except Exception:
                pass
        # If we have the required components, create the Interface.
        if audio_component is not None and dropdown_component is not None and textbox_out1 is not None:
            try:
                demo = gr.Interface(
                    fn=_ui_transcribe,
                    inputs=[audio_component, dropdown_component],
                    outputs=[textbox_out1, textbox_out2],
                    title="Multilingual Transcriber (server)"
                )
            except Exception as e:
                print("Failed to create gr.Interface:", e)
                demo = None
        if demo is not None:
            try:
                # NOTE(review): gr.mount_gradio_app targets ASGI (FastAPI) apps;
                # mounting onto a Flask (WSGI) app may raise here — the except
                # below keeps the REST API working regardless. Verify on deploy.
                app = gr.mount_gradio_app(app, demo, path="/ui")
                gradio_mounted = True
                print("Gradio mounted at /ui")
            except Exception as e:
                print("Failed to mount Gradio app:", e)
                gradio_mounted = False
        else:
            print("Gradio demo not created; continuing without mounted UI.")
    except Exception as e:
        print("Gradio UI unavailable or failed to mount:", e)
        gradio_mounted = False
else:
    print("Gradio mounting disabled via DISABLE_GRADIO=1")
    gradio_mounted = False
# Root endpoint
@app.route("/")
def index():
    """Send browsers to the Gradio UI when it is mounted, else report status."""
    if not gradio_mounted:
        return Response("Server running. REST endpoint available at /transcribe", status=200, mimetype="text/plain")
    return redirect("/ui")
# Run the Flask dev server directly; HF Spaces injects the port via $PORT
# (7860 is the Spaces default).
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)