Vurtnec committed on
Commit
79c6cb6
·
verified ·
1 Parent(s): f8bdcc7

Upload compare_eot_models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. compare_eot_models.py +32 -57
compare_eot_models.py CHANGED
@@ -33,7 +33,7 @@ import torch
33
  # ============================================================
34
 
35
  def normalize_text_multilingual(text: str) -> str:
36
- """Normalize text for multilingual model (from base.py:56-67)"""
37
  if not text:
38
  return ""
39
  text = unicodedata.normalize("NFKC", text.lower())
@@ -45,7 +45,7 @@ def normalize_text_multilingual(text: str) -> str:
45
  return text
46
 
47
  def format_chat_for_livekit(messages: list, tokenizer) -> str:
48
- """Format chat context for LiveKit model (from base.py:69-93)"""
49
  new_chat_ctx = []
50
  last_msg = None
51
 
@@ -55,7 +55,6 @@ def format_chat_for_livekit(messages: list, tokenizer) -> str:
55
 
56
  content = normalize_text_multilingual(msg["content"])
57
 
58
- # Combine adjacent turns
59
  if last_msg and last_msg["role"] == msg["role"]:
60
  last_msg["content"] += f" {content}"
61
  else:
@@ -70,7 +69,6 @@ def format_chat_for_livekit(messages: list, tokenizer) -> str:
70
  tokenize=False
71
  )
72
 
73
- # Remove the EOU token from current utterance
74
  ix = convo_text.rfind("<|im_end|>")
75
  text = convo_text[:ix]
76
  return text
@@ -83,7 +81,7 @@ def predict_livekit(session, tokenizer, messages: list) -> float:
83
  text,
84
  add_special_tokens=False,
85
  return_tensors="np",
86
- max_length=128, # MAX_HISTORY_TOKENS from base.py
87
  truncation=True,
88
  )
89
 
@@ -97,7 +95,6 @@ def predict_livekit(session, tokenizer, messages: list) -> float:
97
 
98
  def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
99
  """Run inference with fine-tuned model"""
100
- # Format as ChatML
101
  formatted = ""
102
  for msg in messages:
103
  role = msg["role"]
@@ -117,14 +114,13 @@ def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
117
 
118
  generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
119
 
120
- # Parse prediction
121
  generated_lower = generated.strip().lower()
122
  if "<|eot|>" in generated_lower or "eot" in generated_lower:
123
- return 1.0 # Complete turn
124
  elif "<|continue|>" in generated_lower or "continue" in generated_lower:
125
- return 0.0 # Incomplete turn
126
  else:
127
- return 0.5 # Uncertain
128
 
129
  # ============================================================
130
  # Main Evaluation
@@ -139,6 +135,7 @@ def main():
139
  print("\n[1/4] Loading test dataset...")
140
  dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
141
  print(f" Loaded {len(dataset)} test samples")
 
142
 
143
  # Load fine-tuned model
144
  print("\n[2/4] Loading fine-tuned model (Vurtnec/eot-detector-smollm2)...")
@@ -155,8 +152,7 @@ def main():
155
  # Load LiveKit model
156
  print("\n[3/4] Loading LiveKit model (livekit/turn-detector)...")
157
 
158
- # Download ONNX model (using multilingual version)
159
- revision = "v0.4.1-intl" # multilingual
160
  onnx_path = hf_hub_download(
161
  repo_id="livekit/turn-detector",
162
  filename="model_q8.onnx",
@@ -180,34 +176,14 @@ def main():
180
  ground_truth = []
181
 
182
  for i, sample in enumerate(dataset):
183
- text = sample["text"]
 
 
184
 
185
- # Parse ground truth from text
186
- if "<|eot|>" in text:
187
- label = 1 # Complete
188
- else:
189
- label = 0 # Incomplete
190
  ground_truth.append(label)
191
 
192
- # Extract conversation from text
193
- messages = []
194
- parts = text.split("<|im_end|>")
195
- for part in parts[:-1]: # Skip the label part
196
- if "<|im_start|>" in part:
197
- idx = part.find("<|im_start|>")
198
- content_part = part[idx + len("<|im_start|>"):]
199
- if "\n" in content_part:
200
- role, content = content_part.split("\n", 1)
201
- role = role.strip()
202
- content = content.strip()
203
- if role in ["user", "assistant"] and content:
204
- messages.append({"role": role, "content": content})
205
-
206
- if not messages:
207
- # Fallback: treat as user message
208
- clean_text = text.split("<|eot|>")[0].split("<|continue|>")[0].strip()
209
- messages = [{"role": "user", "content": clean_text}]
210
-
211
  # Fine-tuned prediction
212
  try:
213
  ft_prob = predict_finetuned(ft_model, ft_tokenizer, messages, device)
@@ -220,11 +196,10 @@ def main():
220
  # LiveKit prediction
221
  try:
222
  lk_prob = predict_livekit(lk_session, lk_tokenizer, messages)
223
- # LiveKit uses 0.5 as default threshold
224
  lk_pred = 1 if lk_prob >= 0.5 else 0
225
  except Exception as e:
226
  print(f" Warning: LiveKit model error on sample {i}: {e}")
227
- lk_pred = 1 # Default to complete on error
228
  lk_predictions.append(lk_pred)
229
 
230
  if (i + 1) % 10 == 0:
@@ -268,25 +243,25 @@ def main():
268
  print(f"Actual Incomplete {lk_cm[0][0]:3d} {lk_cm[0][1]:3d}")
269
  print(f" Complete {lk_cm[1][0]:3d} {lk_cm[1][1]:3d}")
270
 
271
- # Save results
272
- results = {
273
- "fine_tuned": {
274
- "accuracy": accuracy_score(ground_truth, ft_predictions),
275
- "precision": precision_score(ground_truth, ft_predictions, zero_division=0),
276
- "recall": recall_score(ground_truth, ft_predictions, zero_division=0),
277
- "f1": f1_score(ground_truth, ft_predictions, zero_division=0),
278
- },
279
- "livekit": {
280
- "accuracy": accuracy_score(ground_truth, lk_predictions),
281
- "precision": precision_score(ground_truth, lk_predictions, zero_division=0),
282
- "recall": recall_score(ground_truth, lk_predictions, zero_division=0),
283
- "f1": f1_score(ground_truth, lk_predictions, zero_division=0),
284
- }
285
- }
286
 
287
- with open("comparison_results.json", "w") as f:
288
- json.dump(results, f, indent=2)
289
- print("\nResults saved to comparison_results.json")
290
 
291
  if __name__ == "__main__":
292
  main()
 
33
  # ============================================================
34
 
35
  def normalize_text_multilingual(text: str) -> str:
36
+ """Normalize text for multilingual model"""
37
  if not text:
38
  return ""
39
  text = unicodedata.normalize("NFKC", text.lower())
 
45
  return text
46
 
47
  def format_chat_for_livekit(messages: list, tokenizer) -> str:
48
+ """Format chat context for LiveKit model"""
49
  new_chat_ctx = []
50
  last_msg = None
51
 
 
55
 
56
  content = normalize_text_multilingual(msg["content"])
57
 
 
58
  if last_msg and last_msg["role"] == msg["role"]:
59
  last_msg["content"] += f" {content}"
60
  else:
 
69
  tokenize=False
70
  )
