# Source: eot-detector-smollm2 / eval_finetuned_livekit.py
# Uploaded by Vurtnec via huggingface_hub (commit 5972e29, verified)
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets>=3.1",
# "transformers>=4.46",
# "torch>=2.5",
# "peft>=0.13",
# "huggingface_hub>=0.26",
# "onnxruntime>=1.19",
# "scikit-learn>=1.5",
# "tabulate>=0.9",
# ]
# ///
"""
Evaluate Fine-tuned LiveKit model vs Original LiveKit model
"""
import unicodedata
import re
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from huggingface_hub import hf_hub_download
import onnxruntime as ort
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch
from tabulate import tabulate
def normalize_text(text: str) -> str:
    """Return *text* lowercased, NFKC-normalized, with punctuation removed
    (apostrophes and hyphens are kept) and whitespace collapsed to single
    spaces. Empty/None input yields an empty string."""
    if not text:
        return ""
    lowered = unicodedata.normalize("NFKC", text.lower())
    kept = []
    for ch in lowered:
        # Drop Unicode punctuation (category P*) except intra-word ' and -.
        if unicodedata.category(ch).startswith("P") and ch not in ("'", "-"):
            continue
        kept.append(ch)
    collapsed = re.sub(r"\s+", " ", "".join(kept))
    return collapsed.strip()
def format_chat_for_livekit(messages: list, tokenizer) -> str:
    """Format a chat message list for LiveKit EOT model inference.

    Normalizes each message's text, merges consecutive messages from the
    same role into one turn, renders the conversation with the tokenizer's
    chat template, and strips the final ``<|im_end|>`` marker so the model
    can score whether the turn is actually complete.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts.
        tokenizer: a tokenizer exposing ``apply_chat_template``.

    Returns:
        The rendered conversation text truncated before the last
        ``<|im_end|>`` token (or unchanged if the marker is absent).
    """
    new_chat_ctx = []
    last_msg = None
    for msg in messages:
        # Skip messages with empty or missing content.
        if not msg.get("content"):
            continue
        content = normalize_text(msg["content"])
        if last_msg and last_msg["role"] == msg["role"]:
            # Merge consecutive same-role messages into a single turn.
            last_msg["content"] += f" {content}"
        else:
            new_msg = {"role": msg["role"], "content": content}
            new_chat_ctx.append(new_msg)
            last_msg = new_msg
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx,
        add_generation_prompt=False,
        add_special_tokens=False,
        tokenize=False
    )
    # BUGFIX: str.rfind returns -1 when the marker is absent, and the old
    # convo_text[:ix] would then silently chop the last character off the
    # prompt. Return the text unchanged in that case.
    ix = convo_text.rfind("<|im_end|>")
    if ix == -1:
        return convo_text
    return convo_text[:ix]
def predict_onnx(session, tokenizer, messages: list) -> float:
    """Score *messages* with the original LiveKit ONNX turn detector.

    Returns the model's end-of-turn probability for the final token.
    """
    prompt = format_chat_for_livekit(messages, tokenizer)
    encoded = tokenizer(
        prompt,
        add_special_tokens=False,
        return_tensors="np",
        max_length=128,
        truncation=True,
    )
    input_ids = encoded["input_ids"].astype("int64")
    result = session.run(None, {"input_ids": input_ids})
    # The model emits one score per position; the last one is the EOT score.
    return float(result[0].flatten()[-1])
