File size: 24,429 Bytes
35717ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 | #!/usr/bin/env python3
"""
Train a per-word protocol segment classifier.
Purpose: Given a mixed dictation like "I want to check the directory ls dash la",
classify each word as protocol (1) or natural (0), so we can extract segments
to send to the fused Qwen3 model.
Architecture:
- Per-word logistic regression with contextual features
- Context window of Β±2 words
- Output: weights + bias β port to Swift as ProtocolSegmentClassifier
Training data: Synthetic mixed dictations combining natural speech fragments
with protocol command dictations from eval-fuzzy.json.
"""
import json
import numpy as np
from pathlib import Path
import re
# ─── Protocol vocabulary split into strong/weak tiers ───
# Strong: words that almost never appear in natural speech; their presence
# alone is good evidence of a protocol segment.
# Weak: words that frequently appear in natural speech; they should only be
# treated as protocol when the surrounding context supports it.
STRONG_PROTOCOL = {
    # Symbol words (unambiguous)
    "dash", "dot", "slash", "pipe", "tilde", "hash", "dollar",
    "caret", "ampersand", "equals", "underscore", "backslash",
    "backtick", "semicolon", "colon",
    # Synonyms
    "minus", "hyphen", "asterisk", "hashtag",
    # Brackets (unambiguous components)
    "paren", "brace", "bracket", "parenthesis", "curly",
    # Casing directives
    "capital", "caps", "camel", "snake", "pascal", "kebab", "screaming",
    # Space as protocol
    "space",
    # Redirect
    "redirect", "append",
}
WEAK_PROTOCOL = {
    # These appear frequently in natural speech
    "at", "star", "bang", "exclamation", "question", "comma", "quote",
    "period", "plus", "percent",
    # Multi-word components that are also common English
    "single", "open", "close", "angle", "forward", "back", "sign",
    "double", "mark", "than", "less", "new", "line", "all", "case",
    # Number words REMOVED — too ambiguous in natural speech.
    # Numbers near protocol words get captured by ±2 expansion instead.
}
# Full vocabulary: union of both tiers.
PROTOCOL_VOCAB = STRONG_PROTOCOL | WEAK_PROTOCOL
# Expanded symbols: literal tokens produced after the symbolic mapping pass.
EXPANDED_SYMBOLS = {
    "-", ".", "/", "\\", "_", "|", "~", "@", "#", "*",
    "+", "=", ":", ";", "&", "%", "^", "!", "?", "`",
    "$", "<", ">", "--", "&&", "||",
}
# Characters that signal raw command-line syntax inside a token.
SYNTAX_CHARS = set("-./\\_|~@#:=")
# ─── Natural speech fragments ───
# Conversational filler phrases used as the "natural" (label 0) portions of
# the synthetic mixed dictations.
NATURAL_PHRASES = [
    "I want to check the directory",
    "can you run this command for me",
    "let's see what happens when we",
    "okay so basically I need to",
    "the next thing I want to do is",
    "alright let me try",
    "I think we should also look at",
    "and then after that",
    "go ahead and type out",
    "so the idea is to",
    "I was thinking maybe we could",
    "hold on let me think about this",
    "what if we try something different",
    "actually no wait",
    "right so the problem is",
    "I need to figure out why",
    "let's debug this real quick",
    "the output should show us",
    "and that would give us",
    "basically what I'm trying to do is",
    "so first we need to",
    "and then we can see if",
    "alright this is important",
    "I want to emphasize",
    "the reason I'm doing this is",
    "so we did have a classifier",
    "wasn't that part of the mission",
    "like if we run all of a long dictation",
    "I mean we'd be lucky to get",
    "there's some stuff happening",
    "that was not at all what I said",
    "my last dictation was",
    "we're not trying to do any list extraction",
    "can you give me a few different test commands",
    "I don't really understand",
    "we should be able to fix that",
    "alright let's do a proper dictation",
    "the whole point is that",
    "I think the idea would be",
    "we want to build a new one",
    "so yeah I think like we want to",
    "and that would be a cleanup job",
    "with maybe a bigger model",
    "right so I think the plan is",
    "let me explain what's going on",
    "this is getting really interesting",
    "okay perfect that works",
    "no that's not right",
    "wait what happened there",
    "I'm gonna try again",
]
# ─── Protocol command dictations ───
# Spoken-form command fragments used as the "protocol" (label 1) portions of
# the synthetic mixed dictations; ranges from single flags to full pipelines.
PROTOCOL_COMMANDS = [
    # Simple commands
    "ls dash la",
    "git dash dash help",
    "cd dot dot",
    "rm dash rf",
    "mkdir dash p",
    "cat slash etc slash hosts",
    "chmod seven five five",
    "grep dash r",
    "find dot dash name",
    "ps aux",
    "kill dash nine",
    # Medium commands
    "git space push space dash u space origin space main",
    "docker space run space dash dash rm space dash p space eight zero eight zero colon eight zero",
    "npm space install space dash capital D space typescript",
    "ssh space dash i space tilde slash dot ssh slash id underscore rsa",
    "curl space dash capital X space all caps POST",
    "kubectl space get space pods space dash n space kube dash system",
    "brew space install space dash dash cask",
    "pip space install space dash r space requirements dot txt",
    "python space dash m space pytest space dash v",
    "cargo space build space dash dash release",
    # Complex commands with paths
    "rsync space dash avz space dash e space ssh space dot slash dist slash",
    "export space all caps DATABASE underscore URL equals quote postgres colon slash slash admin",
    "redis dash cli space dash h space one two seven dot zero dot zero dot one",
    "terraform space plan space dash var dash file equals production dot tfvars",
    # Short fragments
    "dash dash help",
    "dash la",
    "dot slash",
    "slash usr slash local slash bin",
    "tilde slash dot config",
    "star dot js",
    "dollar sign open paren",
    # Tool names with hyphens
    "talkie dash dev",
    "visual dash studio dash code",
    "docker dash compose",
    "kube dash system",
    "type dash script",
]
# Additional pure natural sentences — deliberately includes words that overlap
# with the protocol vocab (at, three, one, new, line, back, sign, case, all,
# open, close, etc.) so the classifier learns to keep them natural when there
# is no protocol context around them.
PURE_NATURAL = [
    "I was just thinking about how to approach this problem differently",
    "the meeting is at three o'clock tomorrow afternoon",
    "can you send me the link to that document",
    "I'll get back to you on that later today",
    "the project deadline is next Friday",
    "we should probably discuss this with the rest of the team",
    "that's a great idea I hadn't thought of that",
    "let me know if you need anything else from me",
    "I think the best approach would be to start from scratch",
    "have you seen the latest updates to the design",
    "the performance numbers look really good",
    "we might need to reconsider our strategy here",
    "I'll set up a call with the engineering team",
    "the documentation needs to be updated before release",
    "this reminds me of a similar issue we had last month",
    # Sentences with ambiguous protocol words used naturally
    "I arrived at the office at nine this morning",
    "there are three new cases to review today",
    "we need to go back and sign the contract",
    "one of the most important things is to stay open",
    "all of the line items need to be double checked",
    "the new sign at the front of the building looks great",
    "I need to close out all the open tickets by five",
    "she gave us a star review which was really nice",
    "the question mark at the end was confusing",
    "he made a plus sized version of the original",
    "we should open a new line of inquiry",
    "there were less than ten people at the event",
    "I'm going to mark this case as resolved",
    "go back to the previous page and forward it to me",
    "we had one single issue in all of last quarter",
    "the angle of the photo makes it look close to the sign",
    "I'll be back at four thirty with the new draft",
    "that's a great point about the new direction",
    "we need all hands on deck for this one",
    "the bottom line is we need more time",
    "let me quote you on that",
    "he was at a loss for words",
    "the new employee starts on day one next week",
    "all in all it was a pretty good quarter",
    "the case study shows a ten percent improvement",
    "I need to sign off on this before close of business",
    "the angle was all wrong for that particular shot",
]
def is_protocol_word(word: str) -> bool:
    """Return True when *word* looks like protocol dictation rather than speech.

    A word is protocol when it is in the combined vocabulary, is an
    already-expanded symbol token, or embeds raw syntax characters that are
    not merely a contraction apostrophe or a sentence-final period.
    """
    normalized = word.lower().strip(".,!?;:'\"")
    if normalized in PROTOCOL_VOCAB or word.strip() in EXPANDED_SYMBOLS:
        return True
    # Raw syntax characters (e.g. "a/b", "--flag") mark protocol, except in
    # contractions and words whose only dot is a trailing period.
    has_syntax = any(ch in SYNTAX_CHARS for ch in normalized)
    looks_contraction = "'" in word or "\u2019" in word
    trailing_period_only = word.endswith(".") and "." not in word[:-1]
    return has_syntax and not looks_contraction and not trailing_period_only
