| |
| """ |
| End-to-end test of the segment classifier + fused model chain. |
| |
| Simulates the full dictation pipeline without a microphone: |
| 1. Takes mixed dictation text (natural + protocol) |
| 2. Runs ProtocolSegmentClassifier to extract segments |
| 3. Sends protocol segments to TalkieInference via talkie-dev CLI |
| 4. Reassembles output and compares against expected |
| |
| Usage: |
| python3 test-full-chain.py # run all test cases |
| python3 test-full-chain.py --case 3 # run specific case |
| python3 test-full-chain.py --classifier-only # skip model, just test segmentation |
| """ |
|
|
| import json |
| import subprocess |
| import sys |
| import time |
| import importlib |
| from pathlib import Path |
| import numpy as np |
|
|
| |
# The classifier module's file name contains hyphens, which a plain ``import``
# statement cannot spell, so it is loaded by string name and its helpers are
# bound as module-level names here.
_mod = importlib.import_module("train-segment-classifier")
extract_features, is_protocol_word, is_strong_protocol = (
    getattr(_mod, _helper)
    for _helper in ("extract_features", "is_protocol_word", "is_strong_protocol")
)
|
|
| |
# Trained classifier parameters are read from a JSON file next to this script.
MODEL_PATH = Path(__file__).parent / "segment-classifier-model.json"
# JSON is UTF-8 by specification; naming the encoding keeps loading independent
# of the platform's locale default.
with open(MODEL_PATH, encoding="utf-8") as f:
    _model = json.load(f)
