STTR committed on
Commit
df4ae9b
·
1 Parent(s): 87733fb

Add SeamlessM4T v2 Large STT + NLLB-200 with T4 GPU

Browse files
Files changed (3) hide show
  1. README.md +19 -7
  2. app.py +95 -94
  3. requirements.txt +2 -1
README.md CHANGED
@@ -1,12 +1,24 @@
1
  ---
2
- title: STTR
3
- emoji: πŸ‘
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.2.0
8
  app_file: app.py
9
- pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: STTR - Speech Translation
3
+ emoji: 🌍
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: "4.44.0"
8
  app_file: app.py
9
+ pinned: true
10
+ license: mit
11
+ hardware: t4-small
12
  ---
13
 
14
+ # 🌍 STTR - Speech-to-Text & Translation API
15
+
16
+ **Meta AI Models:**
17
+ - 🎤 **SeamlessM4T v2 Large** - STT (101 languages)
18
+ - 🌍 **NLLB-200** - Translation (200 languages + Darija!)
19
+ - 🎭 **SeamlessExpressive** - Expressive Speech Translation
20
+
21
+ **API Endpoints:**
22
+ - `/stt` - Speech-to-Text
23
+ - `/translate` - Text Translation
24
+ - `/expressive` - Expressive Speech Translation
app.py CHANGED
@@ -1,28 +1,41 @@
1
  import gradio as gr
2
- from transformers import AutoProcessor, SeamlessM4Tv2ForSpeechToText, AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
 
 
3
  import torch
4
  import numpy as np
5
 
6
  # ============================================================
7
- # 🚀 Load Models
8
  # ============================================================
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  print(f"πŸ–₯️ Device: {device}")
12
 
 
 
 
 
13
  # SeamlessM4T v2 Large for STT
14
  print("πŸ“₯ Loading SeamlessM4T v2 Large...")
15
- stt_model_name = "facebook/seamless-m4t-v2-large"
16
- stt_processor = AutoProcessor.from_pretrained(stt_model_name)
17
- stt_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(stt_model_name).to(device)
18
- print("βœ… SeamlessM4T v2 Large loaded")
 
 
19
 
20
  # NLLB-200 for Translation
21
  print("πŸ“₯ Loading NLLB-200...")
22
- nllb_model_name = "facebook/nllb-200-distilled-600M"
23
- nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_model_name)
24
- nllb_model = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_name).to(device)
25
- print("βœ… NLLB-200 loaded")
 
 
26
 
27
  print("πŸŽ‰ All models ready!")
28
 
@@ -31,122 +44,110 @@ print("πŸŽ‰ All models ready!")
31
  # ============================================================
32
 
33
  NLLB_LANGS = {
34
- "English": "eng_Latn",
35
- "French": "fra_Latn",
36
- "Arabic": "arb_Arab",
37
- "Moroccan Arabic": "ary_Arab",
38
- "Spanish": "spa_Latn",
39
- "German": "deu_Latn",
40
- "Italian": "ita_Latn",
41
- "Portuguese": "por_Latn",
42
- "Chinese": "zho_Hans",
43
- "Japanese": "jpn_Jpan",
44
- "Korean": "kor_Hang",
45
- "Russian": "rus_Cyrl",
46
- "Turkish": "tur_Latn",
47
- "Dutch": "nld_Latn",
48
- "Hindi": "hin_Deva",
49
  }
50
 
51
  STT_LANGS = {
52
- "English": "eng",
53
- "French": "fra",
54
- "Arabic": "arb",
55
- "Spanish": "spa",
56
- "German": "deu",
57
- "Italian": "ita",
58
- "Portuguese": "por",
59
- "Chinese": "cmn",
60
- "Japanese": "jpn",
61
- "Korean": "kor",
62
- "Russian": "rus",
63
- "Turkish": "tur",
64
- "Dutch": "nld",
65
- "Hindi": "hin",
66
  }
67
 
68
  # ============================================================
69
- # STT Function (SeamlessM4T v2 Large)
70
  # ============================================================
71
 
72
  def stt(audio, src_lang):
73
  """Speech-to-Text using SeamlessM4T v2 Large"""
74
  if audio is None:
75
- return ""
76
-
77
- # Handle tuple input from Gradio
78
- if isinstance(audio, tuple):
79
- sample_rate, audio_data = audio
80
- audio_data = audio_data.astype(np.float32)
81
- if audio_data.max() > 1.0:
82
- audio_data = audio_data / 32768.0
83
- else:
84
- return "Error: Invalid audio format"
85
-
86
- src_code = STT_LANGS.get(src_lang, "eng")
87
-
88
- inputs = stt_processor(
89
- audios=audio_data,
90
- sampling_rate=sample_rate,
91
- return_tensors="pt"
92
- ).to(device)
93
 
