| |
| """Evaluate fine-tuned LoRA models on expanded bakeoff test set. |
| |
| 30 tests organized by tier: |
| Tier 1 (1-10): Core functionality — basic symbol subs, casing, paths, URLs |
| Tier 2 (11-20): Compound patterns — multi-symbol, git commands, env vars, pipes |
| Tier 3 (21-30): Known failure modes — dot dot, compound &&, numbers, fidelity |
| """ |
|
|
| import json |
| import time |
| from collections import defaultdict |
| from mlx_lm import load, generate |
|
|
# System prompt sent with every test case.
SYS = "Reconstruct the intended syntax from the dictated text. Output only the result."

# Compact case table; each row is (id, category, dictated input, expected output).
# Rows are expanded into dicts below so the rest of the script can index by key.
_CASES = (
    # Tier 1 (1-10): core functionality
    (1, "symbols-basic", "hello dash world", "hello-world"),
    (2, "symbols-basic", "hello underscore world", "hello_world"),
    (3, "symbols-compound", "dash dash verbose", "--verbose"),
    (4, "symbols-compound", "equals equals equals", "==="),
    (5, "casing", "camel case get user name", "getUserName"),
    (6, "casing", "snake case total tokens generated", "total_tokens_generated"),
    (7, "casing", "kebab case dark mode toggle", "dark-mode-toggle"),
    (8, "quotes", "quote hello world quote", "\"hello world\""),
    (9, "paths", "tilde slash dev slash talkie", "~/dev/talkie"),
    (
        10,
        "urls",
        "HTTPS colon slash slash GitHub dot com slash arach slash talkie",
        "https://github.com/arach/talkie",
    ),
    # Tier 2 (11-20): compound patterns
    (
        11,
        "mixed",
        "git commit dash M quote fix latency quote",
        "git commit -m \"fix latency\"",
    ),
    (
        12,
        "mixed",
        "export all caps API underscore KEY equals quote my dash key dash one two three quote",
        "export API_KEY=\"my-key-123\"",
    ),
    (13, "mixed", "shebang slash bin slash bash", "#!/bin/bash"),
    (
        14,
        "mixed",
        "docker run dash D dash P eighty eighty colon eighty eighty nginx",
        "docker run -d -p 8080:8080 nginx",
    ),
    (
        15,
        "mixed",
        "func camel case view did load open paren close paren",
        "func viewDidLoad()",
    ),
    (
        16,
        "mixed",
        "import open brace camel case use state close brace from single quote react single quote",
        "import { useState } from 'react'",
    ),
    (17, "mixed", "LS dash L A pipe grep dot swift", "ls -la | grep .swift"),
    (
        18,
        "mixed",
        "GH PR create dash dash title quote fix inference latency quote dash dash body quote added TTFT tracking and latency instrumentation quote",
        "gh pr create --title \"fix inference latency\" --body \"Added TTFT tracking and latency instrumentation\"",
    ),
    (19, "identifiers", "dot E N V dot local", ".env.local"),
    (
        20,
        "operators",
        "open paren X close paren fat arrow open brace close brace",
        "(x) => {}",
    ),
    # Tier 3 (21-30): known failure modes
    (21, "symbols-compound", "dot dot slash dev", "../dev"),
    (22, "paths", "dot dot slash dot dot slash dot dot slash", "../../../"),
    (23, "paths", "dot dot slash configs", "../configs"),
    (24, "operators", "A and and B and and C", "a && b && c"),
    (25, "numbers", "zero point seven", "0.7"),
    (26, "numbers", "one two seven dot zero dot zero dot one", "127.0.0.1"),
    (
        27,
        "mixed",
        "git add dash A and and git commit dash M quote fix typo quote and and git push",
        "git add -A && git commit -m \"fix typo\" && git push",
    ),
    (28, "spacing", "no space git hub", "github"),
    (29, "brackets", "open bracket colon colon dash one close bracket", "[::-1]"),
    (
        30,
        "mixed",
        "dash dash temp zero point seven dash dash tokens five twelve",
        "--temp 0.7 --tokens 512",
    ),
)

# One dict per case, keyed exactly as the evaluation loop expects.
tests = [
    {"id": case_id, "cat": category, "dictated": spoken, "expected": wanted}
    for case_id, category, spoken, wanted in _CASES
]

