Spaces:
Sleeping
Sleeping
"""
STEP 2 — Train Document Classification Model (LayoutLMv3)
Input: data2/combined_train.json, data2/combined_val.json, data2/label_mappings.json
Output: models/classifier/
Fixes applied:
- evaluation_strategy → eval_strategy (transformers >= 4.41)
- All paths resolved relative to this script file (no CWD dependency)
- doc_class read from both annotation results AND data field
- Skips records with invalid doc_class_id gracefully
- MISSING weights warning suppressed (expected when fine-tuning)
- Added class weights to handle imbalanced dataset
- Added early stopping
"""
| import json | |
| import torch | |
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from transformers import ( | |
| LayoutLMv3ForSequenceClassification, | |
| LayoutLMv3Processor, | |
| TrainingArguments, | |
| Trainer, | |
| EarlyStoppingCallback, | |
| ) | |
| from sklearn.metrics import classification_report | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
# ── PATHS (resolved relative to this script) ────────────────────────────────
# Every location is anchored at the folder that contains this script, so the
# working directory of the caller does not matter.
BASE_DIR = Path(__file__).resolve().parents[0]
DATA_DIR = BASE_DIR.joinpath("data2")
TRAIN_JSON = DATA_DIR.joinpath("combined_train.json")
VAL_JSON = DATA_DIR.joinpath("combined_val.json")
MAPPINGS = DATA_DIR.joinpath("label_mappings.json")
MODEL_OUTPUT = BASE_DIR.joinpath("models", "classifier")
LOGS_DIR = BASE_DIR.joinpath("outputs", "logs_classifier")

# ── HYPERPARAMETERS ──────────────────────────────────────────────────────────
MODEL_NAME = "microsoft/layoutlmv3-base"
MAX_LENGTH = 512
# Per-device batch size; main() accumulates gradients over 2 steps, giving an
# effective batch of 16.
BATCH_SIZE = 8
# Upper bound on epochs; main() installs early stopping with patience 3, so
# training may end sooner.
EPOCHS = 10
# Small learning rate for fine-tuning a pretrained model (original note:
# never increase this).
LEARNING_RATE = 2e-5
# Original note: 6% of 770 total steps. NOTE(review): hard-coded — recompute
# if the dataset size, batch size, or epoch count changes.
WARMUP_STEPS = 46
WEIGHT_DECAY = 0.01
# ── HELPERS ──────────────────────────────────────────────────────────────────
def get_doc_class_from_record(rec, doc2id):
    """
    Map a record's document class name to its integer class id.

    The class name is read from the record's top-level ``doc_class`` field.
    When that field is missing or empty, the nested ``data["doc_class"]``
    value is tried as a fallback, matching the module header's statement that
    the class may also be stored in the data field.

    Args:
        rec: One dataset record (a dict loaded from the dataset JSON).
        doc2id: Mapping from class name to integer class id.

    Returns:
        The id found in ``doc2id``, or -1 when the class name is missing, is
        not a string, or is not a key of ``doc2id``.
    """
    doc_class = rec.get("doc_class", "")
    if not doc_class:
        nested = rec.get("data")
        if isinstance(nested, dict):
            doc_class = nested.get("doc_class")

    # A non-string value (None, list, dict, ...) cannot name a class, and an
    # unhashable one would make the lookup raise instead of letting the
    # caller skip the record.
    if not isinstance(doc_class, str):
        return -1
    return doc2id.get(doc_class, -1)
