Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| from pathlib import Path | |
| import torch | |
| from huggingface_hub import HfApi | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Parent of the directory that contains this file; data and model output
# paths used by the script are resolved relative to it.
ROOT: Path = Path(__file__).resolve().parents[1]

# Routing targets the classifier can predict. A label's position in this list
# is its integer class id, so the order must stay stable across training runs.
LABELS: list[str] = ["ingredients", "nutrition", "license", "dates", "refuse_absolute"]

# Lookup from label name to its integer class id (the label's index in LABELS).
LABEL_TO_ID: dict[str, int] = {label: index for index, label in enumerate(LABELS)}
class RouterDataset(Dataset):
    """Map-style dataset that tokenizes routing records when they are indexed.

    Each record is a mapping with a ``"text"`` entry (the string to classify)
    and a ``"label"`` entry that must be a key of ``LABEL_TO_ID``.

    Args:
        records: Indexable, sized collection of record mappings.
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            truncation=..., max_length=..., return_tensors="pt")`` whose result
            exposes ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Token length every example is padded or truncated to.
            Defaults to 32, the value previously hard-coded here.
    """

    def __init__(self, records, tokenizer, max_length=32):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, index):
        """Tokenize the record at ``index`` and return model-ready tensors.

        Returns:
            A dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.

        Raises:
            KeyError: If the record's label is not a key of ``LABEL_TO_ID``.
        """
        record = self.records[index]
        encoded = self.tokenizer(
            record["text"],
            # Fixed-length padding gives every example the same shape so the
            # default DataLoader collation can stack them.
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # squeeze(0) drops the leading size-1 dimension that the tokenizer
            # is expected to add for a single text (verify for custom tokenizers).
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(LABEL_TO_ID[record["label"]]),
        }
def evaluate(model, loader, device):
    """Return the fraction of examples in ``loader`` the model classifies correctly.

    Args:
        model: Callable model with an ``eval()`` method; it is called with the
            batch's non-label entries as keyword arguments and must return an
            object with a ``logits`` attribute.
        loader: Iterable of batches; each batch is a dict of tensors that
            includes a ``"labels"`` entry.
        device: Device the tensors are moved to before the forward pass.

    Returns:
        Accuracy as a float in ``[0, 1]``. An empty ``loader`` yields ``0.0``
        instead of raising ``ZeroDivisionError``.

    Note:
        The model is switched to evaluation mode and left there; callers that
        continue training must call ``model.train()`` again.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            labels = batch["labels"].to(device)
            # Build the inputs without mutating the caller's batch dict.
            inputs = {key: value.to(device) for key, value in batch.items() if key != "labels"}
            logits = model(**inputs).logits
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
    if total == 0:
        return 0.0
    return correct / total
def _load_records(path):
    """Read one JSON object per non-blank line of the UTF-8 file at ``path``.

    Raises:
        ValueError: If a non-blank line is not valid JSON; the message names
            the file and line number.
    """
    records = []
    for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline added by an editor).
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError as error:
            raise ValueError(f"{path}:{line_number}: invalid JSON record: {error.msg}") from error
    return records


def _split_by_label(records):
    """Hold out one random record per label; return ``(training, validation)``.

    Uses the global ``random`` state, so seed it beforehand for reproducibility.

    Raises:
        ValueError: If a record carries a label outside ``LABELS``, or a label
            has fewer than two records (one to hold out, one to train on).
    """
    grouped = {label: [] for label in LABELS}
    for record in records:
        label = record["label"]
        if label not in grouped:
            raise ValueError(f"unknown label {label!r}; expected one of {LABELS}")
        grouped[label].append(record)
    for label, group in grouped.items():
        if len(group) < 2:
            raise ValueError(
                f"label {label!r} has {len(group)} record(s); at least 2 are required"
            )
    for group in grouped.values():
        random.shuffle(group)
    validation = [group.pop() for group in grouped.values()]
    training = [record for group in grouped.values() for record in group]
    random.shuffle(training)
    return training, validation


def _render_model_card(base_model, parameter_count, score):
    """Return the README.md model card text for the published router."""
    return f"""---
license: apache-2.0
base_model: {base_model}
tags:
- text-classification
- build-small-hackathon
- packetcourt
- fine-tuned
---
# PacketCourt Evidence Router
A {parameter_count:,}-parameter fine-tuned classifier used by
PacketCourt's investigation agent to choose the next evidence tool for a packet claim.
Labels: `{", ".join(LABELS)}`.
Held-out validation accuracy: `{score:.3f}` on a small PacketCourt-specific routing set.
The router proposes an investigation tool; deterministic code remains responsible for final verdicts.
"""


def main():
    """Fine-tune the evidence router, save it locally, and publish it to the Hub.

    Command-line options:
        --repo-id: Hub model repository to create (private) and upload to.
        --base-model: Pretrained checkpoint to fine-tune.
        --epochs: Number of passes over the training set.

    Reads ``data/router_training.jsonl`` under ``ROOT``, writes the model,
    tokenizer and a README.md model card to ``ROOT / "router_model"``, then
    uploads that folder.

    Raises:
        ValueError: If the training file contains invalid JSON, an unknown
            label, or fewer than two records for some label.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", default="build-small-hackathon/packetcourt-evidence-router")
    parser.add_argument("--base-model", default="google/bert_uncased_L-2_H-128_A-2")
    parser.add_argument("--epochs", type=int, default=30)
    args = parser.parse_args()

    # Fixed seeds keep the train/validation split and weight init repeatable.
    random.seed(42)
    torch.manual_seed(42)

    records = _load_records(ROOT / "data/router_training.jsonl")
    training, validation = _split_by_label(records)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.base_model,
        num_labels=len(LABELS),
        id2label={index: label for index, label in enumerate(LABELS)},
        label2id=LABEL_TO_ID,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loader = DataLoader(RouterDataset(training, tokenizer), batch_size=8, shuffle=True)
    # One held-out record per label, evaluated as a single batch.
    validation_loader = DataLoader(RouterDataset(validation, tokenizer), batch_size=len(validation))
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(args.epochs):
        # evaluate() leaves the model in eval mode, so re-enable training mode each epoch.
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            labels = batch.pop("labels").to(device)
            loss = model(**{key: value.to(device) for key, value in batch.items()}, labels=labels).loss
            loss.backward()
            optimizer.step()
        print(f"epoch={epoch + 1} validation_accuracy={evaluate(model, validation_loader, device):.3f}")

    output = ROOT / "router_model"
    model.save_pretrained(output)
    tokenizer.save_pretrained(output)

    score = evaluate(model, validation_loader, device)
    parameter_count = sum(parameter.numel() for parameter in model.parameters())
    card = _render_model_card(args.base_model, parameter_count, score)
    # Explicit UTF-8 so the card is written identically on every platform.
    (output / "README.md").write_text(card, encoding="utf-8")

    api = HfApi()
    api.create_repo(args.repo_id, repo_type="model", private=True, exist_ok=True)
    api.upload_folder(
        repo_id=args.repo_id,
        repo_type="model",
        folder_path=output,
        commit_message="feat: publish PacketCourt fine-tuned evidence router",
    )
    print(f"published={args.repo_id} validation_accuracy={score:.3f}")
# Run the training/publishing pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()