# Source captured from a Hugging Face Space (status: Running).
| """Speech-to-expressive-speech agent pipeline with full W&B Weave tracing. | |
| Pipeline: | |
| 1. Audio input -> Voxtral transcription (with ElevenLabs v3 expressive tags) | |
| 2. Tagged text -> ElevenLabs v3 TTS -> expressive audio output | |
| Every step is traced with @weave.op() decorators for observability. | |
| """ | |
| import os | |
| import re | |
| import time | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Optional | |
| import httpx | |
| import torch | |
| import weave | |
| import librosa | |
| from dotenv import load_dotenv | |
| from transformers import VoxtralForConditionalGeneration, AutoProcessor | |
| from peft import PeftModel | |
# Read ELEVENLABS_API_KEY* and other secrets from a local .env file, if present.
load_dotenv()

# Hugging Face model ID of the Voxtral base model used for transcription.
MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
# ElevenLabs model that understands bracketed expressive tags in the input text.
ELEVENLABS_MODEL_ID = "eleven_v3"
# Instruction sent to Voxtral together with the audio clip.
TRANSCRIPTION_PROMPT = "Transcribe this audio with expressive tags."
# Candidate ElevenLabs voices; SpeechAgent defaults to the first entry.
VOICE_POOL = [
    {"id": "CwhRBWXzGAHq8TQ4Fs17", "name": "Roger", "gender": "male", "age": "middle_aged"},
    {"id": "cjVigY5qzO86Huf0OWal", "name": "Eric", "gender": "male", "age": "middle_aged"},
    {"id": "EXAVITQu4vr4xnSDxMaL", "name": "Sarah", "gender": "female", "age": "young"},
    {"id": "XrExE9yKIg1WjnnlVkGX", "name": "Matilda", "gender": "female", "age": "middle_aged"},
    {"id": "TX3LPaxmHKxFdv7VOQHJ", "name": "Liam", "gender": "male", "age": "young"},
    {"id": "cgSgspJ2msm6clMCkdW9", "name": "Jessica", "gender": "female", "age": "young"},
    {"id": "JBFqnCBsd6RMkjVDRZzb", "name": "George", "gender": "male", "age": "middle_aged"},
    {"id": "pFZP5JQG7iQjIQuC4Bku", "name": "Lily", "gender": "female", "age": "middle_aged"},
]
# Matches one bracket tag, e.g. "[laughs]"; group 1 is the tag body.
TAG_PATTERN = re.compile(r"\[([^\]]+)\]")


def get_api_keys() -> list[str]:
    """Load all ElevenLabs API keys from environment.

    Every variable whose name starts with ``ELEVENLABS_API_KEY`` counts
    (e.g. ELEVENLABS_API_KEY, ELEVENLABS_API_KEY_2); sorting by variable
    name makes the ordering deterministic.

    Returns:
        Non-empty list of API key values.

    Raises:
        ValueError: If no matching environment variable is set.
    """
    keys = [
        value
        for name, value in sorted(os.environ.items())
        if name.startswith("ELEVENLABS_API_KEY")
    ]
    if not keys:
        raise ValueError("No ELEVENLABS_API_KEY* found in environment")
    return keys


def strip_tags(tagged_text: str) -> str:
    """Remove all bracket tags from text, returning plain transcription.

    Deleting a tag that sits between two words would leave a doubled
    space ("Hi [laughs] there" -> "Hi  there"), so runs of blanks are
    collapsed to a single space after tag removal.
    """
    without_tags = TAG_PATTERN.sub("", tagged_text)
    # Collapse the double spaces left behind where a tag sat between words.
    return re.sub(r"[ \t]{2,}", " ", without_tags).strip()


def extract_tags(tagged_text: str) -> list[str]:
    """Extract all bracket tag bodies (without brackets) from tagged text."""
    return TAG_PATTERN.findall(tagged_text)