# ── DATASET ──────────────────────────────────────────────────────────────────
class DocumentDataset(Dataset):
    """
    Torch dataset that turns JSON records into LayoutLMv3 classifier inputs.

    Records whose document class cannot be mapped to an id (helper returns -1)
    are dropped at construction time; the resolved id is cached on each kept
    record under the ``"_class_id"`` key.
    """

    def __init__(self, json_path, processor, doc2id):
        """
        Load and filter the records.

        Args:
            json_path: Path to a JSON file containing a list of records.
            processor: Callable processor applied to (image, words, boxes) in
                ``__getitem__``.
            doc2id: Mapping from class name to integer class id.
        """
        with open(json_path, encoding="utf-8") as f:
            raw = json.load(f)

        self.processor = processor

        # Keep only records with a known class; count the rest for reporting.
        self.records = []
        skipped = 0
        for rec in raw:
            cid = get_doc_class_from_record(rec, doc2id)
            if cid == -1:
                skipped += 1
                continue
            rec["_class_id"] = cid
            self.records.append(rec)

        if skipped:
            print(f" β οΈ Skipped {skipped} records with unknown doc_class")

    def __len__(self):
        """Return the number of records that survived filtering."""
        return len(self.records)

    def __getitem__(self, idx):
        """
        Build the model inputs for record ``idx``.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``bbox`` and
            ``pixel_values`` (processor outputs with the leading batch
            dimension removed) plus ``labels`` (scalar long tensor holding
            the class id).
        """
        rec = self.records[idx]

        # Image: fall back to a blank white page when the file is unavailable.
        img_path = rec.get("image_path")
        if img_path and Path(img_path).exists():
            # Image.open is lazy and keeps the file open; convert() returns an
            # independent in-memory image, so the handle can be closed here.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        else:
            image = Image.new("RGB", (1654, 2339), color=(255, 255, 255))

        # Text: at most the first 128 whitespace-separated OCR words.
        ocr_text = rec.get("ocr_text", "") or ""
        words = ocr_text.split()[:128]
        if not words:
            # The processor needs at least one word.
            words = ["[PAD]"]

        # Every word gets the same full-page box on the 0-1000 scale; no
        # per-word layout information is used for classification.
        boxes = [[0, 0, 1000, 1000]] * len(words)

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch axis added by return_tensors="pt";
        # a bare squeeze() would also collapse any other size-1 dimension.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "bbox": encoding["bbox"].squeeze(0),
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "labels": torch.tensor(rec["_class_id"], dtype=torch.long),
        }
# ── CLASS WEIGHTS (handle imbalance) ─────────────────────────────────────────
def compute_class_weights(dataset, num_classes):
    """
    Derive inverse-frequency class weights from a dataset's records.

    Args:
        dataset: Object exposing ``records``, an iterable of dicts that each
            carry an integer ``"_class_id"``.
        num_classes: Total number of classes (length of the result).

    Returns:
        Float tensor of length ``num_classes``; weights are proportional to
        1 / count and rescaled so that they sum to ``num_classes``.
    """
    tallies = np.zeros(num_classes)
    for record in dataset.records:
        tallies[record["_class_id"]] += 1

    # Classes with no samples are treated as having one, which keeps the
    # reciprocal finite.
    tallies[tallies == 0] = 1

    inverse = 1.0 / tallies
    normalized = inverse / inverse.sum() * num_classes
    return torch.tensor(normalized, dtype=torch.float)
