STTR committed on
Commit
87733fb
·
1 Parent(s): 448a6e3

Add SeamlessM4T v2 Large STT + NLLB-200

Browse files
Files changed (2) hide show
  1. app.py +120 -30
  2. requirements.txt +5 -2
app.py CHANGED
@@ -1,18 +1,36 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
 
 
 
 
 
4
 
5
- # Load NLLB-200 (distilled for speed)
6
- MODEL_NAME = "facebook/nllb-200-distilled-600M"
7
- print(f"Loading {MODEL_NAME}...")
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model = model.to(device)
12
- print(f"Model loaded on {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Language codes
15
- LANGS = {
 
 
 
 
 
16
  "English": "eng_Latn",
17
  "French": "fra_Latn",
18
  "Arabic": "arb_Arab",
@@ -30,33 +48,105 @@ LANGS = {
30
  "Hindi": "hin_Deva",
31
  }
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def translate(text, src_lang, tgt_lang):
 
34
  if not text.strip():
35
  return ""
36
 
37
- src_code = LANGS.get(src_lang, "eng_Latn")
38
- tgt_code = LANGS.get(tgt_lang, "fra_Latn")
39
 
40
- tokenizer.src_lang = src_code
41
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
42
 
43
- forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
44
 
45
  with torch.no_grad():
46
- outputs = model.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512, num_beams=5)
47
-
48
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
49
-
50
- demo = gr.Interface(
51
- fn=translate,
52
- inputs=[
53
- gr.Textbox(label="Text to translate", lines=3),
54
- gr.Dropdown(list(LANGS.keys()), label="Source Language", value="English"),
55
- gr.Dropdown(list(LANGS.keys()), label="Target Language", value="French"),
56
- ],
57
- outputs=gr.Textbox(label="Translation", lines=3),
58
- title="NLLB-200 Translation API",
59
- description="200 languages including Moroccan Arabic!",
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText, AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
+ import numpy as np
5
+
6
+ # ============================================================
7
+ # 🚀 Load Models
8
+ # ============================================================
9
 
 
 
 
 
 
10
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the emoji prefixes in the log strings below look mis-encoded
# (UTF-8 bytes decoded with a legacy code page). They are reproduced verbatim
# here; confirm against the real app.py before changing any of them.
print(f"πŸ–₯️ Device: {device}")

# SeamlessM4T v2 Large for STT: the processor prepares the audio features and
# decodes token ids, the model generates the transcription tokens.
print("πŸ“₯ Loading SeamlessM4T v2 Large...")
stt_model_name = "facebook/seamless-m4t-v2-large"
stt_processor = AutoProcessor.from_pretrained(stt_model_name)
stt_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(stt_model_name).to(device)
print("βœ… SeamlessM4T v2 Large loaded")

# NLLB-200 (distilled 600M checkpoint) for text-to-text translation.
print("πŸ“₯ Loading NLLB-200...")
nllb_model_name = "facebook/nllb-200-distilled-600M"
nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_model_name)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_name).to(device)
print("βœ… NLLB-200 loaded")

print("πŸŽ‰ All models ready!")
28
+
29
+ # ============================================================
30
+ # Language Codes
31
+ # ============================================================
32
+
33
+ NLLB_LANGS = {
34
  "English": "eng_Latn",
35
  "French": "fra_Latn",
36
  "Arabic": "arb_Arab",
 
48
  "Hindi": "hin_Deva",
49
  }
50
 
