Upload compare_eot_models.py with huggingface_hub
Browse files
- compare_eot_models.py +32 -57
compare_eot_models.py
CHANGED
|
@@ -33,7 +33,7 @@ import torch
|
|
| 33 |
# ============================================================
|
| 34 |
|
| 35 |
def normalize_text_multilingual(text: str) -> str:
|
| 36 |
-
"""Normalize text for multilingual model
|
| 37 |
if not text:
|
| 38 |
return ""
|
| 39 |
text = unicodedata.normalize("NFKC", text.lower())
|
|
@@ -45,7 +45,7 @@ def normalize_text_multilingual(text: str) -> str:
|
|
| 45 |
return text
|
| 46 |
|
| 47 |
def format_chat_for_livekit(messages: list, tokenizer) -> str:
|
| 48 |
-
"""Format chat context for LiveKit model
|
| 49 |
new_chat_ctx = []
|
| 50 |
last_msg = None
|
| 51 |
|
|
@@ -55,7 +55,6 @@ def format_chat_for_livekit(messages: list, tokenizer) -> str:
|
|
| 55 |
|
| 56 |
content = normalize_text_multilingual(msg["content"])
|
| 57 |
|
| 58 |
-
# Combine adjacent turns
|
| 59 |
if last_msg and last_msg["role"] == msg["role"]:
|
| 60 |
last_msg["content"] += f" {content}"
|
| 61 |
else:
|
|
@@ -70,7 +69,6 @@ def format_chat_for_livekit(messages: list, tokenizer) -> str:
|
|
| 70 |
tokenize=False
|
| 71 |
)
|
| 72 |
|
| 73 |
-
# Remove the EOU token from current utterance
|
| 74 |
ix = convo_text.rfind("<|im_end|>")
|
| 75 |
text = convo_text[:ix]
|
| 76 |
return text
|
|
@@ -83,7 +81,7 @@ def predict_livekit(session, tokenizer, messages: list) -> float:
|
|
| 83 |
text,
|
| 84 |
add_special_tokens=False,
|
| 85 |
return_tensors="np",
|
| 86 |
-
max_length=128,
|
| 87 |
truncation=True,
|
| 88 |
)
|
| 89 |
|
|
@@ -97,7 +95,6 @@ def predict_livekit(session, tokenizer, messages: list) -> float:
|
|
| 97 |
|
| 98 |
def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
|
| 99 |
"""Run inference with fine-tuned model"""
|
| 100 |
-
# Format as ChatML
|
| 101 |
formatted = ""
|
| 102 |
for msg in messages:
|
| 103 |
role = msg["role"]
|
|
@@ -117,14 +114,13 @@ def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
|
|
| 117 |
|
| 118 |
generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 119 |
|
| 120 |
-
# Parse prediction
|
| 121 |
generated_lower = generated.strip().lower()
|
| 122 |
if "<|eot|>" in generated_lower or "eot" in generated_lower:
|
| 123 |
-
return 1.0
|
| 124 |
elif "<|continue|>" in generated_lower or "continue" in generated_lower:
|
| 125 |
-
return 0.0
|
| 126 |
else:
|
| 127 |
-
return 0.5
|
| 128 |
|
| 129 |
# ============================================================
|
| 130 |
# Main Evaluation
|
|
@@ -139,6 +135,7 @@ def main():
|
|
| 139 |
print("\n[1/4] Loading test dataset...")
|
| 140 |
dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
|
| 141 |
print(f" Loaded {len(dataset)} test samples")
|
|
|
|
| 142 |
|
| 143 |
# Load fine-tuned model
|
| 144 |
print("\n[2/4] Loading fine-tuned model (Vurtnec/eot-detector-smollm2)...")
|
|
@@ -155,8 +152,7 @@ def main():
|
|
| 155 |
# Load LiveKit model
|
| 156 |
print("\n[3/4] Loading LiveKit model (livekit/turn-detector)...")
|
| 157 |
|
| 158 |
-
|
| 159 |
-
revision = "v0.4.1-intl" # multilingual
|
| 160 |
onnx_path = hf_hub_download(
|
| 161 |
repo_id="livekit/turn-detector",
|
| 162 |
filename="model_q8.onnx",
|
|
@@ -180,34 +176,14 @@ def main():
|
|
| 180 |
ground_truth = []
|
| 181 |
|
| 182 |
for i, sample in enumerate(dataset):
|
| 183 |
-
|
|
|
|
|
|
|
| 184 |
|
| 185 |
-
#
|
| 186 |
-
if
|
| 187 |
-
label = 1 # Complete
|
| 188 |
-
else:
|
| 189 |
-
label = 0 # Incomplete
|
| 190 |
ground_truth.append(label)
|
| 191 |
|
| 192 |
-
# Extract conversation from text
|
| 193 |
-
messages = []
|
| 194 |
-
parts = text.split("<|im_end|>")
|
| 195 |
-
for part in parts[:-1]: # Skip the label part
|
| 196 |
-
if "<|im_start|>" in part:
|
| 197 |
-
idx = part.find("<|im_start|>")
|
| 198 |
-
content_part = part[idx + len("<|im_start|>"):]
|
| 199 |
-
if "\n" in content_part:
|
| 200 |
-
role, content = content_part.split("\n", 1)
|
| 201 |
-
role = role.strip()
|
| 202 |
-
content = content.strip()
|
| 203 |
-
if role in ["user", "assistant"] and content:
|
| 204 |
-
messages.append({"role": role, "content": content})
|
| 205 |
-
|
| 206 |
-
if not messages:
|
| 207 |
-
# Fallback: treat as user message
|
| 208 |
-
clean_text = text.split("<|eot|>")[0].split("<|continue|>")[0].strip()
|
| 209 |
-
messages = [{"role": "user", "content": clean_text}]
|
| 210 |
-
|
| 211 |
# Fine-tuned prediction
|
| 212 |
try:
|
| 213 |
ft_prob = predict_finetuned(ft_model, ft_tokenizer, messages, device)
|
|
@@ -220,11 +196,10 @@ def main():
|
|
| 220 |
# LiveKit prediction
|
| 221 |
try:
|
| 222 |
lk_prob = predict_livekit(lk_session, lk_tokenizer, messages)
|
| 223 |
-
# LiveKit uses 0.5 as default threshold
|
| 224 |
lk_pred = 1 if lk_prob >= 0.5 else 0
|
| 225 |
except Exception as e:
|
| 226 |
print(f" Warning: LiveKit model error on sample {i}: {e}")
|
| 227 |
-
lk_pred = 1
|
| 228 |
lk_predictions.append(lk_pred)
|
| 229 |
|
| 230 |
if (i + 1) % 10 == 0:
|
|
@@ -268,25 +243,25 @@ def main():
|
|
| 268 |
print(f"Actual Incomplete {lk_cm[0][0]:3d} {lk_cm[0][1]:3d}")
|
| 269 |
print(f" Complete {lk_cm[1][0]:3d} {lk_cm[1][1]:3d}")
|
| 270 |
|
| 271 |
-
#
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
print("
|
| 290 |
|
| 291 |
if __name__ == "__main__":
|
| 292 |
main()
|
|
|
|
| 33 |
# ============================================================
|
| 34 |
|
| 35 |
def normalize_text_multilingual(text: str) -> str:
|
| 36 |
+
"""Normalize text for multilingual model"""
|
| 37 |
if not text:
|
| 38 |
return ""
|
| 39 |
text = unicodedata.normalize("NFKC", text.lower())
|
|
|
|
| 45 |
return text
|
| 46 |
|
| 47 |
def format_chat_for_livekit(messages: list, tokenizer) -> str:
|
| 48 |
+
"""Format chat context for LiveKit model"""
|
| 49 |
new_chat_ctx = []
|
| 50 |
last_msg = None
|
| 51 |
|
|
|
|
| 55 |
|
| 56 |
content = normalize_text_multilingual(msg["content"])
|
| 57 |
|
|
|
|
| 58 |
if last_msg and last_msg["role"] == msg["role"]:
|
| 59 |
last_msg["content"] += f" {content}"
|
| 60 |
else:
|
|
|
|
| 69 |
tokenize=False
|
| 70 |
)
|
| 71 |
|
|
|
|
| 72 |
ix = convo_text.rfind("<|im_end|>")
|
| 73 |
text = convo_text[:ix]
|
| 74 |
return text
|
|
|
|
| 81 |
text,
|
| 82 |
add_special_tokens=False,
|
| 83 |
return_tensors="np",
|
| 84 |
+
max_length=128,
|
| 85 |
truncation=True,
|
| 86 |
)
|
| 87 |
|
|
|
|
| 95 |
|
| 96 |
def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
|
| 97 |
"""Run inference with fine-tuned model"""
|
|
|
|
| 98 |
formatted = ""
|
| 99 |
for msg in messages:
|
| 100 |
role = msg["role"]
|
|
|
|
| 114 |
|
| 115 |
generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 116 |
|
|
|
|
| 117 |
generated_lower = generated.strip().lower()
|
| 118 |
if "<|eot|>" in generated_lower or "eot" in generated_lower:
|
| 119 |
+
return 1.0
|
| 120 |
elif "<|continue|>" in generated_lower or "continue" in generated_lower:
|
| 121 |
+
return 0.0
|
| 122 |
else:
|
| 123 |
+
return 0.5
|
| 124 |
|
| 125 |
# ============================================================
|
| 126 |
# Main Evaluation
|
|
|
|
| 135 |
print("\n[1/4] Loading test dataset...")
|
| 136 |
dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
|
| 137 |
print(f" Loaded {len(dataset)} test samples")
|
| 138 |
+
print(f" Columns: {dataset.column_names}")
|
| 139 |
|
| 140 |
# Load fine-tuned model
|
| 141 |
print("\n[2/4] Loading fine-tuned model (Vurtnec/eot-detector-smollm2)...")
|
|
|
|
| 152 |
# Load LiveKit model
|
| 153 |
print("\n[3/4] Loading LiveKit model (livekit/turn-detector)...")
|
| 154 |
|
| 155 |
+
revision = "v0.4.1-intl"
|
|
|
|
| 156 |
onnx_path = hf_hub_download(
|
| 157 |
repo_id="livekit/turn-detector",
|
| 158 |
filename="model_q8.onnx",
|
|
|
|
| 176 |
ground_truth = []
|
| 177 |
|
| 178 |
for i, sample in enumerate(dataset):
|
| 179 |
+
# Dataset structure: messages (list), is_complete (bool)
|
| 180 |
+
messages = sample["messages"]
|
| 181 |
+
is_complete = sample["is_complete"]
|
| 182 |
|
| 183 |
+
# Ground truth: 1 = complete, 0 = incomplete
|
| 184 |
+
label = 1 if is_complete else 0
|
|
|
|
|
|
|
|
|
|
| 185 |
ground_truth.append(label)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
# Fine-tuned prediction
|
| 188 |
try:
|
| 189 |
ft_prob = predict_finetuned(ft_model, ft_tokenizer, messages, device)
|
|
|
|
| 196 |
# LiveKit prediction
|
| 197 |
try:
|
| 198 |
lk_prob = predict_livekit(lk_session, lk_tokenizer, messages)
|
|
|
|
| 199 |
lk_pred = 1 if lk_prob >= 0.5 else 0
|
| 200 |
except Exception as e:
|
| 201 |
print(f" Warning: LiveKit model error on sample {i}: {e}")
|
| 202 |
+
lk_pred = 1
|
| 203 |
lk_predictions.append(lk_pred)
|
| 204 |
|
| 205 |
if (i + 1) % 10 == 0:
|
|
|
|
| 243 |
print(f"Actual Incomplete {lk_cm[0][0]:3d} {lk_cm[0][1]:3d}")
|
| 244 |
print(f" Complete {lk_cm[1][0]:3d} {lk_cm[1][1]:3d}")
|
| 245 |
|
| 246 |
+
# Final summary
|
| 247 |
+
print("\n" + "=" * 60)
|
| 248 |
+
print("FINAL COMPARISON SUMMARY")
|
| 249 |
+
print("=" * 60)
|
| 250 |
+
|
| 251 |
+
ft_acc = accuracy_score(ground_truth, ft_predictions)
|
| 252 |
+
ft_f1 = f1_score(ground_truth, ft_predictions, zero_division=0)
|
| 253 |
+
lk_acc = accuracy_score(ground_truth, lk_predictions)
|
| 254 |
+
lk_f1 = f1_score(ground_truth, lk_predictions, zero_division=0)
|
| 255 |
+
|
| 256 |
+
print(f"\nFine-tuned Model: Accuracy={ft_acc*100:.2f}%, F1={ft_f1*100:.2f}%")
|
| 257 |
+
print(f"LiveKit Official: Accuracy={lk_acc*100:.2f}%, F1={lk_f1*100:.2f}%")
|
| 258 |
+
|
| 259 |
+
diff_acc = (lk_acc - ft_acc) * 100
|
| 260 |
+
diff_f1 = (lk_f1 - ft_f1) * 100
|
| 261 |
|
| 262 |
+
print(f"\nDifference (LiveKit - Fine-tuned):")
|
| 263 |
+
print(f" Accuracy: {'+' if diff_acc >= 0 else ''}{diff_acc:.2f}%")
|
| 264 |
+
print(f" F1 Score: {'+' if diff_f1 >= 0 else ''}{diff_f1:.2f}%")
|
| 265 |
|
| 266 |
if __name__ == "__main__":
|
| 267 |
main()
|