|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Any, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
DataCollatorWithPadding, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
) |
|
|
import evaluate |
|
|
from sklearn.utils.class_weight import compute_class_weight |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Runtime configuration; each value may be overridden through an environment variable.
MODEL_NAME = os.getenv("MODEL_NAME", "google/bert_uncased_L-2_H-128_A-2")
HUB_REPO = os.getenv("HUB_REPO", "tlogandesigns/fairhousing-bert-tiny")
MAX_LEN = int(os.getenv("MAX_LEN", "256"))  # truncation length used by tokenize()

# Input CSV files for the two splits.
TRAIN_PATH = os.getenv("TRAIN_PATH", "train.csv")
VAL_PATH = os.getenv("VAL_PATH", "val.csv")

# Binary label scheme: index 0 = compliant text, index 1 = potential violation.
id2label = {0: "Compliant", 1: "Potential Violation"}
label2id = dict((name, idx) for idx, name in id2label.items())
|
|
|
|
|
|
|
|
# Classification metrics, loaded once and reused by compute_metrics().
accuracy, precision, recall, f1 = (
    evaluate.load(metric_name) for metric_name in ("accuracy", "precision", "recall", "f1")
)

# Read the train/validation CSVs into a DatasetDict keyed by split name.
data_files = dict(train=TRAIN_PATH, validation=VAL_PATH)
raw = load_dataset("csv", data_files=data_files)
|
|
|
|
|
|
|
|
|
|
|
def cast_label(example):
    """Coerce the ``label`` field of a single row to a Python ``int``.

    The row mapping is updated in place and the same object is returned, so
    the function can be passed directly to ``raw.map``.
    """
    label_value = example["label"]
    example.update(label=int(label_value))
    return example
|
|
|
|
|
raw = raw.map(cast_label)

# Prefer the fast tokenizer. If it cannot be loaded (e.g. the checkpoint ships
# no fast tokenizer or conversion fails), fall back to the slow implementation.
# NOTE: the fallback must pass use_fast=False explicitly; AutoTokenizer's
# default is use_fast=True, so omitting it would just retry the failing call.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
except Exception as err:
    print(f"Fast tokenizer unavailable ({err}); falling back to the slow tokenizer.")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
|
|
|
|
|
|
|
|
def tokenize(batch):
    """Tokenize the ``text`` column of a batch of rows.

    Sequences are truncated to ``MAX_LEN`` tokens and left unpadded; padding
    is applied per batch later by the ``DataCollatorWithPadding``.
    """
    options = dict(truncation=True, padding=False, max_length=MAX_LEN)
    texts = batch["text"]
    return tokenizer(texts, **options)
|
|
|
|
|
|
|
|
# Columns the model and collator consume; every other column (e.g. raw text) is dropped.
keep = {"input_ids", "attention_mask", "token_type_ids", "labels"}


def to_bin(v: int) -> int:
    """Collapse a label to binary: 0 stays 0, any other value becomes 1."""
    return 0 if int(v) == 0 else 1


def _split_labels(raw_splits) -> Dict[str, np.ndarray]:
    """Return each split's ``label`` column as an int64 NumPy array, keyed by split name."""
    return {
        name: np.asarray(list(split["label"]), dtype=np.int64)
        for name, split in raw_splits.items()
    }


def _non_binary(labels: np.ndarray) -> set:
    """Return the set of values in *labels* other than 0 and 1 (empty if binary)."""
    return set(np.unique(labels).tolist()) - {0, 1}


def _tokenize_splits(raw_splits):
    """Tokenize every split and shape it for the Trainer.

    Renames ``label`` to ``labels``, drops every column not in ``keep`` and
    switches the output format to torch tensors.
    """
    tokenized = raw_splits.map(tokenize, batched=True)
    if "label" in tokenized["train"].column_names:
        tokenized = tokenized.rename_column("label", "labels")
    extra = [c for c in tokenized["train"].column_names if c not in keep]
    if extra:
        tokenized = tokenized.remove_columns(extra)
    tokenized.set_format(type="torch")
    return tokenized


# Check labels on the untokenized data, across ALL splits: a non-binary label in
# the validation split would otherwise only surface as a crash in the binary
# metrics. Binarizing before tokenizing also avoids tokenizing twice.
labels_by_split = _split_labels(raw)
if any(_non_binary(split_labels) for split_labels in labels_by_split.values()):
    print("Non-binary labels found; mapping 0 -> 0 and every other value -> 1.")
    raw = raw.map(lambda ex: {"label": to_bin(int(ex["label"]))})
    labels_by_split = _split_labels(raw)
# After to_bin every split is guaranteed to be binary 0/1.

y_train = labels_by_split["train"]
unique = np.unique(y_train)

tok = _tokenize_splits(raw)
train_ds = tok["train"]
val_ds = tok["validation"]

