Update app.py
Browse files
app.py
CHANGED
|
@@ -8,21 +8,27 @@ import soundfile as sf
|
|
| 8 |
# --------------------------
|
| 9 |
asr = pipeline(
|
| 10 |
task="automatic-speech-recognition",
|
| 11 |
-
model="openai/whisper-small",
|
| 12 |
-
device=-1
|
| 13 |
)
|
| 14 |
|
| 15 |
# --------------------------
|
| 16 |
-
# 2. Language Model (LLM) -
|
| 17 |
# --------------------------
|
| 18 |
-
llm_model_id = "google/flan-t5-
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
|
| 20 |
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")
|
| 21 |
|
| 22 |
-
def ask_llm(prompt, max_new_tokens=
|
| 23 |
inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
|
| 24 |
with torch.no_grad():
|
| 25 |
-
outputs = llm_model.generate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 27 |
|
| 28 |
# --------------------------
|
|
@@ -31,8 +37,6 @@ def ask_llm(prompt, max_new_tokens=100):
|
|
| 31 |
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
|
| 32 |
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
|
| 33 |
|
| 34 |
-
# fixed dummy speaker embedding (instead of dataset)
|
| 35 |
-
# dimension must match SpeechT5 (512)
|
| 36 |
speaker_embedding = torch.randn(1, 512)
|
| 37 |
|
| 38 |
def text_to_speech(text, out_path="output.wav"):
|
|
@@ -56,7 +60,7 @@ def full_pipeline(audio_file):
|
|
| 56 |
user_text = result.get("text", "")
|
| 57 |
|
| 58 |
try:
|
| 59 |
-
llm_response = ask_llm(user_text)
|
| 60 |
except Exception as e:
|
| 61 |
return f"Assistant generation error: {e}", None
|
| 62 |
|
|
@@ -74,8 +78,8 @@ iface = gr.Interface(
|
|
| 74 |
fn=full_pipeline,
|
| 75 |
inputs=gr.Audio(type="filepath", label="Record or upload audio"),
|
| 76 |
outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
|
| 77 |
-
title="Persian Voice Assistant (
|
| 78 |
-
description="ASR →
|
| 79 |
)
|
| 80 |
|
| 81 |
if __name__ == "__main__":
|
|
|
|
| 8 |
# --------------------------
# 1. Speech recognition (ASR) via Whisper; transcribes the uploaded audio.
# NOTE(review): reconstructed formatting from a diff view — verify against app.py.
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
    device=-1  # -1 = force CPU execution
)

# --------------------------
# 2. Language Model (LLM) - more reliable
# --------------------------
# Flan-T5-Base generates the assistant's text reply; pinned to CPU.
llm_model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")
|
| 21 |
|
| 22 |
+
def ask_llm(prompt, max_new_tokens=200):
    """Generate a text reply from the Flan-T5 model for *prompt*.

    Args:
        prompt: Input text fed to the seq2seq model.
        max_new_tokens: Upper bound on generated tokens (default 200).

    Returns:
        The decoded generation with special tokens stripped.

    Note:
        Sampling is enabled (do_sample/top_k/top_p), so output is
        non-deterministic across calls.
    """
    # Move tokenized input onto whatever device the model lives on (CPU here).
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95
        )
    # outputs[0]: first (only) sequence in the batch.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 33 |
|
| 34 |
# --------------------------
|
|
|
|
| 37 |
# --------------------------
# 3. Text-to-speech (SpeechT5) setup.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

# Fixed dummy speaker embedding instead of loading a speaker dataset;
# shape (1, 512) matches what SpeechT5 expects.
# NOTE(review): torch.randn is re-drawn on every launch, so the synthesized
# voice differs between runs — confirm this is intended (seed it otherwise).
speaker_embedding = torch.randn(1, 512)
|
| 41 |
|
| 42 |
def text_to_speech(text, out_path="output.wav"):
|
|
|
|
| 60 |
user_text = result.get("text", "")
|
| 61 |
|
| 62 |
try:
|
| 63 |
+
llm_response = ask_llm(f"پاسخ بده به زبان ساده: {user_text}")
|
| 64 |
except Exception as e:
|
| 65 |
return f"Assistant generation error: {e}", None
|
| 66 |
|
|
|
|
| 78 |
fn=full_pipeline,
|
| 79 |
inputs=gr.Audio(type="filepath", label="Record or upload audio"),
|
| 80 |
outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
|
| 81 |
+
title="Persian Voice Assistant (Reliable LLM)",
|
| 82 |
+
description="ASR → Flan-T5-Base → TTS"
|
| 83 |
)
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|