"""Semi-supervised fine-tuning of a CRNN using GAN-generated pseudo-labeled data.

Loads the latest baseline CRNN and GAN generator checkpoints, then fine-tunes
the CRNN on real labeled batches combined with confident pseudo-labels decoded
from synthetic images.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import sys
import glob

# Add project root to path so `src.*` imports resolve when run as a script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from src.data.dataset import IAMDataset, collate_fn
from src.models.crnn import CRNN
from src.models.gan import Generator

# Define transforms matching training exactly (same size/normalization as the
# baseline CRNN training phase, so the fine-tuned model sees identical input).
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


def decode_pseudo_labels(preds):
    """Greedy CTC decode of model outputs into pseudo-label target sequences.

    Collapses consecutive repeats and removes blanks (class index 0, matching
    ``CTCLoss(blank=0)`` below).

    Args:
        preds: Tensor of shape (seq_len, batch, classes).

    Returns:
        Tuple of (list of 1-D long tensors, list of their lengths). Entries
        may be empty sequences (length 0); callers filter those out.
    """
    _, max_preds = torch.max(preds, 2)
    max_preds = max_preds.permute(1, 0)  # (batch, seq_len)
    targets_list = []
    target_lengths = []
    for batch_idx in range(max_preds.size(0)):
        pred_seq = max_preds[batch_idx]
        decoded_seq = []
        for i in range(len(pred_seq)):
            # Keep a symbol only if it is non-blank and not a repeat of the
            # previous timestep (standard CTC best-path collapse).
            if pred_seq[i] != 0 and (i == 0 or pred_seq[i] != pred_seq[i - 1]):
                decoded_seq.append(pred_seq[i].item())
        targets_list.append(torch.tensor(decoded_seq, dtype=torch.long))
        target_lengths.append(len(decoded_seq))
    return targets_list, target_lengths


def train_ssl(model, generator, dataloader, optimizer, criterion, device,
              epochs=5, threshold=0.8, latent_dim=100):
    """Pseudo-labeling approach for Semi-Supervised Learning.

    Combines real labeled data with synthetic unlabeled data generated
    dynamically by the GAN. For each batch: (1) CTC loss on the real labeled
    batch; (2) generate fake images, keep only those whose mean per-timestep
    confidence exceeds ``threshold``, greedy-decode them into pseudo-labels,
    and add a down-weighted CTC loss on those.

    Args:
        model: CRNN whose forward returns (batch, seq_len, classes).
        generator: trained GAN generator; frozen during this phase.
        dataloader: yields (images, padded_texts, text_lengths) via collate_fn.
        optimizer: optimizer over ``model`` parameters.
        criterion: ``nn.CTCLoss`` (blank=0).
        device: torch device for all tensors.
        epochs: number of passes over ``dataloader``.
        threshold: minimum average max-probability for a fake image to be
            accepted as pseudo-labeled.
        latent_dim: dimensionality of the generator's noise input.

    Side effects:
        Saves ``weights/crnn_ssl_epoch_{n}.pth`` after every epoch.
    """
    model.train()
    generator.eval()  # Generator is fixed during this phase
    for epoch in range(epochs):
        total_loss_real = 0
        total_loss_fake = 0
        for step, (labeled_imgs, labeled_texts, labeled_lengths) in enumerate(dataloader):
            labeled_imgs = labeled_imgs.to(device)
            labeled_texts = labeled_texts.to(device)
            batch_size = labeled_imgs.size(0)
            optimizer.zero_grad()

            # ==============================================================
            # 1. Train on Real Labeled Data
            # ==============================================================
            preds_l = model(labeled_imgs)
            preds_l = preds_l.permute(1, 0, 2)  # (seq_len, batch, classes)
            input_lengths_l = torch.full(
                size=(preds_l.size(1),), fill_value=preds_l.size(0), dtype=torch.long)
            # Strip padding: concatenate each target truncated to its true length,
            # the flattened layout CTCLoss expects.
            targets_list_l = [labeled_texts[i][:labeled_lengths[i]]
                              for i in range(labeled_texts.size(0))]
            targets_concat_l = torch.cat(targets_list_l)
            loss_real = criterion(preds_l, targets_concat_l, input_lengths_l, labeled_lengths)

            # ==============================================================
            # 2. Train on Synthetic GAN Data (Pseudo-Labeling)
            # ==============================================================
            # Generate fake images and score the model's confidence on them.
            # The whole selection pass is grad-free: it only produces a boolean
            # mask, so building an autograd graph here would waste memory.
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)  # (batch, 1, 32, 1024), range [-1, 1]

                model.eval()
                preds_fake_eval = model(fake_imgs)
                # NOTE(review): exp() assumes the CRNN emits log-probabilities
                # (log_softmax), consistent with CTCLoss usage — confirm in
                # CRNN.forward.
                probs = torch.exp(preds_fake_eval)
                max_probs, _ = torch.max(probs, dim=2)
                avg_confidence = max_probs.mean(dim=1)

            # Mask confident predictions
            mask = avg_confidence > threshold
            model.train()

            loss_fake = torch.tensor(0.0, device=device)
            if mask.sum() > 0:
                confident_imgs = fake_imgs[mask]
                preds_fake = model(confident_imgs)
                preds_fake_perm = preds_fake.permute(1, 0, 2)

                # Decode the pseudo-labels into CTC targets
                targets_list_u, target_lengths_u = decode_pseudo_labels(
                    preds_fake_perm.detach())

                # Filter out empty pseudo-labels (all-blank decodes)
                valid_idx = [i for i, length in enumerate(target_lengths_u) if length > 0]
                if valid_idx:
                    valid_preds_fake_perm = preds_fake_perm[:, valid_idx, :]
                    valid_targets_list = [targets_list_u[i].to(device) for i in valid_idx]
                    valid_target_lengths = torch.tensor(
                        [target_lengths_u[i] for i in valid_idx],
                        dtype=torch.long).to(device)
                    valid_targets_concat = torch.cat(valid_targets_list)
                    input_lengths_u = torch.full(
                        size=(valid_preds_fake_perm.size(1),),
                        fill_value=valid_preds_fake_perm.size(0),
                        dtype=torch.long).to(device)
                    loss_fake = criterion(valid_preds_fake_perm, valid_targets_concat,
                                          input_lengths_u, valid_target_lengths)
                    # Scale down the fake loss slightly so it doesn't overwhelm
                    # real data
                    loss_fake = loss_fake * 0.5

            # Total loss
            total_loss = loss_real + loss_fake
            total_loss.backward()
            optimizer.step()

            total_loss_real += loss_real.item()
            total_loss_fake += loss_fake.item() if loss_fake > 0 else 0

            if step % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Real Loss: {loss_real.item():.4f}, Fake Loss: {loss_fake.item() if loss_fake > 0 else 0:.4f}, Confident Fakes: {mask.sum().item()}/{batch_size}")

        print(f"Epoch {epoch+1} Average Real Loss: {total_loss_real/len(dataloader):.4f}, Average Fake Loss: {total_loss_fake/len(dataloader):.4f}")

        # Save checkpoints
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_ssl_epoch_{epoch+1}.pth')


def _latest_checkpoint(pattern, description):
    """Return the highest-epoch checkpoint path matching *pattern*, or exit.

    Filenames are expected to end in ``_<epoch>.pth``; sorted numerically by
    that epoch so e.g. epoch 10 beats epoch 9.
    """
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        print(f"Error: Could not find {description} weights.")
        sys.exit(1)
    checkpoints.sort(key=lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0]))
    return checkpoints[-1]


if __name__ == "__main__":
    print("Starting Semi-Supervised Learning (SSL) Training Phase...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Dataset
    data_dir = 'data/iam_words'
    csv_file = 'data/labels.csv'
    dataset = IAMDataset(data_dir=data_dir, csv_file=csv_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # 2. Load the Baseline CRNN Model
    num_classes = dataset.num_classes
    crnn_model = CRNN(img_channel=1, img_height=32, img_width=1024,
                      num_class=num_classes).to(device)
    latest_crnn = _latest_checkpoint('weights/crnn_baseline_epoch_*.pth', 'baseline CRNN')
    print(f"Loading Baseline CRNN from {latest_crnn}")
    crnn_model.load_state_dict(torch.load(latest_crnn, map_location=device))

    # 3. Load the Trained GAN Generator
    generator = Generator(latent_dim=100).to(device)
    latest_gan = _latest_checkpoint('weights/gan_generator_epoch_*.pth', 'GAN Generator')
    print(f"Loading GAN Generator from {latest_gan}")
    generator.load_state_dict(torch.load(latest_gan, map_location=device))

    # 4. Setup Optimizer & Loss
    # Use a smaller learning rate for fine-tuning
    optimizer = optim.Adam(crnn_model.parameters(), lr=0.0001)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 5. Start SSL Training Loop
    train_ssl(crnn_model, generator, dataloader, optimizer, criterion, device,
              epochs=5, threshold=0.8)

    print("SSL Training complete!")