51
# Display name -> SeamlessM4T language code. The keys populate the STT
# language dropdown; the value is passed to the model as the transcription
# language. Insertion order is the order shown in the dropdown.
STT_LANGS = {
    "English": "eng",
    "French": "fra",
    "Arabic": "arb",
    "Spanish": "spa",
    "German": "deu",
    "Italian": "ita",
    "Portuguese": "por",
    "Chinese": "cmn",
    "Japanese": "jpn",
    "Korean": "kor",
    "Russian": "rus",
    "Turkish": "tur",
    "Dutch": "nld",
    "Hindi": "hin",
}
67
+
68
+ # ============================================================
69
+ # STT Function (SeamlessM4T v2 Large)
70
+ # ============================================================
71
+
72
def _prepare_audio(audio_data, sample_rate, target_rate):
    """Return *audio_data* as mono float32 in [-1, 1] sampled at *target_rate* Hz."""
    import math

    samples = np.asarray(audio_data)
    source_dtype = samples.dtype
    samples = samples.astype(np.float32)

    if np.issubdtype(source_dtype, np.unsignedinteger):
        # Unsigned PCM is offset-binary: recentre on zero, then scale.
        half_range = (np.iinfo(source_dtype).max + 1) / 2.0
        samples = (samples - half_range) / half_range
    elif np.issubdtype(source_dtype, np.integer):
        # Signed PCM: scale by the dtype's own range (32768 for int16) rather
        # than guessing from the peak, which misfires on quiet clips.
        samples = samples / float(-np.iinfo(source_dtype).min)
    elif np.abs(samples).max() > 1.0:
        # Float data outside [-1, 1] is assumed to carry int16-range values,
        # which is the heuristic the previous implementation applied.
        samples = samples / 32768.0

    if samples.ndim > 1:
        # Multi-channel clips are expected as (samples, channels) -- see the
        # Gradio Audio docs -- so average the channels down to mono.
        samples = samples.reshape(samples.shape[0], -1).mean(axis=1)

    sample_rate = int(sample_rate)
    if sample_rate != target_rate:
        # Local import: scipy is only needed when the clip is off-rate.
        from scipy.signal import resample_poly

        common = math.gcd(sample_rate, target_rate)
        samples = resample_poly(samples, target_rate // common, sample_rate // common)

    return np.ascontiguousarray(samples, dtype=np.float32)


def stt(audio, src_lang):
    """Speech-to-Text using SeamlessM4T v2 Large.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was recorded.
        src_lang: Display name of the spoken language (a key of ``STT_LANGS``).
            Unknown names fall back to English.

    Returns:
        The transcription, ``""`` when there is no audio to transcribe, or
        ``"Error: Invalid audio format"`` when *audio* is not a usable
        ``(sample_rate, samples)`` pair.
    """
    if audio is None:
        return ""

    if not isinstance(audio, tuple) or len(audio) != 2:
        return "Error: Invalid audio format"
    sample_rate, audio_data = audio
    if audio_data is None or not sample_rate or sample_rate <= 0:
        return "Error: Invalid audio format"
    if np.size(audio_data) == 0:
        # Nothing to transcribe; also avoids calling max() on an empty array.
        return ""

    # Feed the model at the rate its feature extractor is configured for.
    # NOTE(review): 16000 is the fallback if the processor does not expose
    # feature_extractor.sampling_rate -- confirm against the checkpoint config.
    target_rate = int(
        getattr(getattr(stt_processor, "feature_extractor", None), "sampling_rate", 16000)
    )
    waveform = _prepare_audio(audio_data, sample_rate, target_rate)

    src_code = STT_LANGS.get(src_lang, "eng")

    inputs = stt_processor(
        audios=waveform,
        sampling_rate=target_rate,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # tgt_lang is the *spoken* language so the model transcribes rather
        # than translates.
        # NOTE(review): generate_speech is kept from the original call; verify
        # that SeamlessM4Tv2ForSpeechToText.generate accepts it.
        output_tokens = stt_model.generate(
            **inputs,
            tgt_lang=src_code,
            generate_speech=False,
        )

    return stt_processor.decode(output_tokens[0], skip_special_tokens=True)
103
+
104
+ # ============================================================
105
+ # Translation Function (NLLB-200)
106
+ # ============================================================
107
+
108
def translate(text, src_lang, tgt_lang):
    """Translate *text* with NLLB-200.

    Args:
        text: Text to translate. ``None``, empty, or whitespace-only input
            yields ``""`` without running the model.
        src_lang: Display name of the source language (a key of
            ``NLLB_LANGS``); unknown names fall back to English.
        tgt_lang: Display name of the target language (a key of
            ``NLLB_LANGS``); unknown names fall back to French.

    Returns:
        The translated text with special tokens stripped.
    """
    # API callers can send None; treat it like blank input instead of
    # failing on None.strip().
    if not text or not text.strip():
        return ""

    src_code = NLLB_LANGS.get(src_lang, "eng_Latn")
    tgt_code = NLLB_LANGS.get(tgt_lang, "fra_Latn")

    # NOTE(review): this sets state on the shared module-level tokenizer;
    # confirm requests with different source languages cannot overlap.
    nllb_tokenizer.src_lang = src_code
    inputs = nllb_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    # NLLB selects the output language via the first generated token.
    forced_bos_token_id = nllb_tokenizer.convert_tokens_to_ids(tgt_code)

    with torch.no_grad():
        outputs = nllb_model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=512,
            num_beams=5,
        )

    return nllb_tokenizer.decode(outputs[0], skip_special_tokens=True)
125
+
126
+ # ============================================================
127
+ # Gradio Interface
128
+ # ============================================================
129
+
130
# Two-tab Gradio app: one tab transcribes audio via stt(), the other
# translates text via translate(). Each button is also exposed as a named API
# endpoint through api_name.
# NOTE(review): the widget nesting below (what sits inside each gr.Row) is
# reconstructed, and the emoji in the labels look mis-encoded; both are to be
# confirmed against the real app.py. Label strings are reproduced verbatim.
with gr.Blocks(title="STTR - Speech & Translation API") as demo:
    gr.Markdown("# 🌍 STTR - Speech-to-Text & Translation API")
    gr.Markdown("**SeamlessM4T v2 Large** for STT + **NLLB-200** for Translation")

    with gr.Tab("🎀 STT (Speech-to-Text)"):
        with gr.Row():
            # type="numpy" hands stt() a (sample_rate, samples) tuple.
            stt_audio = gr.Audio(label="Record/Upload Audio", type="numpy")
            stt_lang = gr.Dropdown(list(STT_LANGS.keys()), label="Language", value="English")
        stt_output = gr.Textbox(label="Transcription", lines=3)
        stt_btn = gr.Button("🎀 Transcribe", variant="primary")
        stt_btn.click(stt, inputs=[stt_audio, stt_lang], outputs=stt_output, api_name="stt")

    with gr.Tab("🌍 Translation"):
        with gr.Row():
            trans_text = gr.Textbox(label="Text to translate", lines=3)
        with gr.Row():
            trans_src = gr.Dropdown(list(NLLB_LANGS.keys()), label="Source", value="English")
            trans_tgt = gr.Dropdown(list(NLLB_LANGS.keys()), label="Target", value="French")
        trans_output = gr.Textbox(label="Translation", lines=3)
        trans_btn = gr.Button("🌍 Translate", variant="primary")
        trans_btn.click(translate, inputs=[trans_text, trans_src, trans_tgt], outputs=trans_output, api_name="translate")

demo.launch()
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
- transformers>=4.30.0
2
  torch>=2.0.0
3
  sentencepiece
4
  protobuf
5
- gradio
 
 
 
 
1
+ transformers>=4.40.0
2
  torch>=2.0.0
3
  sentencepiece
4
  protobuf
5
+ gradio>=4.0.0
6
+ numpy
7
+ scipy
8
+ torchaudio