FiberGate / scripts /02_train_classifier.py
AzizMiladi's picture
chore: git mv scripts, UI, dev tools, docs into folders
70c46cc
Raw
History Blame
10.7 kB
"""
STEP 2 — Train Document Classification Model (LayoutLMv3)
Input: data2/combined_train.json, data2/combined_val.json, data2/label_mappings.json
Output: models/classifier/
Fixes applied:
- evaluation_strategy → eval_strategy (transformers >= 4.41)
- All paths resolved relative to this script file (no CWD dependency)
- doc_class read from both annotation results AND data field
- Skips records with invalid doc_class_id gracefully
- MISSING weights warning suppressed (expected when fine-tuning)
- Added class weights to handle imbalanced dataset
- Added early stopping
"""
import json
import torch
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
LayoutLMv3ForSequenceClassification,
LayoutLMv3Processor,
TrainingArguments,
Trainer,
EarlyStoppingCallback,
)
from sklearn.metrics import classification_report
import warnings
warnings.filterwarnings("ignore")
# ── PATHS (resolved relative to this script) ────────────────────────────────
# Every location below is derived from the directory that contains this
# script (via __file__), so none of them depend on the current working
# directory.
# NOTE(review): the file appears to live in a `scripts/` folder; confirm that
# `data2/`, `models/` and `outputs/` are really expected next to the script
# rather than at the repository root.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data2"
TRAIN_JSON = DATA_DIR / "combined_train.json"  # records used for training
VAL_JSON = DATA_DIR / "combined_val.json"  # records used for evaluation
MAPPINGS = DATA_DIR / "label_mappings.json"  # provides "doc_classes" and "doc2id"
MODEL_OUTPUT = BASE_DIR / "models" / "classifier"  # checkpoints + final model
LOGS_DIR = BASE_DIR / "outputs" / "logs_classifier"  # created by main()
# ── HYPERPARAMETERS ──────────────────────────────────────────────────────────
# Identifier of the pretrained checkpoint that is fine-tuned.
MODEL_NAME = "microsoft/layoutlmv3-base"
# Token sequence length every sample is padded/truncated to.
MAX_LENGTH = 512
# Per-device batch size; with gradient_accumulation_steps=2 in main() the
# effective batch per device is 16.
BATCH_SIZE = 8
# Upper bound on epochs; early stopping (patience 3) may end training sooner.
# NOTE(review): the original comment expected a stop around epoch 7-8 -- that
# depends on the data and is not guaranteed.
EPOCHS = 10
# Deliberately small learning rate for fine-tuning a pretrained model; the
# original author advises against raising it.
LEARNING_RATE = 2e-5
# NOTE(review): chosen as roughly 6% of an assumed 770 optimizer steps;
# re-derive if dataset size, batch size or epoch count changes.
WARMUP_STEPS = 46
WEIGHT_DECAY = 0.01
# ── HELPERS ──────────────────────────────────────────────────────────────────
def get_doc_class_from_record(rec, doc2id):
    """Return the integer class id for a record, or -1 when it is unknown.

    The class name is looked up in this order:

    1. ``rec["doc_class"]`` (top-level field);
    2. ``rec["data"]["doc_class"]`` when ``rec["data"]`` is a dict.

    The second location is a fallback for records that still carry the
    class name inside a nested ``data`` mapping -- presumably the raw
    annotation layout; verify against the preprocessing script.

    Args:
        rec: Record dict loaded from the dataset JSON.
        doc2id: Mapping from class name to integer class id.

    Returns:
        The id found in ``doc2id`` for the first candidate that resolves,
        otherwise -1.
    """
    candidates = [rec.get("doc_class", "")]
    nested = rec.get("data")
    if isinstance(nested, dict):
        candidates.append(nested.get("doc_class", ""))

    for name in candidates:
        try:
            class_id = doc2id.get(name, -1)
        except TypeError:
            # Unhashable value (e.g. a list of choices) cannot be a key;
            # treat it as unknown instead of crashing dataset loading.
            continue
        if class_id != -1:
            return class_id
    return -1
# ── DATASET ──────────────────────────────────────────────────────────────────
class DocumentDataset(Dataset):
    """Torch dataset that turns document records into LayoutLMv3 inputs.

    Records are read once from a JSON file; those whose class cannot be
    resolved (``get_doc_class_from_record`` returns -1) are dropped and the
    resolved id is cached on each kept record under ``"_class_id"``.
    """

    def __init__(self, json_path, processor, doc2id):
        """Load and filter the records.

        Args:
            json_path: Path of the JSON file; it is expected to contain an
                iterable of record dicts.
            processor: Callable used by ``__getitem__`` to encode one
                sample (a LayoutLMv3 processor in this script).
            doc2id: Mapping from class name to integer class id.
        """
        with open(json_path, encoding="utf-8") as f:
            raw = json.load(f)

        # Keep only records whose class is known.
        self.records = []
        skipped = 0
        for rec in raw:
            cid = get_doc_class_from_record(rec, doc2id)
            if cid == -1:
                skipped += 1
                continue
            rec["_class_id"] = cid
            self.records.append(rec)
        if skipped:
            print(f" ⚠️ Skipped {skipped} records with unknown doc_class")
        self.processor = processor

    def __len__(self):
        """Return the number of usable records."""
        return len(self.records)

    def __getitem__(self, idx):
        """Encode record *idx* into the tensors the model consumes.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``bbox``,
            ``pixel_values`` (batch axis removed) and a scalar ``labels``
            tensor of dtype ``torch.long``.
        """
        rec = self.records[idx]

        # Image: use the page scan when it exists, otherwise a blank white
        # page so a missing file does not abort training.
        img_path = rec.get("image_path")
        if img_path and Path(img_path).exists():
            # The context manager releases the file handle right away;
            # convert() returns an independent in-memory image.
            with Image.open(img_path) as source_image:
                image = source_image.convert("RGB")
        else:
            image = Image.new("RGB", (1654, 2339), color=(255, 255, 255))

        # Text: the first 128 whitespace-separated OCR words.
        ocr_text = rec.get("ocr_text", "") or ""
        words = ocr_text.split()[:128]
        if not words:
            # The processor needs at least one word.
            words = ["[PAD]"]

        # Every word gets the same full-page box (0-1000 normalised), i.e.
        # no real layout information is supplied.
        boxes = [[0, 0, 1000, 1000]] * len(words)

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the leading batch axis added by
        # return_tensors="pt"; a bare squeeze() would also collapse any other
        # axis that happens to have size 1.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "bbox": encoding["bbox"].squeeze(0),
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "labels": torch.tensor(rec["_class_id"], dtype=torch.long),
        }
# ── CLASS WEIGHTS (handle imbalance) ─────────────────────────────────────────
def compute_class_weights(dataset, num_classes):
    """Build inverse-frequency class weights from a dataset's records.

    Args:
        dataset: Object exposing ``records``, an iterable of dicts that each
            carry an integer ``"_class_id"``.
        num_classes: Length of the returned weight vector.

    Returns:
        A float32 tensor of ``num_classes`` weights, proportional to
        ``1 / count`` and rescaled so that they sum to ``num_classes``.
    """
    tallies = np.zeros(num_classes)
    for sample in dataset.records:
        tallies[sample["_class_id"]] += 1

    # A class with no samples is counted as one so the reciprocal is defined.
    tallies[tallies == 0] = 1

    inverse = np.reciprocal(tallies)
    rescaled = inverse / inverse.sum() * num_classes
    return torch.tensor(rescaled, dtype=torch.float)
# ── WEIGHTED TRAINER ──────────────────────────────────────────────────────────
class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy.

    The loss is computed here from the model's logits rather than by the
    model, so that per-class weights can be applied.
    """

    def __init__(self, class_weights, *args, **kwargs):
        """Create the trainer.

        Args:
            class_weights: 1-D tensor with one weight per class.
            *args, **kwargs: Forwarded unchanged to ``Trainer``.
        """
        super().__init__(*args, **kwargs)
        device = self.model.device if hasattr(self, "model") else "cpu"
        self.class_weights = class_weights.to(device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted cross-entropy loss for one batch.

        Args:
            model: Model to run.
            inputs: Batch mapping; must contain ``"labels"``.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Extra arguments passed by the training loop; unused.
        """
        labels = inputs["labels"]
        # Call the model without labels so it does not compute its own
        # (unweighted) loss, and leave the caller's batch mapping untouched.
        model_inputs = {
            key: value for key, value in inputs.items() if key != "labels"
        }
        outputs = model(**model_inputs)
        logits = outputs.logits

        # Because this method replaces the stock loss computation, the
        # configured label_smoothing_factor has to be applied here or it
        # would silently have no effect.
        smoothing = getattr(self.args, "label_smoothing_factor", 0.0) or 0.0
        loss_fn = torch.nn.CrossEntropyLoss(
            # Weights are moved per batch in case the logits live on a
            # different device than the one captured in __init__.
            weight=self.class_weights.to(logits.device),
            label_smoothing=smoothing,
        )
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss
# ── METRICS ──────────────────────────────────────────────────────────────────
def compute_metrics(eval_pred):
    """Compute accuracy from an evaluation result.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` is either
            an array of logits with the class axis last, or a tuple of model
            outputs whose first element is that array.

    Returns:
        ``{"accuracy": <float>}`` -- the fraction of samples whose arg-max
        class equals the label.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # When the model returns extra outputs the predictions arrive as a
        # tuple; only the first entry holds the class logits.
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    accuracy = np.mean(preds == np.asarray(labels))
    return {"accuracy": float(accuracy)}
# ── MAIN ─────────────────────────────────────────────────────────────────────
def main():
    """Fine-tune the classifier, save it, and print a validation report.

    Steps: check that the input files exist, read the label mappings, load
    the processor and model, build the train/validation datasets, derive
    class weights from the training set, train with early stopping, save
    model and processor to ``MODEL_OUTPUT``, then print a per-class report
    and the overall accuracy on the validation set.

    Raises:
        FileNotFoundError: If any of the three input files is missing.
    """
    # Verify data files exist
    for p in [TRAIN_JSON, VAL_JSON, MAPPINGS]:
        if not p.exists():
            raise FileNotFoundError(
                f"\n❌ File not found: {p}\n"
                f" Run script 1_convert_labelstudio.py first to generate data files."
            )

    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)
    doc_classes = mappings["doc_classes"]
    doc2id = mappings["doc2id"]
    num_labels = len(doc_classes)

    print(f"πŸ“‚ Data directory : {DATA_DIR}")
    print(f"πŸ“¦ Model output : {MODEL_OUTPUT}")
    print(f"🏷️ Classes ({num_labels}): {doc_classes}")
    print(f"πŸ€– Base model : {MODEL_NAME}\n")

    # apply_ocr=False: words and boxes are supplied by the dataset.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=False)
    model = LayoutLMv3ForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        # NOTE(review): id2label follows the order of doc_classes while
        # label2id comes from the file; they are assumed to agree.
        id2label=dict(enumerate(doc_classes)),
        label2id=doc2id,
        # Tolerates checkpoint tensors whose shape differs from the model's.
        # NOTE(review): the original comment said this silences the
        # missing-weights warning; confirm it is still needed.
        ignore_mismatched_sizes=True,
    )

    print("πŸ“Š Loading datasets...")
    train_dataset = DocumentDataset(TRAIN_JSON, processor, doc2id)
    val_dataset = DocumentDataset(VAL_JSON, processor, doc2id)
    print(f" Train: {len(train_dataset)} | Val: {len(val_dataset)}\n")

    # Class weights for imbalanced data
    class_weights = compute_class_weights(train_dataset, num_labels)
    print("βš–οΈ Class weights:")
    for cls, w in zip(doc_classes, class_weights):
        print(f" {cls}: {w:.3f}")
    print()

    MODEL_OUTPUT.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(MODEL_OUTPUT),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        weight_decay=WEIGHT_DECAY,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Best checkpoint = highest validation accuracy.
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        # Mixed precision only when a CUDA device is present.
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=0,
        lr_scheduler_type="cosine",
        gradient_accumulation_steps=2,
        save_total_limit=2,
        # Only takes effect if the trainer's loss computation honours it.
        label_smoothing_factor=0.083,
    )

    trainer = WeightedTrainer(
        class_weights=class_weights,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    print("πŸš€ Starting classification training...")
    trainer.train()

    # Save final model together with the processor so the output directory
    # can be loaded on its own.
    trainer.save_model(str(MODEL_OUTPUT))
    processor.save_pretrained(str(MODEL_OUTPUT))
    print(f"\nβœ… Model saved to: {MODEL_OUTPUT}")

    # Final evaluation with full report
    print("\nπŸ“ˆ Final evaluation on validation set:")
    preds_output = trainer.predict(val_dataset)
    preds = np.argmax(preds_output.predictions, axis=-1)
    true_ids = preds_output.label_ids
    print(classification_report(
        true_ids, preds,
        # Listing every class id keeps the report aligned with target_names
        # even when a class is absent from both the validation labels and
        # the predictions; without it the call raises on such data.
        labels=list(range(num_labels)),
        target_names=doc_classes,
        zero_division=0,
    ))
    acc = (preds == true_ids).mean()
    print(f"Overall accuracy: {acc:.1%}")
# Run training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()