File size: 6,946 Bytes
4768ab6 d79b7f7 4768ab6 d79b7f7 f0e14bb d79b7f7 f0e14bb d79b7f7 f0e14bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
from src.sroie_loader import load_sroie
from PIL import Image
from tqdm import tqdm
from seqeval.metrics import f1_score, precision_score, recall_score
from pathlib import Path
import os
# --- 1. Global Configuration & Label Mapping ---
print("Setting up configuration...")
# BIO tag set: 'O' plus a B-/I- pair for each SROIE field that is extracted.
label_list = [
    'O',
    'B-COMPANY', 'I-COMPANY',
    'B-DATE', 'I-DATE',
    'B-ADDRESS', 'I-ADDRESS',
    'B-TOTAL', 'I-TOTAL',
]
# Lookups between tag strings and their contiguous integer ids (position in label_list).
id2label = dict(enumerate(label_list))
label2id = {tag: tag_id for tag_id, tag in id2label.items()}
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
# Dataset root directory; can be overridden with the SROIE_DATA_PATH environment variable.
SROIE_DATA_PATH = os.environ.get("SROIE_DATA_PATH", os.path.join("data", "sroie"))
# --- 2. PyTorch Dataset Class ---
class SROIEDataset(Dataset):
    """PyTorch Dataset that encodes SROIE examples with a LayoutLMv3 processor.

    Each element of ``data`` is read as a mapping with the keys
    ``'image_path'``, ``'words'``, ``'bboxes'`` and ``'ner_tags'``.

    Args:
        data: Indexable, sized sequence of example mappings.
        processor: Callable invoked as ``processor(image, text=..., boxes=...,
            word_labels=..., truncation=True, max_length=512,
            return_tensors="pt")``; its result must support ``.items()``.
        label2id: Mapping from NER tag string to integer label id.
    """

    def __init__(self, data, processor, label2id):
        self.data = data
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize_box(box, width, height):
        """Return ``box`` as ``[x0, y0, x1, y1]`` scaled to a 0-1000 grid and clipped.

        The box is unpacked as ``(x, y, w, h)`` and converted to corner form.
        NOTE(review): this assumes the loader emits top-left corner plus
        width/height rather than two corners -- confirm against ``load_sroie``.

        Args:
            box: Four numbers ``(x, y, w, h)`` in pixel units of the image.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            List of four ints, each truncated toward zero and clipped to [0, 1000].
        """
        x, y, w, h = box
        corners = (x, y, x + w, y + h)
        extents = (width, height, width, height)
        # Scale each coordinate by its image dimension, then clip to the 0-1000
        # range the LayoutLMv3 bbox embedding accepts.
        return [
            max(0, min(int((coord / extent) * 1000), 1000))
            for coord, extent in zip(corners, extents)
        ]

    def __getitem__(self, idx):
        """Encode example ``idx`` and return a dict of tensors without a batch dim."""
        example = self.data[idx]
        # The context manager releases the file handle deterministically;
        # convert() loads the pixels into a new image before the file is closed.
        with Image.open(example['image_path']) as raw_image:
            image = raw_image.convert("RGB")
        width, height = image.size

        boxes = [self._normalize_box(box, width, height) for box in example['bboxes']]
        word_labels = [self.label2id[label] for label in example['ner_tags']]

        # Encode image, words, boxes and labels together, truncating to 512 tokens.
        encoding = self.processor(
            image,
            text=example['words'],
            boxes=boxes,
            word_labels=word_labels,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dimension of 1; drop it so
        # the DataLoader's collate_fn can do the batching.
        return {key: val.squeeze(0) for key, val in encoding.items()}
# --- 3. Main Training Script ---
def _train_one_epoch(model, dataloader, optimizer, device, epoch):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Model called as ``model(**batch)`` whose output exposes ``.loss``.
        dataloader: Yields dicts of tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to.
        epoch: Zero-based epoch index, used only in the progress-bar label.

    Returns:
        Mean of the per-batch loss values, or 0.0 when ``dataloader`` is empty.
    """
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}")
    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_value = loss.item()
        total_loss += loss_value
        progress_bar.set_postfix({'loss': f'{loss_value:.4f}'})
    num_batches = len(dataloader)
    # Avoid ZeroDivisionError on an empty training split.
    return total_loss / num_batches if num_batches else 0.0


def _evaluate(model, dataloader, device):
    """Predict over ``dataloader`` and return seqeval ``(f1, precision, recall)``.

    Token positions whose gold label id is -100 are dropped from both the
    gold and predicted sequences. NOTE(review): -100 is presumably the ignore
    index used for special tokens, non-first sub-words and collator padding --
    confirm against the processor/collator settings.
    """
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            batch = {k: v.to(device) for k, v in batch.items()}
            predictions = model(**batch).logits.argmax(dim=-1)
            # One host transfer per batch instead of a .item() sync per token.
            for pred_row, label_row in zip(predictions.tolist(), batch['labels'].tolist()):
                kept = [(p, l) for p, l in zip(pred_row, label_row) if l != -100]
                all_labels.append([id2label[l] for _, l in kept])
                all_predictions.append([id2label[p] for p, _ in kept])
    return (
        f1_score(all_labels, all_predictions),
        precision_score(all_labels, all_predictions),
        recall_score(all_labels, all_predictions),
    )


def train(*, num_epochs=10, batch_size=2, learning_rate=5e-5,
          save_dir="./models/layoutlmv3-sroie-best"):
    """Fine-tune LayoutLMv3 on SROIE and keep the checkpoint with the best test F1.

    All arguments are keyword-only and default to the script's original
    settings, so a bare ``train()`` call behaves as before.

    Args:
        num_epochs: Number of passes over the training split.
        batch_size: Batch size used for both the train and test dataloaders.
        learning_rate: AdamW learning rate.
        save_dir: Directory that receives ``save_pretrained`` output for the
            model and processor.

    Side effects:
        Prints progress. Writes the model and processor to ``save_dir`` every
        time the test-split F1 strictly exceeds the best seen so far (which
        starts at 0), so nothing is written if F1 never rises above 0.
    """
    print("Loading SROIE dataset...")
    raw_dataset = load_sroie(SROIE_DATA_PATH)

    print("Creating processor...")
    # apply_ocr=False: words and boxes are supplied from the dataset itself.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)

    print("Creating PyTorch datasets and dataloaders...")
    train_dataset = SROIEDataset(raw_dataset['train'], processor, label2id)
    test_dataset = SROIEDataset(raw_dataset['test'], processor, label2id)
    # Pads each batch to its longest encoding (padding=True).
    data_collator = DataCollatorForTokenClassification(
        tokenizer=processor.tokenizer,
        padding=True,
        return_tensors="pt",
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

    print("Loading LayoutLMv3 model for fine-tuning...")
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on: {device}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    save_path = Path(save_dir)
    best_f1 = 0.0
    checkpoint_saved = False
    for epoch in range(num_epochs):
        print(f"\n{'='*60}\nEpoch {epoch + 1}/{num_epochs}\n{'='*60}")

        avg_train_loss = _train_one_epoch(model, train_dataloader, optimizer, device, epoch)
        f1, precision, recall = _evaluate(model, test_dataloader, device)

        print(f"\n📊 Epoch {epoch + 1} Results:")
        print(f" Train Loss: {avg_train_loss:.4f}")
        print(f" F1 Score: {f1:.4f}")
        print(f" Precision: {precision:.4f}")
        print(f" Recall: {recall:.4f}")

        # Keep only the checkpoint with the highest F1 seen so far.
        if f1 > best_f1:
            best_f1 = f1
            print(" 🌟 New best F1! Saving model...")
            save_path.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(save_path)
            processor.save_pretrained(save_path)
            checkpoint_saved = True

    print(f"\n🎉 TRAINING COMPLETE! Best F1 Score: {best_f1:.4f}")
    # Only claim a saved model when one was actually written.
    if checkpoint_saved:
        print(f"Model saved to: {save_dir}")
    else:
        print("No model saved: validation F1 never rose above 0.")


if __name__ == '__main__':
    train()
|