| """
|
| Train SwiftContext search-strategy router.
|
|
|
| The router is a lightweight DistilBERT (66M params) classifier that runs in
|
| ~5 ms and tells the 4B explorer LLM which search strategy to apply before
|
| it starts exploring. This is the core token-saving improvement over
|
| FastContext, which always starts blind and wastes the first turn discovering
|
| the search strategy itself.
|
|
|
| Classes:
|
| 0 - broad_scan : wide exploration, file/module locations unknown
|
| 1 - targeted_search : specific named symbol to locate
|
| 2 - pinpoint_cite : exact line-level citation of already-scoped code
|
|
|
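Base model : distilbert-base-uncased (66M params)

Training   : ~2 min on GPU, ~10 min on CPU

Data       : <DATA_DIR>/{train,val,test}.jsonl, one JSON object per line with
             a "text" string and an integer "label" (class id as listed above)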
|
| """
|
|
|
| import json
|
| import os
|
| import numpy as np
|
| from datasets import Dataset
|
| from transformers import (
|
| AutoTokenizer,
|
| AutoModelForSequenceClassification,
|
| TrainingArguments,
|
| Trainer,
|
| EarlyStoppingCallback,
|
| )
|
| from sklearn.metrics import accuracy_score, f1_score, classification_report
|
| import torch
|
|
|
|
|
|
|
MODEL_NAME = "distilbert-base-uncased"
OUTPUT_DIR = "./model"
DATA_DIR = "./data"
MAX_LENGTH = 128  # max tokens per example; longer inputs are truncated
BATCH_SIZE = 32
NUM_EPOCHS = 5
LEARNING_RATE = 2e-5
# Number of evaluations without improvement in the best-model metric before
# early stopping ends training.
EARLY_STOPPING_PATIENCE = 2

# Single source of truth for the class mapping. ID2LABEL is derived from it so
# the two dictionaries can never drift out of sync when a route is added.
LABEL2ID = {"broad_scan": 0, "targeted_search": 1, "pinpoint_cite": 2}
ID2LABEL = {class_id: name for name, class_id in LABEL2ID.items()}
|
|
|
|
|
|
|
def load_jsonl(path: str) -> list[dict]:
    """Read a JSON Lines file into a list of parsed objects.

    Blank or whitespace-only lines (for example an extra empty line at the end
    of the file) are skipped instead of being passed to ``json.loads``, which
    would raise on them.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace, so only empty lines
        # need filtering.
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
def load_split(split_name: str) -> Dataset:
    """Load ``<DATA_DIR>/<split_name>.jsonl`` as a ``datasets.Dataset``.

    Only the ``"text"`` and ``"label"`` fields of each record are kept; a
    record missing either key raises ``KeyError``.
    """
    jsonl_path = os.path.join(DATA_DIR, f"{split_name}.jsonl")
    records = load_jsonl(jsonl_path)
    texts = [record["text"] for record in records]
    labels = [record["label"] for record in records]
    return Dataset.from_dict({"text": texts, "label": labels})
|
|
|
|
|
|
|
def get_tokenizer():
    """Return the pretrained tokenizer that matches ``MODEL_NAME``."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return tokenizer
|
|
|
|
|
def tokenize(batch, tokenizer):
    """Tokenize a batched ``{"text": [...]}`` mapping for ``Dataset.map``.

    Every sequence is truncated and padded to exactly ``MAX_LENGTH`` tokens,
    so all examples share the same length.

    Args:
        batch: Mapping whose ``"text"`` entry is a list of strings.
        tokenizer: Callable Hugging Face tokenizer.

    Returns:
        Whatever the tokenizer returns for the batch (encoded fields such as
        ``input_ids`` and ``attention_mask``).
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LENGTH,
    }
    return tokenizer(batch["text"], **encode_options)
|
|
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute classification metrics for ``Trainer`` evaluation/prediction.

    Args:
        eval_pred: ``(logits, labels)`` pair as handed over by ``Trainer``;
            logits hold one score per class along the last axis, labels are
            integer class ids.

    Returns:
        Dict with ``accuracy``, support-weighted ``f1`` (the metric used to
        select the best checkpoint) and unweighted ``f1_macro``.
    """
    logits, labels = eval_pred
    # Guard against models configured to return extra outputs alongside the
    # logits; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        # zero_division=0 yields the same score sklearn reports by default but
        # silences UndefinedMetricWarning when a class is never predicted
        # (common in early epochs).
        "f1": f1_score(labels, preds, average="weighted", zero_division=0),
        # Macro F1 weighs every route equally, so a collapsed minority class
        # shows up even when the weighted score looks healthy.
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
    }
|
|
|
|
|
|
|
def _prepare_datasets(tokenizer):
    """Load the train/val/test splits and tokenize them into torch tensors.

    Returns:
        ``(train_ds, val_ds, test_ds)`` exposing ``input_ids``,
        ``attention_mask`` and ``label`` in torch format.
    """
    print("Loading datasets ...")
    train_ds = load_split("train")
    val_ds = load_split("val")
    test_ds = load_split("test")
    print(f" train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}")

    prepared = []
    for ds in (train_ds, val_ds, test_ds):
        # fn_kwargs forwards the tokenizer without assigning a lambda.
        ds = ds.map(tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
        ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])
        prepared.append(ds)
    return tuple(prepared)


def _build_training_args() -> TrainingArguments:
    """Return the Trainer configuration: per-epoch eval/save, best model by F1."""
    return TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        warmup_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        dataloader_num_workers=2,
        report_to="none",
        fp16=torch.cuda.is_available(),
    )


def _report_test_metrics(trainer, test_ds) -> None:
    """Predict on the held-out test split and print summary metrics."""
    print("\nEvaluating on test set ...")
    preds_output = trainer.predict(test_ds)
    preds = np.argmax(preds_output.predictions, axis=-1)
    labels = preds_output.label_ids

    class_ids = sorted(ID2LABEL)
    print("\n=== Test Set Results ===")
    print(f"Accuracy : {accuracy_score(labels, preds):.4f}")
    print(f"F1 (weighted): {f1_score(labels, preds, average='weighted', zero_division=0):.4f}")
    print("\nPer-class report:")
    # Passing labels= explicitly keeps target_names aligned even when the
    # test labels and predictions together cover fewer than all classes;
    # otherwise sklearn raises a size-mismatch ValueError.
    print(classification_report(
        labels, preds,
        labels=class_ids,
        target_names=[ID2LABEL[i] for i in class_ids],
        zero_division=0,
    ))


def main():
    """Fine-tune the router, report test-set metrics and save the final model.

    Reads ``<DATA_DIR>/{train,val,test}.jsonl``, writes epoch checkpoints under
    ``OUTPUT_DIR`` and the best model plus tokenizer to ``OUTPUT_DIR/final``.
    """
    # Local import keeps this change self-contained.
    from transformers import set_seed

    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    print(f"Loading tokenizer from {MODEL_NAME} ...")
    tokenizer = get_tokenizer()

    train_ds, val_ds, test_ds = _prepare_datasets(tokenizer)

    training_args = _build_training_args()
    # Seed BEFORE building the model: the classification head is randomly
    # initialised inside from_pretrained, while Trainer only seeds in its own
    # constructor, so without this the head init is not reproducible.
    set_seed(training_args.seed)

    print("Loading model ...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(LABEL2ID),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],
    )

    print("\nTraining ...")
    trainer.train()

    # load_best_model_at_end=True, so this evaluates the best checkpoint.
    _report_test_metrics(trainer, test_ds)

    final_dir = os.path.join(OUTPUT_DIR, "final")
    print(f"\nSaving model -> {final_dir}")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Done.")
|
|
|
|
|
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|