import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import sys
import glob
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from src.data.dataset import IAMDataset, collate_fn
from src.models.crnn import CRNN
from src.models.gan import Generator
# Preprocessing pipeline -- must match the baseline CRNN training transforms
# exactly: resize to the fixed network input size, convert to a tensor, then
# map pixel values from [0, 1] to [-1, 1] (the GAN generator's output range).
transform = transforms.Compose(
    [
        transforms.Resize((32, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def decode_pseudo_labels(preds):
    """Greedy CTC decode of network outputs into pseudo-label targets.

    Args:
        preds: tensor of shape (seq_len, batch, classes).

    Returns:
        (targets_list, target_lengths): per batch item, a 1-D long tensor of
        decoded labels (CTC collapse applied) and its length.
    """
    # Best path per timestep, laid out as (batch, seq_len).
    best_path = preds.argmax(dim=2).permute(1, 0)
    targets_list = []
    target_lengths = []
    for seq in best_path:
        labels = seq.tolist()
        # CTC collapse: drop consecutive repeats, then drop blanks (class 0).
        decoded = [
            lab
            for j, lab in enumerate(labels)
            if lab != 0 and (j == 0 or lab != labels[j - 1])
        ]
        targets_list.append(torch.tensor(decoded, dtype=torch.long))
        target_lengths.append(len(decoded))
    return targets_list, target_lengths
def _supervised_ctc_loss(model, criterion, imgs, texts, lengths):
    """CTC loss for one labeled batch; returns the (grad-carrying) loss tensor."""
    preds = model(imgs).permute(1, 0, 2)  # CTC wants (seq_len, batch, classes)
    input_lengths = torch.full(size=(preds.size(1),), fill_value=preds.size(0), dtype=torch.long)
    # Strip per-sample padding before concatenating targets for CTC.
    targets = torch.cat([texts[i][: lengths[i]] for i in range(texts.size(0))])
    return criterion(preds, targets, input_lengths, lengths)


def _pseudo_label_loss(model, criterion, fake_imgs, mask, device):
    """CTC loss on GAN images that passed the confidence threshold.

    Pseudo-labels are the model's own greedy decodes (detached so they act as
    fixed targets). Samples whose decode is empty are dropped. Returns a zero
    tensor when nothing qualifies.
    """
    zero = torch.tensor(0.0, device=device)
    if mask.sum().item() == 0:
        return zero
    preds = model(fake_imgs[mask]).permute(1, 0, 2)
    targets_list, target_lengths = decode_pseudo_labels(preds.detach())
    valid_idx = [i for i, length in enumerate(target_lengths) if length > 0]
    if not valid_idx:
        return zero
    valid_preds = preds[:, valid_idx, :]
    targets = torch.cat([targets_list[i].to(device) for i in valid_idx])
    lengths = torch.tensor([target_lengths[i] for i in valid_idx], dtype=torch.long).to(device)
    input_lengths = torch.full(
        size=(valid_preds.size(1),), fill_value=valid_preds.size(0), dtype=torch.long
    ).to(device)
    # Scale down the fake loss so it doesn't overwhelm the real-data signal.
    return criterion(valid_preds, targets, input_lengths, lengths) * 0.5


def train_ssl(model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8, latent_dim=100):
    """
    Pseudo-labeling approach for Semi-Supervised Learning.
    Combines real labeled data with synthetic unlabeled data generated
    dynamically by the GAN.

    Args:
        model: CRNN emitting per-timestep class scores (batch, seq, classes).
        generator: trained GAN generator; frozen for this phase.
        dataloader: yields (imgs, padded_texts, text_lengths) labeled batches.
        optimizer / criterion: optimizer over `model` and a CTCLoss.
        device: torch device for all tensors.
        epochs: number of passes over `dataloader`.
        threshold: min mean per-timestep confidence to accept a fake sample.
        latent_dim: GAN latent vector size.

    Side effects: saves `weights/crnn_ssl_epoch_<e>.pth` after each epoch.
    """
    model.train()
    generator.eval()  # Generator is fixed during this phase
    for epoch in range(epochs):
        total_loss_real = 0.0
        total_loss_fake = 0.0
        for step, (labeled_imgs, labeled_texts, labeled_lengths) in enumerate(dataloader):
            labeled_imgs = labeled_imgs.to(device)
            labeled_texts = labeled_texts.to(device)
            batch_size = labeled_imgs.size(0)
            optimizer.zero_grad()

            # 1. Supervised CTC loss on the real labeled batch.
            loss_real = _supervised_ctc_loss(model, criterion, labeled_imgs, labeled_texts, labeled_lengths)

            # 2. Confidence-filtered pseudo-labeling on synthetic GAN images.
            #    Confidence is measured in eval mode under no_grad, so dropout/BN
            #    updates don't pollute the filter.
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)  # same shape as real imgs, range [-1, 1]
                model.eval()
                # NOTE(review): exp() assumes the model outputs log-probs
                # (log_softmax), consistent with the CTC setup -- confirm in CRNN.
                probs = torch.exp(model(fake_imgs))
                avg_confidence = probs.max(dim=2).values.mean(dim=1)
                mask = avg_confidence > threshold
            model.train()
            loss_fake = _pseudo_label_loss(model, criterion, fake_imgs, mask, device)

            # 3. Joint update on real + (down-weighted) pseudo-labeled loss.
            total_loss = loss_real + loss_fake
            total_loss.backward()
            optimizer.step()

            fake_val = loss_fake.item()  # read once; avoids tensor-vs-0 comparisons
            total_loss_real += loss_real.item()
            total_loss_fake += fake_val
            if step % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Real Loss: {loss_real.item():.4f}, Fake Loss: {fake_val:.4f}, Confident Fakes: {mask.sum().item()}/{batch_size}")
        print(f"Epoch {epoch+1} Average Real Loss: {total_loss_real/len(dataloader):.4f}, Average Fake Loss: {total_loss_fake/len(dataloader):.4f}")
        # Save checkpoints
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_ssl_epoch_{epoch+1}.pth')
if __name__ == "__main__":
    print("Starting Semi-Supervised Learning (SSL) Training Phase...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    def _latest_checkpoint(pattern, missing_desc):
        """Return the highest-epoch checkpoint matching *pattern*, or exit.

        Epoch number is parsed from the filename tail, e.g.
        'crnn_baseline_epoch_12.pth' -> 12.
        """
        paths = glob.glob(pattern)
        if not paths:
            print(f"Error: Could not find {missing_desc}.")
            sys.exit(1)
        paths.sort(key=lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0]))
        return paths[-1]

    # 1. Load Dataset (transform must match the baseline training pipeline).
    data_dir = 'data/iam_words'
    csv_file = 'data/labels.csv'
    dataset = IAMDataset(data_dir=data_dir, csv_file=csv_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # 2. Load the Baseline CRNN Model from its latest checkpoint.
    num_classes = dataset.num_classes
    crnn_model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)
    latest_crnn = _latest_checkpoint('weights/crnn_baseline_epoch_*.pth', "baseline CRNN weights")
    print(f"Loading Baseline CRNN from {latest_crnn}")
    crnn_model.load_state_dict(torch.load(latest_crnn, map_location=device))

    # 3. Load the Trained GAN Generator (kept frozen during SSL).
    generator = Generator(latent_dim=100).to(device)
    latest_gan = _latest_checkpoint('weights/gan_generator_epoch_*.pth', "GAN Generator weights")
    print(f"Loading GAN Generator from {latest_gan}")
    generator.load_state_dict(torch.load(latest_gan, map_location=device))

    # 4. Setup Optimizer & Loss.
    # Use a smaller learning rate for fine-tuning; blank index 0 for CTC.
    optimizer = optim.Adam(crnn_model.parameters(), lr=0.0001)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 5. Start SSL Training Loop.
    train_ssl(crnn_model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8)
    print("SSL Training complete!")