Bhavibond committed
Commit b623c58 · verified · 1 Parent(s): 4a54047

use whisper-tiny and check processing speeds for low tiers

Files changed (1): app.py (+33, -26)
app.py CHANGED
@@ -1,56 +1,63 @@
 import gradio as gr
-from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech
+from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, AutoModelForSeq2SeqLM, AutoTokenizer
 import torchaudio
 import torch
 from datasets import load_dataset
 import os
 
-# Load ASR and Translation models
-translator = pipeline("translation", model="Helsinki-NLP/opus-mt-en-mul", device=torch.device('cpu'))
-asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=torch.device('cpu'))
-
-# Load TTS model and processor
+# Load lightweight models
+ASR_MODEL = "openai/whisper-tiny"  # Faster ASR model
+TRANSLATION_MODEL = "Helsinki-NLP/opus-mt-en-mul"  # Lightweight translation model
+
+# Load ASR model
+from transformers import pipeline
+asr = pipeline("automatic-speech-recognition", model=ASR_MODEL, device=0 if torch.cuda.is_available() else -1)
+
+# Load translation model and tokenizer
+translator_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL)
+translator_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
+
+# Load TTS processor and model (use float16 for better speed)
 processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
-tts = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+tts = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch.float16)
 
-# Load speaker embeddings from dataset
+# Cache speaker embeddings to avoid reloading every time
 embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-speaker_embeddings = embeddings_dataset[7306]["xvector"]
-speaker_embeddings = torch.tensor(speaker_embeddings).unsqueeze(0)
+speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(torch.float16)
 
-# Ensure cache directory for output files
+# Ensure output directory exists
 os.makedirs("output", exist_ok=True)
 
-# Function to handle transcription, translation, TTS, and Braille file generation
+# Processing function
 def process_audio(audio, target_language):
+    if not audio:
+        return "Error: No audio file provided.", None, None
+
     try:
-        if not audio:
-            return "Error: No audio file provided.", None, None
-
         # Step 1: Transcribe the audio
         result = asr(audio)["text"]
+
         if not result:
             return "Error: Failed to transcribe audio.", None, None
 
         # Step 2: Translate the text
-        translated_text = translator(result, tgt_lang=target_language)
-        if isinstance(translated_text, list):
-            translated_text = translated_text[0].get('translation_text', '')
-        elif isinstance(translated_text, dict):
-            translated_text = translated_text.get('translation_text', '')
+        inputs = translator_tokenizer(result, return_tensors="pt", padding=True)
+        outputs = translator_model.generate(**inputs)
+        translated_text = translator_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
         if not translated_text:
             return "Error: Translation failed.", None, None
 
         # Step 3: Generate speech from translated text
         inputs = processor(text=translated_text, return_tensors="pt")
-        input_features = inputs.input_features
+        input_features = inputs.input_features.to(torch.float16)
 
         with torch.no_grad():
             speech = tts.generate_speech(input_features, speaker_embeddings)
-
+
         # Save generated speech
         output_audio_path = "output/generated_speech.wav"
-        torchaudio.save(output_audio_path, speech, 24000)
+        torchaudio.save(output_audio_path, speech.cpu(), 24000)
 
         # Step 4: Create Braille-compatible file
         braille_output_path = "output/braille.txt"
@@ -65,7 +72,6 @@ def process_audio(audio, target_language):
 # Define Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Multi-Language Voice Translator")
-    gr.Markdown("Transcribe, translate, and generate speech in multiple languages with accessibility features.")
 
     with gr.Row():
         audio_input = gr.Audio(type="filepath", label="Upload Audio")
@@ -76,20 +82,21 @@ with gr.Blocks() as demo:
     )
 
     with gr.Row():
-        submit_button = gr.Button("Submit")
+        submit_button = gr.Button("Translate & Synthesize")
        clear_button = gr.Button("Clear")
 
     with gr.Row():
         translated_text = gr.Textbox(label="Translated Text")
         generated_speech = gr.Audio(label="Generated Speech", interactive=False)
-        braille_file = gr.File(label="Download Braille-Compatible File")
+        braille_file = gr.File(label="Download Braille File")
 
+    # Link functions to buttons
     submit_button.click(
         fn=process_audio,
        inputs=[audio_input, target_language],
         outputs=[translated_text, generated_speech, braille_file],
     )
-
+
     clear_button.click(
         fn=lambda: ("", None, None),
         inputs=[],
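The stated goal, checking processing speed on low tiers, is not actually measured anywhere in this diff. A minimal sketch of such a check, assuming a local `sample.wav` test clip (the harness and the file name are illustrative, not part of the commit):

```python
import time

from transformers import pipeline

# Same ASR model the commit switches to; device=-1 forces CPU, matching a
# low-tier Space without a GPU.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", device=-1)

start = time.perf_counter()
text = asr("sample.wav")["text"]  # hypothetical test clip
print(f"whisper-tiny CPU transcription: {time.perf_counter() - start:.2f}s")
print(text)
```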
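One behavioral note on the rewritten Step 2: `target_language` is no longer used anywhere, so the language dropdown has no effect on the output. Helsinki-NLP/opus-mt-en-mul selects its target language through a `>>xxx<<` token prepended to the source text. A hedged sketch of how that could be wired in (`>>hin<<` is one example code from the model's language list; mapping the UI value to these codes is an assumption, not something the commit does):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_ID = "Helsinki-NLP/opus-mt-en-mul"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

def translate(text: str, lang_token: str) -> str:
    # The ">>xxx<<" prefix tells the multilingual Marian model which target
    # language to decode into; without it the output language is arbitrary.
    inputs = tokenizer(f"{lang_token} {text}", return_tensors="pt", padding=True)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print(translate("How are you today?", ">>hin<<"))  # Hindi, as an example
```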
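Two caveats on the new Step 3 are worth flagging. For text input, SpeechT5Processor returns integer token ids under `input_ids`; its output has no `input_features` attribute, so that line raises an AttributeError, and casting token ids to float16 would break the embedding lookup even if it did not. Beyond that, float16 rarely speeds up CPU inference, `generate_speech` returns a mel spectrogram unless a vocoder is passed, and `torchaudio.save` expects a 2-D (channels, frames) tensor at SpeechT5's 16 kHz rate rather than 24 kHz. A float32 sketch of that path (the `microsoft/speecht5_hifigan` vocoder is the model's usual companion, not something this commit loads):

```python
import os

import torch
import torchaudio
from datasets import load_dataset
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

# Keep float32: several SpeechT5 ops lack float16 CPU kernels.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

inputs = processor(text="Hello world", return_tensors="pt")  # token ids, not features

with torch.no_grad():
    # Passing the vocoder makes generate_speech return a waveform.
    speech = tts.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

os.makedirs("output", exist_ok=True)
# (channels, frames) tensor at the vocoder's native 16 kHz sample rate.
torchaudio.save("output/generated_speech.wav", speech.unsqueeze(0).cpu(), 16000)
```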