# ── WEIGHTED TRAINER ─────────────────────────────────────────────────────────
class WeightedTrainer(Trainer):
    """
    Trainer whose loss is class-weighted cross-entropy.

    Args:
        class_weights: 1-D float tensor with one weight per class.
        *args, **kwargs: Forwarded unchanged to ``Trainer``.
    """

    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights.to(self.model.device if hasattr(self, 'model') else "cpu")

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute the weighted cross-entropy loss for one batch.

        Args:
            model: Model to run; its output must expose ``logits``.
            inputs: Batch dict; ``"labels"`` is removed from it before the
                forward pass.
            return_outputs: When true, also return the model outputs.
            **kwargs: Extra arguments passed by the Trainer; accepted and
                ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs``
            is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # The weights must live on the same device as the logits.
        weights = self.class_weights.to(logits.device)

        # This override replaces the default loss computation, so the label
        # smoothing configured in the training arguments has to be applied
        # here explicitly or it is silently dropped.
        smoothing = getattr(self.args, "label_smoothing_factor", 0.0) or 0.0
        loss_fn = torch.nn.CrossEntropyLoss(weight=weights, label_smoothing=smoothing)
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss
# ── METRICS ──────────────────────────────────────────────────────────────────
def compute_metrics(eval_pred):
    """
    Compute accuracy from a ``(logits, labels)`` pair.

    Args:
        eval_pred: Two-item iterable of per-class scores and true class ids.

    Returns:
        Dict with a single ``"accuracy"`` entry (Python float): the fraction
        of rows whose highest-scoring class equals the label.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    hit_rate = np.mean(predicted == gold)
    return {"accuracy": float(hit_rate)}
# ── MAIN ─────────────────────────────────────────────────────────────────────
def main():
    """
    Fine-tune the LayoutLMv3 classifier and report validation metrics.

    Steps: verify the input files, load the label mappings, build the
    processor/model and both datasets, train with class-weighted loss and
    early stopping, save the model and processor to ``MODEL_OUTPUT``, then
    print a per-class report and overall accuracy on the validation set.

    Raises:
        FileNotFoundError: If a required input file is missing.
        ValueError: If no training record has a known document class.
    """
    # Verify data files exist
    for p in [TRAIN_JSON, VAL_JSON, MAPPINGS]:
        if not p.exists():
            raise FileNotFoundError(
                f"\nβ File not found: {p}\n"
                f" Run script 1_convert_labelstudio.py first to generate data files."
            )

    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)
    doc_classes = mappings["doc_classes"]
    doc2id = mappings["doc2id"]
    num_labels = len(doc_classes)

    print(f"π Data directory : {DATA_DIR}")
    print(f"π¦ Model output : {MODEL_OUTPUT}")
    print(f"π·οΈ Classes ({num_labels}): {doc_classes}")
    print(f"π€ Base model : {MODEL_NAME}\n")

    # apply_ocr=False: words and boxes are supplied by the dataset itself.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=False)
    model = LayoutLMv3ForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label={i: c for i, c in enumerate(doc_classes)},
        label2id=doc2id,
        # Original note: intended to silence the missing-weights warning that
        # is expected for the freshly initialised classification head.
        ignore_mismatched_sizes=True,
    )

    print("π Loading datasets...")
    train_dataset = DocumentDataset(TRAIN_JSON, processor, doc2id)
    val_dataset = DocumentDataset(VAL_JSON, processor, doc2id)
    print(f" Train: {len(train_dataset)} | Val: {len(val_dataset)}\n")

    # Fail early with a clear message instead of training on nothing.
    if len(train_dataset) == 0:
        raise ValueError(
            f"No usable training records in {TRAIN_JSON}: "
            "every record has an unknown doc_class."
        )

    # Class weights for imbalanced data
    class_weights = compute_class_weights(train_dataset, num_labels)
    print("βοΈ Class weights:")
    for cls, w in zip(doc_classes, class_weights):
        print(f" {cls}: {w:.3f}")
    print()

    MODEL_OUTPUT.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(MODEL_OUTPUT),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        weight_decay=WEIGHT_DECAY,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=0,
        lr_scheduler_type="cosine",
        gradient_accumulation_steps=2,
        save_total_limit=2,
        label_smoothing_factor=0.083,
    )

    trainer = WeightedTrainer(
        class_weights=class_weights,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    print("π Starting classification training...")
    trainer.train()

    # Save final model together with its processor.
    trainer.save_model(str(MODEL_OUTPUT))
    processor.save_pretrained(str(MODEL_OUTPUT))
    print(f"\nβ Model saved to: {MODEL_OUTPUT}")

    # Final evaluation with full per-class report.
    print("\nπ Final evaluation on validation set:")
    preds_output = trainer.predict(val_dataset)
    preds = np.argmax(preds_output.predictions, axis=-1)
    labels = preds_output.label_ids
    # Passing the full label list keeps target_names aligned even when a
    # class is absent from both the validation labels and the predictions;
    # without it the report raises on the length mismatch.
    print(classification_report(
        labels, preds,
        labels=list(range(num_labels)),
        target_names=doc_classes,
        zero_division=0
    ))
    acc = (preds == labels).mean()
    print(f"Overall accuracy: {acc:.1%}")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()