71
 
 
72
  ix = convo_text.rfind("<|im_end|>")
73
  text = convo_text[:ix]
74
  return text
 
81
  text,
82
  add_special_tokens=False,
83
  return_tensors="np",
84
+ max_length=128,
85
  truncation=True,
86
  )
87
 
 
95
 
96
  def predict_finetuned(model, tokenizer, messages: list, device: str) -> float:
97
  """Run inference with fine-tuned model"""
 
98
  formatted = ""
99
  for msg in messages:
100
  role = msg["role"]
 
114
 
115
  generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
116
 
 
117
  generated_lower = generated.strip().lower()
118
  if "<|eot|>" in generated_lower or "eot" in generated_lower:
119
+ return 1.0
120
  elif "<|continue|>" in generated_lower or "continue" in generated_lower:
121
+ return 0.0
122
  else:
123
+ return 0.5
124
 
125
  # ============================================================
126
  # Main Evaluation
 
135
  print("\n[1/4] Loading test dataset...")
136
  dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
137
  print(f" Loaded {len(dataset)} test samples")
138
+ print(f" Columns: {dataset.column_names}")
139
 
140
  # Load fine-tuned model
141
  print("\n[2/4] Loading fine-tuned model (Vurtnec/eot-detector-smollm2)...")
 
152
  # Load LiveKit model
153
  print("\n[3/4] Loading LiveKit model (livekit/turn-detector)...")
154
 
155
+ revision = "v0.4.1-intl"
 
156
  onnx_path = hf_hub_download(
157
  repo_id="livekit/turn-detector",
158
  filename="model_q8.onnx",
 
176
  ground_truth = []
177
 
178
  for i, sample in enumerate(dataset):
179
+ # Dataset structure: messages (list), is_complete (bool)
180
+ messages = sample["messages"]
181
+ is_complete = sample["is_complete"]
182
 
183
+ # Ground truth: 1 = complete, 0 = incomplete
184
+ label = 1 if is_complete else 0
 
 
 
185
  ground_truth.append(label)
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  # Fine-tuned prediction
188
  try:
189
  ft_prob = predict_finetuned(ft_model, ft_tokenizer, messages, device)
 
196
  # LiveKit prediction
197
  try:
198
  lk_prob = predict_livekit(lk_session, lk_tokenizer, messages)
 
199
  lk_pred = 1 if lk_prob >= 0.5 else 0
200
  except Exception as e:
201
  print(f" Warning: LiveKit model error on sample {i}: {e}")
202
+ lk_pred = 1
203
  lk_predictions.append(lk_pred)
204
 
205
  if (i + 1) % 10 == 0:
 
243
  print(f"Actual Incomplete {lk_cm[0][0]:3d} {lk_cm[0][1]:3d}")
244
  print(f" Complete {lk_cm[1][0]:3d} {lk_cm[1][1]:3d}")
245
 
246
+ # Final summary
247
+ print("\n" + "=" * 60)
248
+ print("FINAL COMPARISON SUMMARY")
249
+ print("=" * 60)
250
+
251
+ ft_acc = accuracy_score(ground_truth, ft_predictions)
252
+ ft_f1 = f1_score(ground_truth, ft_predictions, zero_division=0)
253
+ lk_acc = accuracy_score(ground_truth, lk_predictions)
254
+ lk_f1 = f1_score(ground_truth, lk_predictions, zero_division=0)
255
+
256
+ print(f"\nFine-tuned Model: Accuracy={ft_acc*100:.2f}%, F1={ft_f1*100:.2f}%")
257
+ print(f"LiveKit Official: Accuracy={lk_acc*100:.2f}%, F1={lk_f1*100:.2f}%")
258
+
259
+ diff_acc = (lk_acc - ft_acc) * 100
260
+ diff_f1 = (lk_f1 - ft_f1) * 100
261
 
262
+ print(f"\nDifference (LiveKit - Fine-tuned):")
263
+ print(f" Accuracy: {'+' if diff_acc >= 0 else ''}{diff_acc:.2f}%")
264
+ print(f" F1 Score: {'+' if diff_f1 >= 0 else ''}{diff_f1:.2f}%")
265
 
266
  if __name__ == "__main__":
267
  main()