# tiny-scribe / benchmark_extraction.py
# Author: Luigi — commit 126dfa5, "Add extraction benchmark and results
# for small models". (Hugging Face page header captured along with the
# file; kept here as comments so the script remains valid Python.)
#!/usr/bin/env python3
"""
Benchmark script for testing extraction models individually.
Tests each model on a single small window to verify extraction works.
"""
import json
import time
from typing import Dict, List, Tuple, Optional
import sys
sys.path.insert(0, '/home/luigi/tiny-scribe')
from meeting_summarizer.extraction import (
_build_schema_extraction_prompt,
_build_reasoning_extraction_prompt,
_try_parse_extraction_json,
)
from llama_cpp import Llama
# Test window - small excerpt from transcripts/full.txt
# NOTE: Traditional-Chinese meeting transcript (memory-chip supply talk).
# This is runtime input data fed to the models — do not translate or edit.
TEST_WINDOW: str = """SPEAKER_02: 三星在去年Q3的時候已經告訴,今年,它所有的產出50會在AI跟Service上面。25在Mobile20在PCM那模組廠就是PCMOthers這一塊。所以26年的供給已經會比25年的供給在PCMOthers這塊少了15那再加上現在的狀況。所以我們覺得看起來應該缺到了8年,再加上現在昨天我不知道昨天你們看到SanDisk有一個這不是只有DDRName也是這樣Name你知道。
SPEAKER_03: 我想請教一下,以現在來講第四三一,對於就是說三星他們減產,或是甚至於後面可能會停產的。這樣的狀況跟凱力士也差不多的情況。
SPEAKER_02: 對於這塊,你們怎麼應?該是這樣說他們就算減產或停產,vivo是不會停的,顆粒會停,它的成品會停,但vivo是不會停的。"""

