File size: 11,388 Bytes
d7c90a6
 
 
e6584a7
 
d7c90a6
 
 
e6584a7
 
d7c90a6
 
 
 
 
 
 
 
e6584a7
 
d7c90a6
 
e6584a7
d7c90a6
 
 
e6584a7
d7c90a6
 
 
ddb6bb8
d7c90a6
 
 
 
ddb6bb8
d7c90a6
 
 
 
 
 
 
 
 
 
 
e6584a7
 
 
 
 
 
d7c90a6
e6584a7
 
 
 
 
 
d7c90a6
 
e6584a7
d7c90a6
 
 
e6584a7
d7c90a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6584a7
 
d7c90a6
 
e6584a7
d7c90a6
e6584a7
d7c90a6
 
 
 
 
 
 
e6584a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c90a6
 
e6584a7
 
 
d7c90a6
 
 
 
 
 
 
 
 
 
 
e6584a7
d7c90a6
 
 
 
 
 
 
 
 
 
e6584a7
d7c90a6
 
 
 
 
 
e6584a7
d7c90a6
 
 
 
 
 
 
 
 
 
 
ddb6bb8
e6584a7
ddb6bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6584a7
ddb6bb8
 
 
 
 
 
 
 
 
 
 
e6584a7
ddb6bb8
 
 
 
 
 
e6584a7
ddb6bb8
e6584a7
ddb6bb8
 
 
 
 
 
e6584a7
ddb6bb8
e6584a7
ddb6bb8
 
 
 
 
 
e6584a7
ddb6bb8
e6584a7
ddb6bb8
 
 
 
 
 
 
 
e6584a7
ddb6bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6584a7
d7c90a6
ddb6bb8
e6584a7
 
 
 
 
d7c90a6
 
e6584a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# app.py
import os
import tempfile
from pathlib import Path
from flask import Flask, request, Response, redirect
from flask_cors import CORS
import torch
import torchaudio

# Transformers imports (the classes are imported eagerly here; the model weights are loaded lazily in ModelManager.load to reduce startup overhead)
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)

# ---------- Configuration ----------
# Hugging Face model identifiers; each can be overridden via an environment
# variable of the same name.
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "openai/whisper-small")
NLLB_MODEL = os.getenv("NLLB_MODEL", "facebook/nllb-200-distilled-600M")

# Supported language name -> (Whisper language code, NLLB source-language tag).
# A None Whisper code means transcription is requested without a language
# hint; a None NLLB tag means translate_to_english returns the text unchanged.
LANG_MAP = dict(
    akan=(None, "aka_Latn"),
    hausa=("ha", "hau_Latn"),
    swahili=("sw", "swh_Latn"),
    french=("fr", "fra_Latn"),
    arabic=("ar", "arb_Arab"),
    english=("en", None),
)

# Inference always runs on CPU (free HF Spaces hardware).
DEVICE = torch.device("cpu")

# Flask application with CORS enabled for every route.
app = Flask(__name__)
CORS(app)