def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
    """Score *messages* with the fine-tuned causal LM.

    Returns the softmax probability that the next token is ``<|im_end|>``,
    i.e. the model's end-of-turn estimate. Falls back to 0.5 if the marker
    cannot be encoded to a token id.
    """
    prompt = format_chat_for_livekit(messages, tokenizer)
    encoded = tokenizer(
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
        max_length=128,
        truncation=True,
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = model(**batch)
    # Distribution over the vocabulary for the token following the prompt.
    next_token_probs = torch.softmax(outputs.logits[:, -1, :], dim=-1)
    im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
    if not im_end_ids:
        # Neutral fallback when the marker does not encode to any id.
        return 0.5
    return next_token_probs[0, im_end_ids[0]].item()
def main():
    """Evaluate the fine-tuned LiveKit EOT model against the original.

    Loads a shared test set, runs both the original ONNX turn detector and
    the fine-tuned LoRA model over every sample, then prints accuracy,
    precision, recall, F1, confusion matrices, and a comparison summary.
    Requires network access (HF Hub downloads) on first run.
    """
    print("=" * 60)
    print("Evaluation: Fine-tuned vs Original LiveKit model")
    print("=" * 60)
    # Load test dataset
    print("\n[1/4] Loading test dataset...")
    dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
    print(f" Loaded {len(dataset)} test samples")
    # Load original LiveKit model (ONNX) — pinned revision for reproducibility.
    print("\n[2/4] Loading original LiveKit model (ONNX)...")
    revision = "v0.4.1-intl"
    onnx_path = hf_hub_download(
        repo_id="livekit/turn-detector",
        filename="model_q8.onnx",
        subfolder="onnx",
        revision=revision
    )
    orig_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # truncation_side="left" keeps the most recent turns when the prompt
    # exceeds max_length, which is what an end-of-turn detector needs.
    orig_tokenizer = AutoTokenizer.from_pretrained(
        "livekit/turn-detector",
        revision=revision,
        truncation_side="left"
    )
    print(" Original model loaded!")
    # Load fine-tuned model (LoRA adapter applied on top of the base model).
    print("\n[3/4] Loading fine-tuned LiveKit model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" Using device: {device}")
    base_model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
    ft_model = PeftModel.from_pretrained(base_model, "Vurtnec/livekit-eot-finetuned")
    ft_model = ft_model.to(device)
    ft_model.eval()
    # NOTE(review): unlike orig_tokenizer, this is loaded WITHOUT
    # truncation_side="left" — long prompts will be truncated from the
    # right for the fine-tuned model. Confirm this asymmetry is intended.
    ft_tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
    print(" Fine-tuned model loaded!")
    # Run evaluation over every sample; per-model errors fall back to a
    # "complete" prediction (1) with a neutral 0.5 probability so the
    # prediction lists stay aligned with ground_truth.
    print("\n[4/4] Running evaluation...")
    orig_predictions = []
    ft_predictions = []
    ground_truth = []
    orig_probs = []
    ft_probs = []
    for i, sample in enumerate(dataset):
        messages = sample["messages"]
        is_complete = sample["is_complete"]
        label = 1 if is_complete else 0
        ground_truth.append(label)
        # Original model prediction
        try:
            orig_prob = predict_onnx(orig_session, orig_tokenizer, messages)
            orig_pred = 1 if orig_prob >= 0.5 else 0
            orig_probs.append(orig_prob)
        except Exception as e:
            print(f" Warning: Original model error on sample {i}: {e}")
            orig_pred = 1
            orig_probs.append(0.5)
        orig_predictions.append(orig_pred)
        # Fine-tuned model prediction
        try:
            ft_prob = predict_finetuned(ft_model, ft_tokenizer, messages, device)
            ft_pred = 1 if ft_prob >= 0.5 else 0
            ft_probs.append(ft_prob)
        except Exception as e:
            print(f" Warning: Fine-tuned model error on sample {i}: {e}")
            ft_pred = 1
            ft_probs.append(0.5)
        ft_predictions.append(ft_pred)
        # Lightweight progress indicator every 10 samples.
        if (i + 1) % 10 == 0:
            print(f" Processed {i + 1}/{len(dataset)} samples...")
    # Calculate metrics
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    def calc_metrics(preds, labels, name):
        # Build one formatted table row; zero_division=0 avoids warnings
        # when a model never predicts the positive class.
        return {
            "Model": name,
            "Accuracy": f"{accuracy_score(labels, preds) * 100:.2f}%",
            "Precision": f"{precision_score(labels, preds, zero_division=0) * 100:.2f}%",
            "Recall": f"{recall_score(labels, preds, zero_division=0) * 100:.2f}%",
            "F1 Score": f"{f1_score(labels, preds, zero_division=0) * 100:.2f}%",
        }
    orig_metrics = calc_metrics(orig_predictions, ground_truth, "Original LiveKit")
    ft_metrics = calc_metrics(ft_predictions, ground_truth, "Fine-tuned LiveKit")
    table = [orig_metrics, ft_metrics]
    print("\n" + tabulate(table, headers="keys", tablefmt="grid"))
    # Confusion matrices (rows = actual, columns = predicted).
    print("\n--- Confusion Matrices ---")
    print("\nOriginal Model:")
    print(" Predicted")
    print(" Incomplete Complete")
    orig_cm = confusion_matrix(ground_truth, orig_predictions)
    print(f"Actual Incomplete {orig_cm[0][0]:3d} {orig_cm[0][1]:3d}")
    print(f" Complete {orig_cm[1][0]:3d} {orig_cm[1][1]:3d}")
    print("\nFine-tuned Model:")
    print(" Predicted")
    print(" Incomplete Complete")
    ft_cm = confusion_matrix(ground_truth, ft_predictions)
    print(f"Actual Incomplete {ft_cm[0][0]:3d} {ft_cm[0][1]:3d}")
    print(f" Complete {ft_cm[1][0]:3d} {ft_cm[1][1]:3d}")
    # Probability analysis — mean predicted EOT probability per model.
    print("\n--- Probability Analysis ---")
    print(f"Original model avg prob: {sum(orig_probs)/len(orig_probs):.4f}")
    print(f"Fine-tuned model avg prob: {sum(ft_probs)/len(ft_probs):.4f}")
    # Summary
    print("\n" + "=" * 60)
    print("COMPARISON SUMMARY")
    print("=" * 60)
    orig_acc = accuracy_score(ground_truth, orig_predictions)
    ft_acc = accuracy_score(ground_truth, ft_predictions)
    orig_f1 = f1_score(ground_truth, orig_predictions, zero_division=0)
    ft_f1 = f1_score(ground_truth, ft_predictions, zero_division=0)
    print(f"\nOriginal LiveKit: Accuracy={orig_acc*100:.2f}%, F1={orig_f1*100:.2f}%")
    print(f"Fine-tuned: Accuracy={ft_acc*100:.2f}%, F1={ft_f1*100:.2f}%")
    diff_acc = (ft_acc - orig_acc) * 100
    diff_f1 = (ft_f1 - orig_f1) * 100
    print(f"\nImprovement (Fine-tuned - Original):")
    print(f" Accuracy: {'+' if diff_acc >= 0 else ''}{diff_acc:.2f}%")
    print(f" F1 Score: {'+' if diff_f1 >= 0 else ''}{diff_f1:.2f}%")
if __name__ == "__main__":
    main()