WEIGHTS = np.array(_model["weights"])
BIAS = _model["bias"]
# Probability cut-off for labelling a word as protocol; 0.5 when the model
# file does not specify one.
THRESHOLD = _model.get("threshold", 0.5)
|
|
| |
# Defaults handed to the inference CLI by call_inference().
MODEL_ID = "arach/qwen3-0.6b-bash-v1"
# Written as adjacent literals (one sentence per line) purely for readability;
# the concatenated value is a single line of text.
SYSTEM_PROMPT = (
    "You are a dictation-to-bash converter. "
    "The input is a segment of dictated speech containing a command. "
    "Words like 'space', 'dash', 'dot', 'slash', 'colon', 'underscore', 'tilde', 'quote' are literal symbols. "
    "Number words like 'eight zero eight zero' mean '8080'. "
    "'capital' means uppercase the next letter. "
    "'all caps' means uppercase the next word. "
    "Ignore filler words like 'so', 'try', 'run', 'do', 'with'. "
    "Output ONLY the reconstructed command, nothing else."
)
|
|
|
|
| |
|
|
# Each case is a dict with:
#   input            - the dictated text fed to the segmenter
#   expected_natural - the natural-speech portion (not checked by run_test)
#   expected_command - the command the model should reconstruct; an empty
#                      string marks a case that must produce no protocol segment
#   note             - label printed in the test header
TEST_CASES = [
    {
        "input": "I want to check the directory ls dash la",
        "expected_natural": "I want to check the directory",
        "expected_command": "ls -la",
        "note": "Simple mixed: natural + short command",
    },
    {
        "input": "let me see what's running ps dash aux",
        "expected_natural": "let me see what's running",
        "expected_command": "ps -aux",
        "note": "Casual intro + ps",
    },
    {
        "input": "git space push space dash u space origin space main",
        "expected_natural": "",
        "expected_command": "git push -u origin main",
        "note": "Pure protocol, no natural speech",
    },
    {
        "input": "the meeting is at three o'clock tomorrow afternoon",
        "expected_natural": "the meeting is at three o'clock tomorrow afternoon",
        "expected_command": "",
        "note": "Pure natural, should NOT trigger model",
    },
    {
        "input": "can you run git space clone space https colon slash slash github dot com slash user slash repo for me",
        "expected_natural": "can you run ... for me",
        "expected_command": "git clone https://github.com/user/repo",
        "note": "Command sandwiched in natural speech",
    },
    {
        "input": "alright first do cd space tilde slash projects and then run npm space install",
        "expected_natural": "alright first do ... and then run",
        # NOTE(review): " | " here looks like a separator between two expected
        # commands rather than a shell pipe; run_test compares by plain
        # substring, so this expectation can only ever land in CHECK — confirm.
        "expected_command": "cd ~/projects | npm install",
        "note": "Two commands with natural bridge",
    },
    {
        "input": "the quarterly report shows revenue increased by twelve percent year over year",
        "expected_natural": "the quarterly report shows revenue increased by twelve percent year over year",
        "expected_command": "",
        "note": "Business English with 'percent' — should NOT trigger",
    },
    {
        "input": "ssh space dash i space tilde slash dot ssh slash id underscore rsa space root at one nine two dot one six eight dot one dot one",
        "expected_natural": "",
        "expected_command": "ssh -i ~/.ssh/id_rsa root@192.168.1.1",
        "note": "Complex SSH command with IP",
    },
    {
        "input": "I think the problem is the firewall so try iptables space dash capital A space INPUT space dash p space tcp space dash dash dport space eight zero eight zero space dash j space ACCEPT",
        "expected_natural": "I think the problem is the firewall so try",
        "expected_command": "iptables -A INPUT -p tcp --dport 8080 -j ACCEPT",
        "note": "Diagnostic reasoning + iptables",
    },
    {
        "input": "according to the handbook all new employees must complete training within two weeks",
        "expected_natural": "according to the handbook all new employees must complete training within two weeks",
        "expected_command": "",
        "note": "HR policy with 'all', 'new', 'two' — pure natural",
    },
    {
        "input": "to find large files run find space dot space dash type space f space dash size space plus one hundred capital M",
        "expected_natural": "to find large files run",
        "expected_command": "find . -type f -size +100M",
        "note": "Instruction + find command",
    },
    {
        "input": "curl space dash capital X space all caps POST space dash capital H space quote Content dash Type colon space application slash json quote space https colon slash slash api dot example dot com slash users",
        "expected_natural": "",
        "expected_command": "curl -X POST -H \"Content-Type: application/json\" https://api.example.com/users",
        "note": "Complex curl — pure protocol",
    },
]
|
|
|
|
def classify_words(words, weights, bias, threshold=0.5):
    """Score every word with the logistic classifier.

    Each word is featurised together with a window of up to two neighbouring
    words on either side, and its probability is the sigmoid of
    ``dot(features, weights) + bias``.

    Returns:
        One ``(word, prob, is_protocol)`` tuple per input word, where
        ``is_protocol`` is ``prob >= threshold``.
    """
    total = len(words)

    def _score(pos, token):
        # Context window: the word itself plus at most two words each side.
        window = words[max(0, pos - 2):min(total, pos + 3)]
        feats = extract_features(token, window, pos, total)
        z = float(np.dot(feats, weights) + bias)
        p = 1.0 / (1.0 + np.exp(-z))
        return (token, p, p >= threshold)

    return [_score(pos, token) for pos, token in enumerate(words)]
|
|
|
|
def extract_segments(text, weights, bias, threshold=0.5, expansion=2):
    """Split *text* into alternating natural / protocol segments.

    Words the classifier flags act as anchors; every word within
    ``expansion`` positions of an anchor is treated as protocol as well.
    Consecutive words of the same kind are then merged.

    Returns:
        A list of ``{"kind": "protocol" | "natural", "text": str}`` dicts in
        input order, or an empty list when *text* contains no words.
    """
    tokens = text.split()
    if len(tokens) == 0:
        return []

    count = len(tokens)

    # Spread each anchor's protocol label over its neighbourhood.
    flagged = [False] * count
    for pos, (_, _, is_anchor) in enumerate(
        classify_words(tokens, weights, bias, threshold)
    ):
        if not is_anchor:
            continue
        lo = max(0, pos - expansion)
        hi = min(count, pos + expansion + 1)
        for k in range(lo, hi):
            flagged[k] = True

    labels = ["protocol" if hit else "natural" for hit in flagged]

    # Close a run whenever the label changes or the input ends.
    segments = []
    run_start = 0
    for pos in range(1, count + 1):
        if pos == count or labels[pos] != labels[run_start]:
            segments.append(
                {"kind": labels[run_start], "text": " ".join(tokens[run_start:pos])}
            )
            run_start = pos

    return segments