94
- with torch.no_grad():
95
- output_tokens = stt_model.generate(
96
- **inputs,
97
- tgt_lang=src_code,
98
- generate_speech=False
99
- )
100
-
101
- text = stt_processor.decode(output_tokens[0], skip_special_tokens=True)
102
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # ============================================================
105
- # Translation Function (NLLB-200)
106
  # ============================================================
107
 
108
  def translate(text, src_lang, tgt_lang):
109
  """Translation using NLLB-200"""
110
- if not text.strip():
111
  return ""
112
 
113
- src_code = NLLB_LANGS.get(src_lang, "eng_Latn")
114
- tgt_code = NLLB_LANGS.get(tgt_lang, "fra_Latn")
115
-
116
- nllb_tokenizer.src_lang = src_code
117
- inputs = nllb_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
118
-
119
- forced_bos_token_id = nllb_tokenizer.convert_tokens_to_ids(tgt_code)
120
-
121
- with torch.no_grad():
122
- outputs = nllb_model.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512, num_beams=5)
123
-
124
- return nllb_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
125
 
126
  # ============================================================
127
  # Gradio Interface
128
  # ============================================================
129
 
130
- with gr.Blocks(title="STTR - Speech & Translation API") as demo:
131
- gr.Markdown("# 🌍 STTR - Speech-to-Text & Translation API")
132
- gr.Markdown("**SeamlessM4T v2 Large** for STT + **NLLB-200** for Translation")
133
 
134
- with gr.Tab("🎀 STT (Speech-to-Text)"):
135
- with gr.Row():
136
- stt_audio = gr.Audio(label="Record/Upload Audio", type="numpy")
137
- stt_lang = gr.Dropdown(list(STT_LANGS.keys()), label="Language", value="English")
138
  stt_output = gr.Textbox(label="Transcription", lines=3)
139
  stt_btn = gr.Button("🎀 Transcribe", variant="primary")
140
- stt_btn.click(stt, inputs=[stt_audio, stt_lang], outputs=stt_output, api_name="stt")
141
 
142
  with gr.Tab("🌍 Translation"):
 
143
  with gr.Row():
144
- trans_text = gr.Textbox(label="Text to translate", lines=3)
145
- with gr.Row():
146
- trans_src = gr.Dropdown(list(NLLB_LANGS.keys()), label="Source", value="English")
147
- trans_tgt = gr.Dropdown(list(NLLB_LANGS.keys()), label="Target", value="French")
148
  trans_output = gr.Textbox(label="Translation", lines=3)
149
  trans_btn = gr.Button("🌍 Translate", variant="primary")
150
- trans_btn.click(translate, inputs=[trans_text, trans_src, trans_tgt], outputs=trans_output, api_name="translate")
151
 
152
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import (
3
+ AutoProcessor,
4
+ SeamlessM4Tv2ForSpeechToText,
5
+ AutoModelForSeq2SeqLM,
6
+ AutoTokenizer
7
+ )
8
  import torch
9
  import numpy as np
10
 
11
  # ============================================================
12
+ # 🚀 Device Setup
13
  # ============================================================
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"πŸ–₯️ Device: {device}")
17
 
18
+ # ============================================================
19
+ # 📥 Load Models
20
+ # ============================================================
21
+
22
  # SeamlessM4T v2 Large for STT
23
  print("πŸ“₯ Loading SeamlessM4T v2 Large...")
24
+ STT_MODEL = "facebook/seamless-m4t-v2-large"
25
+ stt_processor = AutoProcessor.from_pretrained(STT_MODEL)
26
+ stt_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(STT_MODEL)
27
+ stt_model = stt_model.to(device)
28
+ stt_model.eval()
29
+ print("βœ… SeamlessM4T v2 Large loaded!")
30
 
31
  # NLLB-200 for Translation
32
  print("πŸ“₯ Loading NLLB-200...")
33
+ NLLB_MODEL = "facebook/nllb-200-distilled-600M"
34
+ nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
35
+ nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL)
36
+ nllb_model = nllb_model.to(device)
37
+ nllb_model.eval()
38
+ print("βœ… NLLB-200 loaded!")
39
 
40
  print("πŸŽ‰ All models ready!")
41
 
 
44
  # ============================================================
45
 
46
  NLLB_LANGS = {
47
+ "English": "eng_Latn", "French": "fra_Latn", "Arabic": "arb_Arab",
48
+ "Moroccan Arabic": "ary_Arab", "Spanish": "spa_Latn", "German": "deu_Latn",
49
+ "Italian": "ita_Latn", "Portuguese": "por_Latn", "Chinese": "zho_Hans",
50
+ "Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Russian": "rus_Cyrl",
51
+ "Turkish": "tur_Latn", "Dutch": "nld_Latn", "Hindi": "hin_Deva",
 
 
 
 
 
 
 
 
 
 
52
  }
53
 
