| """Evoxtral-Bench: 3-layer evaluation with W&B Weave tracing. | |
| Layer 1: Text quality (WER on plain text) | |
| Layer 2: Tag classification (F1 on tag extraction + emphasis) | |
| Layer 3: Round-trip TTS quality (optional) | |
| All evaluations are traced via W&B Weave for full reproducibility. | |
| """ | |
import asyncio
import json
from pathlib import Path

import torch
import wandb
import weave
from datasets import load_from_disk
from dotenv import load_dotenv
from jiwer import wer as compute_wer

from .tag_metrics import tag_f1, emphasis_f1, tag_hallucination_rate, strip_tags

load_dotenv()
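# (load_dotenv() above reads a local .env file; W&B and Weave then pick up
# credentials such as WANDB_API_KEY from the environment. That setup is
# assumed, not shown here.)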

# ── Weave Scorer Functions ──────────────────────────────────────────────
@weave.op()
def wer_scorer(output: dict, expected: str) -> dict:
    """Layer 1: Word Error Rate on plain text (tags stripped)."""
    pred_plain = strip_tags(output["prediction"])
    ref_plain = strip_tags(expected)
    score = compute_wer(ref_plain, pred_plain) if ref_plain.strip() else 0.0
    return {"wer": score}

@weave.op()
def tag_f1_scorer(output: dict, expected: str) -> dict:
    """Layer 2a: Tag extraction F1."""
    return tag_f1(output["prediction"], expected)


@weave.op()
def emphasis_f1_scorer(output: dict, expected: str) -> dict:
    """Layer 2b: Emphasis (CAPS) F1."""
    return emphasis_f1(output["prediction"], expected)


@weave.op()
def hallucination_scorer(output: dict, expected: str) -> dict:
    """Layer 2c: Tag hallucination rate."""
    rate = tag_hallucination_rate(output["prediction"], expected)
    return {"tag_hallucination_rate": rate}

# ── Model Wrapper ───────────────────────────────────────────────────────
class EvoxtralModel(weave.Model):
    """Wraps Voxtral inference for Weave evaluation."""

    model_id: str = "mistralai/Voxtral-Mini-3B-2507"
    adapter_path: str | None = None
    _model: object = None
    _processor: object = None

    class Config:
        arbitrary_types_allowed = True

    def _load(self):
        if self._model is not None:
            return
        from transformers import VoxtralForConditionalGeneration, AutoProcessor

        self._processor = AutoProcessor.from_pretrained(self.model_id)
        if self.adapter_path:
            from peft import PeftModel

            base = VoxtralForConditionalGeneration.from_pretrained(
                self.model_id, torch_dtype=torch.bfloat16, device_map="auto",
            )
            self._model = PeftModel.from_pretrained(base, self.adapter_path)
        else:
            self._model = VoxtralForConditionalGeneration.from_pretrained(
                self.model_id, torch_dtype=torch.bfloat16, device_map="auto",
            )
        self._model.eval()
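        # Optional: for inference-only use, the adapter could be folded into
        # the base weights via PeftModel.merge_and_unload() for a small
        # speedup, at the cost of no longer being able to swap adapters.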

    @weave.op()
    def predict(self, audio_path: str) -> dict:
        """Run inference on a single audio file."""
        self._load()
        import librosa

        # Resample to 16 kHz, the rate Voxtral's audio encoder expects.
        audio_array, sr = librosa.load(audio_path, sr=16000)
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_array},
                    {"type": "text", "text": "Transcribe this audio with expressive tags."},
                ],
            },
        ]
        inputs = self._processor.apply_chat_template(
            [conversation],
            return_tensors="pt",
        )
        inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
        with torch.no_grad():
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )
        # Decode only the newly generated tokens (skip the prompt).
        input_len = inputs["input_ids"].shape[1]
        prediction = self._processor.tokenizer.decode(
            output_ids[0][input_len:], skip_special_tokens=True,
        )
        return {"prediction": prediction}

# ── Evaluation Runner ───────────────────────────────────────────────────
def build_eval_dataset(
    dataset_path: str = "data/processed",
    split: str = "test",
    manifest_path: str = "data/audio/manifest.json",
) -> list[dict]:
    """Build evaluation examples from the processed dataset."""
    ds = load_from_disk(dataset_path)
    test_ds = ds[split]
    with open(manifest_path) as f:
        manifest = json.load(f)
    # Build an index from tagged_text to audio_path, keeping only clips that
    # were synthesized successfully.
    audio_lookup = {
        m["tagged_text"]: m["audio_path"]
        for m in manifest
        if m.get("synthesis_success")
    }
    examples = []
    for i in range(len(test_ds)):
        row = test_ds[i]
        tagged_text = row["tagged_text"]
        audio_path = audio_lookup.get(tagged_text)
        if audio_path and Path(audio_path).exists():
            examples.append({
                "audio_path": audio_path,
                "expected": tagged_text,
            })
    return examples
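
# Manifest shape assumed by build_eval_dataset (field values illustrative):
#   [{"tagged_text": "...", "audio_path": "data/audio/0001.wav",
#     "synthesis_success": true}, ...]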

def run_eval(
    adapter_path: str | None = None,
    dataset_path: str = "data/processed",
    split: str = "test",
    wandb_project: str = "evoxtral",
):
    """Run the full Evoxtral-Bench evaluation with Weave tracing."""
    weave.init(wandb_project)

    # Build the eval dataset
    examples = build_eval_dataset(dataset_path, split)
    print(f"Evaluation examples: {len(examples)}")
    eval_dataset = weave.Dataset(
        name=f"evoxtral-bench-{split}",
        rows=examples,
    )

    # Create the model (base or finetuned)
    model_name = "evoxtral-finetuned" if adapter_path else "voxtral-base"
    model = EvoxtralModel(adapter_path=adapter_path)

    # Run the evaluation; Evaluation.evaluate is a coroutine, hence asyncio.run
    evaluation = weave.Evaluation(
        dataset=eval_dataset,
        scorers=[wer_scorer, tag_f1_scorer, emphasis_f1_scorer, hallucination_scorer],
        evaluation_name=f"evoxtral-bench-{model_name}",
    )
    results = asyncio.run(evaluation.evaluate(model))

    # Also log a summary to a W&B run for easy comparison
    wandb.init(
        project=wandb_project,
        name=f"eval-{model_name}",
        job_type="evaluation",
        tags=["eval", model_name],
    )
    wandb.log({
        "eval/model": model_name,
        "eval/adapter_path": adapter_path or "none",
        "eval/num_examples": len(examples),
        "eval/split": split,
    })

    # Log per-metric summary values. Weave returns nested dicts (e.g.
    # {"wer_scorer": {"wer": {"mean": ...}}}), so recurse to find the numbers.
    def _log_numeric(d: dict, prefix: str = "eval"):
        for key, value in d.items():
            if isinstance(value, dict):
                _log_numeric(value, f"{prefix}/{key}")
            elif isinstance(value, (int, float)):
                wandb.log({f"{prefix}/{key}": value})

    if isinstance(results, dict):
        _log_numeric(results)
    wandb.finish()
    print(f"Evaluation complete for {model_name}")
    return results

def compare_base_vs_finetuned(
    adapter_path: str = "model/evoxtral-lora",
    dataset_path: str = "data/processed",
    wandb_project: str = "evoxtral",
):
    """Run the eval on both the base and the finetuned model for comparison."""
    print("=" * 50)
    print("EVALUATING BASE MODEL")
    print("=" * 50)
    base_results = run_eval(
        adapter_path=None,
        dataset_path=dataset_path,
        wandb_project=wandb_project,
    )

    print("\n" + "=" * 50)
    print("EVALUATING FINETUNED MODEL")
    print("=" * 50)
    ft_results = run_eval(
        adapter_path=adapter_path,
        dataset_path=dataset_path,
        wandb_project=wandb_project,
    )
    return {"base": base_results, "finetuned": ft_results}

if __name__ == "__main__":
    import sys

    adapter = sys.argv[1] if len(sys.argv) > 1 else None
    if adapter == "compare":
        compare_base_vs_finetuned()
    else:
        run_eval(adapter_path=adapter)
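
# CLI usage (sketch): the relative import of .tag_metrics means this file must
# be run as a module from the repo root; the package name "src" is assumed.
#   python -m src.eval                       # evaluate the base model
#   python -m src.eval model/evoxtral-lora   # evaluate a LoRA adapter
#   python -m src.eval compare               # base vs. finetuned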