File size: 8,176 Bytes
f9a156f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import sys
import glob

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from src.data.dataset import IAMDataset, collate_fn
from src.models.crnn import CRNN
from src.models.gan import Generator

# Define transforms matching training exactly
# Pipeline: resize to the fixed CRNN input geometry (H=32, W=1024),
# convert to a [0, 1] tensor, then normalize to [-1, 1] (mean 0.5, std 0.5
# on a single channel) — this matches the GAN generator's stated [-1, 1]
# output range, so real and fake images share the same value scale.
# NOTE(review): single-channel Normalize implies grayscale input — confirm
# IAMDataset yields 1-channel images.
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

def decode_pseudo_labels(preds):
    """Greedy (best-path) CTC decode of network outputs into training targets.

    Args:
        preds: tensor of shape (seq_len, batch, classes) holding per-timestep
            class scores.

    Returns:
        A pair (targets_list, target_lengths): a list of 1-D long tensors
        (one per batch element, with consecutive repeats collapsed and the
        blank symbol, index 0, removed) and a list of their lengths.
    """
    # Best class index at each timestep, laid out as (batch, seq_len).
    best_path = preds.argmax(dim=2).permute(1, 0)

    targets_list = []
    target_lengths = []

    for seq in best_path:
        # Standard CTC collapse rule: merge runs of identical symbols first,
        # then strip the blank (class 0).
        collapsed = torch.unique_consecutive(seq)
        decoded = collapsed[collapsed != 0].to(dtype=torch.long).cpu()
        targets_list.append(decoded)
        target_lengths.append(decoded.numel())

    return targets_list, target_lengths

def train_ssl(model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8, latent_dim=100):
    """
    Pseudo-labeling approach for Semi-Supervised Learning.
    Combines real labeled data with synthetic unlabeled data generated dynamically by the GAN.

    Args:
        model: CRNN recognizer being fine-tuned. Its forward output is
            permuted below from (batch, seq_len, classes) to CTC's expected
            (seq_len, batch, classes) layout.
        generator: frozen GAN generator mapping latent vectors to fake
            images (per the inline note: (batch, 1, 32, 1024) in [-1, 1]).
        dataloader: yields (images, padded_label_tensor, label_lengths)
            batches of real labeled data.
        optimizer: optimizer over model's parameters.
        criterion: CTC loss (the caller builds nn.CTCLoss(blank=0, zero_infinity=True)).
        device: torch device for inputs and generated data.
        epochs: number of passes over dataloader.
        threshold: mean max-probability a fake image's prediction must exceed
            before it is trusted as a pseudo-label.
        latent_dim: dimensionality of the generator's latent input.

    Side effects: writes weights/crnn_ssl_epoch_<epoch>.pth after each epoch.
    """
    model.train()
    generator.eval() # Generator is fixed during this phase
    
    for epoch in range(epochs):
        total_loss_real = 0
        total_loss_fake = 0
        
        for step, (labeled_imgs, labeled_texts, labeled_lengths) in enumerate(dataloader):
            labeled_imgs = labeled_imgs.to(device)
            labeled_texts = labeled_texts.to(device)
            batch_size = labeled_imgs.size(0)
            
            optimizer.zero_grad()
            
            # ==============================================================
            # 1. Train on Real Labeled Data
            # ==============================================================
            preds_l = model(labeled_imgs)
            preds_l = preds_l.permute(1, 0, 2) # (seq_len, batch, classes)
            
            # Every sample uses the model's full output length as its CTC input length.
            input_lengths_l = torch.full(size=(preds_l.size(1),), fill_value=preds_l.size(0), dtype=torch.long)
            
            # Strip padding from each label row using its true length, then
            # concatenate — CTC loss accepts flattened 1-D targets.
            targets_list_l = []
            for i in range(labeled_texts.size(0)):
                targets_list_l.append(labeled_texts[i][:labeled_lengths[i]])
            targets_concat_l = torch.cat(targets_list_l)
            
            loss_real = criterion(preds_l, targets_concat_l, input_lengths_l, labeled_lengths)
            
            # ==============================================================
            # 2. Train on Synthetic GAN Data (Pseudo-Labeling)
            # ==============================================================
            # Generate fake images
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z) # Shape: (batch, 1, 32, 1024), range [-1, 1]
                
                # Get pseudo-labels
                # NOTE(review): torch.exp below treats the model output as
                # log-probabilities — confirm the CRNN's head ends in log_softmax.
                model.eval()
                preds_fake_eval = model(fake_imgs)
                probs = torch.exp(preds_fake_eval) # Softmax probs
                # Confidence = mean over timesteps of the max class probability.
                max_probs, _ = torch.max(probs, dim=2)
                avg_confidence = max_probs.mean(dim=1)
                
                # Mask confident predictions
                mask = avg_confidence > threshold
                
            model.train()
            loss_fake = torch.tensor(0.0).to(device)
            
            if mask.sum() > 0:
                confident_imgs = fake_imgs[mask]
                
                # Re-run the confident fakes in train mode; this single forward
                # supplies both the CTC inputs and (after detach) the pseudo-label
                # targets, so the target is the model's own current best path
                # rather than the eval-mode prediction computed above.
                preds_fake = model(confident_imgs)
                preds_fake_perm = preds_fake.permute(1, 0, 2)
                
                # Decode the pseudo-labels into CTC targets
                targets_list_u, target_lengths_u = decode_pseudo_labels(preds_fake_perm.detach())
                
                # Filter out empty pseudo-labels
                valid_idx = [i for i, length in enumerate(target_lengths_u) if length > 0]
                
                if valid_idx:
                    valid_preds_fake_perm = preds_fake_perm[:, valid_idx, :]
                    valid_targets_list = [targets_list_u[i].to(device) for i in valid_idx]
                    valid_target_lengths = torch.tensor([target_lengths_u[i] for i in valid_idx], dtype=torch.long).to(device)
                    
                    valid_targets_concat = torch.cat(valid_targets_list)
                    input_lengths_u = torch.full(size=(valid_preds_fake_perm.size(1),), fill_value=valid_preds_fake_perm.size(0), dtype=torch.long).to(device)
                    
                    loss_fake = criterion(valid_preds_fake_perm, valid_targets_concat, input_lengths_u, valid_target_lengths)
                    # Scale down the fake loss slightly so it doesn't overwhelm real data
                    loss_fake = loss_fake * 0.5 

            # Total loss
            total_loss = loss_real + loss_fake
            total_loss.backward()
            optimizer.step()
            
            total_loss_real += loss_real.item()
            # loss_fake stays the zero tensor when no fake passed the threshold.
            total_loss_fake += loss_fake.item() if loss_fake > 0 else 0
            
            if step % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Real Loss: {loss_real.item():.4f}, Fake Loss: {loss_fake.item() if loss_fake > 0 else 0:.4f}, Confident Fakes: {mask.sum().item()}/{batch_size}")
                
        print(f"Epoch {epoch+1} Average Real Loss: {total_loss_real/len(dataloader):.4f}, Average Fake Loss: {total_loss_fake/len(dataloader):.4f}")
        
        # Save checkpoints
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_ssl_epoch_{epoch+1}.pth')

