"""Evoxtral-Bench: 3-layer evaluation with W&B Weave tracing.
Layer 1: Text quality (WER on plain text)
Layer 2: Tag classification (F1 on tag extraction + emphasis)
Layer 3: Round-trip TTS quality (optional)
All evaluations are traced via W&B Weave for full reproducibility.
"""
import asyncio
import json
import torch
import wandb
import weave
from pathlib import Path
from datasets import load_from_disk
from jiwer import wer as compute_wer
from dotenv import load_dotenv
from .tag_metrics import tag_f1, emphasis_f1, tag_hallucination_rate, strip_tags
load_dotenv()
# ── Weave Scorer Functions ──────────────────────────────────────────────
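# Weave calls each scorer with the model output under `output` and fills the
# remaining parameters (here `expected`) from the matching dataset-row columns.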
@weave.op()
def wer_scorer(output: dict, expected: str) -> dict:
"""Layer 1: Word Error Rate on plain text (tags stripped)."""
pred_plain = strip_tags(output["prediction"])
ref_plain = strip_tags(expected)
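    # jiwer raises on an empty reference string, so score such rows as 0.0.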
score = compute_wer(ref_plain, pred_plain) if ref_plain.strip() else 0.0
return {"wer": score}
@weave.op()
def tag_f1_scorer(output: dict, expected: str) -> dict:
"""Layer 2a: Tag extraction F1."""
return tag_f1(output["prediction"], expected)
@weave.op()
def emphasis_f1_scorer(output: dict, expected: str) -> dict:
"""Layer 2b: Emphasis (CAPS) F1."""
return emphasis_f1(output["prediction"], expected)
@weave.op()
def hallucination_scorer(output: dict, expected: str) -> dict:
"""Layer 2c: Tag hallucination rate."""
rate = tag_hallucination_rate(output["prediction"], expected)
return {"tag_hallucination_rate": rate}
# ── Model Wrapper ───────────────────────────────────────────────────────
class EvoxtralModel(weave.Model):
"""Wraps Voxtral inference for Weave evaluation."""
model_id: str = "mistralai/Voxtral-Mini-3B-2507"
adapter_path: str | None = None
_model: object = None
_processor: object = None
class Config:
arbitrary_types_allowed = True
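    # Heavy objects live in private attrs so pydantic/Weave neither validates
    # nor serializes them; they are loaded lazily on first predict().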
def _load(self):
if self._model is not None:
return
from transformers import VoxtralForConditionalGeneration, AutoProcessor
self._processor = AutoProcessor.from_pretrained(self.model_id)
if self.adapter_path:
from peft import PeftModel
base = VoxtralForConditionalGeneration.from_pretrained(
self.model_id, torch_dtype=torch.bfloat16, device_map="auto",
)
self._model = PeftModel.from_pretrained(base, self.adapter_path)
else:
self._model = VoxtralForConditionalGeneration.from_pretrained(
self.model_id, torch_dtype=torch.bfloat16, device_map="auto",
)
self._model.eval()
@weave.op()
def predict(self, audio_path: str) -> dict:
"""Run inference on a single audio file."""
self._load()
import librosa
audio_array, sr = librosa.load(audio_path, sr=16000)
conversation = [
{
"role": "user",
"content": [
{"type": "audio", "audio": audio_array},
{"type": "text", "text": "Transcribe this audio with expressive tags."},
],
},
]
        inputs = self._processor.apply_chat_template(
            [conversation],
            return_tensors="pt",
        )
        # Match the model dtype: BatchFeature.to() casts only floating-point
        # tensors (the audio features), leaving input_ids as integers.
        inputs = inputs.to(self._model.device, dtype=torch.bfloat16)
with torch.no_grad():
output_ids = self._model.generate(
**inputs,
max_new_tokens=512,
do_sample=False,
)
# Decode only the generated tokens (skip input)
input_len = inputs["input_ids"].shape[1]
prediction = self._processor.tokenizer.decode(
output_ids[0][input_len:], skip_special_tokens=True,
)
return {"prediction": prediction}
# ── Evaluation Runner ───────────────────────────────────────────────────
def build_eval_dataset(
dataset_path: str = "data/processed",
split: str = "test",
manifest_path: str = "data/audio/manifest.json",
) -> list[dict]:
"""Build evaluation examples from the processed dataset."""
ds = load_from_disk(dataset_path)
test_ds = ds[split]
with open(manifest_path) as f:
manifest = json.load(f)
# Build index from tagged_text to audio_path
    audio_lookup = {
        m["tagged_text"]: m["audio_path"]
        for m in manifest
        if m.get("synthesis_success")
    }
examples = []
    for row in test_ds:
tagged_text = row["tagged_text"]
audio_path = audio_lookup.get(tagged_text)
if audio_path and Path(audio_path).exists():
examples.append({
"audio_path": audio_path,
"expected": tagged_text,
})
return examples
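# Illustrative manifest entry (field names taken from the lookup above):
#   {"tagged_text": "I CANNOT believe it <laugh>", "audio_path": "data/audio/0001.wav",
#    "synthesis_success": true}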
def run_eval(
adapter_path: str | None = None,
dataset_path: str = "data/processed",
split: str = "test",
wandb_project: str = "evoxtral",
):
"""Run full Evoxtral-Bench evaluation with Weave tracing."""
    weave.init(wandb_project)
# Build eval dataset
examples = build_eval_dataset(dataset_path, split)
print(f"Evaluation examples: {len(examples)}")
eval_dataset = weave.Dataset(
name=f"evoxtral-bench-{split}",
rows=examples,
)
# Create model (base or finetuned)
model_name = "evoxtral-finetuned" if adapter_path else "voxtral-base"
    model = EvoxtralModel(
        name=model_name,
        adapter_path=adapter_path,
    )
# Run evaluation
evaluation = weave.Evaluation(
dataset=eval_dataset,
scorers=[wer_scorer, tag_f1_scorer, emphasis_f1_scorer, hallucination_scorer],
evaluation_name=f"evoxtral-bench-{model_name}",
)
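    # evaluate() calls predict() on every dataset row, applies each scorer to
    # the output, and aggregates per-metric summaries in Weave; it is async,
    # hence asyncio.run().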
results = asyncio.run(evaluation.evaluate(model))
# Also log summary to W&B run for easy comparison
    wandb.init(
        project=wandb_project,
        name=f"eval-{model_name}",
        job_type="evaluation",
        tags=["eval", model_name],
        # Run metadata belongs in config rather than in metric logs.
        config={
            "model": model_name,
            "adapter_path": adapter_path or "none",
            "num_examples": len(examples),
            "split": split,
        },
    )
    # Log per-metric summary; Weave returns nested dicts, so flatten first.
    if isinstance(results, dict):
        flat, stack = {}, [("", results)]
        while stack:
            prefix, node = stack.pop()
            for key, value in node.items():
                if isinstance(value, dict):
                    stack.append((f"{prefix}{key}/", value))
                elif isinstance(value, (int, float)):
                    flat[f"eval/{prefix}{key}"] = value
        wandb.log(flat)
wandb.finish()
print(f"Evaluation complete for {model_name}")
return results
def compare_base_vs_finetuned(
adapter_path: str = "model/evoxtral-lora",
dataset_path: str = "data/processed",
wandb_project: str = "evoxtral",
):
"""Run eval on both base and finetuned model for comparison."""
print("=" * 50)
print("EVALUATING BASE MODEL")
print("=" * 50)
base_results = run_eval(
adapter_path=None,
dataset_path=dataset_path,
wandb_project=wandb_project,
)
print("\n" + "=" * 50)
print("EVALUATING FINETUNED MODEL")
print("=" * 50)
ft_results = run_eval(
adapter_path=adapter_path,
dataset_path=dataset_path,
wandb_project=wandb_project,
)
return {"base": base_results, "finetuned": ft_results}
if __name__ == "__main__":
import sys
adapter = sys.argv[1] if len(sys.argv) > 1 else None
if adapter == "compare":
compare_base_vs_finetuned()
else:
run_eval(adapter_path=adapter)