| |
| """ |
| Train a per-word protocol segment classifier. |
| |
| Purpose: Given a mixed dictation like "I want to check the directory ls dash la", |
| classify each word as protocol (1) or natural (0), so we can extract segments |
| to send to the fused Qwen3 model. |
| |
| Architecture: |
| - Per-word logistic regression with contextual features |
- Context window of ±2 words
- Output: weights + bias → port to Swift as ProtocolSegmentClassifier
| |
| Training data: Synthetic mixed dictations combining natural speech fragments |
| with protocol command dictations from eval-fuzzy.json. |
| """ |
|
|
| import json |
| import numpy as np |
| from pathlib import Path |
| import re |
|
|
| |
| |
| |
|
|
# Words that almost always signal protocol dictation (symbol names, casing
# directives, separators). Presence of one of these is strong evidence the
# surrounding words belong to a command segment.
STRONG_PROTOCOL = {
    # Symbol names
    "dash", "dot", "slash", "pipe", "tilde", "hash", "dollar",
    "caret", "ampersand", "equals", "underscore", "backslash",
    "backtick", "semicolon", "colon",
    # Symbol synonyms
    "minus", "hyphen", "asterisk", "hashtag",
    # Bracket / grouping names
    "paren", "brace", "bracket", "parenthesis", "curly",
    # Casing directives
    "capital", "caps", "camel", "snake", "pascal", "kebab", "screaming",
    # Explicit separator
    "space",
    # Shell redirection
    "redirect", "append",
}

# Words that can signal protocol dictation but also occur freely in natural
# speech ("at three o'clock", "less than", "all in all"); these only count
# as protocol cues when the surrounding context supports it.
WEAK_PROTOCOL = {
    # Ambiguous symbol names
    "at", "star", "bang", "exclamation", "question", "comma", "quote",
    "period", "plus", "percent",
    # Modifiers / fragments of multi-word symbol names
    "single", "open", "close", "angle", "forward", "back", "sign",
    "double", "mark", "than", "less", "new", "line", "all", "case",
}

# Full protocol vocabulary: union of strong and weak cue words.
PROTOCOL_VOCAB = STRONG_PROTOCOL | WEAK_PROTOCOL

# Literal symbol tokens that may appear already expanded in a transcript
# (e.g. the recognizer emitted "-" rather than the word "dash").
EXPANDED_SYMBOLS = {
    "-", ".", "/", "\\", "_", "|", "~", "@", "#", "*",
    "+", "=", ":", ";", "&", "%", "^", "!", "?", "`",
    "$", "<", ">", "--", "&&", "||",
}