# Model/adapter pairs to evaluate; each entry is loaded and scored in turn.
configs = [
    dict(
        label="QWEN 0.5B + LoRA v3",
        model="mlx-community/Qwen2.5-0.5B-Instruct-4bit",
        adapter="/Users/arach/dev/talkie/datasets/finetune/adapters/qwen-0.5b-lora-v3",
    ),
]
|
|
def _pct(hits, total):
    """Return hits/total as a whole-number percentage, or 0 when total is 0."""
    return round(hits / total * 100) if total else 0


# Tier membership is keyed on test id rather than list position, so reordering
# or inserting cases in `tests` cannot silently shift a case into another tier.
_TIER_ID_RANGES = [
    ("Tier 1: Core", range(1, 11)),
    ("Tier 2: Compound", range(11, 21)),
    ("Tier 3: Failure modes", range(21, 31)),
]

# Maps config label -> list of per-test result dicts; read again by the
# summary printed after this loop.
all_results = {}

for cfg in configs:
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" {cfg['label']}")
    print(f"{banner}\n")

    # Load the base model together with its LoRA adapter.
    model, tokenizer = load(cfg["model"], adapter_path=cfg["adapter"])

    results = []
    total_time = 0
    for t in tests:
        messages = [
            {"role": "system", "content": SYS},
            {"role": "user", "content": t["dictated"]},
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Only the generate() call is timed; prompt templating is excluded.
        start = time.perf_counter()
        got = generate(
            model, tokenizer, prompt=prompt,
            max_tokens=80, verbose=False
        )
        elapsed = time.perf_counter() - start
        total_time += elapsed

        # Scoring is exact string equality after trimming surrounding whitespace.
        got = got.strip()
        match = got == t["expected"]
        results.append({"id": t["id"], "cat": t["cat"], "match": match, "got": got})

        # Green check / red cross via ANSI colour codes.
        icon = "\033[32m✓\033[0m" if match else "\033[31m✗\033[0m"
        print(f"{icon} {str(t['id']).rjust(2)}. \"{t['dictated']}\"")
        print(f" expected: {t['expected']}")
        if not match:
            print(f" got: {got}")

    correct = sum(1 for r in results if r["match"])
    pct = _pct(correct, len(results))
    avg_ms = round(total_time / len(results) * 1000)
    print(f"\nScore: {correct}/{len(results)} ({pct}%)")
    print(f"Avg latency: {avg_ms}ms per inference")
    all_results[cfg["label"]] = results

    # Per-category breakdown.
    print(f"\nPer-category accuracy:")
    cat_results = defaultdict(lambda: {"correct": 0, "total": 0})
    for r in results:
        cat_results[r["cat"]]["total"] += 1
        if r["match"]:
            cat_results[r["cat"]]["correct"] += 1

    for cat in sorted(cat_results):
        cr = cat_results[cat]
        cat_pct = _pct(cr["correct"], cr["total"])
        # One filled block per pass, one light-shade block per failure.
        # (The previous literals were mojibake of these two characters.)
        bar = "█" * cr["correct"] + "░" * (cr["total"] - cr["correct"])
        print(f" {cat:<20} {cr['correct']}/{cr['total']} ({cat_pct:>3}%) {bar}")

    # Per-tier breakdown, bucketed by test id.
    print(f"\nPer-tier accuracy:")
    for tier_name, tier_ids in _TIER_ID_RANGES:
        tier_results = [r for r in results if r["id"] in tier_ids]
        tier_correct = sum(1 for r in tier_results if r["match"])
        # _pct keeps an empty tier from raising ZeroDivisionError.
        tier_pct = _pct(tier_correct, len(tier_results))
        print(f" {tier_name:<25} {tier_correct}/{len(tier_results)} ({tier_pct}%)")

    # Drop the references before the next config is loaded
    # (presumably so the weights can be reclaimed; not verified here).
    del model, tokenizer
|
|
| |
# Final roll-up: one score line per evaluated config, then fixed reference baselines.
rule = "=" * 60
print("\n" + rule)
print(" RESULTS SUMMARY")
print(rule + "\n")

for label, results in all_results.items():
    correct = len([r for r in results if r["match"]])
    pct = round(correct / len(results) * 100)
    print(f" {label}: {correct}/{len(results)} ({pct}%)")

# Hard-coded scores from earlier runs, printed for side-by-side comparison.
print("\nBaselines:")
for baseline_line in (
    " LoRA v1 (240 train): 13/15 (87%)",
    " LoRA v2 (474 train): 27/30 (90%)",
    " Claude: 14/15 (93%)",
):
    print(baseline_line)
|
|