File size: 6,946 Bytes
4768ab6
 
 
 
 
d79b7f7
 
 
4768ab6
d79b7f7
 
 
 
f0e14bb
d79b7f7
 
 
 
 
 
 
 
 
f0e14bb
d79b7f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0e14bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import sys
from pathlib import Path

# Make the project root importable so the `src.*` modules below resolve
# when this script is run directly from its own directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import torch
from PIL import Image
from seqeval.metrics import f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    DataCollatorForTokenClassification,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
)

from src.sroie_loader import load_sroie

# --- 1. Global Configuration & Label Mapping ---
print("Setting up configuration...")

# BIO tag set for the four SROIE entity types; list position = class id.
label_list = [
    'O',
    'B-COMPANY', 'I-COMPANY',
    'B-DATE', 'I-DATE',
    'B-ADDRESS', 'I-ADDRESS',
    'B-TOTAL', 'I-TOTAL',
]
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}

MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
SROIE_DATA_PATH = os.getenv("SROIE_DATA_PATH", os.path.join("data", "sroie"))

# --- 2. PyTorch Dataset Class ---
class SROIEDataset(Dataset):
    """Map-style PyTorch Dataset over pre-parsed SROIE examples.

    Each example dict is expected to provide 'image_path', 'words',
    'bboxes' (pixel x, y, w, h) and 'ner_tags'; __getitem__ returns the
    LayoutLMv3 processor encoding with the batch axis squeezed away.
    """

    def __init__(self, data, processor, label2id):
        self.data = data
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize_box(box, width, height):
        # Convert a pixel (x, y, w, h) box to corner format on the
        # 0-1000 grid LayoutLMv3 expects, clipping into range.
        x, y, w, h = box
        corners = (x, y, x + w, y + h)
        extents = (width, height, width, height)
        return [
            max(0, min(int((value / extent) * 1000), 1000))
            for value, extent in zip(corners, extents)
        ]

    def __getitem__(self, idx):
        sample = self.data[idx]

        # The receipt image serves both as model input and as the source
        # of the pixel dimensions used to normalize the boxes.
        image = Image.open(sample['image_path']).convert("RGB")
        width, height = image.size

        boxes = [self._normalize_box(box, width, height) for box in sample['bboxes']]
        label_ids = [self.label2id[tag] for tag in sample['ner_tags']]

        # The processor tokenizes words, aligns boxes/labels to subword
        # tokens, and truncates to the model's 512-token window.
        encoding = self.processor(
            image,
            text=sample['words'],
            boxes=boxes,
            word_labels=label_ids,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Squeeze the singleton batch axis so the collator can re-batch.
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}

# --- 3. Main Training Script ---
def _train_one_epoch(model, dataloader, optimizer, device, epoch):
    """Run one optimization pass over `dataloader`.

    Returns the mean per-batch training loss for the epoch.
    """
    model.train()
    total_loss = 0.0
    progress = tqdm(dataloader, desc=f"Training Epoch {epoch+1}")
    for batch in progress:
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        progress.set_postfix({'loss': f'{loss.item():.4f}'})
    return total_loss / len(dataloader)


def _evaluate(model, dataloader, device):
    """Predict over `dataloader` and score with seqeval.

    Positions labeled -100 (special/padding tokens) are dropped before
    scoring. Returns an (f1, precision, recall) tuple.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            batch = {k: v.to(device) for k, v in batch.items()}
            predictions = model(**batch).logits.argmax(dim=-1)
            labels = batch['labels']

            for pred_row, label_row in zip(predictions, labels):
                # Keep only real word positions, aligned pairwise.
                kept = [(p.item(), l.item())
                        for p, l in zip(pred_row, label_row) if l.item() != -100]
                all_predictions.append([id2label[p] for p, _ in kept])
                all_labels.append([id2label[l] for _, l in kept])

    return (f1_score(all_labels, all_predictions),
            precision_score(all_labels, all_predictions),
            recall_score(all_labels, all_predictions))


def train():
    """Fine-tune LayoutLMv3 on SROIE and save the best-F1 checkpoint.

    Loads data from SROIE_DATA_PATH, trains for a fixed number of
    epochs, evaluates after each one, and writes the best model and
    processor to ./models/layoutlmv3-sroie-best.
    """
    # --- Load Data ---
    print("Loading SROIE dataset...")
    raw_dataset = load_sroie(SROIE_DATA_PATH)

    # --- Load Processor ---
    print("Creating processor...")
    # apply_ocr=False: SROIE already ships words and boxes, so the
    # processor must not run its own OCR pass.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)

    # --- Create PyTorch Datasets and DataLoaders ---
    print("Creating PyTorch datasets and dataloaders...")
    train_dataset = SROIEDataset(raw_dataset['train'], processor, label2id)
    test_dataset = SROIEDataset(raw_dataset['test'], processor, label2id)

    # Pads variable-length sequences (and their labels) per batch.
    data_collator = DataCollatorForTokenClassification(
        tokenizer=processor.tokenizer,
        padding=True,
        return_tensors="pt"
    )

    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=data_collator)

    # --- Load Model ---
    print("Loading LayoutLMv3 model for fine-tuning...")
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on: {device}")

    # --- Setup Optimizer ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # --- Training Loop ---
    best_f1 = 0
    NUM_EPOCHS = 10

    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*60}\nEpoch {epoch + 1}/{NUM_EPOCHS}\n{'='*60}")

        avg_train_loss = _train_one_epoch(model, train_dataloader, optimizer, device, epoch)
        f1, precision, recall = _evaluate(model, test_dataloader, device)

        print(f"\n📊 Epoch {epoch + 1} Results:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  F1 Score:   {f1:.4f}")
        print(f"  Precision:  {precision:.4f}")
        print(f"  Recall:     {recall:.4f}")

        # --- Save Best Model ---
        if f1 > best_f1:
            best_f1 = f1
            print(f"  🌟 New best F1! Saving model...")
            save_path = Path("./models/layoutlmv3-sroie-best")
            save_path.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(save_path)
            processor.save_pretrained(save_path)

    print(f"\n🎉 TRAINING COMPLETE! Best F1 Score: {best_f1:.4f}")
    print(f"Model saved to: ./models/layoutlmv3-sroie-best")


# Entry-point guard: run fine-tuning only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    train()