# Small models to test (< 2B parameters)
# Each entry describes one model run by test_model():
#   name               - display label used in console output and results
#   repo_id            - Hugging Face repo passed to Llama.from_pretrained
#   filename           - glob matching the GGUF quantization file in the repo
#   temperature        - sampling temperature for the extraction call
#   supports_reasoning - True -> use the reasoning prompt variant
TEST_MODELS: List[Dict] = [
    {
        "name": "Falcon-H1 100M",
        "repo_id": "tiiuae/Falcon-H1-100M-Base-GGUF",
        "filename": "*Q8_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Gemma-3 270M",
        "repo_id": "google/gemma-3-270m-it-GGUF",
        "filename": "*Q4_K_M.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Granite-4.0 350M",
        "repo_id": "unsloth/granite-4.0-h-350m-GGUF",
        "filename": "*Q8_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "BitCPM4 0.5B",
        "repo_id": "openbmb/BitCPM4-0.5B-GGUF",
        "filename": "*q4_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Qwen3 0.6B",
        "repo_id": "unsloth/Qwen3-0.6B-GGUF",
        "filename": "*Q4_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": True,
    },
    {
        "name": "Granite 3.1 1B",
        "repo_id": "bartowski/granite-3.1-1b-a400m-instruct-GGUF",
        "filename": "*Q8_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Falcon-H1 1.5B",
        "repo_id": "unsloth/Falcon-H1-1.5B-Deep-Instruct-GGUF",
        "filename": "*Q4_K_M.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Qwen3 1.7B",
        "repo_id": "unsloth/Qwen3-1.7B-GGUF",
        "filename": "*Q4_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": True,
    },
]
def test_model(model_config: Dict) -> Dict:
    """Load one model from the Hugging Face hub and run a single extraction pass.

    Args:
        model_config: One entry of TEST_MODELS (keys: name, repo_id, filename,
            temperature, optional supports_reasoning).

    Returns:
        Result dict with keys: model, repo_id, success, items_extracted,
        response (truncated to 500 chars), error, time_seconds, and
        parsed_data when JSON parsing succeeded.
    """
    print(f"\n{'='*60}")
    print(f"Testing: {model_config['name']}")
    print(f"{'='*60}")

    result = {
        "model": model_config['name'],
        "repo_id": model_config['repo_id'],
        "success": False,
        "items_extracted": 0,
        "response": "",
        "error": "",
        "time_seconds": 0,
    }

    # Initialize before the try block so the except handler can always compute
    # elapsed time (replaces the fragile `'start_time' in locals()` check).
    start_time = time.time()
    try:
        print(f"Loading {model_config['name']}...")
        llm = Llama.from_pretrained(
            repo_id=model_config['repo_id'],
            filename=model_config['filename'],
            n_ctx=4096,
            verbose=False,
        )

        # Reasoning-capable models get the think-step prompt variant; the rest
        # get the plain JSON-schema prompt. Target language is zh-TW.
        if model_config.get('supports_reasoning', False):
            system_prompt = _build_reasoning_extraction_prompt('zh-TW')
        else:
            system_prompt = _build_schema_extraction_prompt('zh-TW')

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Transcript:\n\n{TEST_WINDOW}"},
        ]

        print("Running extraction...")
        response = llm.create_chat_completion(
            messages=messages,
            max_tokens=1024,
            temperature=model_config['temperature'],
            top_p=0.9,
            top_k=30,
        )
        result['time_seconds'] = time.time() - start_time

        full_response = response["choices"][0]["message"]["content"]
        # Keep a truncated copy so the saved JSON results file stays small.
        if len(full_response) > 500:
            result['response'] = full_response[:500] + "..."
        else:
            result['response'] = full_response

        print("\nRaw response (first 300 chars):")
        print(full_response[:300])

        parsed = _try_parse_extraction_json(full_response, log_repair=True)
        if parsed:
            total_items = sum(len(v) for v in parsed.values())
            result['success'] = True
            result['items_extracted'] = total_items
            result['parsed_data'] = parsed
            print(f"\n✅ SUCCESS - Extracted {total_items} items:")
            for key, items in parsed.items():
                print(f" {key}: {len(items)} items")
                for item in items[:2]:  # Show first 2 items per category
                    # str() guards against non-string items (dicts/nested
                    # lists), which would otherwise crash the slice preview.
                    print(f" - {str(item)[:80]}...")
        else:
            result['error'] = "Failed to parse JSON"
            print("\n❌ FAILED - Could not parse JSON")

        # Release the model explicitly so eight sequential loads in main()
        # don't accumulate RAM across iterations.
        del llm
    except Exception as e:
        # Broad catch is deliberate: this is a per-model benchmark boundary,
        # and any failure should be recorded rather than abort the run.
        result['error'] = str(e)
        result['time_seconds'] = time.time() - start_time
        print(f"\n❌ ERROR: {e}")
    return result
def main():
    """Run the extraction benchmark over every model in TEST_MODELS.

    Prints a per-model report plus a final summary, then writes all result
    dicts to extraction_benchmark_results.json (UTF-8, non-ASCII preserved).
    """
    print("=" * 60)
    print("EXTRACTION MODEL BENCHMARK")
    print("=" * 60)
    print(f"\nTest window size: {len(TEST_WINDOW)} characters")
    print(f"Models to test: {len(TEST_MODELS)}")

    results = []
    for i, model_config in enumerate(TEST_MODELS):
        results.append(test_model(model_config))
        # Brief pause between models; skipped after the last one so the run
        # doesn't idle for 2s with nothing left to load.
        if i < len(TEST_MODELS) - 1:
            time.sleep(2)

    # Summary
    print("\n" + "=" * 60)
    print("BENCHMARK SUMMARY")
    print("=" * 60)
    successful = [r for r in results if r['success']]
    failed = [r for r in results if not r['success']]
    print(f"\nSuccessful: {len(successful)}/{len(results)}")
    print(f"Failed: {len(failed)}/{len(results)}")

    print("\nSuccessful Models:")
    for r in successful:
        print(f" ✅ {r['model']}: {r['items_extracted']} items ({r['time_seconds']:.1f}s)")
    print("\nFailed Models:")
    for r in failed:
        print(f" ❌ {r['model']}: {r['error']}")

    # ensure_ascii=False keeps the Chinese transcript text human-readable
    # in the saved file.
    with open('extraction_benchmark_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print("\nResults saved to: extraction_benchmark_results.json")


if __name__ == "__main__":
    main()