mkfallah commited on
Commit
66951ec
·
verified ·
1 Parent(s): f13240a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -8,21 +8,27 @@ import soundfile as sf
8
  # --------------------------
9
  asr = pipeline(
10
  task="automatic-speech-recognition",
11
- model="openai/whisper-small", # smaller model = faster
12
- device=-1 # set to 0 for GPU
13
  )
14
 
15
  # --------------------------
16
- # 2. Language Model (LLM) - lightweight
17
  # --------------------------
18
- llm_model_id = "google/flan-t5-small"
19
  tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
20
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")
21
 
22
- def ask_llm(prompt, max_new_tokens=100):
23
  inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
24
  with torch.no_grad():
25
- outputs = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
 
 
 
26
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
27
 
28
  # --------------------------
@@ -31,8 +37,6 @@ def ask_llm(prompt, max_new_tokens=100):
31
  processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
32
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
33
 
34
- # fixed dummy speaker embedding (instead of dataset)
35
- # dimension must match SpeechT5 (512)
36
  speaker_embedding = torch.randn(1, 512)
37
 
38
  def text_to_speech(text, out_path="output.wav"):
@@ -56,7 +60,7 @@ def full_pipeline(audio_file):
56
  user_text = result.get("text", "")
57
 
58
  try:
59
- llm_response = ask_llm(user_text)
60
  except Exception as e:
61
  return f"Assistant generation error: {e}", None
62
 
@@ -74,8 +78,8 @@ iface = gr.Interface(
74
  fn=full_pipeline,
75
  inputs=gr.Audio(type="filepath", label="Record or upload audio"),
76
  outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
77
- title="Persian Voice Assistant (Fast LLM)",
78
- description="ASR → Lightweight LLM → TTS"
79
  )
80
 
81
  if __name__ == "__main__":
 
8
  # --------------------------
9
  asr = pipeline(
10
  task="automatic-speech-recognition",
11
+ model="openai/whisper-small",
12
+ device=-1
13
  )
14
 
15
  # --------------------------
16
+ # 2. Language Model (LLM) - more reliable
17
  # --------------------------
18
+ llm_model_id = "google/flan-t5-base"
19
  tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
20
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")
21
 
22
def ask_llm(prompt, max_new_tokens=200):
    """Generate a text reply from the Flan-T5 language model.

    Parameters
    ----------
    prompt : str
        User text (typically the ASR transcript) to answer.
    max_new_tokens : int, optional
        Upper bound on the number of tokens generated (default 200).

    Returns
    -------
    str
        Decoded model output with special tokens stripped.
    """
    # truncation=True caps the input at the encoder's max length (512 for
    # T5); without it a long Whisper transcript can overflow the model's
    # position limit and break generation.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).to(llm_model.device)
    # Inference only — no_grad skips autograd bookkeeping.
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,   # sampling: less repetitive than greedy decoding
            top_k=50,
            top_p=0.95,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
33
 
34
  # --------------------------
 
37
  processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
38
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
39
 
 
 
40
  speaker_embedding = torch.randn(1, 512)
41
 
42
  def text_to_speech(text, out_path="output.wav"):
 
60
  user_text = result.get("text", "")
61
 
62
  try:
63
+ llm_response = ask_llm(f"پاسخ بده به زبان ساده: {user_text}")
64
  except Exception as e:
65
  return f"Assistant generation error: {e}", None
66
 
 
78
  fn=full_pipeline,
79
  inputs=gr.Audio(type="filepath", label="Record or upload audio"),
80
  outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
81
+ title="Persian Voice Assistant (Reliable LLM)",
82
+ description="ASR → Flan-T5-Base → TTS"
83
  )
84
 
85
  if __name__ == "__main__":