pradeepsengarr committed on
Commit
80ffd7b
·
verified ·
1 Parent(s): dfaf587

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -5,17 +5,17 @@ from gtts import gTTS
5
  from pydub import AudioSegment
6
  import tempfile
7
  import os
8
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
9
 
10
  # Load Whisper model
11
  whisper_model = whisper.load_model("base")
12
 
13
- # Load mBART for multilingual response
14
  model_name = "facebook/mbart-large-50-many-to-many-mmt"
15
- tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
16
  model = MBartForConditionalGeneration.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
- # Default target language (can be dynamic)
19
  TARGET_LANG = "hi_IN" # Hindi
20
 
21
  def respond(prompt_text, audio_file):
@@ -32,13 +32,13 @@ def respond(prompt_text, audio_file):
32
  else:
33
  return "No prompt provided", "", None
34
 
35
- # Tokenize and generate
36
  tokenizer.src_lang = "en_XX"
37
  encoded = tokenizer(final_prompt, return_tensors="pt").to(model.device)
38
  generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.lang_code_to_id[TARGET_LANG], max_new_tokens=100)
39
  translated = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
40
 
41
- # Text to speech
42
  tts = gTTS(translated, lang='hi')
43
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
44
  tts.save(fp.name)
@@ -53,12 +53,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Vidhya") as demo:
53
  gr.Markdown("""
54
  # 🧠 Chat with Vidhya
55
  **An AI assistant that listens to your voice or reads your text, and responds in your language.**
56
-
57
- 💡 Try prompts about:
58
- - Technology
59
- - Bikes
60
- - Money
61
- - Games
62
  """)
63
 
64
  with gr.Row():
 
5
  from pydub import AudioSegment
6
  import tempfile
7
  import os
8
+ from transformers import MBartForConditionalGeneration, MBart50Tokenizer
9
 
10
  # Load Whisper model
11
  whisper_model = whisper.load_model("base")
12
 
13
+ # Load mBART
14
  model_name = "facebook/mbart-large-50-many-to-many-mmt"
15
+ tokenizer = MBart50Tokenizer.from_pretrained(model_name)
16
  model = MBartForConditionalGeneration.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ # Target language
19
  TARGET_LANG = "hi_IN" # Hindi
20
 
21
  def respond(prompt_text, audio_file):
 
32
  else:
33
  return "No prompt provided", "", None
34
 
35
+ # Generate response
36
  tokenizer.src_lang = "en_XX"
37
  encoded = tokenizer(final_prompt, return_tensors="pt").to(model.device)
38
  generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.lang_code_to_id[TARGET_LANG], max_new_tokens=100)
39
  translated = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
40
 
41
+ # TTS
42
  tts = gTTS(translated, lang='hi')
43
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
44
  tts.save(fp.name)
 
53
  gr.Markdown("""
54
  # 🧠 Chat with Vidhya
55
  **An AI assistant that listens to your voice or reads your text, and responds in your language.**
 
 
 
 
 
 
56
  """)
57
 
58
  with gr.Row():