import gradio as gr
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
)
import torch
import soundfile as sf
# --------------------------
# 1. ASR (speech-to-text)
# --------------------------
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
    device=-1,
)
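# device=-1 keeps Whisper on CPU (pass a GPU index such as 0 to use CUDA).
# whisper-small is multilingual and auto-detects the spoken language, so it
# can transcribe Persian speech directly.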
# --------------------------
# 2. Language model (LLM): Flan-T5-Base, small enough to run reliably on CPU
# --------------------------
llm_model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")
def ask_llm(prompt, max_new_tokens=200):
    """Generate a response from Flan-T5 using top-k / nucleus sampling."""
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # sample rather than greedy-decode
            top_k=50,        # keep only the 50 most likely next tokens
            top_p=0.95,      # nucleus sampling: keep tokens covering 95% of probability
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
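# NOTE: Flan-T5 is trained mostly on English data, so Persian responses may be
# weak; a multilingual seq2seq model (e.g. an mT5 variant) could be swapped in.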
# --------------------------
# 3. TTS (text-to-speech) using SpeechT5
# --------------------------
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
# A random vector stands in for a real speaker embedding; for a natural voice,
# load an x-vector (e.g. from the CMU ARCTIC speaker embeddings) instead.
speaker_embedding = torch.randn(1, 512)
def text_to_speech(text, out_path="output.wav"):
    inputs = processor(text=text, return_tensors="pt")
    # Without a vocoder, generate_speech returns a mel spectrogram rather than
    # a waveform, so pass the HiFi-GAN vocoder to get playable audio.
    speech = tts_model.generate_speech(
        inputs["input_ids"], speaker_embedding, vocoder=vocoder
    )
    sf.write(out_path, speech.numpy(), 16000)  # SpeechT5 outputs 16 kHz audio
    return out_path
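# e.g. text_to_speech("Hello!") writes a 16 kHz mono WAV to output.wav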
# --------------------------
# 4. Full pipeline function
# --------------------------
def full_pipeline(audio_file):
    if not audio_file:
        return "No audio input detected.", None
    try:
        # Long-form transcription: 30 s chunks with 5 s of overlap on each side
        result = asr(audio_file, chunk_length_s=30, stride_length_s=[5, 5])
    except Exception as e:
        return f"ASR error: {e}", None
    user_text = result.get("text", "")
    try:
        # Persian prompt meaning "Answer in simple language:"
        llm_response = ask_llm(f"پاسخ بده به زبان ساده: {user_text}")
    except Exception as e:
        return f"Assistant generation error: {e}", None
    try:
        audio_path = text_to_speech(llm_response, "response.wav")
    except Exception as e:
        return f"TTS error: {e}", None
    return f"User said: {user_text}\nAssistant: {llm_response}", audio_path
# --------------------------
# 5. Gradio interface
# --------------------------
iface = gr.Interface(
    fn=full_pipeline,
    inputs=gr.Audio(type="filepath", label="Record or upload audio"),
    outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
    title="Persian Voice Assistant (Reliable LLM)",
    description="ASR → Flan-T5-Base → TTS",
)
if __name__ == "__main__":
    iface.launch()
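    # launch() serves the app locally (http://127.0.0.1:7860 by default);
    # pass share=True to iface.launch() for a temporary public link.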