#!/usr/bin/env python3
"""
Benchmark script for testing extraction models individually.
Tests each model on a single small window to verify extraction works.
"""
import json
import time
from typing import Dict, List, Tuple, Optional
import sys
sys.path.insert(0, '/home/luigi/tiny-scribe')
from meeting_summarizer.extraction import (
_build_schema_extraction_prompt,
_build_reasoning_extraction_prompt,
_try_parse_extraction_json,
)
from llama_cpp import Llama
# Test window - small excerpt from transcripts/full.txt
# Traditional Chinese (zh-TW) meeting excerpt with three diarized speakers;
# deliberately small so each model finishes one completion quickly.
TEST_WINDOW = """SPEAKER_02: 三星在去年Q3的時候已經告訴,今年,它所有的產出50會在AI跟Service上面。25在Mobile20在PCM那模組廠就是PCMOthers這一塊。所以26年的供給已經會比25年的供給在PCMOthers這塊少了15那再加上現在的狀況。所以我們覺得看起來應該缺到了8年,再加上現在昨天我不知道昨天你們看到SanDisk有一個這不是只有DDRName也是這樣Name你知道。
SPEAKER_03: 我想請教一下,以現在來講第四三一,對於就是說三星他們減產,或是甚至於後面可能會停產的。這樣的狀況跟凱力士也差不多的情況。
SPEAKER_02: 對於這塊,你們怎麼應?該是這樣說他們就算減產或停產,vivo是不會停的,顆粒會停,它的成品會停,但vivo是不會停的。"""
# Small models to test (< 2B parameters)
# Each entry: display name, Hugging Face hub repo, GGUF filename glob passed
# to Llama.from_pretrained, sampling temperature, and whether the model gets
# the reasoning-style extraction prompt instead of the schema prompt.
TEST_MODELS = [
    {
        "name": "Falcon-H1 100M",
        "repo_id": "tiiuae/Falcon-H1-100M-Base-GGUF",
        "filename": "*Q8_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Gemma-3 270M",
        "repo_id": "google/gemma-3-270m-it-GGUF",
        "filename": "*Q4_K_M.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Granite-4.0 350M",
        "repo_id": "unsloth/granite-4.0-h-350m-GGUF",
        "filename": "*Q8_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "BitCPM4 0.5B",
        "repo_id": "openbmb/BitCPM4-0.5B-GGUF",
        "filename": "*q4_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Qwen3 0.6B",
        "repo_id": "unsloth/Qwen3-0.6B-GGUF",
        "filename": "*Q4_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": True,
    },
    {
        "name": "Granite 3.1 1B",
        "repo_id": "bartowski/granite-3.1-1b-a400m-instruct-GGUF",
        "filename": "*Q8_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Falcon-H1 1.5B",
        "repo_id": "unsloth/Falcon-H1-1.5B-Deep-Instruct-GGUF",
        "filename": "*Q4_K_M.gguf",
        "temperature": 0.1,
        "supports_reasoning": False,
    },
    {
        "name": "Qwen3 1.7B",
        "repo_id": "unsloth/Qwen3-1.7B-GGUF",
        "filename": "*Q4_0.gguf",
        "temperature": 0.1,
        "supports_reasoning": True,
    },
]
def test_model(model_config: Dict) -> Dict:
    """Benchmark a single extraction model on TEST_WINDOW.

    Loads the GGUF model from the Hugging Face hub via llama-cpp, builds the
    appropriate extraction prompt (reasoning vs. schema variant), runs one
    chat completion, and attempts to parse the JSON extraction output.

    Args:
        model_config: One TEST_MODELS entry with keys 'name', 'repo_id',
            'filename', 'temperature' and optional 'supports_reasoning'.

    Returns:
        Result dict with keys: model, repo_id, success, items_extracted,
        response (preview truncated to 500 chars), error, time_seconds,
        plus 'parsed_data' when JSON parsing succeeded.
    """
    print(f"\n{'=' * 60}")
    print(f"Testing: {model_config['name']}")
    print("=" * 60)

    result = {
        "model": model_config['name'],
        "repo_id": model_config['repo_id'],
        "success": False,
        "items_extracted": 0,
        "response": "",
        "error": "",
        "time_seconds": 0,
    }

    # Bind the clock BEFORE the try so the except path never has to probe
    # locals() for a possibly-unbound name (original used
    # `'start_time' in locals()`, which silently reported 0 seconds).
    start_time = time.time()
    try:
        # Load model (downloads from the hub on first use).
        print(f"Loading {model_config['name']}...")
        llm = Llama.from_pretrained(
            repo_id=model_config['repo_id'],
            filename=model_config['filename'],
            n_ctx=4096,
            verbose=False,
        )

        # Reasoning-capable models (e.g. Qwen3) get the reasoning prompt.
        if model_config.get('supports_reasoning', False):
            system_prompt = _build_reasoning_extraction_prompt('zh-TW')
        else:
            system_prompt = _build_schema_extraction_prompt('zh-TW')

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Transcript:\n\n{TEST_WINDOW}"},
        ]

        # Run extraction
        print("Running extraction...")
        response = llm.create_chat_completion(
            messages=messages,
            max_tokens=1024,
            temperature=model_config['temperature'],
            top_p=0.9,
            top_k=30,
        )
        result['time_seconds'] = time.time() - start_time

        # Keep only a 500-char preview in the saved results.
        full_response = response["choices"][0]["message"]["content"]
        result['response'] = (
            full_response[:500] + "..." if len(full_response) > 500 else full_response
        )
        print("\nRaw response (first 300 chars):")
        print(full_response[:300])

        # Parse JSON out of the model response (with repair logging).
        parsed = _try_parse_extraction_json(full_response, log_repair=True)
        # NOTE(review): a valid-but-empty dict is reported as a parse failure
        # here ("Failed to parse JSON"); confirm that is intended semantics.
        if parsed:
            total_items = sum(len(v) for v in parsed.values())
            result['success'] = True
            result['items_extracted'] = total_items
            result['parsed_data'] = parsed
            print(f"\n✅ SUCCESS - Extracted {total_items} items:")
            for key, items in parsed.items():
                print(f" {key}: {len(items)} items")
                for item in items[:2]:  # Show first 2 items per category
                    print(f" - {item[:80]}...")
        else:
            result['error'] = "Failed to parse JSON"
            print("\n❌ FAILED - Could not parse JSON")
    except Exception as e:  # benchmark boundary: record failure, keep going
        result['error'] = str(e)
        result['time_seconds'] = time.time() - start_time
        print(f"\n❌ ERROR: {e}")

    return result
