| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Evaluate Fine-tuned LiveKit model vs Original LiveKit model |
| """ |
|
|
| import unicodedata |
| import re |
| from datasets import load_dataset |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
| from huggingface_hub import hf_hub_download |
| import onnxruntime as ort |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix |
| import torch |
| from tabulate import tabulate |
|
|
|
|
def normalize_text(text: str) -> str:
    """Lowercase, NFKC-normalize, strip punctuation (keeping ' and -), and collapse whitespace."""
    if not text:
        return ""
    lowered = unicodedata.normalize("NFKC", text.lower())
    kept_chars = []
    for ch in lowered:
        # Drop every Unicode punctuation character except apostrophe and hyphen,
        # which carry meaning in spoken-text transcripts ("it's", "well-known").
        if unicodedata.category(ch).startswith("P") and ch not in ("'", "-"):
            continue
        kept_chars.append(ch)
    collapsed = re.sub(r"\s+", " ", "".join(kept_chars))
    return collapsed.strip()
|
|
|
|
def format_chat_for_livekit(messages: list, tokenizer) -> str:
    """Render chat messages into the prompt string LiveKit EOT models expect.

    Messages with empty/None content are dropped, consecutive same-role
    messages are merged into one turn, and every content string is passed
    through normalize_text(). The trailing "<|im_end|>" marker is removed so
    the model can score whether the final turn is actually complete.

    Args:
        messages: list of {"role": str, "content": str} dicts.
        tokenizer: a tokenizer exposing apply_chat_template().

    Returns:
        The rendered conversation text, truncated at the last "<|im_end|>".
    """
    merged_ctx = []
    for msg in messages:
        # Skip messages whose content is missing or empty.
        if not msg.get("content"):
            continue
        content = normalize_text(msg["content"])
        if merged_ctx and merged_ctx[-1]["role"] == msg["role"]:
            # Fold runs of same-role messages into a single turn.
            merged_ctx[-1]["content"] += f" {content}"
        else:
            merged_ctx.append({"role": msg["role"], "content": content})

    convo_text = tokenizer.apply_chat_template(
        merged_ctx,
        add_generation_prompt=False,
        add_special_tokens=False,
        tokenize=False,
    )

    # Cut everything from the final "<|im_end|>" onward so the prompt ends
    # mid-turn. Bug fix: rfind() returns -1 when the marker is absent, and the
    # original convo_text[:-1] would then silently chop the last character.
    ix = convo_text.rfind("<|im_end|>")
    return convo_text if ix == -1 else convo_text[:ix]
|
|
|
|
def predict_onnx(session, tokenizer, messages: list) -> float:
    """Score end-of-turn probability with the original LiveKit ONNX model.

    Formats the conversation, tokenizes it (left-truncated to 128 tokens via
    the tokenizer's configuration), and runs the ONNX session.
    """
    prompt = format_chat_for_livekit(messages, tokenizer)
    encoded = tokenizer(
        prompt,
        add_special_tokens=False,
        return_tensors="np",
        max_length=128,
        truncation=True,
    )
    input_ids = encoded["input_ids"].astype("int64")
    result = session.run(None, {"input_ids": input_ids})
    # The first output holds one score per position; the last position is the
    # end-of-turn probability for the full prompt.
    return float(result[0].flatten()[-1])
|
|
|
|
def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
    """Score end-of-turn probability with the fine-tuned (LoRA) causal LM.

    Returns the next-token probability of "<|im_end|>" after the rendered
    prompt; falls back to a neutral 0.5 when the marker cannot be encoded.
    """
    prompt = format_chat_for_livekit(messages, tokenizer)
    encoded = tokenizer(
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
        max_length=128,
        truncation=True,
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**batch)

    # Probability distribution over the vocabulary for the token that would
    # follow the prompt.
    next_token_probs = torch.softmax(outputs.logits[:, -1, :], dim=-1)

    im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
    if not im_end_ids:
        # Neutral fallback when the marker is unknown to the tokenizer.
        return 0.5
    # NOTE(review): only the first token id is used — assumes "<|im_end|>"
    # encodes to a single special token; confirm against the tokenizer vocab.
    return next_token_probs[0, im_end_ids[0]].item()
|
|
|
|
def _calc_metrics(preds: list, labels: list, name: str) -> dict:
    """Build one table row of percentage-formatted classification metrics."""
    return {
        "Model": name,
        "Accuracy": f"{accuracy_score(labels, preds) * 100:.2f}%",
        "Precision": f"{precision_score(labels, preds, zero_division=0) * 100:.2f}%",
        "Recall": f"{recall_score(labels, preds, zero_division=0) * 100:.2f}%",
        "F1 Score": f"{f1_score(labels, preds, zero_division=0) * 100:.2f}%",
    }


def _print_confusion(title: str, labels: list, preds: list) -> None:
    """Print a 2x2 confusion matrix (rows = actual, columns = predicted)."""
    print(f"\n{title}:")
    print("                    Predicted")
    print("                    Incomplete  Complete")
    # Bug fix: pass labels=[0, 1] so the matrix is always 2x2 — without it,
    # a single-class prediction/ground-truth set yields a 1x1 matrix and the
    # indexing below raises IndexError.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    print(f"Actual Incomplete   {cm[0][0]:3d}         {cm[0][1]:3d}")
    print(f"       Complete     {cm[1][0]:3d}         {cm[1][1]:3d}")


def main():
    """Evaluate the fine-tuned LiveKit EOT model against the original ONNX model.

    Loads the test set and both models, scores every sample with each model
    at a 0.5 decision threshold, then prints metric tables, confusion
    matrices, and an accuracy/F1 comparison summary.
    """
    print("=" * 60)
    print("Evaluation: Fine-tuned vs Original LiveKit model")
    print("=" * 60)

    print("\n[1/4] Loading test dataset...")
    dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
    print(f"  Loaded {len(dataset)} test samples")

    print("\n[2/4] Loading original LiveKit model (ONNX)...")
    revision = "v0.4.1-intl"
    onnx_path = hf_hub_download(
        repo_id="livekit/turn-detector",
        filename="model_q8.onnx",
        subfolder="onnx",
        revision=revision,
    )
    orig_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # truncation_side="left" keeps the most recent turns when the prompt
    # exceeds max_length.
    orig_tokenizer = AutoTokenizer.from_pretrained(
        "livekit/turn-detector",
        revision=revision,
        truncation_side="left",
    )
    print("  Original model loaded!")

    print("\n[3/4] Loading fine-tuned LiveKit model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"  Using device: {device}")

    base_model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
    ft_model = PeftModel.from_pretrained(base_model, "Vurtnec/livekit-eot-finetuned")
    ft_model = ft_model.to(device)
    ft_model.eval()
    ft_tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
    print("  Fine-tuned model loaded!")

    print("\n[4/4] Running evaluation...")

    orig_predictions, ft_predictions, ground_truth = [], [], []
    orig_probs, ft_probs = [], []

    for i, sample in enumerate(dataset):
        messages = sample["messages"]
        ground_truth.append(1 if sample["is_complete"] else 0)

        # On a per-sample failure, fall back to a neutral 0.5 probability
        # (which thresholds to "complete") so one bad sample doesn't abort
        # the whole run.
        try:
            orig_prob = predict_onnx(orig_session, orig_tokenizer, messages)
        except Exception as e:
            print(f"  Warning: Original model error on sample {i}: {e}")
            orig_prob = 0.5
        orig_probs.append(orig_prob)
        orig_predictions.append(1 if orig_prob >= 0.5 else 0)

        try:
            ft_prob = predict_finetuned(ft_model, ft_tokenizer, messages, device)
        except Exception as e:
            print(f"  Warning: Fine-tuned model error on sample {i}: {e}")
            ft_prob = 0.5
        ft_probs.append(ft_prob)
        ft_predictions.append(1 if ft_prob >= 0.5 else 0)

        if (i + 1) % 10 == 0:
            print(f"  Processed {i + 1}/{len(dataset)} samples...")

    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    table = [
        _calc_metrics(orig_predictions, ground_truth, "Original LiveKit"),
        _calc_metrics(ft_predictions, ground_truth, "Fine-tuned LiveKit"),
    ]
    print("\n" + tabulate(table, headers="keys", tablefmt="grid"))

    print("\n--- Confusion Matrices ---")
    _print_confusion("Original Model", ground_truth, orig_predictions)
    _print_confusion("Fine-tuned Model", ground_truth, ft_predictions)

    print("\n--- Probability Analysis ---")
    # Guard against an empty dataset to avoid ZeroDivisionError.
    if orig_probs:
        print(f"Original model avg prob:   {sum(orig_probs) / len(orig_probs):.4f}")
    if ft_probs:
        print(f"Fine-tuned model avg prob: {sum(ft_probs) / len(ft_probs):.4f}")

    print("\n" + "=" * 60)
    print("COMPARISON SUMMARY")
    print("=" * 60)

    orig_acc = accuracy_score(ground_truth, orig_predictions)
    ft_acc = accuracy_score(ground_truth, ft_predictions)
    orig_f1 = f1_score(ground_truth, orig_predictions, zero_division=0)
    ft_f1 = f1_score(ground_truth, ft_predictions, zero_division=0)

    print(f"\nOriginal LiveKit: Accuracy={orig_acc * 100:.2f}%, F1={orig_f1 * 100:.2f}%")
    print(f"Fine-tuned:       Accuracy={ft_acc * 100:.2f}%, F1={ft_f1 * 100:.2f}%")

    diff_acc = (ft_acc - orig_acc) * 100
    diff_f1 = (ft_f1 - orig_f1) * 100

    print(f"\nImprovement (Fine-tuned - Original):")
    print(f"  Accuracy: {'+' if diff_acc >= 0 else ''}{diff_acc:.2f}%")
    print(f"  F1 Score: {'+' if diff_f1 >= 0 else ''}{diff_f1:.2f}%")


if __name__ == "__main__":
    main()
|
|