|
|
|
|
def call_inference(prompt, model_id=MODEL_ID, system_prompt=SYSTEM_PROMPT, timeout=30):
    """Run ``talkie-dev inference generate`` on *prompt* and return its output.

    Args:
        prompt: Text passed to the CLI as the generation prompt.
        model_id: Value for the CLI's ``--model`` option.
        system_prompt: Value for the CLI's ``--system`` option.
        timeout: Seconds to wait for the subprocess before giving up.

    Returns:
        ``(text, None)`` on success, or ``(None, error_message)`` when the
        process cannot be run, times out, exits non-zero, or returns a JSON
        object whose ``"text"`` field is not a string. This function does not
        raise for those conditions.
    """
    import re  # local import: the file's top-level imports do not include ``re``

    def _strip_think(raw):
        # Drop complete <think>...</think> blocks, then any unterminated
        # <think> tail (e.g. when generation hit the token limit mid-thought).
        raw = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
        return re.sub(r'<think>.*', '', raw, flags=re.DOTALL).strip()

    cmd = [
        "talkie-dev", "inference", "generate", prompt,
        "--model", model_id,
        "--system", system_prompt,
        "--temp", "0",
        "--tokens", "120",
        "--json",
    ]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=str(Path(__file__).parent.parent),
        )
    except subprocess.TimeoutExpired:
        return None, "timeout"
    except Exception as e:
        # Best-effort boundary: any launch failure (missing binary, bad
        # arguments, undecodable output) is reported to the caller, not raised.
        return None, str(e)

    if result.returncode != 0:
        return None, f"exit {result.returncode}: {result.stderr.strip()}"

    stdout = result.stdout.strip()
    try:
        data = json.loads(stdout)
    except json.JSONDecodeError:
        data = None

    if isinstance(data, dict):
        text = data.get("text", "")
        if not isinstance(text, str):
            return None, f"unexpected 'text' field in response: {text!r}"
        return _strip_think(text), None

    # Plain-text output (not JSON, or JSON that is not an object, such as a
    # bare number): strip think blocks first so the first line is real output
    # rather than a "<think>" tag, then keep the first line only.
    cleaned = _strip_think(stdout)
    return cleaned.split("\n")[0].strip(), None