def label_words(words: list[str], labels: list[int]) -> list[dict]:
    """Pair each word with its label plus position metadata for its sentence."""
    total = len(words)
    return [
        {"word": w, "label": lab, "position": idx, "total_words": total}
        for idx, (w, lab) in enumerate(zip(words, labels))
    ]
def generate_mixed_examples(n_examples: int = 500, seed: int = 42) -> list[dict]:
    """Generate synthetic mixed dictations with word-level labels.

    Five sampling patterns are used so the classifier sees commands at the
    start, middle, and end of utterances, plus pure-natural and pure-protocol
    sentences. RNG draws happen in a fixed order so results are reproducible
    for a given seed.
    """
    rng = np.random.RandomState(seed)
    dataset: list[dict] = []

    def emit(segments):
        # segments: (phrase, label) pairs; tokenize each phrase and append
        # the labeled words as one sentence.
        words: list[str] = []
        labels: list[int] = []
        for phrase, label in segments:
            tokens = phrase.split()
            words += tokens
            labels += [label] * len(tokens)
        dataset.extend(label_words(words, labels))

    # Pattern 1: natural + protocol + natural (sandwich)
    for _ in range(n_examples // 3):
        emit([
            (rng.choice(NATURAL_PHRASES), 0),
            (rng.choice(PROTOCOL_COMMANDS), 1),
            (rng.choice(NATURAL_PHRASES), 0),
        ])
    # Pattern 2: protocol only (full command dictation)
    for _ in range(n_examples // 4):
        emit([(rng.choice(PROTOCOL_COMMANDS), 1)])
    # Pattern 3: natural only (pure speech, no protocol)
    for _ in range(n_examples // 4):
        emit([(rng.choice(PURE_NATURAL + NATURAL_PHRASES), 0)])
    # Pattern 4: natural + protocol (command at end)
    for _ in range(n_examples // 6):
        emit([
            (rng.choice(NATURAL_PHRASES), 0),
            (rng.choice(PROTOCOL_COMMANDS), 1),
        ])
    # Pattern 5: protocol + natural (command at start, explanation after)
    for _ in range(n_examples // 6):
        emit([
            (rng.choice(PROTOCOL_COMMANDS), 1),
            (rng.choice(NATURAL_PHRASES), 0),
        ])
    return dataset
def is_strong_protocol(word: str) -> bool:
    """Return True for unambiguous protocol words and expanded symbol tokens."""
    core = word.lower().strip(".,!?;:'\"")
    return core in STRONG_PROTOCOL or word.strip() in EXPANDED_SYMBOLS
def extract_features(word: str, context: list[str], position: int, total: int) -> list[float]:
    """Compute the 14-dimensional feature vector for one word in context.

    Feature order (must match FEATURE_NAMES and the Swift port):
      0 is_strong_protocol     1 is_weak_protocol     2 is_expanded_symbol
      3 has_syntax_chars       4 word_length_norm     5 is_short_word
      6 context_strong_density 7 context_any_density  8 left_is_strong
      9 right_is_strong       10 is_number_like      11 strong_neighbor_count
     12 is_all_lower          13 position_ratio

    `context` is the ±2-word window slice *including* the word itself, so
    the density features count the center word too.
    """
    lower = word.lower().strip(".,!?;:'\"")
    stripped = word.strip()

    # 0-2: direct lexicon membership.
    f_strong = float(lower in STRONG_PROTOCOL)
    f_weak = float(lower in WEAK_PROTOCOL)
    f_symbol = float(stripped in EXPANDED_SYMBOLS)

    # 3: raw syntax characters — skipped for contractions and for words
    # whose only dot is a sentence-final period.
    f_syntax = 0.0
    if any(ch in SYNTAX_CHARS for ch in lower):
        contraction = "'" in word or "\u2019" in word
        trailing_period_only = word.endswith(".") and "." not in word[:-1]
        if not contraction and not trailing_period_only:
            f_syntax = 1.0

    # 4-5: length cues (short tokens like "ls", "cd", "rm" are command-ish).
    f_len = len(word) / 10.0
    f_short = float(len(lower) <= 3)

    # 6-7: protocol density inside the context window.
    strong_count = sum(1 for w in context if is_strong_protocol(w))
    any_count = sum(1 for w in context if is_protocol_word(w))
    window = max(len(context), 1)
    f_ctx_strong = strong_count / window
    f_ctx_any = any_count / window

    # 8-9: immediate neighbours. The word sits at index min(position, 2)
    # inside its window slice.
    center = min(position, 2)
    f_left = 0.0
    if position > 0 and context and center > 0 and center - 1 < len(context):
        f_left = float(is_strong_protocol(context[center - 1]))
    f_right = 0.0
    if position < total - 1 and context and center + 1 < len(context):
        f_right = float(is_strong_protocol(context[center + 1]))

    # 10: digit strings or spelled-out numbers zero..ten.
    spelled_numbers = {"zero", "one", "two", "three", "four", "five", "six",
                       "seven", "eight", "nine", "ten"}
    f_number = float(lower in spelled_numbers or lower.isdigit())

    # 11: raw strong-protocol count in the window (not normalized).
    f_neighbors = float(strong_count)

    # 12: purely alphabetic, entirely lowercase word.
    f_lowercase = float(word.isalpha() and word == word.lower())

    # 13: relative position within the sentence.
    f_position = position / max(total - 1, 1)

    return [f_strong, f_weak, f_symbol, f_syntax, f_len, f_short,
            f_ctx_strong, f_ctx_any, f_left, f_right, f_number,
            f_neighbors, f_lowercase, f_position]
def build_dataset(labeled_words: list[dict], all_words_by_sentence=None):
    """Turn a flat stream of labeled word entries into (X, y) numpy arrays.

    `all_words_by_sentence` is accepted for interface compatibility but is
    unused; sentence boundaries are recovered from each entry's position
    (a new sentence starts at position 0, or defensively when the
    total_words value changes).
    """
    sentences: list[list[dict]] = []
    bucket: list[dict] = []
    for entry in labeled_words:
        if bucket and (entry["position"] == 0
                       or entry["total_words"] != bucket[0]["total_words"]):
            sentences.append(bucket)
            bucket = []
        bucket.append(entry)
    if bucket:
        sentences.append(bucket)

    feature_rows = []
    targets = []
    for sentence in sentences:
        words = [e["word"] for e in sentence]
        for entry in sentence:
            pos = entry["position"]
            # ±2-word context window (includes the word itself).
            window = words[max(0, pos - 2):min(len(words), pos + 3)]
            feature_rows.append(
                extract_features(entry["word"], window, pos, entry["total_words"])
            )
            targets.append(entry["label"])
    return np.array(feature_rows), np.array(targets)
def train_logistic_regression(X, y, lr=0.05, lambda_reg=0.1, max_epochs=500, tol=1e-6):
    """Fit L2-regularized logistic regression via batch gradient descent.

    Returns (weights, bias). The bias is not regularized. Training stops
    early once the per-epoch loss improvement drops below `tol`; progress
    is printed every 50 epochs.
    """
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    eps = 1e-15  # guards log(0) in the cross-entropy
    last_loss = float("inf")

    def sigmoid(z):
        # Clip logits to keep exp() from overflowing.
        return 1.0 / (1.0 + np.exp(-np.clip(z, -500, 500)))

    for epoch in range(max_epochs):
        p = sigmoid(X @ w + b)
        # Binary cross-entropy plus L2 penalty on the weights.
        bce = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
        loss = bce + 0.5 * lambda_reg * np.sum(w ** 2)
        if abs(last_loss - loss) < tol:
            print(f" Converged at epoch {epoch}, loss={loss:.6f}")
            break
        last_loss = loss
        residual = p - y
        w -= lr * ((X.T @ residual) / n_samples + lambda_reg * w)
        b -= lr * np.mean(residual)
        if epoch % 50 == 0:
            acc = np.mean((p >= 0.5) == y)
            print(f" Epoch {epoch}: loss={loss:.4f}, acc={acc:.3f}")

    final_p = sigmoid(X @ w + b)
    acc = np.mean((final_p >= 0.5) == y)
    print(f" Final: loss={last_loss:.4f}, acc={acc:.3f}")
    return w, b
def evaluate(X, y, weights, bias, threshold=0.5):
    """Score the classifier on (X, y), print metrics, and return them.

    Returns (accuracy, precision, recall, f1). Precision/recall/F1 fall
    back to 0 when their denominators are empty.
    """
    scores = X @ weights + bias
    probabilities = 1.0 / (1.0 + np.exp(-np.clip(scores, -500, 500)))
    predictions = (probabilities >= threshold).astype(int)
    acc = np.mean(predictions == y)

    # Confusion-matrix cells.
    tp = np.sum((predictions == 1) & (y == 1))
    fp = np.sum((predictions == 1) & (y == 0))
    fn = np.sum((predictions == 0) & (y == 1))
    tn = np.sum((predictions == 0) & (y == 0))

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0

    print(f"\n Accuracy: {acc:.3f} ({int(acc * len(y))}/{len(y)})")
    print(f" Precision: {precision:.3f}")
    print(f" Recall: {recall:.3f}")
    print(f" F1: {f1:.3f}")
    print(f" Confusion: TP={tp} FP={fp} FN={fn} TN={tn}")
    return acc, precision, recall, f1
def test_mixed_examples(weights, bias):
    """Run the trained classifier over hand-crafted mixed dictations.

    Prints each sentence with per-word ANSI color coding (red = classified
    protocol, green = natural) and reports MISSED / FALSE+ protocol words
    against the expected sets. Fix: removed the unused `natural` and
    `correct` locals from the original.
    """
    test_cases = [
        # (text, expected_protocol_words)
        ("I want to check the directory ls dash la", {"ls", "dash", "la"}),
        ("can you run git dash dash help for me", {"git", "dash", "help"}),
        ("alright let me try talkie dash dev dash dash help", {"talkie", "dash", "dev", "help"}),
        ("the meeting is at three o'clock tomorrow afternoon", set()),
        ("I think we should also look at cd dot dot slash src", {"cd", "dot", "slash", "src"}),
        ("so basically npm space install space dash capital D", {"npm", "space", "install", "dash", "capital", "D"}),
        ("let me know if you need anything else from me", set()),
        ("okay go ahead and type out chmod seven five five", {"chmod", "seven", "five"}),
        ("the performance numbers look really good", set()),
        ("ssh dash i tilde slash dot ssh slash id underscore rsa", {"ssh", "dash", "i", "tilde", "slash", "dot", "id", "underscore", "rsa"}),
    ]
    print("\nβββ Test Cases βββ")
    for text, expected_protocol in test_cases:
        words = text.split()
        preds = []
        for i, word in enumerate(words):
            # Same ±2 context window used during training.
            ctx_start = max(0, i - 2)
            ctx_end = min(len(words), i + 3)
            context = words[ctx_start:ctx_end]
            features = extract_features(word, context, i, len(words))
            logit = np.dot(features, weights) + bias
            prob = 1.0 / (1.0 + np.exp(-logit))
            preds.append((word, prob, prob >= 0.5))
        detected = {w for w, p, is_proto in preds if is_proto}
        # Color output: red = protocol, green = natural.
        colored = [
            f"\033[91m{w}\033[0m" if is_proto else f"\033[92m{w}\033[0m"
            for w, _p, is_proto in preds
        ]
        print(f" {' '.join(colored)}")
        # Compare against the expected protocol word set.
        if expected_protocol:
            missed = expected_protocol - detected
            false_pos = detected - expected_protocol
            if missed:
                print(f" MISSED: {missed}")
            if false_pos:
                print(f" FALSE+: {false_pos}")
        elif detected:
            print(f" FALSE+: {detected}")
def export_model(weights, bias, feature_names, output_path):
    """Serialize the trained model to JSON so it can be ported to Swift."""
    payload = {
        "classifier": "ProtocolSegmentClassifier",
        "description": "Per-word logistic regression for protocol segment detection",
        "features": feature_names,
        "weights": weights.tolist(),
        "bias": float(bias),
        "threshold": 0.5,
    }
    Path(output_path).write_text(json.dumps(payload, indent=2))
    print(f"\nModel exported to {output_path}")
# Feature names in the exact order emitted by extract_features(). This
# ordering is part of the exported model contract — the Swift port indexes
# weights by position — so do not reorder without retraining/re-exporting.
FEATURE_NAMES = [
    "is_strong_protocol",
    "is_weak_protocol",
    "is_expanded_symbol",
    "has_syntax_chars",
    "word_length_norm",
    "is_short_word",
    "context_strong_density",
    "context_any_density",
    "left_is_strong",
    "right_is_strong",
    "is_number_like",
    "strong_neighbor_count",
    "is_all_lower",
    "position_ratio",
]
def main():
    """End-to-end pipeline: generate data, featurize, train, evaluate, export."""
    print("=" * 60)
    print("Protocol Segment Classifier Training")
    print("=" * 60)

    # 1. Synthetic training data.
    print("\n1. Generating training data...")
    labeled = generate_mixed_examples(n_examples=600, seed=42)
    print(f" {len(labeled)} labeled words generated")
    n_protocol = sum(1 for e in labeled if e["label"] == 1)
    n_natural = len(labeled) - n_protocol  # labels are binary 0/1
    print(f" Protocol: {n_protocol} ({100*n_protocol/len(labeled):.1f}%)")
    print(f" Natural: {n_natural} ({100*n_natural/len(labeled):.1f}%)")

    # 2. Feature matrix.
    print("\n2. Extracting features...")
    X, y = build_dataset(labeled)
    print(f" Feature matrix: {X.shape}")

    # 80/20 train/test split with a fixed-seed permutation.
    order = np.random.RandomState(42).permutation(len(y))
    cut = int(0.8 * len(y))
    train_idx, test_idx = order[:cut], order[cut:]
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]
    print(f" Train: {len(y_train)}, Test: {len(y_test)}")

    # 3. Fit.
    print("\n3. Training logistic regression...")
    weights, bias = train_logistic_regression(X_train, y_train)
    print("\n Weights:")
    for name, w in zip(FEATURE_NAMES, weights):
        print(f" {name:30s} {w:+.6f}")
    print(f" {'bias':30s} {bias:+.6f}")

    # 4-5. Held-out and train-set metrics.
    print("\n4. Test set evaluation:")
    evaluate(X_test, y_test, weights, bias)
    print("\n5. Train set evaluation (sanity):")
    evaluate(X_train, y_train, weights, bias)

    # 6. Qualitative check on hand-crafted sentences.
    print("\n6. Mixed dictation examples:")
    test_mixed_examples(weights, bias)

    # Export for the Swift port, then print copy-pasteable constants.
    model_path = Path(__file__).parent / "segment-classifier-model.json"
    export_model(weights, bias, FEATURE_NAMES, model_path)

    print("\nβββ Swift Constants βββ")
    print("private static let weights: [Double] = [")
    for name, w in zip(FEATURE_NAMES, weights):
        print(f" {w:+.20f}, // {name}")
    print("]")
    print(f"private static let bias: Double = {bias:+.20f}")


if __name__ == "__main__":
    main()
|