# packetcourt/scripts/train_router.py
# Commit 644a42b - feat: deploy evidence investigation agent and fine-tuned router
from __future__ import annotations
import argparse
import json
import random
from pathlib import Path
import torch
from huggingface_hub import HfApi
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Directory two levels up from this file: the parent of the folder holding the script.
ROOT = Path(__file__).resolve().parents[1]

# Class names for the router; a label's position in this list is its class id.
LABELS = [
    "ingredients",
    "nutrition",
    "license",
    "dates",
    "refuse_absolute",
]

# Reverse lookup from label name to class id.
LABEL_TO_ID = dict(zip(LABELS, range(len(LABELS))))
class RouterDataset(Dataset):
    """Map-style dataset that tokenizes routing examples on access.

    Every record is a mapping with a ``"text"`` entry (the string to encode)
    and a ``"label"`` entry (a label name present in the label map).

    Args:
        records: Indexable collection of record mappings.
        tokenizer: Callable tokenizer; it is invoked with the text plus
            ``padding``, ``truncation``, ``max_length`` and ``return_tensors``
            keyword arguments and must return ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        max_length: Fixed sequence length every example is padded or
            truncated to. Defaults to 32.
        label_to_id: Optional mapping from label name to class id. When
            omitted, the module-level ``LABEL_TO_ID`` table is used.
    """

    def __init__(self, records, tokenizer, max_length: int = 32, label_to_id=None):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Only fall back to the module constant when no mapping is supplied,
        # so the dataset can be reused with a different label set.
        self.label_to_id = LABEL_TO_ID if label_to_id is None else label_to_id

    def __len__(self):
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, index):
        """Tokenize record ``index`` and return its model inputs and label.

        Returns:
            A dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
            tensors.

        Raises:
            KeyError: If the record lacks ``"text"``/``"label"`` or its label
                is not in the label map.
        """
        record = self.records[index]
        encoded = self.tokenizer(
            record["text"],
            # Pad every example to the same length so batches can be stacked.
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # squeeze(0) drops the leading axis of the per-example encoding;
            # the DataLoader adds the batch axis back when collating.
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.label_to_id[record["label"]]),
        }
def evaluate(model, loader, device) -> float:
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and is left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: Callable accepting the batch tensors as keyword arguments and
            returning an object with a ``logits`` attribute.
        loader: Iterable of batches; each batch is a dict of tensors that
            includes a ``"labels"`` entry. Batches are not modified.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Fraction of examples whose arg-max logit equals the label, or ``0.0``
        when ``loader`` yields no examples.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in loader:
            # Read the labels without popping so a reusable batch container
            # (e.g. a list) can be evaluated more than once.
            labels = batch["labels"].to(device)
            inputs = {key: value.to(device) for key, value in batch.items() if key != "labels"}
            logits = model(**inputs).logits
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
    # An empty loader would otherwise divide by zero.
    return correct / total if total else 0.0
def _load_records(path):
    """Parse a JSON Lines file into a list of records, skipping blank lines."""
    lines = path.read_text(encoding="utf-8").splitlines()
    return [json.loads(line) for line in lines if line.strip()]


def _split_holdout(records):
    """Hold out one randomly chosen record per label; return (training, validation)."""
    grouped = {label: [] for label in LABELS}
    for record in records:
        grouped[record["label"]].append(record)
    missing = [label for label, group in grouped.items() if not group]
    if missing:
        raise ValueError(f"no training examples for label(s): {', '.join(missing)}")
    for group in grouped.values():
        random.shuffle(group)
    # After the shuffle, the last element of each group is a random pick.
    validation = [group.pop() for group in grouped.values()]
    training = [record for group in grouped.values() for record in group]
    random.shuffle(training)
    return training, validation


def main():
    """Fine-tune the evidence router, save it locally and publish it to the Hub.

    Steps: parse the command line, read ``data/router_training.jsonl`` under
    ``ROOT``, hold out one example per label for validation, train the
    sequence classifier for ``--epochs`` epochs with AdamW, write the model,
    tokenizer and a README model card to ``ROOT / "router_model"``, then
    upload that folder to the ``--repo-id`` model repository.

    Raises:
        ValueError: If any label in ``LABELS`` has no examples in the data.
        KeyError: If a record has no ``"label"`` field or an unknown label.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", default="build-small-hackathon/packetcourt-evidence-router")
    parser.add_argument("--base-model", default="google/bert_uncased_L-2_H-128_A-2")
    parser.add_argument("--epochs", type=int, default=30)
    args = parser.parse_args()

    # Fixed seeds for Python's and torch's random number generators.
    random.seed(42)
    torch.manual_seed(42)

    records = _load_records(ROOT / "data/router_training.jsonl")
    training, validation = _split_holdout(records)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.base_model,
        num_labels=len(LABELS),
        id2label={index: label for index, label in enumerate(LABELS)},
        label2id=LABEL_TO_ID,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loader = DataLoader(RouterDataset(training, tokenizer), batch_size=8, shuffle=True)
    validation_loader = DataLoader(RouterDataset(validation, tokenizer), batch_size=5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(args.epochs):
        # evaluate() leaves the model in eval mode, so re-enable training mode each epoch.
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            labels = batch.pop("labels").to(device)
            loss = model(**{key: value.to(device) for key, value in batch.items()}, labels=labels).loss
            loss.backward()
            optimizer.step()
        print(f"epoch={epoch + 1} validation_accuracy={evaluate(model, validation_loader, device):.3f}")

    output = ROOT / "router_model"
    model.save_pretrained(output)
    tokenizer.save_pretrained(output)
    score = evaluate(model, validation_loader, device)

    # The card text starts at column 0 on purpose: it is written verbatim to README.md.
    card = f"""---
license: apache-2.0
base_model: {args.base_model}
tags:
- text-classification
- build-small-hackathon
- packetcourt
- fine-tuned
---
# PacketCourt Evidence Router
A {sum(parameter.numel() for parameter in model.parameters()):,}-parameter fine-tuned classifier used by
PacketCourt's investigation agent to choose the next evidence tool for a packet claim.
Labels: `{", ".join(LABELS)}`.
Held-out validation accuracy: `{score:.3f}` on a small PacketCourt-specific routing set.
The router proposes an investigation tool; deterministic code remains responsible for final verdicts.
"""
    (output / "README.md").write_text(card, encoding="utf-8")

    api = HfApi()
    api.create_repo(args.repo_id, repo_type="model", private=True, exist_ok=True)
    api.upload_folder(
        repo_id=args.repo_id,
        repo_type="model",
        folder_path=output,
        commit_message="feat: publish PacketCourt fine-tuned evidence router",
    )
    print(f"published={args.repo_id} validation_accuracy={score:.3f}")


if __name__ == "__main__":
    main()