|
|
|
|
def run_test(case, idx, classifier_only=False):
    """Run one test case through segmentation and, optionally, model inference.

    Args:
        case: Test-case dict; the "input", "note" and "expected_command" keys
            are read here.
        idx: Case number shown in the printed header.
        classifier_only: When true, stop after the segmentation check and
            skip model inference.

    Returns:
        ``{"pass": True | False | None, "reason": str}``. ``False`` means the
        segmenter's protocol/no-protocol decision was wrong; ``None`` means the
        assembled output did not contain the expected command and needs manual
        review.
    """
    text = case["input"]
    note = case["note"]

    print(f"\n{'='*70}")
    print(f" TEST {idx}: {note}")
    print(f"{'='*70}")
    print(f" INPUT: {text}")

    # Segment with the threshold stored in the model file, not the function
    # default, so the test exercises the classifier as configured.
    t0 = time.perf_counter()
    segments = extract_segments(text, WEIGHTS, BIAS, THRESHOLD)
    seg_ms = (time.perf_counter() - t0) * 1000

    print(f"\n SEGMENTS ({seg_ms:.1f}ms):")
    for s in segments:
        marker = "\033[91mPROTO\033[0m" if s["kind"] == "protocol" else "\033[92mNATRL\033[0m"
        print(f" [{marker}] {s['text']}")

    # NOTE(review): only the presence of a protocol segment is checked;
    # case["expected_natural"] is not compared against anything.
    has_protocol = any(s["kind"] == "protocol" for s in segments)
    expected_has_protocol = bool(case["expected_command"])

    if has_protocol != expected_has_protocol:
        if has_protocol:
            print("\n \033[91mFALSE POSITIVE: detected protocol in pure natural text!\033[0m")
        else:
            print("\n \033[93mFALSE NEGATIVE: missed protocol segment!\033[0m")
        return {"pass": False, "reason": "segmentation"}

    if not has_protocol:
        print("\n \033[92mPASS\033[0m — correctly identified as pure natural speech (no model call)")
        return {"pass": True, "reason": "natural_passthrough"}

    if classifier_only:
        print("\n [classifier-only mode, skipping model inference]")
        return {"pass": True, "reason": "classifier_only"}

    print("\n MODEL INFERENCE:")
    result_parts = []
    total_ms = 0

    for segment in segments:
        # Natural speech passes through untouched; only protocol goes to the model.
        if segment["kind"] == "natural":
            result_parts.append(segment["text"])
            continue

        t0 = time.perf_counter()
        model_output, error = call_inference(segment["text"])
        infer_ms = (time.perf_counter() - t0) * 1000
        total_ms += infer_ms

        if error:
            print(f" \033[91mERROR\033[0m: {error}")
            # Fall back to the raw dictated text so assembly can continue.
            result_parts.append(segment["text"])
        else:
            print(f" IN: {segment['text']}")
            print(f" OUT: {model_output} ({infer_ms:.0f}ms)")
            result_parts.append(model_output or segment["text"])

    assembled = " ".join(result_parts).strip()
    expected = case["expected_command"]

    print(f"\n ASSEMBLED: {assembled}")
    print(f" EXPECTED: {expected}")
    print(f" LATENCY: {total_ms:.0f}ms (model only)")

    # A protocol segment was found and expected, so `expected` is non-empty
    # here; the comparison is a plain substring match on the assembled text.
    if expected in assembled:
        print("\n \033[92mPASS\033[0m — command correctly reconstructed")
        return {"pass": True, "reason": "match"}

    print("\n \033[93mCHECK\033[0m — command may differ, review above")
    return {"pass": None, "reason": "fuzzy"}
|
|
|
|
def main():
    """Parse command-line options, run the selected cases and print a summary.

    Options:
        --case N            run only case N (1-indexed)
        --classifier-only   check segmentation only, skip model inference

    Exits with status 1 when ``--case`` is out of range or when any case
    fails its segmentation check, so shells and CI can detect failures.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Test segment classifier + model chain")
    parser.add_argument("--case", type=int, help="Run specific test case (1-indexed)")
    parser.add_argument("--classifier-only", action="store_true", help="Skip model inference, just test segmentation")
    args = parser.parse_args()

    print("=" * 70)
    print(" FULL CHAIN TEST: Classifier → Segmentation → Model → Assembly")
    if args.classifier_only:
        print(" MODE: classifier-only (no model inference)")
    print("=" * 70)

    # Compare against None rather than truthiness: `--case 0` must be rejected
    # as out of range, not silently treated as "run everything".
    if args.case is None:
        numbered_cases = list(enumerate(TEST_CASES, start=1))
    elif 1 <= args.case <= len(TEST_CASES):
        numbered_cases = [(args.case, TEST_CASES[args.case - 1])]
    else:
        print(f"Invalid case number. Must be 1-{len(TEST_CASES)}")
        sys.exit(1)

    results = [
        run_test(case, idx, classifier_only=args.classifier_only)
        for idx, case in numbered_cases
    ]

    print(f"\n{'='*70}")
    print(" SUMMARY")
    print(f"{'='*70}")
    passed = sum(1 for r in results if r["pass"] is True)
    failed = sum(1 for r in results if r["pass"] is False)
    check = sum(1 for r in results if r["pass"] is None)
    print(f" Passed: {passed}/{len(results)}")
    if failed:
        print(f" Failed: {failed}")
    if check:
        print(f" Check: {check} (review manually)")

    natural_ok = sum(1 for r in results if r["reason"] == "natural_passthrough")
    match_ok = sum(1 for r in results if r["reason"] == "match")
    if natural_ok:
        print(f" Natural passthrough: {natural_ok}")
    if match_ok:
        print(f" Command reconstructed: {match_ok}")

    # Hard failures give a non-zero exit status; "check" results need manual
    # review and do not affect it.
    if failed:
        sys.exit(1)
|
|
|
|
# Run the test chain only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|