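"""Voice Agent Modular Testing Suite.

A Gradio app for exercising each stage of a voice agent (speech-to-text,
language model, text-to-speech) in isolation, plus the combined
STT -> LLM -> TTS pipeline.
"""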
import gradio as gr
from transformers import pipeline
import torch
import numpy as np
# ============================================================================
# STT Module
# ============================================================================
class STTModule:
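    """Speech-to-text tester built on the transformers ASR pipeline."""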
def __init__(self):
self.model_options = {
"Whisper Tiny": "openai/whisper-tiny",
"Whisper Base": "openai/whisper-base",
"Whisper Small": "openai/whisper-small"
}
self.current_model = None
self.pipe = None
def load_model(self, model_name):
try:
model_id = self.model_options[model_name]
device = "cuda" if torch.cuda.is_available() else "cpu"
self.pipe = pipeline(
"automatic-speech-recognition",
model=model_id,
device=device
)
self.current_model = model_name
return f"βœ“ Loaded {model_name} on {device}"
except Exception as e:
return f"βœ— Error loading model: {str(e)}"
def transcribe(self, audio_path):
if self.pipe is None:
return "⚠ Please load a model first"
try:
result = self.pipe(audio_path)
return result["text"]
except Exception as e:
return f"βœ— Error transcribing: {str(e)}"
def create_interface(self):
with gr.Column() as interface:
gr.Markdown("## 🎀 Speech-to-Text Testing")
with gr.Row():
model_selector = gr.Dropdown(
choices=list(self.model_options.keys()),
value="Whisper Base",
label="Select STT Model"
)
load_btn = gr.Button("Load Model", variant="primary")
status = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### Test Transcription")
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="filepath",
label="🎀 Record or Upload Audio"
)
transcribe_btn = gr.Button("Transcribe", variant="secondary")
transcription_output = gr.Textbox(label="Transcription", lines=5)
load_btn.click(fn=self.load_model, inputs=[model_selector], outputs=[status])
transcribe_btn.click(fn=self.transcribe, inputs=[audio_input], outputs=[transcription_output])
return interface
# ============================================================================
# TTS Module
# ============================================================================
class TTSModule:
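    """Text-to-speech tester built on the transformers text-to-speech pipeline."""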
def __init__(self):
self.model_options = {
"SpeechT5": "microsoft/speecht5_tts",
"FastSpeech2": "facebook/fastspeech2-en-ljspeech"
}
self.current_model = None
self.synthesiser = None
def load_model(self, model_name):
try:
model_id = self.model_options.get(model_name, self.model_options["SpeechT5"])
device = "cuda" if torch.cuda.is_available() else "cpu"
self.synthesiser = pipeline("text-to-speech", model=model_id, device=device)
self.current_model = model_name
return f"βœ“ Loaded {model_name} on {device}"
except Exception as e:
return f"βœ— Error loading model: {str(e)}"
def synthesize(self, text):
if self.synthesiser is None:
return None, "⚠ Please load a model first"
if not text.strip():
return None, "⚠ Please enter some text"
        try:
            speech = self.synthesiser(text)
            # Some TTS pipelines return a (1, n) array; flatten it so gr.Audio
            # and the duration calculation both see 1-D samples
            audio_data = np.asarray(speech["audio"]).squeeze()
            sampling_rate = speech["sampling_rate"]
            if audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.float32)
            return (sampling_rate, audio_data), f"✓ Generated {len(audio_data) / sampling_rate:.2f}s of audio"
        except Exception as e:
            return None, f"✗ Error synthesizing: {str(e)}"
def create_interface(self):
with gr.Column() as interface:
gr.Markdown("## πŸ”Š Text-to-Speech Testing")
with gr.Row():
model_selector = gr.Dropdown(
choices=list(self.model_options.keys()),
value="SpeechT5",
label="Select TTS Model"
)
load_btn = gr.Button("Load Model", variant="primary")
status = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### Test Synthesis")
text_input = gr.Textbox(
label="Enter Text",
placeholder="Type something to convert to speech...",
lines=3
)
synthesize_btn = gr.Button("Generate Speech", variant="secondary")
audio_output = gr.Audio(label="Generated Audio", type="numpy")
synthesis_status = gr.Textbox(label="Synthesis Status", interactive=False)
load_btn.click(fn=self.load_model, inputs=[model_selector], outputs=[status])
synthesize_btn.click(fn=self.synthesize, inputs=[text_input], outputs=[audio_output, synthesis_status])
return interface
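# Note (untested sketch): SpeechT5 typically needs an x-vector speaker
# embedding; without one the pipeline may fail or produce a generic voice.
# The pattern from the transformers docs looks roughly like:
#
#     from datasets import load_dataset
#     xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
#     speaker_embedding = torch.tensor(xvectors[7306]["embedding"]).unsqueeze(0)
#     speech = synthesiser(text, forward_params={"speaker_embeddings": speaker_embedding})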
# ============================================================================
# LLM Module
# ============================================================================
class LLMModule:
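    """Chat-style tester built on the transformers text-generation pipeline."""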
def __init__(self):
self.model_options = {
"TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Phi-2": "microsoft/phi-2",
"Qwen 0.5B": "Qwen/Qwen2.5-0.5B-Instruct"
}
self.current_model = None
self.pipe = None
self.chat_history = []
def load_model(self, model_name):
try:
model_id = self.model_options[model_name]
device = "cuda" if torch.cuda.is_available() else "cpu"
self.pipe = pipeline(
"text-generation",
model=model_id,
device=device,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
)
self.current_model = model_name
self.chat_history = []
return f"βœ“ Loaded {model_name} on {device}"
except Exception as e:
return f"βœ— Error loading model: {str(e)}"
    def _chat_pairs(self):
        # gr.Chatbot expects (user, assistant) tuples, not role/content dicts
        return [(self.chat_history[i]["content"], self.chat_history[i + 1]["content"])
                for i in range(0, len(self.chat_history) - 1, 2)]

    def generate_response(self, message, max_tokens, temperature):
        if self.pipe is None:
            return "⚠ Please load a model first", []
        if not message.strip():
            return "⚠ Please enter a message", self._chat_pairs()
        try:
            self.chat_history.append({"role": "user", "content": message})
            # Note: generation conditions on the latest message only; earlier
            # turns in chat_history are not included in the prompt
            response = self.pipe(
                message,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
                top_p=0.9
            )
            assistant_message = response[0]["generated_text"]
            # text-generation returns prompt + completion; strip the echoed prompt
            if assistant_message.startswith(message):
                assistant_message = assistant_message[len(message):].strip()
            self.chat_history.append({"role": "assistant", "content": assistant_message})
            return "", self._chat_pairs()
        except Exception as e:
            return f"✗ Error generating response: {str(e)}", self._chat_pairs()
def clear_history(self):
self.chat_history = []
return [], ""
def create_interface(self):
with gr.Column() as interface:
gr.Markdown("## πŸ€– LLM Testing")
with gr.Row():
model_selector = gr.Dropdown(
choices=list(self.model_options.keys()),
value="Qwen 0.5B",
label="Select LLM Model"
)
load_btn = gr.Button("Load Model", variant="primary")
status = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### Chat Interface")
chatbot = gr.Chatbot(label="Conversation", height=400)
with gr.Row():
message_input = gr.Textbox(label="Message", placeholder="Type your message...", scale=4)
send_btn = gr.Button("Send", variant="secondary", scale=1)
with gr.Row():
max_tokens = gr.Slider(minimum=50, maximum=500, value=150, step=10, label="Max Tokens")
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
clear_btn = gr.Button("Clear Chat", variant="stop")
load_btn.click(fn=self.load_model, inputs=[model_selector], outputs=[status])
send_btn.click(fn=self.generate_response, inputs=[message_input, max_tokens, temperature], outputs=[message_input, chatbot])
message_input.submit(fn=self.generate_response, inputs=[message_input, max_tokens, temperature], outputs=[message_input, chatbot])
clear_btn.click(fn=self.clear_history, outputs=[chatbot, message_input])
return interface
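# Note (untested sketch): for instruction-tuned checkpoints such as
# Qwen2.5-Instruct, formatting the message with the tokenizer's chat template
# usually produces better replies than feeding raw text, e.g.:
#
#     messages = [{"role": "user", "content": message}]
#     prompt = self.pipe.tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True)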
# ============================================================================
# Pipeline Module
# ============================================================================
class VoiceAgentPipeline:
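    """End-to-end voice agent: audio in -> STT -> LLM -> TTS -> audio out."""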
def __init__(self):
self.stt = STTModule()
self.tts = TTSModule()
self.llm = LLMModule()
self.conversation_history = []
def load_models(self, stt_model, tts_model, llm_model):
results = []
results.append(self.stt.load_model(stt_model))
results.append(self.tts.load_model(tts_model))
results.append(self.llm.load_model(llm_model))
return "\n".join(results)
def process_voice_input(self, audio_path, max_tokens, temperature):
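        """Run one voice turn: transcribe, generate a reply, then synthesize it."""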
if not audio_path:
return None, "⚠ Please provide audio input", []
if self.stt.pipe is None or self.tts.synthesiser is None or self.llm.pipe is None:
return None, "⚠ Please load all models first", []
try:
transcription = self.stt.transcribe(audio_path)
if transcription.startswith("βœ—") or transcription.startswith("⚠"):
return None, transcription, []
self.conversation_history.append({"role": "user", "content": transcription})
response = self.llm.pipe(
transcription,
max_new_tokens=int(max_tokens),
temperature=float(temperature),
do_sample=True,
top_p=0.9
)
assistant_message = response[0]["generated_text"]
if assistant_message.startswith(transcription):
assistant_message = assistant_message[len(transcription):].strip()
self.conversation_history.append({"role": "assistant", "content": assistant_message})
audio_output, tts_status = self.tts.synthesize(assistant_message)
chat_display = [(self.conversation_history[i]["content"],
self.conversation_history[i+1]["content"])
for i in range(0, len(self.conversation_history)-1, 2)]
status_message = f"User: {transcription}\n\nAssistant: {assistant_message}\n\n{tts_status}"
return audio_output, status_message, chat_display
except Exception as e:
return None, f"βœ— Pipeline error: {str(e)}", []
def clear_conversation(self):
self.conversation_history = []
return None, "", []
def create_interface(self):
with gr.Column() as interface:
gr.Markdown("## πŸŽ™οΈ Full Voice Agent Pipeline")
gr.Markdown("Test the complete flow: **Voice Input β†’ STT β†’ LLM β†’ TTS β†’ Voice Output**")
gr.Markdown("### 1. Load Models")
with gr.Row():
stt_selector = gr.Dropdown(choices=list(self.stt.model_options.keys()), value="Whisper Base", label="STT Model")
llm_selector = gr.Dropdown(choices=list(self.llm.model_options.keys()), value="Qwen 0.5B", label="LLM Model")
tts_selector = gr.Dropdown(choices=list(self.tts.model_options.keys()), value="SpeechT5", label="TTS Model")
load_all_btn = gr.Button("Load All Models", variant="primary", size="lg")
load_status = gr.Textbox(label="Status", interactive=False, lines=3)
gr.Markdown("### 2. Voice Conversation")
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="filepath",
label="🎀 Speak or Upload Audio"
)
with gr.Row():
max_tokens = gr.Slider(minimum=50, maximum=300, value=100, step=10, label="Max Response Tokens")
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
process_btn = gr.Button("Process Voice Input", variant="secondary", size="lg")
audio_output = gr.Audio(label="AI Response (Audio)", type="numpy")
process_status = gr.Textbox(label="Pipeline Output", interactive=False, lines=4)
gr.Markdown("### Conversation History")
conversation_display = gr.Chatbot(label="Conversation", height=300)
clear_btn = gr.Button("Clear Conversation", variant="stop")
load_all_btn.click(fn=self.load_models, inputs=[stt_selector, tts_selector, llm_selector], outputs=[load_status])
process_btn.click(fn=self.process_voice_input, inputs=[audio_input, max_tokens, temperature], outputs=[audio_output, process_status, conversation_display])
clear_btn.click(fn=self.clear_conversation, outputs=[audio_output, process_status, conversation_display])
return interface
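# Example (sketch, assuming a local WAV file at the hypothetical path
# "sample.wav"): the pipeline can also be driven without the UI:
#
#     agent = VoiceAgentPipeline()
#     agent.load_models("Whisper Base", "SpeechT5", "Qwen 0.5B")
#     audio, status, history = agent.process_voice_input("sample.wav", 100, 0.7)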
# ============================================================================
# Main App
# ============================================================================
stt_module = STTModule()
tts_module = TTSModule()
llm_module = LLMModule()
pipeline_module = VoiceAgentPipeline()
with gr.Blocks(title="Voice Agent Modular Tester", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
    # 🎙️ Voice Agent Modular Testing Suite
    Test individual components or the full voice agent pipeline:
    - **STT Tab**: Test speech-to-text models independently
    - **TTS Tab**: Test text-to-speech models independently
    - **LLM Tab**: Test language models independently
    - **Pipeline Tab**: Test the complete voice agent flow (STT → LLM → TTS)
""")
with gr.Tabs():
with gr.Tab("🎀 STT Module"):
stt_module.create_interface()
with gr.Tab("πŸ”Š TTS Module"):
tts_module.create_interface()
with gr.Tab("πŸ€– LLM Module"):
llm_module.create_interface()
with gr.Tab("πŸŽ™οΈ Full Pipeline"):
pipeline_module.create_interface()
gr.Markdown("""
---
    ### 📝 Usage Tips
    - **Load models first**: Click the "Load Model" buttons before testing
    - **Recording audio**: Click the microphone icon 🎤 to start recording, click again to stop
    - **Upload audio**: Or drag & drop an audio file
    - **GPU acceleration**: Models run on GPU if available, otherwise CPU
    - **Pipeline mode**: Combines all modules for end-to-end voice interaction
    - **Performance**: Use smaller models (Whisper Base, Qwen 0.5B) for faster inference on CPU
""")
if __name__ == "__main__":
demo.launch()