import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, SpeechT5Processor, SpeechT5ForTextToSpeech
import torch
import soundfile as sf

# --------------------------
# 1. ASR (speech to text)
# --------------------------
# Whisper-small transcribes the user's recording; device=-1 pins it to CPU.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=-1,
)

# --------------------------
# 2. Language Model (LLM) - more reliable
# --------------------------
# Flan-T5-Base: instruction-tuned encoder-decoder model, explicitly kept on CPU.
# `tokenizer` and `llm_model` are consumed by ask_llm() below.
llm_model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id).to("cpu")

def ask_llm(prompt, max_new_tokens=200):
    """Generate a text response from the Flan-T5 model.

    Args:
        prompt: Input text fed to the seq2seq model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded model output with special tokens stripped.
    """
    # truncation=True guards against prompts longer than the model's
    # encoder limit (512 tokens for T5), which would otherwise error out
    # or silently degrade generation.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(
        llm_model.device
    )
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # sampled decoding: responses vary run to run
            top_k=50,
            top_p=0.95,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# --------------------------
# 3. TTS (text-to-speech) using SpeechT5
# --------------------------
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

# Fixed-seed random speaker embedding so the synthesized voice is identical
# on every launch (an unseeded torch.randn gave a different voice per run).
# NOTE(review): a pretrained x-vector would sound more natural — TODO consider.
speaker_embedding = torch.randn(
    1, 512, generator=torch.Generator().manual_seed(42)
)

def text_to_speech(text, out_path="output.wav"):
    """Synthesize `text` with SpeechT5 and write a 16 kHz mono WAV.

    Args:
        text: Text to vocalize.
        out_path: Destination WAV file path.

    Returns:
        The path the audio was written to (same as `out_path`).
    """
    inputs = processor(text=text, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        speech = tts_model.generate_speech(
            inputs["input_ids"], speaker_embedding
        )
    # .cpu() is a no-op here (model is on CPU) but keeps .numpy() safe
    # if the model is ever moved to an accelerator.
    sf.write(out_path, speech.cpu().numpy(), 16000)
    return out_path

# --------------------------
# 4. Full pipeline function
# --------------------------
def full_pipeline(audio_file):
    """Run the full ASR → LLM → TTS chain on one audio file.

    Args:
        audio_file: Filesystem path of the recorded/uploaded audio
            (Gradio `type="filepath"`), or falsy when nothing was provided.

    Returns:
        A `(conversation_text, audio_path)` tuple; `audio_path` is None
        whenever any stage fails or there is nothing to process.
    """
    if not audio_file:
        return "No audio input detected.", None

    try:
        # chunk/stride let Whisper transcribe clips longer than 30 s
        result = asr(audio_file, chunk_length_s=30, stride_length_s=[5, 5])
    except Exception as e:
        return f"ASR error: {e}", None

    user_text = result.get("text", "").strip()
    if not user_text:
        # Nothing intelligible transcribed — don't waste the LLM/TTS stages.
        return "No speech detected in the audio.", None

    try:
        llm_response = ask_llm(f"پاسخ بده به زبان ساده: {user_text}")
    except Exception as e:
        return f"Assistant generation error: {e}", None

    try:
        audio_path = text_to_speech(llm_response, "response.wav")
    except Exception as e:
        return f"TTS error: {e}", None

    return f"User said: {user_text}\nAssistant: {llm_response}", audio_path

# --------------------------
# 5. Gradio Interface
# --------------------------
# One audio input; two outputs: the transcript+reply text and the TTS audio.
mic_input = gr.Audio(type="filepath", label="Record or upload audio")
chat_output = gr.Textbox(label="Conversation")
speech_output = gr.Audio(label="TTS Response")

iface = gr.Interface(
    fn=full_pipeline,
    inputs=mic_input,
    outputs=[chat_output, speech_output],
    title="Persian Voice Assistant (Reliable LLM)",
    description="ASR → Flan-T5-Base → TTS",
)

if __name__ == "__main__":
    iface.launch()