54
  STT_LANGS = {
55
+ "English": "eng", "French": "fra", "Arabic": "arb", "Spanish": "spa",
56
+ "German": "deu", "Italian": "ita", "Portuguese": "por", "Chinese": "cmn",
57
+ "Japanese": "jpn", "Korean": "kor", "Russian": "rus", "Turkish": "tur",
58
+ "Dutch": "nld", "Hindi": "hin",
 
 
 
 
 
 
 
 
 
 
59
  }
60
 
61
  # ============================================================
62
+ # STT Function
63
  # ============================================================
64
 
65
  def stt(audio, src_lang):
66
  """Speech-to-Text using SeamlessM4T v2 Large"""
67
  if audio is None:
68
+ return "No audio provided"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ try:
71
+ if isinstance(audio, tuple):
72
+ sample_rate, audio_data = audio
73
+ audio_data = audio_data.astype(np.float32)
74
+ if np.abs(audio_data).max() > 1.0:
75
+ audio_data = audio_data / 32768.0
76
+ else:
77
+ return "Invalid audio format"
78
+
79
+ src_code = STT_LANGS.get(src_lang, "eng")
80
+
81
+ inputs = stt_processor(
82
+ audios=audio_data,
83
+ sampling_rate=sample_rate,
84
+ return_tensors="pt"
85
+ ).to(device)
86
+
87
+ with torch.no_grad():
88
+ output_tokens = stt_model.generate(
89
+ **inputs,
90
+ tgt_lang=src_code,
91
+ generate_speech=False
92
+ )
93
+
94
+ text = stt_processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
95
+ return text
96
+ except Exception as e:
97
+ return f"Error: {str(e)}"
98
 
99
  # ============================================================
100
+ # Translation Function
101
  # ============================================================
102
 
103
  def translate(text, src_lang, tgt_lang):
104
  """Translation using NLLB-200"""
105
+ if not text or not text.strip():
106
  return ""
107
 
108
+ try:
109
+ src_code = NLLB_LANGS.get(src_lang, "eng_Latn")
110
+ tgt_code = NLLB_LANGS.get(tgt_lang, "fra_Latn")
111
+
112
+ nllb_tokenizer.src_lang = src_code
113
+ inputs = nllb_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
114
+
115
+ forced_bos_token_id = nllb_tokenizer.convert_tokens_to_ids(tgt_code)
116
+
117
+ with torch.no_grad():
118
+ outputs = nllb_model.generate(
119
+ **inputs,
120
+ forced_bos_token_id=forced_bos_token_id,
121
+ max_length=512,
122
+ num_beams=5
123
+ )
124
+
125
+ return nllb_tokenizer.decode(outputs[0], skip_special_tokens=True)
126
+ except Exception as e:
127
+ return f"Error: {str(e)}"
128
 
129
  # ============================================================
130
  # Gradio Interface
131
  # ============================================================
132
 
133
+ with gr.Blocks(title="STTR API", theme=gr.themes.Soft()) as demo:
134
+ gr.Markdown("# 🌍 STTR - Speech & Translation API")
135
+ gr.Markdown("**SeamlessM4T v2 Large** + **NLLB-200** (200 languages + Darija!)")
136
 
137
+ with gr.Tab("🎀 Speech-to-Text"):
138
+ stt_audio = gr.Audio(label="Audio", type="numpy")
139
+ stt_lang = gr.Dropdown(list(STT_LANGS.keys()), label="Language", value="English")
 
140
  stt_output = gr.Textbox(label="Transcription", lines=3)
141
  stt_btn = gr.Button("🎀 Transcribe", variant="primary")
142
+ stt_btn.click(stt, [stt_audio, stt_lang], stt_output, api_name="stt")
143
 
144
  with gr.Tab("🌍 Translation"):
145
+ trans_text = gr.Textbox(label="Text", lines=3)
146
  with gr.Row():
147
+ trans_src = gr.Dropdown(list(NLLB_LANGS.keys()), label="From", value="English")
148
+ trans_tgt = gr.Dropdown(list(NLLB_LANGS.keys()), label="To", value="French")
 
 
149
  trans_output = gr.Textbox(label="Translation", lines=3)
150
  trans_btn = gr.Button("🌍 Translate", variant="primary")
151
+ trans_btn.click(translate, [trans_text, trans_src, trans_tgt], trans_output, api_name="translate")
152
 
153
  demo.launch()
requirements.txt CHANGED
@@ -1,8 +1,9 @@
1
  transformers>=4.40.0
2
  torch>=2.0.0
 
3
  sentencepiece
4
  protobuf
5
  gradio>=4.0.0
6
  numpy
7
  scipy
8
- torchaudio
 
1
  transformers>=4.40.0
2
  torch>=2.0.0
3
+ torchaudio
4
  sentencepiece
5
  protobuf
6
  gradio>=4.0.0
7
  numpy
8
  scipy
9
+ accelerate