Senath committed on
Commit
9acc204
·
verified ·
1 Parent(s): 7ce9df0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -37
app.py CHANGED
@@ -2,58 +2,47 @@ import os
2
  import torch
3
  import torchaudio
4
  import gradio as gr
5
- from transformers import AutoProcessor, SeamlessM4TModel
 
 
 
 
 
6
 
7
  # Constants
8
  MODEL_NAME = "facebook/hf-seamless-m4t-medium"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # Load model and processor
12
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
13
- model = SeamlessM4TModel.from_pretrained(MODEL_NAME).to(device).eval()
 
14
 
15
  # Main translation function
16
- def translate(text_input, audio_input, source_lang, target_lang, auto_detect):
17
- src = None if auto_detect else source_lang
18
- translated_text = None
19
- translated_audio = None
20
-
21
- # If text input is provided
22
- if text_input:
23
- inputs = processor(text=text_input, src_lang=src, return_tensors="pt").to(device)
24
 
25
- # Generate translated speech
26
- speech = model.generate(**inputs, tgt_lang=target_lang)[0].cpu().numpy().squeeze()
27
- translated_audio = (16000, speech)
28
-
29
- # Generate translated text
30
- text_tokens = model.generate(**inputs, tgt_lang=target_lang, generate_speech=False)
31
- translated_text = processor.decode(text_tokens[0].tolist(), skip_special_tokens=True)
32
 
33
- # If audio input is provided
34
- elif audio_input:
35
- waveform, sr = torchaudio.load(audio_input)
36
- waveform = torchaudio.functional.resample(waveform, sr, 16000)
37
- inputs = processor(audios=waveform, src_lang=src, return_tensors="pt").to(device)
38
 
39
- # Generate translated speech
40
- speech = model.generate(**inputs, tgt_lang=target_lang)[0].cpu().numpy().squeeze()
41
- translated_audio = (16000, speech)
42
 
43
- # Generate translated text
44
- text_tokens = model.generate(**inputs, tgt_lang=target_lang, generate_speech=False)
45
- translated_text = processor.decode(text_tokens[0].tolist(), skip_special_tokens=True)
46
 
47
- if translated_text or translated_audio:
48
- return translated_text or "", translated_audio
49
- return "No input provided.", None
50
 
51
  # Gradio Interface
52
  iface = gr.Interface(
53
  fn=translate,
54
  inputs=[
55
- gr.Textbox(label="Input Text (optional)"),
56
- gr.Audio(type="filepath", label="Input Audio (optional)"),
57
  gr.Textbox(label="Source Language (e.g. eng)"),
58
  gr.Textbox(label="Target Language (e.g. fra)"),
59
  gr.Checkbox(label="Auto-detect source language")
@@ -62,9 +51,9 @@ iface = gr.Interface(
62
  gr.Textbox(label="Translated Text"),
63
  gr.Audio(label="Translated Speech")
64
  ],
65
- title="iVoice Translate (Text + Speech)"
66
  ).queue()
67
 
68
- # Launch app
69
  if __name__ == "__main__":
70
  iface.launch(server_name="0.0.0.0", share=True, server_port=int(os.environ.get("PORT", 7860)))
 
2
  import torch
3
  import torchaudio
4
  import gradio as gr
5
+ from transformers import (
6
+ AutoProcessor,
7
+ SeamlessM4TProcessor,
8
+ SeamlessM4TForTextToText,
9
+ SeamlessM4TForTextToSpeech
10
+ )
11
 
12
  # Constants
13
  MODEL_NAME = "facebook/hf-seamless-m4t-medium"
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ # Load processor and models
17
+ processor = SeamlessM4TProcessor.from_pretrained(MODEL_NAME)
18
+ t2t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
19
+ t2s_model = SeamlessM4TForTextToSpeech.from_pretrained(MODEL_NAME).to(device).eval()
20
 
21
  # Main translation function
22
def translate(text_input, source_lang, target_lang, auto_detect):
    """Translate input text into the target language as text and speech.

    Returns a ``(translated_text, translated_audio)`` pair, where
    ``translated_audio`` is a ``(sample_rate, waveform)`` tuple suitable
    for a Gradio audio output. When no text is supplied, returns
    ``("No input text provided.", None)``.
    """
    if not text_input:
        return "No input text provided.", None

    # NOTE(review): when auto_detect is set, the processor receives
    # src_lang=None — confirm the SeamlessM4T tokenizer actually
    # auto-detects the source language in that case.
    chosen_src = None if auto_detect else source_lang

    # Tokenize once; the same encoding feeds both generation passes.
    encoded = processor(
        text=text_input, src_lang=chosen_src, return_tensors="pt"
    ).to(device)

    # Text-to-text translation.
    token_ids = t2t_model.generate(**encoded, tgt_lang=target_lang)
    out_text = processor.decode(token_ids[0].tolist(), skip_special_tokens=True)

    # Text-to-speech synthesis; sample rate 16000 matches the original code.
    waveform = t2s_model.generate(**encoded, tgt_lang=target_lang)[0]
    out_audio = (16000, waveform.cpu().numpy().squeeze())

    return out_text, out_audio
 
 
40
 
41
  # Gradio Interface
42
  iface = gr.Interface(
43
  fn=translate,
44
  inputs=[
45
+ gr.Textbox(label="Input Text"),
 
46
  gr.Textbox(label="Source Language (e.g. eng)"),
47
  gr.Textbox(label="Target Language (e.g. fra)"),
48
  gr.Checkbox(label="Auto-detect source language")
 
51
  gr.Textbox(label="Translated Text"),
52
  gr.Audio(label="Translated Speech")
53
  ],
54
+ title="iVoice Translate (T2T + T2S)"
55
  ).queue()
56
 
57
# Entry point: serve the Gradio app when executed as a script.
if __name__ == "__main__":
    serve_port = int(os.environ.get("PORT", 7860))
    iface.launch(server_name="0.0.0.0", share=True, server_port=serve_port)