from __future__ import annotations import argparse import json import random from pathlib import Path import torch from huggingface_hub import HfApi from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer ROOT = Path(__file__).resolve().parents[1] LABELS = ["ingredients", "nutrition", "license", "dates", "refuse_absolute"] LABEL_TO_ID = {label: index for index, label in enumerate(LABELS)} class RouterDataset(Dataset): def __init__(self, records, tokenizer): self.records = records self.tokenizer = tokenizer def __len__(self): return len(self.records) def __getitem__(self, index): record = self.records[index] encoded = self.tokenizer( record["text"], padding="max_length", truncation=True, max_length=32, return_tensors="pt", ) return { "input_ids": encoded["input_ids"].squeeze(0), "attention_mask": encoded["attention_mask"].squeeze(0), "labels": torch.tensor(LABEL_TO_ID[record["label"]]), } def evaluate(model, loader, device): model.eval() correct = total = 0 with torch.no_grad(): for batch in loader: labels = batch.pop("labels").to(device) logits = model(**{key: value.to(device) for key, value in batch.items()}).logits correct += (logits.argmax(dim=-1) == labels).sum().item() total += labels.numel() return correct / total def main(): parser = argparse.ArgumentParser() parser.add_argument("--repo-id", default="build-small-hackathon/packetcourt-evidence-router") parser.add_argument("--base-model", default="google/bert_uncased_L-2_H-128_A-2") parser.add_argument("--epochs", type=int, default=30) args = parser.parse_args() random.seed(42) torch.manual_seed(42) records = [json.loads(line) for line in (ROOT / "data/router_training.jsonl").read_text().splitlines()] grouped = {label: [] for label in LABELS} for record in records: grouped[record["label"]].append(record) for group in grouped.values(): random.shuffle(group) validation = [group.pop() for group in grouped.values()] training = [record for group in grouped.values() for record in group] random.shuffle(training) tokenizer = AutoTokenizer.from_pretrained(args.base_model) model = AutoModelForSequenceClassification.from_pretrained( args.base_model, num_labels=len(LABELS), id2label={index: label for index, label in enumerate(LABELS)}, label2id=LABEL_TO_ID, ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) train_loader = DataLoader(RouterDataset(training, tokenizer), batch_size=8, shuffle=True) validation_loader = DataLoader(RouterDataset(validation, tokenizer), batch_size=5) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) for epoch in range(args.epochs): model.train() for batch in train_loader: optimizer.zero_grad() labels = batch.pop("labels").to(device) loss = model(**{key: value.to(device) for key, value in batch.items()}, labels=labels).loss loss.backward() optimizer.step() print(f"epoch={epoch + 1} validation_accuracy={evaluate(model, validation_loader, device):.3f}") output = ROOT / "router_model" model.save_pretrained(output) tokenizer.save_pretrained(output) score = evaluate(model, validation_loader, device) card = f"""--- license: apache-2.0 base_model: {args.base_model} tags: - text-classification - build-small-hackathon - packetcourt - fine-tuned --- # PacketCourt Evidence Router A {sum(parameter.numel() for parameter in model.parameters()):,}-parameter fine-tuned classifier used by PacketCourt's investigation agent to choose the next evidence tool for a packet claim. Labels: `{", ".join(LABELS)}`. Held-out validation accuracy: `{score:.3f}` on a small PacketCourt-specific routing set. The router proposes an investigation tool; deterministic code remains responsible for final verdicts. """ (output / "README.md").write_text(card) api = HfApi() api.create_repo(args.repo_id, repo_type="model", private=True, exist_ok=True) api.upload_folder( repo_id=args.repo_id, repo_type="model", folder_path=output, commit_message="feat: publish PacketCourt fine-tuned evidence router", ) print(f"published={args.repo_id} validation_accuracy={score:.3f}") if __name__ == "__main__": main()