# SwiftContext / train.py
# Uploaded by tripathyShaswata with huggingface_hub (commit 31befd6, verified), 6.56 kB.
"""
Train SwiftContext search-strategy router.
The router is a lightweight DistilBERT (66M params) classifier that runs in
~5 ms and tells the 4B explorer LLM which search strategy to apply before
it starts exploring. This is the core token-saving improvement over
FastContext, which always starts blind and wastes the first turn discovering
the search strategy itself.

Classes:
    0 - broad_scan      : wide exploration, file/module locations unknown
    1 - targeted_search : specific named symbol to locate
    2 - pinpoint_cite   : exact line-level citation of already-scoped code

Base model : distilbert-base-uncased (66M params)
Training   : ~2 min on GPU, ~10 min on CPU
"""
import json
import os
import numpy as np
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
EarlyStoppingCallback,
)
from sklearn.metrics import accuracy_score, f1_score, classification_report
import torch
# ── Config ───────────────────────────────────────────────────────────────────
# Paths and checkpoint identifier.
MODEL_NAME = "distilbert-base-uncased"
OUTPUT_DIR = "./model"
DATA_DIR = "./data"

# Tokenization and optimisation hyper-parameters.
MAX_LENGTH = 128
BATCH_SIZE = 32
NUM_EPOCHS = 5
LEARNING_RATE = 2e-5
EARLY_STOPPING_PATIENCE = 2

# Class name <-> class id lookup tables. ID2LABEL is the inverse of LABEL2ID,
# so the two mappings cannot drift apart when a class is added or renamed.
LABEL2ID = {
    "broad_scan": 0,
    "targeted_search": 1,
    "pinpoint_cite": 2,
}
ID2LABEL = {class_id: name for name, class_id in LABEL2ID.items()}
# ── Data loading ─────────────────────────────────────────────────────────────
def load_jsonl(path: str) -> list[dict]:
    """Read a JSON Lines file into a list of decoded records.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped rather than handed to ``json.loads``, which
    would raise ``json.JSONDecodeError`` on an empty string.

    Args:
        path: Location of the ``.jsonl`` file; it is read as UTF-8.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace, so only the emptiness
        # check needs the stripped form of the line.
        return [json.loads(line) for line in f if line.strip()]
def load_split(split_name: str) -> Dataset:
    """Build a ``Dataset`` from the file ``DATA_DIR/<split_name>.jsonl``.

    Each record in the file is indexed by ``"text"`` and ``"label"``; those
    two fields become the columns of the returned dataset.

    Args:
        split_name: File stem of the split, e.g. ``"train"``.

    Returns:
        A dataset with a ``text`` column and a ``label`` column.
    """
    records = load_jsonl(os.path.join(DATA_DIR, f"{split_name}.jsonl"))
    # Build the columns one at a time: first every text, then every label.
    columns = {
        field: [record[field] for record in records]
        for field in ("text", "label")
    }
    return Dataset.from_dict(columns)
# ── Tokenization ──────────────────────────────────────────────────────────────
def get_tokenizer():
    """Load the pretrained tokenizer registered under ``MODEL_NAME``."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return tokenizer
def tokenize(batch, tokenizer):
    """Encode the ``"text"`` entry of ``batch`` with ``tokenizer``.

    The tokenizer is called with truncation enabled and ``"max_length"``
    padding, both bounded by ``MAX_LENGTH``.

    Args:
        batch: Mapping that holds the raw inputs under the ``"text"`` key.
        tokenizer: Callable tokenizer accepting ``truncation``, ``padding``
            and ``max_length`` keyword arguments.

    Returns:
        Whatever the tokenizer call returns for the batch.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LENGTH,
    }
    return tokenizer(batch["text"], **encode_options)
# ── Metrics ──────────────────────────────────────────────────────────────────
def compute_metrics(eval_pred):
    """Score a batch of predictions with accuracy and weighted F1.

    Args:
        eval_pred: A ``(logits, labels)`` pair. The predicted class for each
            row is the arg-max of ``logits`` over its last axis.

    Returns:
        A dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=-1)
    accuracy = accuracy_score(labels, predicted)
    weighted_f1 = f1_score(labels, predicted, average="weighted")
    return {"accuracy": accuracy, "f1": weighted_f1}
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Fine-tune the router classifier, evaluate it on the test split, save it.

    Steps:
      1. Load the tokenizer and the train/val/test splits, then tokenize them.
      2. Fine-tune ``MODEL_NAME`` with the Hugging Face ``Trainer``, evaluating
         and checkpointing once per epoch, with early stopping and best-model
         selection on the validation ``f1`` metric.
      3. Print accuracy, weighted F1 and a per-class report for the test split.
      4. Save the model and the tokenizer to ``OUTPUT_DIR/final``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    print(f"Loading tokenizer from {MODEL_NAME} ...")
    tokenizer = get_tokenizer()

    print("Loading datasets ...")
    train_ds = load_split("train")
    val_ds = load_split("val")
    test_ds = load_split("test")
    print(f" train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}")

    # Hand the tokenizer to `tokenize` through fn_kwargs instead of binding a
    # lambda to a name.
    map_options = {"batched": True, "fn_kwargs": {"tokenizer": tokenizer}}
    train_ds = train_ds.map(tokenize, **map_options)
    val_ds = val_ds.map(tokenize, **map_options)
    test_ds = test_ds.map(tokenize, **map_options)

    # set_format works in place, so the loop updates the three datasets above.
    for ds in (train_ds, val_ds, test_ds):
        ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

    print("Loading model ...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        # Derive the head size from the label map so the two cannot disagree.
        num_labels=len(LABEL2ID),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        # Fixed warmup length; it equals 10% only if the run has ~100
        # optimizer steps, so revisit it when the dataset size, BATCH_SIZE or
        # NUM_EPOCHS changes.
        warmup_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Caps the number of checkpoints kept on disk (older ones are deleted
        # first); with load_best_model_at_end the best checkpoint is retained.
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="f1",  # key returned by compute_metrics
        dataloader_num_workers=2,  # parallel data loading
        report_to="none",
        fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is present
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)
        ],
    )

    print("\nTraining ...")
    trainer.train()

    # ── Test evaluation ───────────────────────────────────────────────────
    print("\nEvaluating on test set ...")
    preds_output = trainer.predict(test_ds)
    preds = np.argmax(preds_output.predictions, axis=-1)
    labels = preds_output.label_ids

    accuracy = accuracy_score(labels, preds)
    weighted_f1 = f1_score(labels, preds, average="weighted")
    print("\n=== Test Set Results ===")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"F1 (weighted): {weighted_f1:.4f}")

    print("\nPer-class report:")
    class_ids = sorted(ID2LABEL)
    print(classification_report(
        labels, preds,
        # Pass the class ids explicitly: without `labels`, the report raises
        # ValueError when a class is absent from both labels and predictions,
        # because target_names would then outnumber the classes it found.
        labels=class_ids,
        target_names=[ID2LABEL[i] for i in class_ids],
    ))

    # ── Save ──────────────────────────────────────────────────────────────
    final_dir = os.path.join(OUTPUT_DIR, "final")
    print(f"\nSaving model → {final_dir}")
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Done.")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()