File size: 10,706 Bytes
f4c0357
 
33ddb61
f4c0357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33ddb61
 
 
f4c0357
 
 
 
 
33ddb61
f4c0357
33ddb61
 
 
 
f4c0357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33ddb61
f4c0357
33ddb61
f4c0357
 
 
 
 
 
 
 
33ddb61
 
 
 
f4c0357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""
STEP 2 — Train Document Classification Model (LayoutLMv3)
Input:  data2/combined_train.json, data2/combined_val.json, data2/label_mappings.json
Output: models/classifier/

Fixes applied:
  - evaluation_strategy → eval_strategy (transformers >= 4.41)
  - All paths resolved relative to this script file (no CWD dependency)
  - doc_class read from both annotation results AND data field
  - Skips records with invalid doc_class_id gracefully
  - MISSING weights warning suppressed (expected when fine-tuning)
  - Added class weights to handle imbalanced dataset
  - Added early stopping
"""

import json
import torch
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
from sklearn.metrics import classification_report
import warnings
warnings.filterwarnings("ignore")

# ── PATHS (resolved relative to this script) ────────────────────────────────
BASE_DIR     = Path(__file__).resolve().parent
DATA_DIR     = BASE_DIR / "data2"
TRAIN_JSON = DATA_DIR / "combined_train.json"
VAL_JSON   = DATA_DIR / "combined_val.json"
MAPPINGS     = DATA_DIR / "label_mappings.json"
MODEL_OUTPUT = BASE_DIR / "models" / "classifier"
LOGS_DIR     = BASE_DIR / "outputs" / "logs_classifier"

# ── HYPERPARAMETERS ──────────────────────────────────────────────────────────
MODEL_NAME = "microsoft/layoutlmv3-base"
MAX_LENGTH    = 512
BATCH_SIZE    = 8        # effective batch=16 with gradient_accumulation=2
EPOCHS        = 10       # early stopping will trigger around epoch 7-8
LEARNING_RATE = 2e-5     # fine-tuning pretrained β€” never increase this
WARMUP_STEPS  = 46       # 6% of 770 total steps
WEIGHT_DECAY  = 0.01

# ── HELPERS ──────────────────────────────────────────────────────────────────
def get_doc_class_from_record(rec, doc2id):
    """
    Resolve the integer class id of a record.

    The top-level ``doc_class`` field is tried first; if it does not map to a
    known class, the nested ``data.doc_class`` field (when ``data`` is a dict)
    is tried as a fallback.

    Args:
        rec: Record dict, expected to carry the class name under ``doc_class``
            and/or ``data["doc_class"]``.
        doc2id: Mapping from class name to integer class id.

    Returns:
        The class id from ``doc2id``, or -1 when no candidate value maps to a
        known class (missing, ``None``, unknown name, or an unhashable value
        such as a list).
    """
    # A missing top-level field is looked up as "" to keep the historical behavior.
    candidates = [rec.get("doc_class", "")]

    nested = rec.get("data")
    if isinstance(nested, dict):
        candidates.append(nested.get("doc_class"))

    for doc_class in candidates:
        try:
            class_id = doc2id.get(doc_class, -1)
        except TypeError:
            # Unhashable value (e.g. a list of choices): treat it as unknown
            # instead of crashing the whole dataset load.
            continue
        if class_id != -1:
            return class_id
    return -1


# ── DATASET ──────────────────────────────────────────────────────────────────
class DocumentDataset(Dataset):
    """Dataset that converts document records into LayoutLMv3 model inputs.

    Records are loaded from a JSON file containing a list of dicts. Each record
    is resolved to a class id with ``get_doc_class_from_record``; records that
    resolve to -1 are dropped and counted in a warning message.

    Fields read from each record: ``image_path`` (optional), ``ocr_text``
    (optional). The resolved class id is stored back on the record under
    ``_class_id``.
    """

    def __init__(self, json_path, processor, doc2id):
        """Load and filter records.

        Args:
            json_path: Path to a UTF-8 JSON file holding a list of record dicts.
            processor: Callable processor invoked as
                ``processor(image, words, boxes=..., ...)`` in ``__getitem__``.
            doc2id: Mapping from class name to integer class id.
        """
        with open(json_path, encoding="utf-8") as f:
            raw = json.load(f)

        # Filter out records with unknown class
        self.records = []
        skipped = 0
        for rec in raw:
            cid = get_doc_class_from_record(rec, doc2id)
            if cid == -1:
                skipped += 1
                continue
            rec["_class_id"] = cid
            self.records.append(rec)

        if skipped:
            print(f"  ⚠️  Skipped {skipped} records with unknown doc_class")

        self.processor = processor

    def __len__(self):
        """Return the number of records that survived class filtering."""
        return len(self.records)

    def __getitem__(self, idx):
        """Encode one record.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``bbox`` and
            ``pixel_values`` (processor outputs with the leading batch
            dimension removed) plus ``labels`` (scalar long tensor).
        """
        rec = self.records[idx]

        # ── Image ──────────────────────────────────────────────────────────
        img_path = rec.get("image_path")
        if img_path and Path(img_path).exists():
            # Image.open is lazy and keeps the file open; convert() forces the
            # load, and the context manager closes the handle so long training
            # runs do not leak file descriptors.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        else:
            # Missing image: fall back to a blank white page so the sample can
            # still be classified from its text.
            image = Image.new("RGB", (1654, 2339), color=(255, 255, 255))

        # ── Text (OCR words) ───────────────────────────────────────────────
        ocr_text = rec.get("ocr_text", "") or ""
        words    = ocr_text.split()[:128]   # 128 words is enough for classification
        if not words:
            # The processor needs at least one word; use a placeholder token.
            words = ["[PAD]"]

        # Uniform bounding boxes (0-1000 normalized) — good enough for classification
        boxes = [[0, 0, 1000, 1000]] * len(words)

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) removes only the batch dimension added by return_tensors="pt";
        # a bare squeeze() would also collapse any other size-1 dimension.
        return {
            "input_ids":      encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "bbox":           encoding["bbox"].squeeze(0),
            "pixel_values":   encoding["pixel_values"].squeeze(0),
            "labels":         torch.tensor(rec["_class_id"], dtype=torch.long),
        }


# ── CLASS WEIGHTS (handle imbalance) ─────────────────────────────────────────
def compute_class_weights(dataset, num_classes):
    """Return inverse-frequency class weights as a float tensor.

    Args:
        dataset: Object exposing ``records``, an iterable of dicts that each
            carry an integer ``_class_id``.
        num_classes: Number of classes; also the length of the result.

    Returns:
        1-D ``torch.float`` tensor of length ``num_classes`` whose entries are
        proportional to ``1 / count`` and rescaled to sum to ``num_classes``.
        Classes with no samples are treated as having a count of 1.
    """
    tally = np.zeros(num_classes)
    for class_id in (record["_class_id"] for record in dataset.records):
        tally[class_id] += 1

    # Empty classes are counted as 1 so the reciprocal below stays finite.
    tally[tally == 0] = 1

    inverse_freq = 1.0 / tally
    normalized = inverse_freq / inverse_freq.sum() * num_classes
    return torch.tensor(normalized, dtype=torch.float)


# ── WEIGHTED TRAINER ──────────────────────────────────────────────────────────
class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy.

    The per-class weights counteract class imbalance. Because ``compute_loss``
    is overridden, the base Trainer's own label-smoothing path is not used;
    the configured ``label_smoothing_factor`` is therefore applied here
    directly through ``CrossEntropyLoss``.
    """

    def __init__(self, class_weights, *args, **kwargs):
        """
        Args:
            class_weights: 1-D float tensor with one weight per class.
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Park the weights next to the model when its device is known; they are
        # moved to the logits' device again at loss time, so "cpu" is a safe default.
        model = getattr(self, "model", None)
        self.class_weights = class_weights.to(getattr(model, "device", "cpu"))

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute weighted (and optionally label-smoothed) cross-entropy.

        Args:
            model: Model to run; its output must expose ``logits``.
            inputs: Batch dict; ``labels`` is removed from it before the
                forward pass so the model does not compute its own loss.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Extra arguments passed by the Trainer; accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs`` is set.
        """
        labels  = inputs.pop("labels")
        outputs = model(**inputs)
        logits  = outputs.logits
        loss_fn = torch.nn.CrossEntropyLoss(
            weight=self.class_weights.to(logits.device),
            # Honor TrainingArguments.label_smoothing_factor, which would
            # otherwise be silently dropped by this override.
            label_smoothing=float(getattr(self.args, "label_smoothing_factor", 0.0) or 0.0),
        )
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss


# ── METRICS ──────────────────────────────────────────────────────────────────
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for an evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over the last axis.

    Returns:
        ``{"accuracy": <float>}`` — the fraction of predictions equal to labels.
    """
    scores, targets = eval_pred
    predicted = np.argmax(scores, axis=-1)
    matches = predicted == targets
    return {"accuracy": float(matches.mean())}


# ── MAIN ─────────────────────────────────────────────────────────────────────
def main():
    """Fine-tune LayoutLMv3 for document classification and report metrics.

    Steps: verify the input files, load the label mappings, build the
    processor/model and datasets, train with class-weighted loss and early
    stopping, save the model and processor to ``MODEL_OUTPUT``, then print a
    per-class report and overall accuracy on the validation set.

    Raises:
        FileNotFoundError: If the train JSON, val JSON, or label mappings file
            does not exist.
    """
    # Verify data files exist
    for p in [TRAIN_JSON, VAL_JSON, MAPPINGS]:
        if not p.exists():
            raise FileNotFoundError(
                f"\n❌ File not found: {p}\n"
                f"   Run script 1_convert_labelstudio.py first to generate data files."
            )

    with open(MAPPINGS, encoding="utf-8") as f:
        mappings = json.load(f)

    doc_classes = mappings["doc_classes"]
    doc2id      = mappings["doc2id"]
    num_labels  = len(doc_classes)

    print(f"📂 Data directory : {DATA_DIR}")
    print(f"📦 Model output   : {MODEL_OUTPUT}")
    print(f"🏷️  Classes ({num_labels}): {doc_classes}")
    print(f"🤖 Base model     : {MODEL_NAME}\n")

    # Load processor & model. apply_ocr=False because words and boxes are
    # supplied by the dataset rather than extracted by the processor.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=False)
    model     = LayoutLMv3ForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label={i: c for i, c in enumerate(doc_classes)},
        label2id=doc2id,
        # Tolerate checkpoint/model shape mismatches (the classification head
        # is new and is trained from scratch during fine-tuning).
        ignore_mismatched_sizes=True,
    )

    print("📊 Loading datasets...")
    train_dataset = DocumentDataset(TRAIN_JSON, processor, doc2id)
    val_dataset   = DocumentDataset(VAL_JSON,   processor, doc2id)
    print(f"   Train: {len(train_dataset)} | Val: {len(val_dataset)}\n")

    # Class weights for imbalanced data
    class_weights = compute_class_weights(train_dataset, num_labels)
    print("⚖️  Class weights:")
    for cls, w in zip(doc_classes, class_weights):
        print(f"   {cls}: {w:.3f}")
    print()

    MODEL_OUTPUT.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(MODEL_OUTPUT),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        weight_decay=WEIGHT_DECAY,
        eval_strategy="epoch",
        save_strategy="epoch",          # must match eval_strategy for load_best_model_at_end
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        dataloader_num_workers=0,
        lr_scheduler_type="cosine",
        gradient_accumulation_steps=2,
        save_total_limit=2,
        label_smoothing_factor=0.083,
    )
    trainer = WeightedTrainer(
        class_weights=class_weights,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    print("🚀 Starting classification training...")
    trainer.train()

    # ── Save final model ──────────────────────────────────────────────────
    trainer.save_model(str(MODEL_OUTPUT))
    processor.save_pretrained(str(MODEL_OUTPUT))
    print(f"\n✅ Model saved to: {MODEL_OUTPUT}")

    # ── Final evaluation with full report ────────────────────────────────
    print("\n📈 Final evaluation on validation set:")
    preds_output = trainer.predict(val_dataset)
    y_pred = np.argmax(preds_output.predictions, axis=-1)
    y_true = preds_output.label_ids

    # Pass the full label list explicitly: without it, classification_report
    # raises ValueError when the validation split (or the predictions) do not
    # cover every class, because target_names would not match the labels seen.
    print(classification_report(
        y_true, y_pred,
        labels=list(range(num_labels)),
        target_names=doc_classes,
        zero_division=0
    ))

    acc = (y_pred == y_true).mean()
    print(f"Overall accuracy: {acc:.1%}")


if __name__ == "__main__":
    main()