"""Fine-tune LayoutLMv3 for token classification on a hybrid dataset:
SROIE receipts plus a general invoice set, preferring DocTR-aligned
OCR data when a pre-processed cache is available."""

import os
import pickle
import random
import sys
from pathlib import Path

import torch
from PIL import Image
from seqeval.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

# Make the project root importable so the `src` package resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.sroie_loader import load_sroie
from src.data_loader import load_unified_dataset
SROIE_DATA_PATH = "data/sroie"
DOCTR_CACHE_PATH = "data/doctr_trained_cache.pkl"
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
OUTPUT_DIR = "models/layoutlmv3-doctr-trained"

LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
              'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL',
              'B-INVOICE_NO', 'I-INVOICE_NO', 'B-BILL_TO', 'I-BILL_TO']
label2id = {label: idx for idx, label in enumerate(LABEL_LIST)}
id2label = {idx: label for idx, label in enumerate(LABEL_LIST)}
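
# Expected example schema (field names assumed from how they are consumed
# in UnifiedDataset.__getitem__; produced by load_sroie, load_unified_dataset,
# or the DocTR cache):
#   {
#       'words':    ['TOTAL', '9.00', ...],
#       'bboxes':   [[x0, y0, x1, y1], ...],   # 0-1000 normalized
#       'ner_tags': ['B-TOTAL', 'I-TOTAL', ...],
#       'image':    PIL.Image,                 # or 'image_path': str
#   }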


class UnifiedDataset(Dataset):
    """Wraps mixed SROIE / general-invoice examples for LayoutLMv3."""

    def __init__(self, data, processor, label2id):
        self.data = data
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]

        # Fall back to a blank page if the image is missing or unreadable,
        # so one bad file cannot abort a whole epoch.
        try:
            if 'image' in example and isinstance(example['image'], Image.Image):
                image = example['image']
            elif 'image_path' in example:
                image = Image.open(example['image_path']).convert("RGB")
            else:
                image = Image.new('RGB', (224, 224), color='white')
        except Exception:
            image = Image.new('RGB', (224, 224), color='white')
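        # LayoutLMv3 expects word boxes in a 0-1000 normalized coordinate
        # space; clamp every coordinate so stray OCR boxes cannot push the
        # 2D position embeddings out of range.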
        boxes = []
        for box in example['bboxes']:
            safe_box = [
                max(0, min(int(box[0]), 1000)),
                max(0, min(int(box[1]), 1000)),
                max(0, min(int(box[2]), 1000)),
                max(0, min(int(box[3]), 1000))
            ]
            boxes.append(safe_box)

        # Map tag strings to ids; unknown tags fall back to 'O' (id 0).
        word_labels = []
        for label in example['ner_tags']:
            word_labels.append(self.label2id.get(label, 0))
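        # The processor tokenizes the words, aligns each word label to its
        # first subword (remaining subwords receive -100, which the loss
        # ignores), and resizes/normalizes the page image.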
        encoding = self.processor(
            image,
            text=example['words'],
            boxes=boxes,
            word_labels=word_labels,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )

        # Drop the batch dimension added by return_tensors="pt".
        return {k: v.squeeze(0) for k, v in encoding.items()}


def load_doctr_cache(cache_path: str) -> dict:
    """Load pre-processed DocTR training data from cache."""
    print(f"📦 Loading DocTR cache from {cache_path}...")
    with open(cache_path, "rb") as f:
        data = pickle.load(f)
    print(f"   ✅ Loaded {len(data.get('train', []))} train, {len(data.get('test', []))} test examples")
    return data


def train():
    print(f"{'='*40}\n🚀 STARTING HYBRID TRAINING\n{'='*40}")

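    # Prefer the DocTR-aligned cache (the recommended path); otherwise fall
    # back to the raw SROIE loader.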
    if os.path.exists(DOCTR_CACHE_PATH):
        print("🔄 Using DocTR-aligned training data (recommended)")
        sroie_data = load_doctr_cache(DOCTR_CACHE_PATH)
    else:
        print("⚠️ DocTR cache not found. Using original SROIE loader.")
        print("   Run 'python scripts/prepare_doctr_data.py' to generate the cache.")

        if not os.path.exists(SROIE_DATA_PATH):
            print(f"❌ Error: SROIE path not found at {SROIE_DATA_PATH}")
            print("Please make sure you copied the 'sroie' folder into 'data/'.")
            return

        sroie_data = load_sroie(SROIE_DATA_PATH)

    print(f"   - SROIE Train: {len(sroie_data['train'])}")
    print(f"   - SROIE Test: {len(sroie_data['test'])}")

    print("📦 Loading General Invoice dataset...")
    new_data = load_unified_dataset(split='train', sample_size=600)
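    # Shuffle and hold out 10% of the general invoices for validation.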
    random.shuffle(new_data)
    split_idx = int(len(new_data) * 0.9)
    new_train = new_data[:split_idx]
    new_test = new_data[split_idx:]

    print(f"   - General Train: {len(new_train)}")
    print(f"   - General Test: {len(new_test)}")

    full_train_data = sroie_data['train'] + new_train
    full_test_data = sroie_data['test'] + new_test
    print(f"\n🔗 COMBINED DATASET SIZE: {len(full_train_data)} Training Images")
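    # apply_ocr=False: words and boxes are supplied by our own loaders, so
    # the processor must not run its built-in OCR.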
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT, num_labels=len(LABEL_LIST),
        id2label=id2label, label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"   - Device: {device}")

    train_ds = UnifiedDataset(full_train_data, processor, label2id)
    test_ds = UnifiedDataset(full_test_data, processor, label2id)
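    # Every example is already padded to max_length=512 and returned as
    # fixed-size tensors, so the DataLoader's default collation can stack
    # batches directly. DataCollatorForTokenClassification would try to
    # re-pad the features and can fail on the image tensor (pixel_values).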
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=2)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    best_f1 = 0.0
    NUM_EPOCHS = 10

    print("\n🔥 Beginning Fine-Tuning...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress.set_postfix({"loss": f"{loss.item():.4f}"})

        print(f"   📉 Epoch {epoch+1} average train loss: {total_loss / max(len(train_loader), 1):.4f}")

        model.eval()
        all_preds, all_labels = [], []
        print("   Running Validation...")
        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                predictions = outputs.logits.argmax(dim=-1)
                labels = batch['labels']
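                # Keep only positions with real word labels; -100 marks
                # special tokens and non-first subwords, which seqeval
                # must not see.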
                for i in range(len(labels)):
                    true_labels = [id2label[l.item()] for l in labels[i] if l.item() != -100]
                    pred_labels = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
                    all_labels.append(true_labels)
                    all_preds.append(pred_labels)

        f1 = f1_score(all_labels, all_preds)
        print(f"   📊 Epoch {epoch+1} F1 Score: {f1:.4f}")
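        # Keep only the checkpoint with the best validation F1 so far.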
        if f1 > best_f1:
            best_f1 = f1
            print(f"   💾 Saving Improved Model to {OUTPUT_DIR}")
            Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
            model.save_pretrained(OUTPUT_DIR)
            processor.save_pretrained(OUTPUT_DIR)


if __name__ == "__main__":
    train()
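
# NOTE: run from the repository root (e.g. `python <path-to-this-script>.py`)
# so the relative paths data/sroie, data/doctr_trained_cache.pkl, and models/
# resolve; the sys.path tweak above assumes this script lives one directory
# below the root.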