def latest_checkpoint(pattern):
    """Return the newest checkpoint path matching *pattern*, or None.

    Checkpoint filenames are expected to end in ``_<epoch>.pth``; "newest"
    means the largest trailing epoch number (numeric sort, so epoch 10
    beats epoch 2 where a lexical sort would not).
    """
    paths = glob.glob(pattern)
    if not paths:
        return None
    paths.sort(key=lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0]))
    return paths[-1]


if __name__ == "__main__":
    print("Starting Semi-Supervised Learning (SSL) Training Phase...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Dataset
    data_dir = 'data/iam_words'
    csv_file = 'data/labels.csv'
    dataset = IAMDataset(data_dir=data_dir, csv_file=csv_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # 2. Load the Baseline CRNN Model
    num_classes = dataset.num_classes
    crnn_model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)

    latest_crnn = latest_checkpoint('weights/crnn_baseline_epoch_*.pth')
    if latest_crnn is None:
        print("Error: Could not find baseline CRNN weights.")
        sys.exit(1)
    print(f"Loading Baseline CRNN from {latest_crnn}")
    crnn_model.load_state_dict(torch.load(latest_crnn, map_location=device))

    # 3. Load the Trained GAN Generator
    generator = Generator(latent_dim=100).to(device)
    latest_gan = latest_checkpoint('weights/gan_generator_epoch_*.pth')
    if latest_gan is None:
        print("Error: Could not find GAN Generator weights.")
        sys.exit(1)
    print(f"Loading GAN Generator from {latest_gan}")
    generator.load_state_dict(torch.load(latest_gan, map_location=device))

    # 4. Setup Optimizer & Loss
    # Use a smaller learning rate for fine-tuning the pre-trained CRNN.
    optimizer = optim.Adam(crnn_model.parameters(), lr=0.0001)
    # blank=0 must agree with decode_pseudo_labels, which drops class 0.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 5. Start SSL Training Loop
    train_ssl(crnn_model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8)
    print("SSL Training complete!")