"""Train a multi-label classifier for Terraform plan risk labels.

Example:
    python train.py \
        --dataset_path ../tfplan-risk-labels \
        --model_name microsoft/deberta-v3-small \
        --output_dir ./out \
        --max_length 512 \
        --epochs 3

This script expects JSONL files: train.jsonl, validation.jsonl, test.jsonl
and label_map.json containing {"labels":[...]}.
"""
|
|
| import argparse, json, os |
| from dataclasses import dataclass |
| from typing import Dict, List |
|
|
| import numpy as np |
| from datasets import load_dataset |
| from transformers import ( |
| AutoTokenizer, AutoModelForSequenceClassification, |
| TrainingArguments, Trainer, EvalPrediction |
| ) |
| import torch |
| from torch.nn import BCEWithLogitsLoss |
|
|
def read_labels(label_map_path: str) -> List[str]:
    """Load the ordered label list from a label-map JSON file.

    Args:
        label_map_path: Path to a JSON file of the form ``{"labels": [...]}``.

    Returns:
        The label names in their canonical (index-defining) order.

    Raises:
        KeyError: If the file has no ``"labels"`` key.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the original `json.load(open(...))` relied on the GC to close it.
    with open(label_map_path, "r", encoding="utf-8") as f:
        return json.load(f)["labels"]
|
|
def encode_labels(example, label_to_id: Dict[str, int]):
    """Attach a multi-hot float32 vector under ``label_vec``.

    Labels present in ``example["labels"]`` but absent from ``label_to_id``
    are silently ignored (they contribute no bit to the vector).
    """
    vec = np.zeros(len(label_to_id), dtype=np.float32)
    for name in example.get("labels", []):
        idx = label_to_id.get(name)
        if idx is not None:
            vec[idx] = 1.0
    example["label_vec"] = vec
    return example
|
|
def compute_metrics(p: EvalPrediction):
    """Compute micro- and macro-averaged F1 at a 0.5 sigmoid threshold.

    Logits are squashed through a sigmoid and binarized at 0.5; the small
    epsilon keeps the F1 ratio defined when a class has no positives.
    """
    y_true = p.label_ids
    scores = 1.0 / (1.0 + np.exp(-p.predictions))
    y_hat = (scores >= 0.5).astype(int)

    def _f1(pred, gold):
        # Standard F1 from TP/FP/FN counts, epsilon-stabilized.
        tp = (pred * gold).sum()
        fp = (pred * (1 - gold)).sum()
        fn = ((1 - pred) * gold).sum()
        return (2 * tp) / (2 * tp + fp + fn + 1e-9)

    micro_f1 = _f1(y_hat, y_true)
    macro_f1 = float(np.mean(
        [_f1(y_hat[:, j], y_true[:, j]) for j in range(y_true.shape[1])]
    ))
    return {"micro_f1": float(micro_f1), "macro_f1": macro_f1}
|
|
@dataclass
class DataCollatorMultiLabel:
    """Collate raw examples into a tokenized batch with multi-hot targets.

    Each feature dict must carry ``"text"`` (str) and ``"label_vec"``
    (a per-class multi-hot vector). Texts are padded/truncated to
    ``max_length`` and the label vectors are stacked as a float32 tensor
    under the ``"labels"`` key, as expected by BCE-with-logits training.
    """

    tokenizer: AutoTokenizer
    max_length: int

    def __call__(self, features):
        batch = self.tokenizer(
            [item["text"] for item in features],
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        batch["labels"] = torch.tensor(
            [item["label_vec"] for item in features], dtype=torch.float32
        )
        return batch
|
|
class MultiLabelTrainer(Trainer):
    """Trainer that swaps the default loss for BCE-with-logits.

    Multi-label targets are independent per-class probabilities, so the
    loss is binary cross-entropy over raw logits rather than softmax CE.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        loss = BCEWithLogitsLoss()(outputs.logits, targets)
        if return_outputs:
            return loss, outputs
        return loss
|
|
def main():
    """CLI entry point: fine-tune, evaluate, and save the classifier.

    Reads JSONL splits plus label_map.json from --dataset_path, trains the
    model, prints test metrics, and writes model, tokenizer, and label map
    to --output_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", required=True, help="Path to dataset repo folder (contains train.jsonl etc.)")
    parser.add_argument("--model_name", default="microsoft/deberta-v3-small")
    parser.add_argument("--output_dir", default="./out")
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-5)
    args = parser.parse_args()

    labels = read_labels(os.path.join(args.dataset_path, "label_map.json"))
    label_to_id = {name: idx for idx, name in enumerate(labels)}

    # One JSONL file per split, all living in the dataset folder.
    data_files = {
        split: os.path.join(args.dataset_path, f"{split}.jsonl")
        for split in ("train", "validation", "test")
    }
    ds = load_dataset("json", data_files=data_files)
    ds = ds.map(lambda ex: encode_labels(ex, label_to_id))

    tok = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(labels),
        problem_type="multi_label_classification",
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        eval_strategy="steps",
        save_strategy="epoch",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_steps=50,
        report_to="none",
    )

    trainer = MultiLabelTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tok,
        data_collator=DataCollatorMultiLabel(tokenizer=tok, max_length=args.max_length),
        compute_metrics=compute_metrics,
    )

    trainer.train()
    metrics = trainer.evaluate(ds["test"])
    print("Test metrics:", metrics)

    # Persist the label order next to the weights so inference can map
    # logit indices back to label names.
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump({"labels": labels}, f, indent=2)

    trainer.save_model(args.output_dir)
    tok.save_pretrained(args.output_dir)
|
|
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|