mkfallah commited on
Commit
9f08613
·
verified ·
1 Parent(s): ce98ad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, SpeechT5Processor, SpeechT5ForTextToSpeech
3
  import torch
4
  import soundfile as sf
5
 
@@ -8,21 +8,18 @@ import soundfile as sf
8
  # --------------------------
9
  asr = pipeline(
10
  task="automatic-speech-recognition",
11
- model="vhdm/whisper-large-fa-v1",
12
- device=-1
13
  )
14
 
15
  # --------------------------
16
- # 2. Language Model (LLM)
17
  # --------------------------
18
- llm_model_id = "tiiuae/falcon-rw-1b"
19
  tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
20
- llm_model = AutoModelForCausalLM.from_pretrained(
21
- llm_model_id,
22
- torch_dtype=torch.float32
23
- ).to("cpu")
24
 
25
- def ask_llm(prompt, max_new_tokens=200):
26
  inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
27
  with torch.no_grad():
28
  outputs = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)
@@ -31,16 +28,20 @@ def ask_llm(prompt, max_new_tokens=200):
31
  # --------------------------
32
  # 3. TTS (text-to-speech) using SpeechT5
33
  # --------------------------
 
34
  processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
35
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
36
 
37
- # Random speaker embedding (can be replaced with a fixed one for consistency)
38
- speaker_embedding = torch.randn(1, 512)
 
 
 
 
39
 
40
  def text_to_speech(text, out_path="output.wav"):
41
  inputs = processor(text=text, return_tensors="pt")
42
- with torch.no_grad():
43
- speech = tts_model.generate_speech(inputs["input_ids"], speaker_embedding)
44
  sf.write(out_path, speech.numpy(), 16000)
45
  return out_path
46
 
@@ -77,8 +78,8 @@ iface = gr.Interface(
77
  fn=full_pipeline,
78
  inputs=gr.Audio(type="filepath", label="Record or upload audio"),
79
  outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
80
- title="Persian Voice Assistant",
81
- description="ASR → LLM → TTS"
82
  )
83
 
84
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, SpeechT5Processor, SpeechT5ForTextToSpeech
3
  import torch
4
  import soundfile as sf
5
 
 
8
  # --------------------------
9
  asr = pipeline(
10
  task="automatic-speech-recognition",
11
+ model="openai/whisper-small", # smaller model = faster
12
+ device=-1 # set to 0 for GPU
13
  )
14
 
15
  # --------------------------
16
+ # 2. Language Model (LLM) - lightweight
17
  # --------------------------
18
+ llm_model_id = "google/flan-t5-small"
19
  tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
20
+ llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")
 
 
 
21
 
22
+ def ask_llm(prompt, max_new_tokens=100):
23
  inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
24
  with torch.no_grad():
25
  outputs = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)
 
28
  # --------------------------
29
  # 3. TTS (text-to-speech) using SpeechT5
30
  # --------------------------
31
+ from datasets import load_dataset
32
  processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
33
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
34
 
35
+ # use a fixed speaker embedding (pre-extracted)
36
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", streaming=True)
37
+ for i, example in enumerate(embeddings_dataset):
38
+ if i == 0: # just take the first speaker embedding
39
+ speaker_embedding = torch.tensor(example["xvector"]).unsqueeze(0)
40
+ break
41
 
42
  def text_to_speech(text, out_path="output.wav"):
43
  inputs = processor(text=text, return_tensors="pt")
44
+ speech = tts_model.generate_speech(inputs["input_ids"], speaker_embedding)
 
45
  sf.write(out_path, speech.numpy(), 16000)
46
  return out_path
47
 
 
78
  fn=full_pipeline,
79
  inputs=gr.Audio(type="filepath", label="Record or upload audio"),
80
  outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
81
+ title="Persian Voice Assistant (Fast LLM)",
82
+ description="ASR → Lightweight LLM → TTS"
83
  )
84
 
85
  if __name__ == "__main__":