Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| import numpy as np | |
| # ============================================================================ | |
| # STT Module | |
| # ============================================================================ | |
class STTModule:
    """Standalone speech-to-text tester backed by Whisper checkpoints."""

    def __init__(self):
        # Display name -> Hugging Face model id.
        self.model_options = {
            "Whisper Tiny": "openai/whisper-tiny",
            "Whisper Base": "openai/whisper-base",
            "Whisper Small": "openai/whisper-small"
        }
        self.current_model = None
        self.pipe = None

    def load_model(self, model_name):
        """Instantiate the ASR pipeline for *model_name*; returns a status string."""
        try:
            checkpoint = self.model_options[model_name]
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=checkpoint,
                device=device
            )
            self.current_model = model_name
            return f"β Loaded {model_name} on {device}"
        except Exception as exc:
            return f"β Error loading model: {str(exc)}"

    def transcribe(self, audio_path):
        """Transcribe the audio file at *audio_path*; returns text or an error string."""
        if self.pipe is None:
            return "β Please load a model first"
        try:
            return self.pipe(audio_path)["text"]
        except Exception as exc:
            return f"β Error transcribing: {str(exc)}"

    def create_interface(self):
        """Build and return the Gradio column for the STT tab."""
        with gr.Column() as interface:
            gr.Markdown("## π€ Speech-to-Text Testing")
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=list(self.model_options.keys()),
                    value="Whisper Base",
                    label="Select STT Model"
                )
                load_button = gr.Button("Load Model", variant="primary")
            status_box = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Test Transcription")
            audio_in = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="π€ Record or Upload Audio"
            )
            run_button = gr.Button("Transcribe", variant="secondary")
            text_out = gr.Textbox(label="Transcription", lines=5)
            load_button.click(fn=self.load_model, inputs=[model_dropdown], outputs=[status_box])
            run_button.click(fn=self.transcribe, inputs=[audio_in], outputs=[text_out])
        return interface
