# tiny-scribe / benchmark_all_embeddings.py
# Commit de2e4cb ("Keep only Granite-107M embedding model", by Luigi)
#!/usr/bin/env python3
"""
Benchmark all 4 embedding models for Chinese deduplication
Tests: granite-107m, granite-278m, gemma-300m, qwen-600m
"""
import sys
import os
import time
import json
import numpy as np
# Add project path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from meeting_summarizer.extraction import EmbeddingModel, EMBEDDING_MODELS
# Test pairs: Chinese text that should/shouldn't match
# Schema per entry:
#   text1/text2   — the two snippets to embed and compare
#   should_match  — expected duplicate verdict at the similarity threshold
#   type          — category label, used only for display in the report
TEST_PAIRS: list[dict] = [
    # Exact duplicates (should match)
    {"text1": "與三星討論Q3產能分配", "text2": "與三星討論Q3產能分配", "should_match": True, "type": "exact"},
    {"text1": "優先供應大客戶浪潮", "text2": "優先供應大客戶浪潮", "should_match": True, "type": "exact"},
    {"text1": "DDR4缺貨持續到2028年", "text2": "DDR4缺貨持續到2028年", "should_match": True, "type": "exact"},
    # Different items (should NOT match)
    {"text1": "與三星討論Q3產能分配", "text2": "確認LPDDR4供應數量", "should_match": False, "type": "different"},
    {"text1": "優先供應大客戶浪潮", "text2": "與浪潮討論大客戶付款能力", "should_match": False, "type": "related"},
    {"text1": "DDR4缺貨持續到2028年", "text2": "AI需求占全球產能45%", "should_match": False, "type": "different"},
    {"text1": "Q2價格漲幅預估", "text2": "深圳測試場良率確認", "should_match": False, "type": "different"},
    # Edge cases
    {"text1": "ModuleHouse為嵌入式產品", "text2": "中興、創惟啟興也是重要客戶", "should_match": False, "type": "different"},
    # Extended text: a superset sentence must NOT be treated as a duplicate
    {"text1": "與三星討論Q3產能分配", "text2": "與三星討論Q3產能分配及價格", "should_match": False, "type": "extended"},
]
def test_embedding_model(model_key, threshold=0.85):
    """Benchmark a single embedding model on TEST_PAIRS.

    Loads the model, embeds each text pair, and compares pairwise cosine
    similarity against ``threshold`` to predict duplicates.

    Args:
        model_key: Key into EMBEDDING_MODELS selecting which model to load.
        threshold: Cosine-similarity cutoff at or above which a pair is
            predicted to be a duplicate.

    Returns:
        A results dict (accuracy, per-test details, load time) on success,
        or None if loading or embedding raised an exception.
    """
    config = EMBEDDING_MODELS[model_key]
    print(f"\n{'='*70}")
    print(f"Testing: {config['name']}")
    print(f"Repo: {config['repo_id']}")
    print(f"Dimensions: {config['embedding_dim']}")
    print(f"{'='*70}")
    model = None
    try:
        # Load model and time it
        start = time.time()
        model = EmbeddingModel(model_key, n_threads=2)
        msg = model.load()
        load_time = time.time() - start
        print(f"✓ Loaded in {load_time:.2f}s")
        print(f" Status: {msg}")
        results = {
            "model_key": model_key,
            "model_name": config['name'],
            "dimensions": config['embedding_dim'],
            "load_time": load_time,
            "threshold": threshold,
            "tests": [],
            "correct": 0,
            "false_positives": 0,
            "false_negatives": 0,
        }
        # Score each pair
        for i, test in enumerate(TEST_PAIRS, 1):
            emb1 = model.embed(test['text1'])
            emb2 = model.embed(test['text2'])
            # Cosine similarity reduces to a dot product because the
            # embedding vectors are already normalized.
            similarity = float(np.dot(emb1, emb2))
            predicted = similarity >= threshold
            is_correct = predicted == test['should_match']
            if is_correct:
                results['correct'] += 1
            elif predicted:
                # Wrongly predicted a duplicate (expected non-match).
                results['false_positives'] += 1
            else:
                # Missed a pair that should have matched.
                results['false_negatives'] += 1
            results['tests'].append({
                "id": i,
                "type": test['type'],
                "similarity": similarity,
                "predicted": predicted,
                "expected": test['should_match'],
                "correct": is_correct
            })
            status = "✅" if is_correct else "❌"
            print(f"{status} Test {i} ({test['type'][:10]:<10}): sim={similarity:.3f}, "
                  f"match={predicted}, expected={test['should_match']}")
        # Overall accuracy
        total = len(TEST_PAIRS)
        results['accuracy'] = results['correct'] / total
        print(f"\n📊 {config['name']} Results:")
        print(f" Accuracy: {results['accuracy']:.1%} ({results['correct']}/{total})")
        print(f" False Positives: {results['false_positives']}")
        print(f" False Negatives: {results['false_negatives']}")
        return results
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Best-effort cleanup: release model memory even when a test raised
        # after load (the original leaked the model on the error path).
        if model is not None:
            try:
                model.unload()
            except Exception:
                pass  # cleanup failure must not mask the real error
def main():
    """Benchmark every configured embedding model and print/save a comparison.

    Runs test_embedding_model() for each entry in EMBEDDING_MODELS, prints a
    ranked comparison table, and writes the full results to a JSON file.
    """
    # Single source of truth for the similarity cutoff: printed, passed to
    # each model run, and recorded in the saved JSON (the original repeated
    # the 0.85 literal in three places).
    threshold = 0.85
    print("="*70)
    print("EMBEDDING MODEL BENCHMARK - All 4 Models")
    print("Chinese Transcript Deduplication")
    print("="*70)
    print(f"\nTest pairs: {len(TEST_PAIRS)}")
    print(f"Similarity threshold: {threshold}")
    print(f"\nModels to test: {len(EMBEDDING_MODELS)}")
    for cfg in EMBEDDING_MODELS.values():
        print(f" • {cfg['name']} ({cfg['embedding_dim']}d)")
    all_results = []
    for model_key in EMBEDDING_MODELS:
        result = test_embedding_model(model_key, threshold=threshold)
        if result:
            all_results.append(result)
    # Summary
    print("\n" + "="*70)
    print("FINAL COMPARISON")
    print("="*70)
    if all_results:
        # Rank models best-accuracy first
        all_results.sort(key=lambda r: r['accuracy'], reverse=True)
        print(f"\n{'Rank':<6}{'Model':<30}{'Dims':<8}{'Accuracy':<12}{'Load(s)':<10}")
        print("-"*70)
        for rank, r in enumerate(all_results, 1):
            # :<12 keeps the accuracy value aligned with the header column
            # (the original padded with ad-hoc literal spaces).
            print(f"{rank:<6}{r['model_name']:<30}{r['dimensions']:<8}{r['accuracy']:<12.1%}{r['load_time']:.2f}")
        # Persist full results
        output_file = "embedding_benchmark_all_models.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                "benchmark_info": {
                    "test_pairs": len(TEST_PAIRS),
                    "threshold": threshold,
                    "models_tested": len(all_results)
                },
                "results": all_results
            }, f, indent=2, ensure_ascii=False)
        print(f"\n💾 Results saved to: {output_file}")
        # Best model (list is sorted descending by accuracy)
        best = all_results[0]
        print(f"\n🏆 Best: {best['model_name']}")
        print(f" Accuracy: {best['accuracy']:.1%}")
        print(f" Dimensions: {best['dimensions']}")
    else:
        print("❌ No models successfully tested")


if __name__ == "__main__":
    main()