Senath committed on
Commit
8ab6697
·
verified ·
1 Parent(s): dbf8c6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -30
app.py CHANGED
@@ -1,48 +1,58 @@
1
- import gradio as gr
2
- import torchaudio
3
  import torch
 
 
4
  from transformers import AutoProcessor, SeamlessM4TModel
5
 
6
- # Load model and processor
7
- model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")
8
- processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
 
 
9
 
10
- def translate(text_input, audio_file, target_lang):
11
- results = []
 
12
 
 
13
  if text_input:
14
- text_inputs = processor(text=text_input, return_tensors="pt")
15
- audio_out = model.generate(**text_inputs, tgt_lang=target_lang)[0].cpu().numpy().squeeze()
16
- results.append(("Translated from text", audio_out))
17
-
18
- if audio_file:
19
- audio_waveform, sr = torchaudio.load(audio_file)
20
- audio_waveform = torchaudio.functional.resample(audio_waveform, sr, 16000)
21
- audio_inputs = processor(audios=audio_waveform, return_tensors="pt")
22
- audio_out = model.generate(**audio_inputs, tgt_lang=target_lang)[0].cpu().numpy().squeeze()
23
- results.append(("Translated from audio", audio_out))
24
-
25
- if results:
26
- combined_text = "\n".join([r[0] for r in results])
27
- combined_audio = results[0][1]
28
- return combined_text, (16000, combined_audio)
 
 
 
 
29
 
30
  return "No input provided.", None
31
 
32
- demo = gr.Interface(
33
  fn=translate,
34
  inputs=[
35
- gr.Textbox(label="Input Text", placeholder="Enter text to translate (optional)"),
36
  gr.Audio(type="filepath", label="Input Audio (optional)"),
37
- gr.Dropdown(choices=["eng", "hin", "spa", "fra", "por"], label="Target Language", value="hin")
 
 
38
  ],
39
  outputs=[
40
- gr.Textbox(label="Translation Info"),
41
  gr.Audio(label="Translated Speech")
42
  ],
43
- title="SeamlessM4T Translation (Text & Audio)",
44
- description="Upload audio or enter text, pick a target language, and get translated text + speech."
45
- )
46
 
47
  if __name__ == "__main__":
48
- demo.launch()
 
1
+ import os
 
2
  import torch
3
+ import torchaudio
4
+ import gradio as gr
5
  from transformers import AutoProcessor, SeamlessM4TModel
6
 
7
# Checkpoint identifier shared by the processor and the model.
MODEL_NAME = "facebook/hf-seamless-m4t-medium"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Load the weights, move them to the chosen device, then switch to eval mode.
model = SeamlessM4TModel.from_pretrained(MODEL_NAME)
model = model.to(device)
model = model.eval()
12
 
13
def translate(text_input, audio_input, source_lang, target_lang, auto_detect):
    """Translate text or speech and return the translated text and audio.

    Text input takes priority: the audio file is only used when no text was
    entered.

    Args:
        text_input: Text to translate; may be empty or None.
        audio_input: Path of an audio file to translate; may be None.
        source_lang: Source language code (e.g. "eng"). Not passed to the
            processor when ``auto_detect`` is set or when it is blank.
        target_lang: Target language code (e.g. "fra").
        auto_detect: When true, no source language is passed to the processor.

    Returns:
        ``(translated_text, (16000, waveform))`` on success, or
        ``(message, None)`` when there is nothing to translate or no target
        language was given.
    """
    sample_rate = 16000

    text = (text_input or "").strip()
    if not text and not audio_input:
        return "No input provided.", None

    target = (target_lang or "").strip()
    if not target:
        # An empty language code cannot be resolved by the model; report it
        # to the user instead of failing inside generate().
        return "Please enter a target language code (e.g. fra).", None

    # A blank source-language box is handled like auto-detect so that an
    # empty string is never forwarded as a language code.
    src = None if auto_detect else ((source_lang or "").strip() or None)

    if text:
        inputs = processor(text=text, src_lang=src, return_tensors="pt")
    else:
        waveform, sr = torchaudio.load(audio_input)
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
        inputs = processor(audios=waveform, src_lang=src, return_tensors="pt")
    inputs = inputs.to(device)

    # With default arguments generate() returns (waveform, waveform_lengths),
    # which contains no text tokens. Requesting the intermediate token ids
    # yields an output object exposing both the text sequences and the speech.
    generated = model.generate(
        **inputs,
        tgt_lang=target,
        return_intermediate_token_ids=True,
    )
    text_out = processor.decode(
        generated.sequences.tolist()[0], skip_special_tokens=True
    )
    speech_out = generated.waveform.cpu().numpy().squeeze()
    return text_out, (sample_rate, speech_out)
40
 
41
# Input widgets, listed in the same order as the parameters of ``translate``.
_input_widgets = [
    gr.Textbox(label="Input Text (optional)"),
    gr.Audio(type="filepath", label="Input Audio (optional)"),
    gr.Textbox(label="Source Language (e.g. eng)"),
    gr.Textbox(label="Target Language (e.g. fra)"),
    gr.Checkbox(label="Auto-detect source language"),
]

# Output widgets: translated text first, synthesized speech second.
_output_widgets = [
    gr.Textbox(label="Translated Text"),
    gr.Audio(label="Translated Speech"),
]

iface = gr.Interface(
    fn=translate,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="iVoice Translate (Text + Speech)",
)
iface = iface.queue()


if __name__ == "__main__":
    # Use the PORT environment variable when set, otherwise port 7860.
    port = int(os.environ.get("PORT", 7860))
    iface.launch(server_name="0.0.0.0", share=True, server_port=port)