""" Train SwiftContext search-strategy router. The router is a lightweight DistilBERT (66M params) classifier that runs in ~5 ms and tells the 4B explorer LLM which search strategy to apply before it starts exploring. This is the core token-saving improvement over FastContext, which always starts blind and wastes the first turn discovering the search strategy itself. Classes: 0 - broad_scan : wide exploration, file/module locations unknown 1 - targeted_search : specific named symbol to locate 2 - pinpoint_cite : exact line-level citation of already-scoped code Base model : distilbert-base-uncased (66 M params) Training : ~2 min on GPU, ~10 min on CPU """ import json import os import numpy as np from datasets import Dataset from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback, ) from sklearn.metrics import accuracy_score, f1_score, classification_report import torch # ── Config ─────────────────────────────────────────────────────────────────── MODEL_NAME = "distilbert-base-uncased" OUTPUT_DIR = "./model" DATA_DIR = "./data" MAX_LENGTH = 128 BATCH_SIZE = 32 NUM_EPOCHS = 5 LEARNING_RATE = 2e-5 EARLY_STOPPING_PATIENCE = 2 LABEL2ID = {"broad_scan": 0, "targeted_search": 1, "pinpoint_cite": 2} ID2LABEL = {0: "broad_scan", 1: "targeted_search", 2: "pinpoint_cite"} # ── Data loading ───────────────────────────────────────────────────────────── def load_jsonl(path: str) -> list[dict]: with open(path, "r", encoding="utf-8") as f: return [json.loads(line.strip()) for line in f] def load_split(split_name: str) -> Dataset: path = os.path.join(DATA_DIR, f"{split_name}.jsonl") raw = load_jsonl(path) return Dataset.from_dict({ "text": [ex["text"] for ex in raw], "label": [ex["label"] for ex in raw], }) # ── Tokenization ────────────────────────────────────────────────────────────── def get_tokenizer(): return AutoTokenizer.from_pretrained(MODEL_NAME) def tokenize(batch, tokenizer): return tokenizer( batch["text"], truncation=True, padding="max_length", max_length=MAX_LENGTH, ) # ── Metrics ────────────────────────────────────────────────────────────────── def compute_metrics(eval_pred): logits, labels = eval_pred preds = np.argmax(logits, axis=-1) return { "accuracy": accuracy_score(labels, preds), "f1": f1_score(labels, preds, average="weighted"), } # ── Main ───────────────────────────────────────────────────────────────────── def main(): print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}") print(f"Loading tokenizer from {MODEL_NAME} ...") tokenizer = get_tokenizer() print("Loading datasets ...") train_ds = load_split("train") val_ds = load_split("val") test_ds = load_split("test") print(f" train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}") tok_fn = lambda batch: tokenize(batch, tokenizer) train_ds = train_ds.map(tok_fn, batched=True) val_ds = val_ds.map(tok_fn, batched=True) test_ds = test_ds.map(tok_fn, batched=True) for ds in (train_ds, val_ds, test_ds): ds.set_format("torch", columns=["input_ids", "attention_mask", "label"]) print("Loading model ...") model = AutoModelForSequenceClassification.from_pretrained( MODEL_NAME, num_labels=3, id2label=ID2LABEL, label2id=LABEL2ID, ) training_args = TrainingArguments( output_dir=OUTPUT_DIR, num_train_epochs=NUM_EPOCHS, per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, weight_decay=0.01, warmup_steps=10, # 10% of ~100 total steps eval_strategy="epoch", save_strategy="epoch", save_total_limit=3, # keep only the 3 best checkpoints on disk load_best_model_at_end=True, metric_for_best_model="f1", dataloader_num_workers=2, # parallel data loading report_to="none", fp16=torch.cuda.is_available(), ) trainer = Trainer( model=model, args=training_args, train_dataset=train_ds, eval_dataset=val_ds, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)], ) print("\nTraining ...") trainer.train() # ── Test evaluation ─────────────────────────────────────────────────────── print("\nEvaluating on test set ...") preds_output = trainer.predict(test_ds) preds = np.argmax(preds_output.predictions, axis=-1) labels = preds_output.label_ids print("\n=== Test Set Results ===") print(f"Accuracy : {accuracy_score(labels, preds):.4f}") print(f"F1 (weighted): {f1_score(labels, preds, average='weighted'):.4f}") print("\nPer-class report:") print(classification_report( labels, preds, target_names=[ID2LABEL[i] for i in sorted(ID2LABEL)], )) # ── Save ───────────────────────────────────────────────────────────────── final_dir = os.path.join(OUTPUT_DIR, "final") print(f"\nSaving model → {final_dir}") trainer.save_model(final_dir) tokenizer.save_pretrained(final_dir) print("Done.") if __name__ == "__main__": main()