# ---------------------------------------------------------------------------
# Initialize Weave project
# ---------------------------------------------------------------------------
# NOTE: runs at import time — importing this module starts a W&B Weave session
# for the "evoxtral" project (may prompt for credentials / touch the network).
weave.init("evoxtral")
class SpeechAgent(weave.Model):
    """End-to-end speech-to-expressive-speech agent.

    Transcribes audio with Voxtral (optionally with a LoRA adapter) to produce
    ElevenLabs v3 tagged text, then synthesizes expressive audio via ElevenLabs TTS.

    Each pipeline step is decorated with ``@weave.op()`` so transcription and
    synthesis calls are traced, as the module docstring promises.

    Attributes:
        adapter_path: Optional path to a LoRA adapter applied on top of the base model.
        default_voice_id: ElevenLabs voice used when synthesize() is given no voice_id.
    """

    adapter_path: Optional[str] = None
    default_voice_id: str = VOICE_POOL[0]["id"]

    # Private attributes (not serialized by weave.Model / Pydantic)
    _model: object = None       # Voxtral model, possibly PEFT-wrapped
    _processor: object = None   # AutoProcessor (chat template + tokenizer)
    _device: object = None      # device the model parameters ended up on
    _api_keys: list = []        # ElevenLabs API keys loaded from the environment

    def model_post_init(self, __context):
        """Load model, processor, and API keys after Pydantic initialization."""
        self._load_model()
        self._api_keys = get_api_keys()

    def _load_model(self):
        """Load Voxtral model (with optional LoRA adapter)."""
        # bf16 on GPU for speed/memory; fp32 on CPU where bf16 is often unsupported.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        self._processor = AutoProcessor.from_pretrained(MODEL_ID)
        base_model = VoxtralForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=dtype,
            device_map=device_map,
        )
        if self.adapter_path is not None:
            print(f"Loading LoRA adapter from {self.adapter_path}")
            self._model = PeftModel.from_pretrained(base_model, self.adapter_path)
        else:
            self._model = base_model
        self._model.eval()
        self._device = next(self._model.parameters()).device
        print(f"Model loaded on {self._device} (dtype={dtype})")

    @weave.op()
    def transcribe(self, audio_path: str) -> dict:
        """Transcribe audio to tagged text using Voxtral.

        Args:
            audio_path: Path to the input audio file.

        Returns:
            dict with keys: tagged_text, plain_text, tags
        """
        # Voxtral expects 16 kHz mono input; librosa resamples on load.
        audio_array, sr = librosa.load(audio_path, sr=16000)

        # Build the chat-style request the processor template understands.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_array},
                    {"type": "text", "text": TRANSCRIPTION_PROMPT},
                ],
            },
        ]
        inputs = self._processor.apply_chat_template(
            conversation,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        )

        # Move inputs to the model's device (template may return a dict or a tensor).
        if isinstance(inputs, dict):
            inputs = {k: v.to(self._device) if hasattr(v, "to") else v for k, v in inputs.items()}
        else:
            inputs = inputs.to(self._device)

        # Generate
        with torch.no_grad():
            if isinstance(inputs, dict):
                output_ids = self._model.generate(**inputs, max_new_tokens=512)
            else:
                output_ids = self._model.generate(inputs, max_new_tokens=512)

        # Decode — skip the input tokens to get only the generated response.
        if isinstance(inputs, dict):
            input_len = inputs["input_ids"].shape[-1]
        else:
            input_len = inputs.shape[-1]
        generated_ids = output_ids[:, input_len:]
        tagged_text = self._processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()

        plain_text = strip_tags(tagged_text)
        tags = extract_tags(tagged_text)
        return {
            "tagged_text": tagged_text,
            "plain_text": plain_text,
            "tags": tags,
        }

    @weave.op()
    def synthesize(self, tagged_text: str, voice_id: str | None = None) -> dict:
        """Synthesize tagged text to expressive audio via ElevenLabs v3 TTS.

        Args:
            tagged_text: Text with ElevenLabs v3 bracket tags.
            voice_id: ElevenLabs voice ID. Defaults to the agent's default voice.

        Returns:
            dict with keys: audio_path, duration_ms

        Raises:
            RuntimeError: If the ElevenLabs API returns a non-200 response.
        """
        if voice_id is None:
            voice_id = self.default_voice_id
        api_key = self._api_keys[0]

        start = time.time()
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
                headers={
                    "xi-api-key": api_key,
                    "Content-Type": "application/json",
                },
                json={
                    "text": tagged_text,
                    "model_id": ELEVENLABS_MODEL_ID,
                    "output_format": "mp3_44100_128",
                },
            )
        elapsed_ms = (time.time() - start) * 1000
        if response.status_code != 200:
            raise RuntimeError(
                f"ElevenLabs TTS failed (HTTP {response.status_code}): {response.text[:200]}"
            )

        # Write output audio to a temp file. mkstemp guarantees a unique name
        # even when two syntheses land in the same millisecond (a timestamp-based
        # filename would silently overwrite the earlier result).
        output_dir = Path(tempfile.gettempdir()) / "evoxtral_output"
        output_dir.mkdir(parents=True, exist_ok=True)
        fd, output_path = tempfile.mkstemp(prefix="synth_", suffix=".mp3", dir=output_dir)
        with os.fdopen(fd, "wb") as f:
            f.write(response.content)

        return {
            "audio_path": output_path,
            "duration_ms": round(elapsed_ms, 2),
        }

    @weave.op()
    def predict(self, audio_path: str) -> dict:
        """Full pipeline: transcribe audio then synthesize expressive speech.

        Args:
            audio_path: Path to the input audio file.

        Returns:
            dict with keys: transcription (dict), synthesis (dict)
        """
        transcription = self.transcribe(audio_path)
        synthesis = self.synthesize(transcription["tagged_text"])
        return {
            "transcription": transcription,
            "synthesis": synthesis,
        }
if __name__ == "__main__":
    import sys

    # Usage: python <script> [adapter_path] [audio_path]
    args = sys.argv
    adapter = args[1] if len(args) > 1 else None
    audio = args[2] if len(args) > 2 else "test.mp3"
    agent = SpeechAgent(adapter_path=adapter)
    print(agent.predict(audio))