# ---------- Model manager ----------
class ModelManager:
    """Lazily loads and serves the Whisper (speech-to-text) and NLLB (translation) models.

    All model handles start as ``None``; ``load()`` fills them in on first use
    and returns immediately on later calls.
    """

    def __init__(self):
        self.whisper_processor = None
        self.whisper_model = None
        self.nllb_tokenizer = None
        self.nllb_model = None
        self._loaded = False

    def load(self):
        """Load both models onto ``DEVICE`` unless they are already loaded.

        Raises:
            RuntimeError: If the Whisper or NLLB model cannot be loaded. The
                underlying exception is chained as ``__cause__``.
        """
        if self._loaded:
            return
        print(f"Loading Whisper model: {WHISPER_MODEL}")
        try:
            self.whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL)
            self.whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(WHISPER_MODEL).to(DEVICE)
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model ({WHISPER_MODEL}): {e}") from e

        print(f"Loading NLLB tokenizer/model: {NLLB_MODEL}")
        try:
            self.nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
            self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL).to(DEVICE)
        except Exception as e:
            raise RuntimeError(f"Failed to load NLLB model ({NLLB_MODEL}): {e}") from e

        # Only mark as loaded once both models are in place, so a failed load
        # is retried on the next call.
        self._loaded = True
        print("Models loaded successfully.")

    def transcribe(self, audio_path, whisper_language_arg=None):
        """Transcribe an audio file to text with Whisper.

        Args:
            audio_path: Path of an audio file readable by ``torchaudio.load``.
            whisper_language_arg: Optional Whisper language code (e.g. ``"ha"``).
                When falsy, no language hint is passed to generation.

        Returns:
            The decoded transcription with surrounding whitespace stripped.

        Raises:
            RuntimeError: If the Whisper model has not been loaded.
        """
        if self.whisper_processor is None or self.whisper_model is None:
            raise RuntimeError("Whisper model not loaded")

        waveform, sr = torchaudio.load(audio_path)
        # torchaudio returns a (channels, frames) tensor. Average the channels
        # so multi-channel input reaches the processor as one mono signal
        # instead of a 2-D array.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
            sr = 16000

        inputs = self.whisper_processor(
            waveform.numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        generate_kwargs = {
            "input_features": inputs["input_features"],
            "attention_mask": inputs.get("attention_mask"),
        }
        # The language/task hint is a generation-time option, so it is passed
        # to generate() rather than to the processor call.
        # NOTE(review): confirm the installed transformers version accepts
        # `language`/`task` on Whisper's generate().
        if whisper_language_arg:
            generate_kwargs["language"] = whisper_language_arg
            generate_kwargs["task"] = "transcribe"

        with torch.no_grad():
            generated_ids = self.whisper_model.generate(**generate_kwargs)
        decoded = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
        return decoded[0].strip()

    def translate_to_english(self, src_text, nllb_src_lang_tag):
        """Translate ``src_text`` into English with NLLB.

        Args:
            src_text: Text to translate. An empty value yields ``""``.
            nllb_src_lang_tag: NLLB source-language tag (e.g. ``"hau_Latn"``).
                When falsy, ``src_text`` is returned unchanged.

        Returns:
            The English translation with surrounding whitespace stripped.

        Raises:
            RuntimeError: If a translation is needed but the NLLB model has
                not been loaded.
        """
        if not src_text:
            return ""
        if not nllb_src_lang_tag:
            return src_text

        if self.nllb_tokenizer is None or self.nllb_model is None:
            raise RuntimeError("NLLB model not loaded")

        try:
            self.nllb_tokenizer.src_lang = nllb_src_lang_tag
        except Exception as e:
            # Best effort: keep translating with the tokenizer's current
            # source language, but make the failure visible in the logs.
            print(f"Warning: could not set NLLB source language {nllb_src_lang_tag}: {e}")

        inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)

        # Force English as the first generated token when the tokenizer can
        # resolve the target-language tag.
        try:
            forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
        except Exception:
            forced_bos = None

        gen_kwargs = {
            "max_length": 512,
            "num_beams": 4,
            "no_repeat_ngram_size": 2,
            "early_stopping": True,
        }
        if forced_bos is not None:
            gen_kwargs["forced_bos_token_id"] = forced_bos

        with torch.no_grad():
            translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs)
        translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
        return translated.strip()

# Module-level instance used by both the REST endpoint and the Gradio UI.
# Constructing it loads nothing; the models are loaded when load() is called.
model_manager = ModelManager()

# ---------- REST endpoint ----------
@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """Transcribe an uploaded audio file and return its English translation.

    Expects a multipart upload under the ``audio`` field and an optional
    ``language`` value (form field or query parameter, default ``"english"``).

    Returns a ``text/plain`` response:
        * 200 with the translated text,
        * 204 when the transcription is empty,
        * 400 when no audio is supplied or the language is not in ``LANG_MAP``,
        * 500 when model loading or processing fails.
    """
    if "audio" not in request.files:
        return Response("No audio file provided", status=400, mimetype="text/plain")

    audio_file = request.files["audio"]
    language = (
        request.form.get("language") or request.args.get("language") or "english"
    ).strip().lower()

    if language not in LANG_MAP:
        return Response(f"Unsupported language: {language}", status=400, mimetype="text/plain")

    whisper_lang_arg, nllb_src_tag = LANG_MAP[language]

    try:
        model_manager.load()
    except Exception as e:
        return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")

    # The upload may carry no filename; fall back to ".wav" in that case.
    suffix = Path(audio_file.filename or "").suffix or ".wav"
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    os.close(tmp_fd)

    try:
        # Saving happens inside the try so a failed write is reported as a
        # processing error and the temp file is still removed below.
        audio_file.save(tmp_path)
        transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
        if not transcription:
            return Response("", status=204, mimetype="text/plain")

        translation = model_manager.translate_to_english(transcription, nllb_src_tag)
        return Response(translation, status=200, mimetype="text/plain")
    except Exception as e:
        return Response(f"Processing failed: {e}", status=500, mimetype="text/plain")
    finally:
        # Best-effort cleanup: a temp file that is already gone is not an error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

