| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| import os |
| import sys |
| import glob |
|
|
| |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) |
|
|
| from src.data.dataset import IAMDataset, collate_fn |
| from src.models.crnn import CRNN |
| from src.models.gan import Generator |
|
|
| |
# Preprocessing pipeline for labeled word images: resize to a fixed
# 32x1024 (H x W) canvas, convert to a tensor in [0, 1], then normalize
# single-channel values to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): presumably this matches the range of the GAN generator's
# output so real and synthetic images are comparable — confirm against
# the Generator definition.
transform = transforms.Compose([
    transforms.Resize((32, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
|
|
def decode_pseudo_labels(preds):
    """Greedy CTC decode of network outputs into pseudo-label targets.

    Args:
        preds: Tensor of shape (T, B, C) holding per-timestep class
            scores (only the per-frame argmax is used, so logits or
            log-probabilities both work). Class 0 is the CTC blank.

    Returns:
        targets_list: list of B 1-D ``torch.long`` tensors, each the
            collapsed label sequence for one batch element (consecutive
            repeats merged, blanks removed).
        target_lengths: list of B ints, the length of each decoded
            sequence (0 when a sample decodes to nothing but blanks).
    """
    # Frame-wise argmax -> (T, B), then transpose so each row is one sample.
    max_preds = preds.argmax(dim=2).permute(1, 0)

    targets_list = []
    target_lengths = []

    # .tolist() once per batch avoids thousands of per-element tensor
    # indexing / .item() calls inside the inner loop.
    for pred_seq in max_preds.tolist():
        # Standard CTC collapse: keep a label when it is not blank and
        # differs from the previous *frame* (a blank between equal labels
        # therefore correctly yields a doubled character).
        decoded_seq = [
            label
            for i, label in enumerate(pred_seq)
            if label != 0 and (i == 0 or label != pred_seq[i - 1])
        ]
        targets_list.append(torch.tensor(decoded_seq, dtype=torch.long))
        target_lengths.append(len(decoded_seq))

    return targets_list, target_lengths
|
|
def train_ssl(model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8, latent_dim=100):
    """
    Pseudo-labeling approach for Semi-Supervised Learning.
    Combines real labeled data with synthetic unlabeled data generated dynamically by the GAN.

    Each step computes a supervised CTC loss on real labeled images plus a
    pseudo-label CTC loss on GAN-generated images; only fakes whose mean
    per-frame confidence exceeds ``threshold`` contribute, and their loss is
    down-weighted by 0.5.

    Args:
        model: recognizer whose forward returns (B, T, C) scores.
            NOTE(review): the confidence pass applies ``torch.exp`` to the
            outputs, i.e. it assumes log-softmax outputs — confirm against
            the CRNN definition.
        generator: frozen GAN generator; forward(z) -> images accepted by
            ``model``. Kept in eval mode; its weights are never updated.
        dataloader: yields (images, padded_targets, target_lengths).
        optimizer: optimizer over ``model`` parameters.
        criterion: ``nn.CTCLoss`` with blank index 0.
        device: torch device for all tensors.
        epochs: number of passes over ``dataloader``.
        threshold: minimum mean max-probability for a fake sample to be kept.
        latent_dim: dimensionality of the generator's noise vector.

    Side effects: prints progress and saves ``weights/crnn_ssl_epoch_N.pth``
    after every epoch.
    """
    model.train()
    generator.eval()

    for epoch in range(epochs):
        total_loss_real = 0
        total_loss_fake = 0

        for step, (labeled_imgs, labeled_texts, labeled_lengths) in enumerate(dataloader):
            labeled_imgs = labeled_imgs.to(device)
            labeled_texts = labeled_texts.to(device)
            batch_size = labeled_imgs.size(0)

            optimizer.zero_grad()

            # ---- Supervised CTC loss on real labeled data ----
            preds_l = model(labeled_imgs)
            preds_l = preds_l.permute(1, 0, 2)  # (B, T, C) -> (T, B, C) for CTCLoss

            input_lengths_l = torch.full(size=(preds_l.size(1),), fill_value=preds_l.size(0), dtype=torch.long)

            # Strip padding: concatenate each sample's true-length target,
            # as CTCLoss expects for a 1-D targets tensor.
            targets_list_l = [labeled_texts[i][:labeled_lengths[i]] for i in range(labeled_texts.size(0))]
            targets_concat_l = torch.cat(targets_list_l)

            loss_real = criterion(preds_l, targets_concat_l, input_lengths_l, labeled_lengths)

            # ---- Pseudo-label loss on GAN-generated (unlabeled) data ----
            # The whole confidence-scoring pass runs under no_grad: it only
            # produces a boolean mask, so building autograd state for it
            # would waste memory for gradients that are never used.
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)

                # Score confidence in eval mode (e.g. to disable dropout).
                model.eval()
                preds_fake_eval = model(fake_imgs)
                probs = torch.exp(preds_fake_eval)
                max_probs, _ = torch.max(probs, dim=2)
                avg_confidence = max_probs.mean(dim=1)

            # Keep only samples the current model is confident about.
            mask = avg_confidence > threshold

            model.train()
            loss_fake = torch.tensor(0.0).to(device)

            if mask.sum() > 0:
                confident_imgs = fake_imgs[mask]

                # Fresh forward in train mode so gradients flow through the
                # recognizer (the generator stays frozen).
                preds_fake = model(confident_imgs)
                preds_fake_perm = preds_fake.permute(1, 0, 2)

                # Pseudo-labels come from the model's own (detached) greedy decode.
                targets_list_u, target_lengths_u = decode_pseudo_labels(preds_fake_perm.detach())

                # Drop samples that decoded to nothing but blanks.
                valid_idx = [i for i, length in enumerate(target_lengths_u) if length > 0]

                if valid_idx:
                    valid_preds_fake_perm = preds_fake_perm[:, valid_idx, :]
                    valid_targets_list = [targets_list_u[i].to(device) for i in valid_idx]
                    valid_target_lengths = torch.tensor([target_lengths_u[i] for i in valid_idx], dtype=torch.long).to(device)

                    valid_targets_concat = torch.cat(valid_targets_list)
                    input_lengths_u = torch.full(size=(valid_preds_fake_perm.size(1),), fill_value=valid_preds_fake_perm.size(0), dtype=torch.long).to(device)

                    loss_fake = criterion(valid_preds_fake_perm, valid_targets_concat, input_lengths_u, valid_target_lengths)

                    # Down-weight the unsupervised term relative to the supervised one.
                    loss_fake = loss_fake * 0.5

            total_loss = loss_real + loss_fake
            total_loss.backward()
            optimizer.step()

            total_loss_real += loss_real.item()
            total_loss_fake += loss_fake.item()  # 0.0 when no confident fakes

            if step % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Real Loss: {loss_real.item():.4f}, Fake Loss: {loss_fake.item():.4f}, Confident Fakes: {mask.sum().item()}/{batch_size}")

        print(f"Epoch {epoch+1} Average Real Loss: {total_loss_real/len(dataloader):.4f}, Average Fake Loss: {total_loss_fake/len(dataloader):.4f}")

        # Checkpoint the recognizer after every epoch.
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_ssl_epoch_{epoch+1}.pth')
|
|
def _latest_checkpoint(pattern, description):
    """Return the newest checkpoint file matching ``pattern``.

    "Newest" means the highest trailing epoch number in the filename
    (``..._epoch_<N>.pth``). Exits the process with an error message when
    no checkpoint is found, since SSL training requires pretrained weights.
    """
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        print(f"Error: Could not find {description} weights.")
        sys.exit(1)
    checkpoints.sort(key=lambda p: int(os.path.basename(p).split('_')[-1].split('.')[0]))
    return checkpoints[-1]


if __name__ == "__main__":
    print("Starting Semi-Supervised Learning (SSL) Training Phase...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Labeled IAM word data.
    data_dir = 'data/iam_words'
    csv_file = 'data/labels.csv'
    dataset = IAMDataset(data_dir=data_dir, csv_file=csv_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # Recognizer, warm-started from the latest supervised baseline checkpoint.
    num_classes = dataset.num_classes
    crnn_model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)
    latest_crnn = _latest_checkpoint('weights/crnn_baseline_epoch_*.pth', 'baseline CRNN')
    print(f"Loading Baseline CRNN from {latest_crnn}")
    crnn_model.load_state_dict(torch.load(latest_crnn, map_location=device))

    # Frozen GAN generator used to synthesize unlabeled word images.
    generator = Generator(latent_dim=100).to(device)
    latest_gan = _latest_checkpoint('weights/gan_generator_epoch_*.pth', 'GAN Generator')
    print(f"Loading GAN Generator from {latest_gan}")
    generator.load_state_dict(torch.load(latest_gan, map_location=device))

    optimizer = optim.Adam(crnn_model.parameters(), lr=0.0001)
    # blank=0 matches the decoder's blank convention; zero_infinity guards
    # against targets longer than the input sequence producing inf loss.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    train_ssl(crnn_model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8)
    print("SSL Training complete!")