# Characters that commonly appear *inside* command-like tokens
# ("dash-dev", "requirements.txt", "~/.config").
SYNTAX_CHARS = set("-./\\_|~@#:=")
|
|
| |
|
|
# Natural-speech fragments used as the non-protocol portions of synthetic
# mixed dictations (labeled 0 word-by-word during data generation).
NATURAL_PHRASES = [
    "I want to check the directory",
    "can you run this command for me",
    "let's see what happens when we",
    "okay so basically I need to",
    "the next thing I want to do is",
    "alright let me try",
    "I think we should also look at",
    "and then after that",
    "go ahead and type out",
    "so the idea is to",
    "I was thinking maybe we could",
    "hold on let me think about this",
    "what if we try something different",
    "actually no wait",
    "right so the problem is",
    "I need to figure out why",
    "let's debug this real quick",
    "the output should show us",
    "and that would give us",
    "basically what I'm trying to do is",
    "so first we need to",
    "and then we can see if",
    "alright this is important",
    "I want to emphasize",
    "the reason I'm doing this is",
    "so we did have a classifier",
    "wasn't that part of the mission",
    "like if we run all of a long dictation",
    "I mean we'd be lucky to get",
    "there's some stuff happening",
    "that was not at all what I said",
    "my last dictation was",
    "we're not trying to do any list extraction",
    "can you give me a few different test commands",
    "I don't really understand",
    "we should be able to fix that",
    "alright let's do a proper dictation",
    "the whole point is that",
    "I think the idea would be",
    "we want to build a new one",
    "so yeah I think like we want to",
    "and that would be a cleanup job",
    "with maybe a bigger model",
    "right so I think the plan is",
    "let me explain what's going on",
    "this is getting really interesting",
    "okay perfect that works",
    "no that's not right",
    "wait what happened there",
    "I'm gonna try again",
]
|
|
| |
|
|
# Dictated command segments used as the protocol portions of synthetic
# mixed dictations (labeled 1 word-by-word during data generation).
PROTOCOL_COMMANDS = [
    # Short shell commands with spoken symbols
    "ls dash la",
    "git dash dash help",
    "cd dot dot",
    "rm dash rf",
    "mkdir dash p",
    "cat slash etc slash hosts",
    "chmod seven five five",
    "grep dash r",
    "find dot dash name",
    "ps aux",
    "kill dash nine",
    # Commands dictated with explicit "space" separators
    "git space push space dash u space origin space main",
    "docker space run space dash dash rm space dash p space eight zero eight zero colon eight zero",
    "npm space install space dash capital D space typescript",
    "ssh space dash i space tilde slash dot ssh slash id underscore rsa",
    "curl space dash capital X space all caps POST",
    "kubectl space get space pods space dash n space kube dash system",
    "brew space install space dash dash cask",
    "pip space install space dash r space requirements dot txt",
    "python space dash m space pytest space dash v",
    "cargo space build space dash dash release",
    # Longer command lines with mixed symbols and casing
    "rsync space dash avz space dash e space ssh space dot slash dist slash",
    "export space all caps DATABASE underscore URL equals quote postgres colon slash slash admin",
    "redis dash cli space dash h space one two seven dot zero dot zero dot one",
    "terraform space plan space dash var dash file equals production dot tfvars",
    # Bare flag / path fragments
    "dash dash help",
    "dash la",
    "dot slash",
    "slash usr slash local slash bin",
    "tilde slash dot config",
    "star dot js",
    "dollar sign open paren",
    # Hyphenated names dictated with "dash"
    "talkie dash dev",
    "visual dash studio dash code",
    "docker dash compose",
    "kube dash system",
    "type dash script",
]
|
|
| |
| |
# Full natural sentences used as pure-negative examples (all words labeled 0).
PURE_NATURAL = [
    "I was just thinking about how to approach this problem differently",
    "the meeting is at three o'clock tomorrow afternoon",
    "can you send me the link to that document",
    "I'll get back to you on that later today",
    "the project deadline is next Friday",
    "we should probably discuss this with the rest of the team",
    "that's a great idea I hadn't thought of that",
    "let me know if you need anything else from me",
    "I think the best approach would be to start from scratch",
    "have you seen the latest updates to the design",
    "the performance numbers look really good",
    "we might need to reconsider our strategy here",
    "I'll set up a call with the engineering team",
    "the documentation needs to be updated before release",
    "this reminds me of a similar issue we had last month",
    # Hard negatives: natural sentences that contain WEAK_PROTOCOL words
    # used in their everyday sense ("sign", "open", "star", "less", ...).
    "I arrived at the office at nine this morning",
    "there are three new cases to review today",
    "we need to go back and sign the contract",
    "one of the most important things is to stay open",
    "all of the line items need to be double checked",
    "the new sign at the front of the building looks great",
    "I need to close out all the open tickets by five",
    "she gave us a star review which was really nice",
    "the question mark at the end was confusing",
    "he made a plus sized version of the original",
    "we should open a new line of inquiry",
    "there were less than ten people at the event",
    "I'm going to mark this case as resolved",
    "go back to the previous page and forward it to me",
    "we had one single issue in all of last quarter",
    "the angle of the photo makes it look close to the sign",
    "I'll be back at four thirty with the new draft",
    "that's a great point about the new direction",
    "we need all hands on deck for this one",
    "the bottom line is we need more time",
    "let me quote you on that",
    "he was at a loss for words",
    "the new employee starts on day one next week",
    "all in all it was a pretty good quarter",
    "the case study shows a ten percent improvement",
    "I need to sign off on this before close of business",
    "the angle was all wrong for that particular shot",
]
|
|
|
|
def is_protocol_word(word: str) -> bool:
    """Return True if *word* looks like part of a protocol segment.

    A word qualifies when it is in the protocol vocabulary, is a literal
    expanded symbol, or contains command-syntax characters — except for
    apostrophe contractions and lone sentence-ending periods.
    """
    core = word.lower().strip(".,!?;:'\"")
    if core in PROTOCOL_VOCAB or word.strip() in EXPANDED_SYMBOLS:
        return True
    if not any(ch in SYNTAX_CHARS for ch in core):
        return False
    # An apostrophe (straight or curly) means a contraction, not syntax.
    if "'" in word or "\u2019" in word:
        return False
    # A single trailing period is sentence punctuation, not syntax.
    ends_with_sentence_period = word.endswith(".") and "." not in word[:-1]
    return not ends_with_sentence_period