# ---------- Robust Gradio UI mount ----------
# True only when the Gradio UI was successfully mounted under /ui.
gradio_mounted = False
if os.environ.get("DISABLE_GRADIO", "0") != "1":
    try:
        import gradio as gr
        import soundfile as sf
        import numpy as np

        def _remove_quietly(path):
            """Best-effort delete of *path*; a file that cannot be removed is ignored."""
            try:
                os.remove(path)
            except (OSError, ValueError):
                pass

        def _write_temp_wav(samples, sample_rate):
            """Write *samples* to a new temporary .wav file and return its path."""
            # Close the handle before soundfile writes to the path so no file
            # descriptor stays open behind the returned name.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                wav_path = tmp.name
            try:
                sf.write(wav_path, samples, sample_rate)
            except Exception:
                _remove_quietly(wav_path)
                raise
            return wav_path

        def _ui_transcribe(audio, language):
            """Gradio callback: transcribe *audio* and translate it to English.

            Args:
                audio: One of: an existing file path, a ``(sample_rate, samples)``
                    pair, an array-like with a ``shape`` attribute (written
                    out at 16000 Hz), or an object exposing a ``name`` path.
                language: Key of ``LANG_MAP``; unknown keys run without a
                    language hint and without translation. ``None`` is treated
                    as ``"english"``.

            Returns:
                ``(transcription, translation)``, or an error message paired
                with ``""`` when the audio is missing or unsupported.
            """
            if audio is None:
                return "No audio", ""

            audio_path = None
            created_tmp = False  # True when this call wrote the file itself
            if isinstance(audio, str) and Path(audio).exists():
                audio_path = audio
            elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
                audio_path = _write_temp_wav(audio[1], audio[0])
                created_tmp = True
            elif isinstance(audio, np.ndarray) or hasattr(audio, "shape"):
                audio_path = _write_temp_wav(audio, 16000)
                created_tmp = True
            else:
                try:
                    audio_path = getattr(audio, "name", None)
                except Exception:
                    audio_path = None

            if not audio_path:
                return "Unsupported audio format from Gradio", ""

            try:
                model_manager.load()
                lang_key = str(language or "english").lower()
                whisper_lang, nllb_tag = LANG_MAP.get(lang_key, (None, None))
                transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
                translation = model_manager.translate_to_english(transcription, nllb_tag)
                return transcription, translation
            finally:
                # Always remove files written by this call; files handed in by
                # Gradio are removed only when they live under a "/tmp" path.
                if created_tmp or "/tmp" in str(audio_path):
                    _remove_quietly(str(audio_path))

        def _component_factory(name):
            """Return ``gr.<name>``, else ``gr.components.<name>``, else ``None``."""
            factory = getattr(gr, name, None)
            if factory is None:
                factory = getattr(getattr(gr, "components", None), name, None)
            return factory

        demo = None
        # Create components robustly across gradio versions
        audio_component = None
        dropdown_component = None
        textbox_out1 = None
        textbox_out2 = None

        # Option A: modern simple API (gr.Audio)
        try:
            audio_factory = _component_factory("Audio")
            if audio_factory is not None:
                try:
                    audio_component = audio_factory(source="microphone", type="filepath")
                except TypeError:
                    # Newer Gradio releases take a list-valued `sources`
                    # argument instead of `source`.
                    audio_component = audio_factory(sources=["microphone"], type="filepath")
        except Exception:
            audio_component = None

        # Dropdown
        try:
            dropdown_factory = _component_factory("Dropdown")
            if dropdown_factory is not None:
                dropdown_component = dropdown_factory(
                    choices=list(LANG_MAP.keys()), value="english", label="Language"
                )
        except Exception:
            dropdown_component = None

        # Output textboxes
        try:
            textbox_factory = _component_factory("Textbox")
            if textbox_factory is not None:
                textbox_out1 = textbox_factory(label="Transcription")
                textbox_out2 = textbox_factory(label="Translation (English)")
        except Exception:
            textbox_out1 = textbox_out2 = None

        # If any component missing, try old 'inputs/outputs' API as final fallback
        if audio_component is None or dropdown_component is None or textbox_out1 is None:
            try:
                # Both legacy namespaces are used below, so both must exist.
                if hasattr(gr, "inputs") and hasattr(gr, "outputs"):
                    audio_component = gr.inputs.Audio(source="microphone", type="filepath")
                    dropdown_component = gr.inputs.Dropdown(choices=list(LANG_MAP.keys()), default="english")
                    textbox_out1 = gr.outputs.Textbox()
                    textbox_out2 = gr.outputs.Textbox()
            except Exception:
                pass

        # If we have required components, create the Interface
        if audio_component is not None and dropdown_component is not None and textbox_out1 is not None:
            try:
                demo = gr.Interface(
                    fn=_ui_transcribe,
                    inputs=[audio_component, dropdown_component],
                    outputs=[textbox_out1, textbox_out2],
                    title="Multilingual Transcriber (server)"
                )
            except Exception as e:
                print("Failed to create gr.Interface:", e)
                demo = None

        if demo is not None:
            try:
                # NOTE(review): gr.mount_gradio_app is documented for FastAPI
                # applications, while `app` here is a Flask app, so this call
                # presumably raises and lands in the except branch below.
                # Confirm, and serve the UI another way if /ui is required.
                app = gr.mount_gradio_app(app, demo, path="/ui")
                gradio_mounted = True
                print("Gradio mounted at /ui")
            except Exception as e:
                print("Failed to mount Gradio app:", e)
                gradio_mounted = False
        else:
            print("Gradio demo not created; continuing without mounted UI.")
    except Exception as e:
        print("Gradio UI unavailable or failed to mount:", e)
        gradio_mounted = False
else:
    print("Gradio mounting disabled via DISABLE_GRADIO=1")
    gradio_mounted = False

# Root endpoint
@app.route("/")
def index():
    """Send visitors to the Gradio UI when it is mounted, else show a status line."""
    if not gradio_mounted:
        return Response(
            "Server running. REST endpoint available at /transcribe",
            status=200,
            mimetype="text/plain",
        )
    return redirect("/ui")

if __name__ == "__main__":
    # Listen on every interface; the port comes from PORT, defaulting to 7860.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )