# NOTE(review): stray commit metadata ("Bharath / add model / 3e213ce") was pasted
# above the shebang, making the file invalid Python; preserved here as a comment.
#!/usr/bin/env python3
"""Train a multi-label classifier for Terraform plan risk labels.
Example:
python train.py \
--dataset_path ../tfplan-risk-labels \
--model_name microsoft/deberta-v3-small \
--output_dir ./out \
--max_length 512 \
--epochs 3
This script expects JSONL files: train.jsonl, validation.jsonl, test.jsonl
and label_map.json containing {"labels":[...]}.
"""
import argparse, json, os
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
from datasets import load_dataset
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
TrainingArguments, Trainer, EvalPrediction
)
import torch
from torch.nn import BCEWithLogitsLoss
def read_labels(label_map_path: str) -> List[str]:
    """Load the ordered label list from a label_map.json file.

    Args:
        label_map_path: Path to a JSON file of the form {"labels": [...]}.

    Returns:
        The label names in index order (index i is class i in the model head).

    Raises:
        FileNotFoundError: If the file does not exist.
        KeyError: If the JSON has no "labels" key.
    """
    # Context manager closes the handle deterministically; the original
    # json.load(open(...)) left it open until garbage collection.
    with open(label_map_path, "r", encoding="utf-8") as f:
        return json.load(f)["labels"]
def encode_labels(example, label_to_id: Dict[str, int]):
    """Attach a multi-hot float32 vector under ``label_vec``.

    Labels present in the example but absent from ``label_to_id`` are
    silently skipped. The example dict is mutated in place and returned.
    """
    vec = np.zeros(len(label_to_id), dtype=np.float32)
    known = [label_to_id[lab] for lab in example.get("labels", []) if lab in label_to_id]
    vec[known] = 1.0
    example["label_vec"] = vec
    return example
def compute_metrics(p: EvalPrediction):
    """Return micro and macro F1 at a fixed 0.5 sigmoid threshold.

    Predictions are logits; they are passed through a sigmoid and
    thresholded at 0.5 to obtain binary per-class decisions. A small
    epsilon avoids division by zero for classes with no positives.
    """
    y_true = p.label_ids
    probs = 1.0 / (1.0 + np.exp(-p.predictions))
    y_pred = (probs >= 0.5).astype(int)
    eps = 1e-9

    # Micro F1: pool true/false positives and false negatives over all cells.
    tp = (y_pred * y_true).sum()
    fp = (y_pred * (1 - y_true)).sum()
    fn = ((1 - y_pred) * y_true).sum()
    micro_f1 = (2 * tp) / (2 * tp + fp + fn + eps)

    # Macro F1: per-class F1 (column-wise counts), then unweighted mean.
    tp_c = (y_pred * y_true).sum(axis=0)
    fp_c = (y_pred * (1 - y_true)).sum(axis=0)
    fn_c = ((1 - y_pred) * y_true).sum(axis=0)
    per_class_f1 = (2 * tp_c) / (2 * tp_c + fp_c + fn_c + eps)
    macro_f1 = float(np.mean(per_class_f1))

    return {"micro_f1": float(micro_f1), "macro_f1": macro_f1}
@dataclass
class DataCollatorMultiLabel:
    """Collate raw examples into a tokenized batch with multi-hot targets.

    Tokenizes each example's ``text`` with truncation and dynamic padding,
    then stacks the precomputed ``label_vec`` arrays into a float32
    ``labels`` tensor, as expected by the BCE-with-logits loss.
    """
    tokenizer: AutoTokenizer
    max_length: int

    def __call__(self, features):
        batch = self.tokenizer(
            [item["text"] for item in features],
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        batch["labels"] = torch.tensor(
            [item["label_vec"] for item in features], dtype=torch.float32
        )
        return batch
class MultiLabelTrainer(Trainer):
    """Trainer that uses BCE-with-logits loss for multi-label targets."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Remove the multi-hot targets before the forward pass so the model
        # does not try to compute its own (single-label) loss from them.
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        loss = BCEWithLogitsLoss()(outputs.logits, targets)
        if return_outputs:
            return loss, outputs
        return loss
def main():
    """Fine-tune a multi-label Terraform-plan risk classifier.

    Reads train/validation/test JSONL splits plus label_map.json from
    ``--dataset_path``, trains a sequence-classification head with BCE
    loss, evaluates on the test split, and writes the model, tokenizer,
    and label map to ``--output_dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", required=True, help="Path to dataset repo folder (contains train.jsonl etc.)")
    ap.add_argument("--model_name", default="microsoft/deberta-v3-small")
    ap.add_argument("--output_dir", default="./out")
    ap.add_argument("--max_length", type=int, default=512)
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--lr", type=float, default=2e-5)
    args = ap.parse_args()

    labels = read_labels(os.path.join(args.dataset_path, "label_map.json"))
    label_to_id = {l: i for i, l in enumerate(labels)}

    data_files = {
        "train": os.path.join(args.dataset_path, "train.jsonl"),
        "validation": os.path.join(args.dataset_path, "validation.jsonl"),
        "test": os.path.join(args.dataset_path, "test.jsonl"),
    }
    ds = load_dataset("json", data_files=data_files)
    # Attach a multi-hot float vector per example; labels missing from the map are ignored.
    ds = ds.map(lambda ex: encode_labels(ex, label_to_id))

    tok = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(labels),
        problem_type="multi_label_classification",  # sigmoid head for independent labels
    )
    collator = DataCollatorMultiLabel(tokenizer=tok, max_length=args.max_length)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        eval_strategy="steps",
        save_strategy="epoch",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_steps=50,
        report_to="none",
    )

    trainer = MultiLabelTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tok,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    # Save the label map for inference BEFORE training, so an interrupted run
    # still leaves a usable artifact next to any saved checkpoints.
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump({"labels": labels}, f, indent=2)

    trainer.train()

    # Prefix test metrics as "test_*" (default would be a misleading "eval_*").
    metrics = trainer.evaluate(eval_dataset=ds["test"], metric_key_prefix="test")
    print("Test metrics:", metrics)

    trainer.save_model(args.output_dir)
    tok.save_pretrained(args.output_dir)
if __name__ == "__main__":
main()