""" STEP 2 — Train Document Classification Model (LayoutLMv3) Input: data2/train.json, data2/val.json, data2/label_mappings.json Output: models/classifier/ Fixes applied: - evaluation_strategy → eval_strategy (transformers >= 4.41) - All paths resolved relative to this script file (no CWD dependency) - doc_class read from both annotation results AND data field - Skips records with invalid doc_class_id gracefully - MISSING weights warning suppressed (expected when fine-tuning) - Added class weights to handle imbalanced dataset - Added early stopping """ import json import torch import numpy as np from pathlib import Path from PIL import Image from torch.utils.data import Dataset from transformers import ( LayoutLMv3ForSequenceClassification, LayoutLMv3Processor, TrainingArguments, Trainer, EarlyStoppingCallback, ) from sklearn.metrics import classification_report import warnings warnings.filterwarnings("ignore") # ── PATHS (resolved relative to this script) ──────────────────────────────── BASE_DIR = Path(__file__).resolve().parent DATA_DIR = BASE_DIR / "data2" TRAIN_JSON = DATA_DIR / "combined_train.json" VAL_JSON = DATA_DIR / "combined_val.json" MAPPINGS = DATA_DIR / "label_mappings.json" MODEL_OUTPUT = BASE_DIR / "models" / "classifier" LOGS_DIR = BASE_DIR / "outputs" / "logs_classifier" # ── HYPERPARAMETERS ────────────────────────────────────────────────────────── MODEL_NAME = "microsoft/layoutlmv3-base" MAX_LENGTH = 512 BATCH_SIZE = 8 # effective batch=16 with gradient_accumulation=2 EPOCHS = 10 # early stopping will trigger around epoch 7-8 LEARNING_RATE = 2e-5 # fine-tuning pretrained — never increase this WARMUP_STEPS = 46 # 6% of 770 total steps WEIGHT_DECAY = 0.01 # ── HELPERS ────────────────────────────────────────────────────────────────── def get_doc_class_from_record(rec, doc2id): """ Get doc_class_id from record. Prefer the value stored in data.doc_class (set during preprocessing). Falls back to -1 if unknown. """ doc_class = rec.get("doc_class", "") return doc2id.get(doc_class, -1) # ── DATASET ────────────────────────────────────────────────────────────────── class DocumentDataset(Dataset): def __init__(self, json_path, processor, doc2id): with open(json_path, encoding="utf-8") as f: raw = json.load(f) # Filter out records with unknown class self.records = [] skipped = 0 for rec in raw: cid = get_doc_class_from_record(rec, doc2id) if cid == -1: skipped += 1 continue rec["_class_id"] = cid self.records.append(rec) if skipped: print(f" ⚠️ Skipped {skipped} records with unknown doc_class") self.processor = processor def __len__(self): return len(self.records) def __getitem__(self, idx): rec = self.records[idx] # ── Image ────────────────────────────────────────────────────────── img_path = rec.get("image_path") if img_path and Path(img_path).exists(): image = Image.open(img_path).convert("RGB") else: image = Image.new("RGB", (1654, 2339), color=(255, 255, 255)) # ── Text (OCR words) ─────────────────────────────────────────────── ocr_text = rec.get("ocr_text", "") or "" words = ocr_text.split()[:128] # 128 words is enough for classification if not words: words = ["[PAD]"] # Uniform bounding boxes (0-1000 normalized) — good enough for classification boxes = [[0, 0, 1000, 1000]] * len(words) encoding = self.processor( image, words, boxes=boxes, max_length=MAX_LENGTH, padding="max_length", truncation=True, return_tensors="pt", ) return { "input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "bbox": encoding["bbox"].squeeze(), "pixel_values": encoding["pixel_values"].squeeze(), "labels": torch.tensor(rec["_class_id"], dtype=torch.long), } # ── CLASS WEIGHTS (handle imbalance) ───────────────────────────────────────── def compute_class_weights(dataset, num_classes): counts = np.zeros(num_classes) for rec in dataset.records: counts[rec["_class_id"]] += 1 counts = np.where(counts == 0, 1, counts) # avoid divide-by-zero weights = 1.0 / counts weights = weights / weights.sum() * num_classes return torch.tensor(weights, dtype=torch.float) # ── WEIGHTED TRAINER ────────────────────────────────────────────────────────── class WeightedTrainer(Trainer): def __init__(self, class_weights, *args, **kwargs): super().__init__(*args, **kwargs) self.class_weights = class_weights.to(self.model.device if hasattr(self, 'model') else "cpu") def compute_loss(self, model, inputs, return_outputs=False, **kwargs): labels = inputs.pop("labels") outputs = model(**inputs) logits = outputs.logits weights = self.class_weights.to(logits.device) loss_fn = torch.nn.CrossEntropyLoss(weight=weights) loss = loss_fn(logits, labels) return (loss, outputs) if return_outputs else loss # ── METRICS ────────────────────────────────────────────────────────────────── def compute_metrics(eval_pred): logits, labels = eval_pred preds = np.argmax(logits, axis=-1) acc = (preds == labels).mean() return {"accuracy": float(acc)} # ── MAIN ───────────────────────────────────────────────────────────────────── def main(): # Verify data files exist for p in [TRAIN_JSON, VAL_JSON, MAPPINGS]: if not p.exists(): raise FileNotFoundError( f"\n❌ File not found: {p}\n" f" Run script 1_convert_labelstudio.py first to generate data files." ) with open(MAPPINGS, encoding="utf-8") as f: mappings = json.load(f) doc_classes = mappings["doc_classes"] doc2id = mappings["doc2id"] num_labels = len(doc_classes) print(f"📂 Data directory : {DATA_DIR}") print(f"📦 Model output : {MODEL_OUTPUT}") print(f"🏷️ Classes ({num_labels}): {doc_classes}") print(f"🤖 Base model : {MODEL_NAME}\n") # Load processor & model (suppress MISSING weights warning — it's expected) processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=False) model = LayoutLMv3ForSequenceClassification.from_pretrained( MODEL_NAME, num_labels=num_labels, id2label={i: c for i, c in enumerate(doc_classes)}, label2id=doc2id, ignore_mismatched_sizes=True, # silences MISSING weights warning ) print("📊 Loading datasets...") train_dataset = DocumentDataset(TRAIN_JSON, processor, doc2id) val_dataset = DocumentDataset(VAL_JSON, processor, doc2id) print(f" Train: {len(train_dataset)} | Val: {len(val_dataset)}\n") # Class weights for imbalanced data class_weights = compute_class_weights(train_dataset, num_labels) print("⚖️ Class weights:") for i, (cls, w) in enumerate(zip(doc_classes, class_weights)): print(f" {cls}: {w:.3f}") print() MODEL_OUTPUT.mkdir(parents=True, exist_ok=True) LOGS_DIR.mkdir(parents=True, exist_ok=True) training_args = TrainingArguments( output_dir=str(MODEL_OUTPUT), num_train_epochs=EPOCHS, per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, warmup_steps=WARMUP_STEPS, weight_decay=WEIGHT_DECAY, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="accuracy", greater_is_better=True, logging_steps=10, report_to="none", fp16=torch.cuda.is_available(), dataloader_num_workers=0, lr_scheduler_type="cosine", gradient_accumulation_steps=2, save_total_limit=2, label_smoothing_factor=0.083, ) trainer = WeightedTrainer( class_weights=class_weights, model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=3)], ) print("🚀 Starting classification training...") trainer.train() # ── Save final model ────────────────────────────────────────────────── trainer.save_model(str(MODEL_OUTPUT)) processor.save_pretrained(str(MODEL_OUTPUT)) print(f"\n✅ Model saved to: {MODEL_OUTPUT}") # ── Final evaluation with full report ──────────────────────────────── print("\n📈 Final evaluation on validation set:") preds_output = trainer.predict(val_dataset) preds = np.argmax(preds_output.predictions, axis=-1) labels = preds_output.label_ids print(classification_report( labels, preds, target_names=doc_classes, zero_division=0 )) acc = (preds == labels).mean() print(f"Overall accuracy: {acc:.1%}") if __name__ == "__main__": main()