| # ============================================================================ | |
| # TTS Module | |
| # ============================================================================ | |
class TTSModule:
    """Standalone text-to-speech tester backed by HF TTS pipelines."""

    def __init__(self):
        # Display name -> Hugging Face model id.
        self.model_options = {
            "SpeechT5": "microsoft/speecht5_tts",
            "FastSpeech2": "facebook/fastspeech2-en-ljspeech"
        }
        self.current_model = None
        self.synthesiser = None

    def load_model(self, model_name):
        """Load the TTS pipeline for *model_name* (unknown names fall back to SpeechT5).

        Returns a human-readable status string; never raises.
        """
        try:
            model_id = self.model_options.get(model_name, self.model_options["SpeechT5"])
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.synthesiser = pipeline("text-to-speech", model=model_id, device=device)
            self.current_model = model_name
            return f"β Loaded {model_name} on {device}"
        except Exception as e:
            return f"β Error loading model: {str(e)}"

    def synthesize(self, text):
        """Convert *text* to speech.

        Returns ((sampling_rate, float32 array), status) on success, or
        (None, error message) on failure. Safe for None / empty input.
        """
        if self.synthesiser is None:
            return None, "β Please load a model first"
        # Guard None as well as whitespace-only text: Gradio passes None for a
        # cleared textbox, and None.strip() would raise AttributeError.
        if not text or not text.strip():
            return None, "β Please enter some text"
        try:
            speech = self.synthesiser(text)
            audio_data = speech["audio"]
            sampling_rate = speech["sampling_rate"]
            # gr.Audio(type="numpy") expects float32 samples.
            if audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.float32)
            return (sampling_rate, audio_data), f"β Generated {len(audio_data)/sampling_rate:.2f}s of audio"
        except Exception as e:
            return None, f"β Error synthesizing: {str(e)}"

    def create_interface(self):
        """Build and return the Gradio column for the TTS tab."""
        with gr.Column() as interface:
            gr.Markdown("## π Text-to-Speech Testing")
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(self.model_options.keys()),
                    value="SpeechT5",
                    label="Select TTS Model"
                )
                load_btn = gr.Button("Load Model", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Test Synthesis")
            text_input = gr.Textbox(
                label="Enter Text",
                placeholder="Type something to convert to speech...",
                lines=3
            )
            synthesize_btn = gr.Button("Generate Speech", variant="secondary")
            audio_output = gr.Audio(label="Generated Audio", type="numpy")
            synthesis_status = gr.Textbox(label="Synthesis Status", interactive=False)
            load_btn.click(fn=self.load_model, inputs=[model_selector], outputs=[status])
            synthesize_btn.click(fn=self.synthesize, inputs=[text_input], outputs=[audio_output, synthesis_status])
        return interface
| # ============================================================================ | |
| # LLM Module | |
| # ============================================================================ | |
class LLMModule:
    """Standalone chat tester for small text-generation models."""

    def __init__(self):
        # Display name -> Hugging Face model id.
        self.model_options = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Phi-2": "microsoft/phi-2",
            "Qwen 0.5B": "Qwen/Qwen2.5-0.5B-Instruct"
        }
        self.current_model = None
        self.pipe = None
        # Alternating {"role": "user"|"assistant", "content": str} entries.
        self.chat_history = []

    def load_model(self, model_name):
        """Load the text-generation pipeline; resets chat history on success."""
        try:
            model_id = self.model_options[model_name]
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.pipe = pipeline(
                "text-generation",
                model=model_id,
                device=device,
                # fp16 only on GPU; CPU kernels need fp32.
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            )
            self.current_model = model_name
            self.chat_history = []
            return f"β Loaded {model_name} on {device}"
        except Exception as e:
            return f"β Error loading model: {str(e)}"

    def _chat_pairs(self):
        # Collapse alternating history entries into the (user, assistant)
        # tuple format gr.Chatbot expects. A trailing unanswered user turn
        # (odd length) is omitted.
        return [(self.chat_history[i]["content"], self.chat_history[i + 1]["content"])
                for i in range(0, len(self.chat_history) - 1, 2)]

    def generate_response(self, message, max_tokens, temperature):
        """Generate one assistant reply.

        Returns (textbox_value, chat_pairs): textbox_value is "" on success
        (clears the input) or an error message; chat_pairs feeds gr.Chatbot.
        """
        if self.pipe is None:
            return "β Please load a model first", []
        # Guard None as well as whitespace-only input (Gradio can pass None).
        if not message or not message.strip():
            return "β Please enter a message", self._chat_pairs()
        try:
            self.chat_history.append({"role": "user", "content": message})
            response = self.pipe(
                message,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
                top_p=0.9
            )
            assistant_message = response[0]["generated_text"]
            # text-generation pipelines echo the prompt; strip it off.
            if assistant_message.startswith(message):
                assistant_message = assistant_message[len(message):].strip()
            self.chat_history.append({"role": "assistant", "content": assistant_message})
            # NOTE: the original paired turn i with history[i+1] instead of
            # history[2*i+1], showing (user, user) pairs after the first turn.
            return "", self._chat_pairs()
        except Exception as e:
            # Drop the dangling user turn so history stays in (user, assistant)
            # pairs and later turns don't misalign.
            if self.chat_history and self.chat_history[-1]["role"] == "user":
                self.chat_history.pop()
            return f"β Error generating response: {str(e)}", self._chat_pairs()

    def clear_history(self):
        """Reset the conversation; clears the chatbot and the message box."""
        self.chat_history = []
        return [], ""

    def create_interface(self):
        """Build and return the Gradio column for the LLM tab."""
        with gr.Column() as interface:
            gr.Markdown("## π€ LLM Testing")
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(self.model_options.keys()),
                    value="Qwen 0.5B",
                    label="Select LLM Model"
                )
                load_btn = gr.Button("Load Model", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Chat Interface")
            chatbot = gr.Chatbot(label="Conversation", height=400)
            with gr.Row():
                message_input = gr.Textbox(label="Message", placeholder="Type your message...", scale=4)
                send_btn = gr.Button("Send", variant="secondary", scale=1)
            with gr.Row():
                max_tokens = gr.Slider(minimum=50, maximum=500, value=150, step=10, label="Max Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
            clear_btn = gr.Button("Clear Chat", variant="stop")
            load_btn.click(fn=self.load_model, inputs=[model_selector], outputs=[status])
            send_btn.click(fn=self.generate_response, inputs=[message_input, max_tokens, temperature], outputs=[message_input, chatbot])
            message_input.submit(fn=self.generate_response, inputs=[message_input, max_tokens, temperature], outputs=[message_input, chatbot])
            clear_btn.click(fn=self.clear_history, outputs=[chatbot, message_input])
        return interface
| # ============================================================================ | |
| # Pipeline Module | |
| # ============================================================================ | |
class VoiceAgentPipeline:
    """End-to-end voice agent: audio in -> STT -> LLM -> TTS -> audio out."""

    def __init__(self):
        # Own instances, independent of the per-tab modules: models loaded
        # here do not affect the individual testing tabs.
        self.stt = STTModule()
        self.tts = TTSModule()
        self.llm = LLMModule()
        # Alternating {"role": "user"|"assistant", "content": str} entries.
        self.conversation_history = []

    def load_models(self, stt_model, tts_model, llm_model):
        """Load all three component models; returns one status line per model."""
        results = [
            self.stt.load_model(stt_model),
            self.tts.load_model(tts_model),
            self.llm.load_model(llm_model),
        ]
        return "\n".join(results)

    def _chat_pairs(self):
        # (user, assistant) tuples for gr.Chatbot; a trailing unanswered
        # user turn (odd length) is omitted.
        return [(self.conversation_history[i]["content"],
                 self.conversation_history[i + 1]["content"])
                for i in range(0, len(self.conversation_history) - 1, 2)]

    def process_voice_input(self, audio_path, max_tokens, temperature):
        """Run one full turn: transcribe, generate a reply, synthesize it.

        Returns (audio_or_None, status_text, chat_pairs).
        """
        if not audio_path:
            return None, "β Please provide audio input", []
        if self.stt.pipe is None or self.tts.synthesiser is None or self.llm.pipe is None:
            return None, "β Please load all models first", []
        try:
            transcription = self.stt.transcribe(audio_path)
            # STTModule signals failure via a prefixed status string.
            if transcription.startswith("β") or transcription.startswith("β "):
                return None, transcription, []
            self.conversation_history.append({"role": "user", "content": transcription})
            response = self.llm.pipe(
                transcription,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
                top_p=0.9
            )
            assistant_message = response[0]["generated_text"]
            # text-generation pipelines echo the prompt; strip it off.
            if assistant_message.startswith(transcription):
                assistant_message = assistant_message[len(transcription):].strip()
            self.conversation_history.append({"role": "assistant", "content": assistant_message})
            audio_output, tts_status = self.tts.synthesize(assistant_message)
            status_message = f"User: {transcription}\n\nAssistant: {assistant_message}\n\n{tts_status}"
            return audio_output, status_message, self._chat_pairs()
        except Exception as e:
            # Drop a dangling user turn so history stays in (user, assistant)
            # pairs; otherwise every later chat pair would misalign.
            if self.conversation_history and self.conversation_history[-1]["role"] == "user":
                self.conversation_history.pop()
            return None, f"β Pipeline error: {str(e)}", []

    def clear_conversation(self):
        """Reset the conversation; clears audio, status, and chatbot outputs."""
        self.conversation_history = []
        return None, "", []

    def create_interface(self):
        """Build and return the Gradio column for the full-pipeline tab."""
        with gr.Column() as interface:
            gr.Markdown("## ποΈ Full Voice Agent Pipeline")
            gr.Markdown("Test the complete flow: **Voice Input β STT β LLM β TTS β Voice Output**")
            gr.Markdown("### 1. Load Models")
            with gr.Row():
                stt_selector = gr.Dropdown(choices=list(self.stt.model_options.keys()), value="Whisper Base", label="STT Model")
                llm_selector = gr.Dropdown(choices=list(self.llm.model_options.keys()), value="Qwen 0.5B", label="LLM Model")
                tts_selector = gr.Dropdown(choices=list(self.tts.model_options.keys()), value="SpeechT5", label="TTS Model")
            load_all_btn = gr.Button("Load All Models", variant="primary", size="lg")
            load_status = gr.Textbox(label="Status", interactive=False, lines=3)
            gr.Markdown("### 2. Voice Conversation")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="π€ Speak or Upload Audio"
            )
            with gr.Row():
                max_tokens = gr.Slider(minimum=50, maximum=300, value=100, step=10, label="Max Response Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
            process_btn = gr.Button("Process Voice Input", variant="secondary", size="lg")
            audio_output = gr.Audio(label="AI Response (Audio)", type="numpy")
            process_status = gr.Textbox(label="Pipeline Output", interactive=False, lines=4)
            gr.Markdown("### Conversation History")
            conversation_display = gr.Chatbot(label="Conversation", height=300)
            clear_btn = gr.Button("Clear Conversation", variant="stop")
            # NOTE: input order (stt, tts, llm) must match load_models' signature.
            load_all_btn.click(fn=self.load_models, inputs=[stt_selector, tts_selector, llm_selector], outputs=[load_status])
            process_btn.click(fn=self.process_voice_input, inputs=[audio_input, max_tokens, temperature], outputs=[audio_output, process_status, conversation_display])
            clear_btn.click(fn=self.clear_conversation, outputs=[audio_output, process_status, conversation_display])
        return interface
| # ============================================================================ | |
| # Main App | |
| # ============================================================================ | |
# Standalone module instances backing the three per-component tabs.
stt_module = STTModule()
tts_module = TTSModule()
llm_module = LLMModule()
# The pipeline tab owns its own STT/TTS/LLM instances; models loaded in the
# component tabs are NOT shared with it (each tab loads independently).
pipeline_module = VoiceAgentPipeline()

# Top-level UI: one tab per component plus a full end-to-end pipeline tab.
with gr.Blocks(title="Voice Agent Modular Tester", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ποΈ Voice Agent Modular Testing Suite
Test individual components or the full voice agent pipeline. Each tab allows you to:
- **STT Tab**: Test speech-to-text models independently
- **TTS Tab**: Test text-to-speech models independently
- **LLM Tab**: Test language models independently
- **Pipeline Tab**: Test the complete voice agent flow (STT β LLM β TTS)
""")
    with gr.Tabs():
        with gr.Tab("π€ STT Module"):
            stt_module.create_interface()
        with gr.Tab("π TTS Module"):
            tts_module.create_interface()
        with gr.Tab("π€ LLM Module"):
            llm_module.create_interface()
        with gr.Tab("ποΈ Full Pipeline"):
            pipeline_module.create_interface()
    gr.Markdown("""
---
### π Usage Tips
- **Load models first**: Click "Load Model" buttons before testing
- **Recording audio**: Click the microphone icon π€ to start recording, click again to stop
- **Upload audio**: Or drag & drop an audio file
- **GPU acceleration**: Models run on GPU if available, otherwise CPU
- **Pipeline mode**: Combines all modules for end-to-end voice interaction
- **Performance**: Use smaller models (Whisper Base, Qwen 0.5B) for faster performance on CPU
""")

# Launch only when run as a script (HF Spaces also calls demo.launch()).
if __name__ == "__main__":
    demo.launch()