"""Fine-tune LayoutLMv3 for token classification on a hybrid dataset.

Combines SROIE receipts (preferring a DocTR pre-processed cache when present)
with a general invoice dataset, trains for a fixed number of epochs, and saves
the checkpoint with the best entity-level F1 (seqeval) to OUTPUT_DIR.
"""

import os
import pickle
import random
import sys
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from seqeval.metrics import f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    DataCollatorForTokenClassification,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
)

# Make the project root importable so `src.*` resolves when run as a script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# --- IMPORTS ---
from src.sroie_loader import load_sroie
from src.data_loader import load_unified_dataset

# --- CONFIGURATION ---
# Points to your local SROIE copy
SROIE_DATA_PATH = "data/sroie"
DOCTR_CACHE_PATH = "data/doctr_trained_cache.pkl"  # DocTR pre-processed cache
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
OUTPUT_DIR = "models/layoutlmv3-doctr-trained"

# Standard BIO label set shared by both datasets.
LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
              'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL',
              'B-INVOICE_NO', 'I-INVOICE_NO', 'B-BILL_TO', 'I-BILL_TO']
label2id = {label: idx for idx, label in enumerate(LABEL_LIST)}
id2label = {idx: label for idx, label in enumerate(LABEL_LIST)}


class UnifiedDataset(Dataset):
    """Adapts pre-OCR'd examples to LayoutLMv3 processor encodings.

    Each example is a dict with keys 'words', 'bboxes' (assumed already
    normalized to the 0-1000 range LayoutLMv3 expects — TODO confirm against
    the loaders), 'ner_tags', and either a PIL 'image' or an 'image_path'.
    """

    def __init__(self, data, processor, label2id):
        self.data = data
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]

        # 1. Image loading: fall back to a blank canvas if the image is
        # missing or unreadable so one bad file cannot abort training.
        try:
            if 'image' in example and isinstance(example['image'], Image.Image):
                image = example['image']
            elif 'image_path' in example:
                image = Image.open(example['image_path']).convert("RGB")
            else:
                image = Image.new('RGB', (224, 224), color='white')
        except Exception:
            image = Image.new('RGB', (224, 224), color='white')

        # 2. Boxes are ALREADY normalized — clamp into [0, 1000] as an extra
        # safety net, and enforce x0 <= x1 / y0 <= y1, since LayoutLMv3
        # rejects inverted boxes.
        boxes = []
        for box in example['bboxes']:
            x0, y0, x1, y1 = (max(0, min(int(v), 1000)) for v in box[:4])
            boxes.append([min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)])

        # 3. Label encoding: unknown tags fall back to 'O' (index 0).
        word_labels = [self.label2id.get(tag, 0) for tag in example['ner_tags']]

        # 4. Processor encoding (apply_ocr=False, so words/boxes come from us).
        encoding = self.processor(
            image,
            text=example['words'],
            boxes=boxes,
            word_labels=word_labels,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )
        # Drop the batch dim the processor adds; the DataLoader re-batches.
        return {k: v.squeeze(0) for k, v in encoding.items()}


def load_doctr_cache(cache_path: str) -> dict:
    """Load pre-processed DocTR training data from cache.

    NOTE(security): pickle deserialization executes arbitrary code — only
    load caches produced by this project's own scripts.
    """
    print(f"šŸ“¦ Loading DocTR cache from {cache_path}...")
    with open(cache_path, "rb") as f:
        data = pickle.load(f)
    print(f"   āœ… Loaded {len(data.get('train', []))} train, "
          f"{len(data.get('test', []))} test examples")
    return data


def train():
    """Run the full hybrid fine-tuning loop and save the best checkpoint."""
    print(f"{'='*40}\nšŸš€ STARTING HYBRID TRAINING\n{'='*40}")

    # 1. Load SROIE data (prefer DocTR cache if available).
    if os.path.exists(DOCTR_CACHE_PATH):
        print("šŸ”„ Using DocTR-aligned training data (recommended)")
        sroie_data = load_doctr_cache(DOCTR_CACHE_PATH)
    else:
        print("āš ļø  DocTR cache not found. Using original SROIE loader.")
        print("   Run 'python scripts/prepare_doctr_data.py' to generate the cache.")
        if not os.path.exists(SROIE_DATA_PATH):
            print(f"āŒ Error: SROIE path not found at {SROIE_DATA_PATH}")
            print("Please make sure you copied the 'sroie' folder into 'data/'.")
            return
        sroie_data = load_sroie(SROIE_DATA_PATH)
    print(f"   - SROIE Train: {len(sroie_data['train'])}")
    print(f"   - SROIE Test:  {len(sroie_data['test'])}")

    # 2. Load the general invoice dataset.
    print("šŸ“¦ Loading General Invoice dataset...")
    # Reduced sample size slightly to stay safe on RAM.
    new_data = load_unified_dataset(split='train', sample_size=600)
    # Seed so the 90/10 train/test split is reproducible across runs
    # (an unseeded shuffle leaks examples between splits run-to-run).
    random.seed(42)
    random.shuffle(new_data)
    split_idx = int(len(new_data) * 0.9)
    new_train = new_data[:split_idx]
    new_test = new_data[split_idx:]
    print(f"   - General Train: {len(new_train)}")
    print(f"   - General Test:  {len(new_test)}")

    # 3. Merge both sources.
    full_train_data = sroie_data['train'] + new_train
    full_test_data = sroie_data['test'] + new_test
    print(f"\nšŸ”— COMBINED DATASET SIZE: {len(full_train_data)} Training Images")

    # 4. Model + processor (apply_ocr=False: we supply words/boxes ourselves).
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(LABEL_LIST),
        id2label=id2label,
        label2id=label2id,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"   - Device: {device}")

    # 5. Dataloaders. Examples are already padded to max_length=512, so the
    # collator mostly just stacks tensors here.
    train_ds = UnifiedDataset(full_train_data, processor, label2id)
    test_ds = UnifiedDataset(full_test_data, processor, label2id)
    collator = DataCollatorForTokenClassification(
        processor.tokenizer, padding=True, return_tensors="pt"
    )
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collator)
    test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator)

    # 6. Optimize & train.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    best_f1 = 0.0
    NUM_EPOCHS = 10

    print("\nšŸ”„ Beginning Fine-Tuning...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress.set_postfix({"loss": f"{loss.item():.4f}"})

        # --- Evaluation: entity-level F1 over non-ignored (-100) tokens ---
        model.eval()
        all_preds, all_labels = [], []
        print("   Running Validation...")
        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                predictions = outputs.logits.argmax(dim=-1)
                labels = batch['labels']
                for i in range(len(labels)):
                    true_labels = [id2label[l.item()]
                                   for l in labels[i] if l.item() != -100]
                    pred_labels = [id2label[p.item()]
                                   for p, l in zip(predictions[i], labels[i])
                                   if l.item() != -100]
                    all_labels.append(true_labels)
                    all_preds.append(pred_labels)

        f1 = f1_score(all_labels, all_preds)
        print(f"   šŸ“Š Epoch {epoch+1} F1 Score: {f1:.4f}")

        # Keep only the best-scoring checkpoint.
        if f1 > best_f1:
            best_f1 = f1
            print(f"   šŸ’¾ Saving Improved Model to {OUTPUT_DIR}")
            Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
            model.save_pretrained(OUTPUT_DIR)
            processor.save_pretrained(OUTPUT_DIR)


if __name__ == "__main__":
    train()