# tokenize() uses padding=False, so each batch is padded here at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Class weights for the weighted loss. CLASS_WEIGHTS="w0,w1" overrides them;
# otherwise use sklearn's "balanced" weights computed from the training labels.
CW_ENV = os.getenv("CLASS_WEIGHTS", "")
if CW_ENV.strip():
    cw = np.array([float(x) for x in CW_ENV.split(",")], dtype=np.float32)
    if cw.shape[0] != 2:
        raise ValueError("CLASS_WEIGHTS must have 2 values for binary classification.")
else:
    missing = {0, 1} - set(unique.tolist())
    if missing:
        # compute_class_weight cannot produce balanced weights for an absent class.
        raise ValueError(
            f"Training labels contain no examples of class(es) {sorted(missing)}; "
            "cannot compute balanced class weights (set CLASS_WEIGHTS explicitly)."
        )
    cw = compute_class_weight(
        class_weight="balanced", classes=np.array([0, 1]), y=y_train
    ).astype(np.float32)

class_weights_tensor = torch.tensor(cw, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use roughly half of the logical CPU cores for intra-op parallelism (at least one).
torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))

# Reuse the module-level label maps so the model config can never drift from them.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WeightedTrainer(Trainer):
    """``Trainer`` whose loss is class-weighted cross-entropy.

    Args:
        class_weights: optional 1-D tensor with one weight per class, passed as
            ``weight`` to cross-entropy. ``None`` gives an unweighted loss.
        *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, class_weights: Optional[torch.Tensor] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        # float32 copies of ``class_weights`` keyed by device, so the tensor is
        # transferred at most once per device rather than on every step.
        self._weights_by_device: Dict[Any, torch.Tensor] = {}

    def _weights_for(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        """Return ``class_weights`` as float32 on ``logits``' device, or ``None``.

        The logits' device is used instead of ``model.device`` because a
        ``torch.nn.DataParallel`` wrapper (multi-GPU Trainer) has no ``device``
        attribute.
        """
        if self.class_weights is None:
            return None
        device = logits.device
        weights = self._weights_by_device.get(device)
        if weights is None:
            weights = self.class_weights.to(device=device, dtype=torch.float32)
            self._weights_by_device[device] = weights
        return weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the batch loss, plus the model outputs when ``return_outputs``.

        ``labels`` is popped from ``inputs`` before the forward pass so the loss
        is computed here (weighted) rather than by the model. Without labels,
        the model's own ``loss`` output is returned if present, else ``None``.
        """
        labels = inputs.pop("labels", None)
        outputs = model(**inputs)
        logits = outputs.get("logits")

        if labels is None:
            loss = outputs["loss"] if "loss" in outputs else None
        else:
            # Computed in float32 so half-precision logits work with float32 weights.
            loss = nn.functional.cross_entropy(
                logits.float(),
                labels.to(torch.long),
                weight=self._weights_for(logits),
            )

        return (loss, outputs) if return_outputs else loss
|
|
|
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute accuracy and binary precision/recall/F1 for Trainer evaluation.

    ``eval_pred`` unpacks to ``(predictions, label_ids)``. If ``predictions`` is
    a tuple/list (the model returned several output arrays), its first element
    is assumed to hold the logits. Precision, recall and F1 treat label 1
    ("Potential Violation") as the positive class.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits, labels = eval_pred
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)

    binary_args = dict(predictions=preds, references=labels, average="binary", pos_label=1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "precision": precision.compute(**binary_args)["precision"],
        "recall": recall.compute(**binary_args)["recall"],
        "f1": f1.compute(**binary_args)["f1"],
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Evaluate and checkpoint every epoch, keep the checkpoint with the best
# validation F1, and push to the Hub only when HUB_REPO is non-empty.
args = TrainingArguments(
    output_dir="runs",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,  # effective train batch of 16 * 2 = 32 per device
    num_train_epochs=5,
    learning_rate=3e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # the "f1" key returned by compute_metrics
    greater_is_better=True,
    report_to=[],
    seed=42,
    dataloader_pin_memory=False,
    push_to_hub=bool(HUB_REPO),
    hub_model_id=HUB_REPO or None,  # avoid passing an empty repo id
    hub_private_repo=False,
)

trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    class_weights=class_weights_tensor,
    compute_metrics=compute_metrics,
)

trainer.train()

metrics = trainer.evaluate()
print("Eval:", metrics)

trainer.save_model("model")
try:
    tokenizer.save_pretrained("model")
except Exception as err:
    # Best effort: the model weights are already saved, so don't abort, but say so.
    print(f"Could not save tokenizer to 'model': {err}")

if HUB_REPO:
    try:
        trainer.push_to_hub()
        tokenizer.push_to_hub(HUB_REPO)
        print(f"Pushed model to {HUB_REPO}")
    except Exception as e:
        print(f"Skipping hub push: {e}")
else:
    print("No hub repo specified, model not pushed.")
|
|
|