|
|
|
|
def label_words(words: list[str], labels: list[int]) -> list[dict]:
    """Pair each word with its label and positional metadata.

    Returns one dict per word carrying the word, its 0/1 label, its index
    within the sentence, and the sentence length — enough to later rebuild
    sentence boundaries and compute positional features.
    """
    total = len(words)
    return [
        {"word": w, "label": lab, "position": idx, "total_words": total}
        for idx, (w, lab) in enumerate(zip(words, labels))
    ]
|
|
|
|
def generate_mixed_examples(n_examples: int = 500, seed: int = 42) -> list[dict]:
    """Generate synthetic mixed dictations with word-level labels.

    Produces five sentence patterns: natural/command/natural sandwiches,
    pure commands, pure natural speech, natural+command, and
    command+natural. Returns a flat list of labeled word entries.

    Note: the RNG draw order matches the original implementation exactly,
    so a fixed seed reproduces the same dataset.
    """
    rng = np.random.RandomState(seed)
    dataset: list[dict] = []

    def emit(*segments) -> None:
        # Each segment is a (text, label) pair; flatten into parallel
        # word / label lists and append the labeled entries.
        flat_words: list[str] = []
        flat_labels: list[int] = []
        for text, label in segments:
            pieces = text.split()
            flat_words.extend(pieces)
            flat_labels.extend([label] * len(pieces))
        dataset.extend(label_words(flat_words, flat_labels))

    # Pattern 1: natural / command / natural sandwich.
    for _ in range(n_examples // 3):
        lead = rng.choice(NATURAL_PHRASES)
        command = rng.choice(PROTOCOL_COMMANDS)
        trail = rng.choice(NATURAL_PHRASES)
        emit((lead, 0), (command, 1), (trail, 0))

    # Pattern 2: pure protocol command.
    for _ in range(n_examples // 4):
        emit((rng.choice(PROTOCOL_COMMANDS), 1))

    # Pattern 3: pure natural speech.
    for _ in range(n_examples // 4):
        emit((rng.choice(PURE_NATURAL + NATURAL_PHRASES), 0))

    # Pattern 4: natural lead-in followed by a command.
    for _ in range(n_examples // 6):
        lead = rng.choice(NATURAL_PHRASES)
        command = rng.choice(PROTOCOL_COMMANDS)
        emit((lead, 0), (command, 1))

    # Pattern 5: command first, natural trailer after.
    for _ in range(n_examples // 6):
        command = rng.choice(PROTOCOL_COMMANDS)
        trail = rng.choice(NATURAL_PHRASES)
        emit((command, 1), (trail, 0))

    return dataset
|
|
|
|
def is_strong_protocol(word: str) -> bool:
    """Return True if *word* is an unambiguous protocol cue.

    Matches the strong vocabulary (after stripping edge punctuation) or a
    literal expanded symbol token such as "-" or "--".
    """
    core = word.lower().strip(".,!?;:'\"")
    return core in STRONG_PROTOCOL or word.strip() in EXPANDED_SYMBOLS
|
|
|
|
def extract_features(word: str, context: list[str], position: int, total: int) -> list[float]:
    """
    Extract features for a single word with its context.

    Args:
        word: The word to featurize.
        context: The ±2-word window around the word as sliced by the
            caller; it includes the word itself and is clipped at the
            sentence edges.
        position: Index of the word within its sentence.
        total: Total number of words in the sentence.

    Features (14 total):
    0. is_strong_protocol — word is an unambiguous protocol word (dash, dot, slash...)
    1. is_weak_protocol — word is an ambiguous protocol word (at, open, less...)
    2. is_expanded_symbol — word is an expanded symbol (-, ., /)
    3. has_syntax_chars — word contains syntax characters
    4. word_length_norm — word length / 10 (normalized)
    5. is_short_word — len <= 3 (commands: ls, cd, rm)
    6. context_strong_density — fraction of ±2 context words that are STRONG protocol
    7. context_any_density — fraction of ±2 context words that are any protocol
    8. left_is_strong — immediate left neighbor is strong protocol
    9. right_is_strong — immediate right neighbor is strong protocol
    10. is_number_like — word looks like a number or number word
    11. strong_neighbor_count — count of strong protocol words in ±2 window
    12. is_all_lower — all lowercase
    13. position_ratio — position / total
    """
    lower = word.lower().strip(".,!?;:'\"")
    stripped = word.strip()

    # Feature 0: membership in the unambiguous protocol vocabulary.
    f_strong = 1.0 if lower in STRONG_PROTOCOL else 0.0

    # Feature 1: membership in the ambiguous (context-dependent) vocabulary.
    f_weak = 1.0 if lower in WEAK_PROTOCOL else 0.0

    # Feature 2: the token is already a literal symbol ("-", "--", "&&", ...).
    f_symbol = 1.0 if stripped in EXPANDED_SYMBOLS else 0.0

    # Feature 3: contains command-syntax characters — excluding apostrophe
    # contractions and a lone sentence-ending period (same rule as
    # is_protocol_word).
    f_syntax = 0.0
    if any(c in SYNTAX_CHARS for c in lower):
        if "'" not in word and "\u2019" not in word:
            if not (word.endswith(".") and "." not in word[:-1]):
                f_syntax = 1.0

    # Feature 4: raw length, roughly scaled into [0, ~1].
    f_len = len(word) / 10.0

    # Feature 5: short tokens are often command names (ls, cd, rm).
    f_short = 1.0 if len(lower) <= 3 else 0.0

    # Feature 6: density of strong protocol words in the context window.
    ctx_strong = sum(1 for w in context if is_strong_protocol(w))
    f_ctx_strong = ctx_strong / max(len(context), 1)

    # Feature 7: density of any protocol word in the context window.
    ctx_any = sum(1 for w in context if is_protocol_word(w))
    f_ctx_any = ctx_any / max(len(context), 1)

    # Feature 8: immediate left neighbor is strong protocol. The word's own
    # index inside `context` is min(position, 2), because the window is
    # clipped at the sentence start.
    f_left = 0.0
    if position > 0 and len(context) > 0:
        ctx_center = min(position, 2)
        if ctx_center > 0 and ctx_center - 1 < len(context):
            f_left = 1.0 if is_strong_protocol(context[ctx_center - 1]) else 0.0

    # Feature 9: immediate right neighbor is strong protocol.
    f_right = 0.0
    if position < total - 1 and len(context) > 0:
        ctx_center = min(position, 2)
        if ctx_center + 1 < len(context):
            f_right = 1.0 if is_strong_protocol(context[ctx_center + 1]) else 0.0

    # Feature 10: digits or spelled-out digits (common in dictated flags,
    # ports, and permission bits like "seven five five").
    number_words = {"zero", "one", "two", "three", "four", "five", "six",
                    "seven", "eight", "nine", "ten"}
    f_number = 1.0 if (lower in number_words or lower.isdigit()) else 0.0

    # Feature 11: absolute count (not density) of strong neighbors.
    f_strong_neighbors = float(ctx_strong)

    # Feature 12: purely alphabetic and entirely lowercase.
    f_lower = 1.0 if word.isalpha() and word == word.lower() else 0.0

    # Feature 13: relative position within the sentence.
    f_pos = position / max(total - 1, 1)

    # Order must match FEATURE_NAMES and the Swift port.
    return [
        f_strong,
        f_weak,
        f_symbol,
        f_syntax,
        f_len,
        f_short,
        f_ctx_strong,
        f_ctx_any,
        f_left,
        f_right,
        f_number,
        f_strong_neighbors,
        f_lower,
        f_pos,
    ]
|
|
|
|
def build_dataset(labeled_words: list[dict], all_words_by_sentence=None):
    """Build the feature matrix X and label vector y from labeled words.

    Sentences are reconstructed from the flat entry list (a new sentence
    starts when `position` resets to 0 or `total_words` changes), then each
    word is featurized against its ±2-word context window.

    `all_words_by_sentence` is accepted for interface compatibility but is
    unused.
    """
    # Rebuild sentence groups from the flat stream of entries.
    sentences: list[list[dict]] = []
    current: list[dict] = []
    for entry in labeled_words:
        if current:
            starts_new = (entry["position"] == 0
                          or entry["total_words"] != current[0]["total_words"])
            if starts_new:
                sentences.append(current)
                current = []
        current.append(entry)
    if current:
        sentences.append(current)

    feature_rows = []
    targets = []
    for sentence in sentences:
        sentence_words = [e["word"] for e in sentence]
        for entry in sentence:
            idx = entry["position"]
            # ±2-word context window, clipped at sentence edges.
            window = sentence_words[max(0, idx - 2):min(len(sentence_words), idx + 3)]
            feature_rows.append(
                extract_features(entry["word"], window, idx, entry["total_words"])
            )
            targets.append(entry["label"])

    return np.array(feature_rows), np.array(targets)
|
|
|
|
def train_logistic_regression(X, y, lr=0.05, lambda_reg=0.1, max_epochs=500, tol=1e-6):
    """Train logistic regression with L2 regularization via batch gradient descent.

    Args:
        X: (n_samples, n_features) feature matrix.
        y: (n_samples,) vector of 0/1 labels.
        lr: Learning rate.
        lambda_reg: L2 penalty strength (applied to weights only, not bias).
        max_epochs: Maximum number of full-batch epochs.
        tol: Stop when the loss changes by less than this between epochs.

    Returns:
        (weights, bias) — learned weight vector and bias term.
    """
    n_samples, n_features = X.shape
    weights = np.zeros(n_features)
    bias = 0.0
    eps = 1e-15  # guards log(0) in the cross-entropy

    def _probs(w, b):
        # Sigmoid with clipped logits so np.exp never overflows.
        return 1.0 / (1.0 + np.exp(-np.clip(X @ w + b, -500, 500)))

    def _loss(p, w):
        # Mean cross-entropy plus the L2 penalty on the weights.
        ce = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
        return ce + 0.5 * lambda_reg * np.sum(w ** 2)

    prev_loss = float("inf")
    for epoch in range(max_epochs):
        probs = _probs(weights, bias)
        loss = _loss(probs, weights)

        # Early stop when the loss plateaus.
        if abs(prev_loss - loss) < tol:
            print(f" Converged at epoch {epoch}, loss={loss:.6f}")
            break
        prev_loss = loss

        # Batch gradients; bias is unregularized.
        error = probs - y
        grad_w = (X.T @ error) / n_samples + lambda_reg * weights
        grad_b = np.mean(error)

        weights -= lr * grad_w
        bias -= lr * grad_b

        if epoch % 50 == 0:
            acc = np.mean((probs >= 0.5) == y)
            print(f" Epoch {epoch}: loss={loss:.4f}, acc={acc:.3f}")

    # Fix: report the final loss recomputed from the *final* weights.
    # The previous version printed prev_loss, which was one update stale
    # relative to the accuracy shown on the same line.
    probs = _probs(weights, bias)
    acc = np.mean((probs >= 0.5) == y)
    print(f" Final: loss={_loss(probs, weights):.4f}, acc={acc:.3f}")

    return weights, bias
|
|
|
|
def evaluate(X, y, weights, bias, threshold=0.5):
    """Evaluate the classifier on (X, y), print metrics, and return them.

    Returns:
        (accuracy, precision, recall, f1) tuple.
    """
    # Forward pass; logits clipped so np.exp cannot overflow.
    scores = X @ weights + bias
    probabilities = 1.0 / (1.0 + np.exp(-np.clip(scores, -500, 500)))
    predictions = (probabilities >= threshold).astype(int)

    accuracy = np.mean(predictions == y)

    # Confusion-matrix cells.
    tp = np.sum((predictions == 1) & (y == 1))
    fp = np.sum((predictions == 1) & (y == 0))
    fn = np.sum((predictions == 0) & (y == 1))
    tn = np.sum((predictions == 0) & (y == 0))

    # Guard the zero-denominator cases.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"\n Accuracy: {accuracy:.3f} ({int(accuracy * len(y))}/{len(y)})")
    print(f" Precision: {precision:.3f}")
    print(f" Recall: {recall:.3f}")
    print(f" F1: {f1:.3f}")
    print(f" Confusion: TP={tp} FP={fp} FN={fn} TN={tn}")

    return accuracy, precision, recall, f1
|
|
|
|
def test_mixed_examples(weights, bias):
    """Sanity-check the trained model on hand-crafted mixed dictations.

    Prints each transcript with protocol-classified words in red and
    natural words in green, plus any missed or falsely-detected protocol
    words versus the expected set.

    Fix: removed the unused locals `natural` and `correct` from the
    original implementation.
    """
    # (transcript, expected set of protocol-classified words)
    test_cases = [
        ("I want to check the directory ls dash la", {"ls", "dash", "la"}),
        ("can you run git dash dash help for me", {"git", "dash", "help"}),
        ("alright let me try talkie dash dev dash dash help", {"talkie", "dash", "dev", "help"}),
        ("the meeting is at three o'clock tomorrow afternoon", set()),
        ("I think we should also look at cd dot dot slash src", {"cd", "dot", "slash", "src"}),
        ("so basically npm space install space dash capital D", {"npm", "space", "install", "dash", "capital", "D"}),
        ("let me know if you need anything else from me", set()),
        ("okay go ahead and type out chmod seven five five", {"chmod", "seven", "five"}),
        ("the performance numbers look really good", set()),
        ("ssh dash i tilde slash dot ssh slash id underscore rsa", {"ssh", "dash", "i", "tilde", "slash", "dot", "id", "underscore", "rsa"}),
    ]

    print("\nβββ Test Cases βββ")
    for text, expected_protocol in test_cases:
        words = text.split()
        preds = []
        for i, word in enumerate(words):
            # Same ±2-word context windowing as build_dataset.
            ctx_start = max(0, i - 2)
            ctx_end = min(len(words), i + 3)
            context = words[ctx_start:ctx_end]
            features = extract_features(word, context, i, len(words))
            logit = np.dot(features, weights) + bias
            prob = 1.0 / (1.0 + np.exp(-logit))
            preds.append((word, prob, prob >= 0.5))

        detected = {w for w, p, is_proto in preds if is_proto}

        # ANSI colors: red = protocol, green = natural.
        colored = []
        for w, p, is_proto in preds:
            if is_proto:
                colored.append(f"\033[91m{w}\033[0m")
            else:
                colored.append(f"\033[92m{w}\033[0m")
        print(f" {' '.join(colored)}")

        # Report disagreements against the expected protocol set.
        if expected_protocol:
            missed = expected_protocol - detected
            false_pos = detected - expected_protocol
            if missed:
                print(f" MISSED: {missed}")
            if false_pos:
                print(f" FALSE+: {false_pos}")
        elif detected:
            print(f" FALSE+: {detected}")
|
|
|
def export_model(weights, bias, feature_names, output_path):
    """Serialize the trained model to JSON for porting to Swift.

    Args:
        weights: Trained weight vector (numpy array).
        bias: Trained bias term.
        feature_names: Ordered feature names matching the weight vector.
        output_path: Destination path for the JSON file.
    """
    payload = {
        "classifier": "ProtocolSegmentClassifier",
        "description": "Per-word logistic regression for protocol segment detection",
        "features": feature_names,
        "weights": weights.tolist(),
        "bias": float(bias),
        "threshold": 0.5,
    }
    Path(output_path).write_text(json.dumps(payload, indent=2))
    print(f"\nModel exported to {output_path}")
|
|
|
|
# Ordered feature names; must stay in sync with the list returned by
# extract_features() and with the Swift port's weight ordering.
FEATURE_NAMES = [
    "is_strong_protocol",
    "is_weak_protocol",
    "is_expanded_symbol",
    "has_syntax_chars",
    "word_length_norm",
    "is_short_word",
    "context_strong_density",
    "context_any_density",
    "left_is_strong",
    "right_is_strong",
    "is_number_like",
    "strong_neighbor_count",
    "is_all_lower",
    "position_ratio",
]
|
|
|
|
def main():
    """Train, evaluate, and export the protocol segment classifier.

    Pipeline: generate synthetic labeled data, featurize, 80/20 split,
    fit logistic regression, report metrics, spot-check mixed dictations,
    then export the model as JSON and printable Swift constants.

    Fix: dropped the needless f-prefix on the format-free string literals
    in the Swift-constants dump (ruff F541); output is unchanged.
    """
    print("=" * 60)
    print("Protocol Segment Classifier Training")
    print("=" * 60)

    # 1. Synthesize labeled training data.
    print("\n1. Generating training data...")
    labeled = generate_mixed_examples(n_examples=600, seed=42)
    print(f" {len(labeled)} labeled words generated")

    # Class-balance report.
    n_protocol = sum(1 for e in labeled if e["label"] == 1)
    n_natural = sum(1 for e in labeled if e["label"] == 0)
    print(f" Protocol: {n_protocol} ({100*n_protocol/len(labeled):.1f}%)")
    print(f" Natural: {n_natural} ({100*n_natural/len(labeled):.1f}%)")

    # 2. Featurize.
    print("\n2. Extracting features...")
    X, y = build_dataset(labeled)
    print(f" Feature matrix: {X.shape}")

    # Shuffled 80/20 train/test split (fixed seed for reproducibility).
    n = len(y)
    indices = np.random.RandomState(42).permutation(n)
    split = int(0.8 * n)
    train_idx, test_idx = indices[:split], indices[split:]
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]
    print(f" Train: {len(y_train)}, Test: {len(y_test)}")

    # 3. Fit the model.
    print("\n3. Training logistic regression...")
    weights, bias = train_logistic_regression(X_train, y_train)

    # Per-feature weight report.
    print("\n Weights:")
    for name, w in zip(FEATURE_NAMES, weights):
        print(f" {name:30s} {w:+.6f}")
    print(f" {'bias':30s} {bias:+.6f}")

    # 4. Held-out metrics.
    print("\n4. Test set evaluation:")
    evaluate(X_test, y_test, weights, bias)

    # 5. Train-set metrics as an overfitting sanity check.
    print("\n5. Train set evaluation (sanity):")
    evaluate(X_train, y_train, weights, bias)

    # 6. Qualitative spot checks on mixed dictations.
    print("\n6. Mixed dictation examples:")
    test_mixed_examples(weights, bias)

    # Export JSON next to this script for the Swift port.
    output_path = Path(__file__).parent / "segment-classifier-model.json"
    export_model(weights, bias, FEATURE_NAMES, output_path)

    # Copy-pasteable Swift constant declarations.
    print("\nβββ Swift Constants βββ")
    print("private static let weights: [Double] = [")
    for name, w in zip(FEATURE_NAMES, weights):
        print(f" {w:+.20f}, // {name}")
    print("]")
    print(f"private static let bias: Double = {bias:+.20f}")
|
|
|
|
# Script entry point: run the full train/evaluate/export pipeline.
if __name__ == "__main__":
    main()
|
|