def main():
    """Run every TEST_MODELS entry through test_model and save the results.

    Prints a per-model report, then a pass/fail summary, and writes the full
    result list to extraction_benchmark_results.json in the working directory.
    """
    banner = "=" * 60
    print(banner)
    print("EXTRACTION MODEL BENCHMARK")
    print(banner)
    print(f"\nTest window size: {len(TEST_WINDOW)} characters")
    print(f"Models to test: {len(TEST_MODELS)}")

    results = []
    for model_config in TEST_MODELS:
        results.append(test_model(model_config))
        # Brief pause between model loads.
        time.sleep(2)

    # Summary report.
    print("\n" + banner)
    print("BENCHMARK SUMMARY")
    print(banner)

    successful = [r for r in results if r['success']]
    failed = [r for r in results if not r['success']]
    print(f"\nSuccessful: {len(successful)}/{len(results)}")
    print(f"Failed: {len(failed)}/{len(results)}")

    print("\nSuccessful Models:")
    for r in successful:
        print(f" ✅ {r['model']}: {r['items_extracted']} items ({r['time_seconds']:.1f}s)")

    print("\nFailed Models:")
    for r in failed:
        print(f" ❌ {r['model']}: {r['error']}")

    # Persist raw results for later comparison.
    with open('extraction_benchmark_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print("\nResults saved to: extraction_benchmark_results.json")
# Script entry point: only run the benchmark when executed directly.
if __name__ == "__main__":
    main()