# Hugging Face upload artifact (not code): triumphh77's picture / Upload 13 files / commit f9a156f verified
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import sys
import glob
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from src.data.dataset import IAMDataset, collate_fn
from src.models.crnn import CRNN
from src.models.gan import Generator
# Preprocessing pipeline — must match the baseline CRNN training transforms
# exactly, otherwise the loaded checkpoint sees out-of-distribution inputs.
_preprocessing_steps = [
    transforms.Resize((32, 1024)),        # fixed height x width the CRNN expects
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)), # rescale to [-1, 1] (matches GAN output range)
]
transform = transforms.Compose(_preprocessing_steps)
def decode_pseudo_labels(preds):
    """Greedy CTC decode of network outputs into pseudo-label targets.

    Args:
        preds: tensor of shape (seq_len, batch, classes) — per-timestep
            class scores (only the argmax is used, so any monotonic
            transform of probabilities works).

    Returns:
        (targets_list, target_lengths): a list of 1-D long tensors, one
        decoded label sequence per batch element, and a parallel list of
        their lengths. Standard CTC collapse: consecutive repeats are
        merged, then blanks (class 0) are removed.
    """
    _, best_path = preds.max(2)      # (seq_len, batch) argmax per timestep
    best_path = best_path.t()        # (batch, seq_len)

    targets_list = []
    target_lengths = []
    for row in best_path:
        collapsed = []
        prev_token = None
        for token in row.tolist():
            # Keep a token only if it is not blank and differs from the
            # previous *raw* token (collapse-then-drop-blanks rule).
            if token != 0 and token != prev_token:
                collapsed.append(token)
            prev_token = token
        targets_list.append(torch.tensor(collapsed, dtype=torch.long))
        target_lengths.append(len(collapsed))
    return targets_list, target_lengths
def train_ssl(model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8, latent_dim=100):
    """
    Fine-tune the recognizer with pseudo-labeling (semi-supervised learning).

    Each optimizer step combines:
      1. a supervised CTC loss on a real labeled batch from `dataloader`;
      2. an unsupervised CTC loss on GAN-generated images whose greedy
         predictions the current model is confident about (mean per-timestep
         max probability > `threshold`); the model's own greedy decode is
         reused as the target (pseudo-label), down-weighted by 0.5.

    Args:
        model: CRNN recognizer; forward returns (batch, seq_len, classes)
            scores. NOTE(review): `torch.exp` below assumes these are
            log-softmax outputs — confirm against the CRNN head.
        generator: trained GAN generator, frozen during this phase.
        dataloader: yields (images, padded_texts, text_lengths) batches.
        optimizer: optimizer over `model` parameters.
        criterion: nn.CTCLoss with blank=0.
        device: torch device for all tensors.
        epochs: number of passes over the labeled data.
        threshold: confidence cutoff for accepting a pseudo-labeled sample.
        latent_dim: size of the generator's latent vector.

    Side effects: saves 'weights/crnn_ssl_epoch_{e}.pth' after every epoch.
    """
    model.train()
    generator.eval()  # Generator is fixed during this phase

    for epoch in range(epochs):
        total_loss_real = 0.0
        total_loss_fake = 0.0

        for step, (labeled_imgs, labeled_texts, labeled_lengths) in enumerate(dataloader):
            labeled_imgs = labeled_imgs.to(device)
            labeled_texts = labeled_texts.to(device)
            batch_size = labeled_imgs.size(0)
            optimizer.zero_grad()

            # ==========================================================
            # 1. Supervised CTC loss on real labeled data
            # ==========================================================
            preds_l = model(labeled_imgs).permute(1, 0, 2)  # (seq_len, batch, classes)
            input_lengths_l = torch.full(
                size=(preds_l.size(1),), fill_value=preds_l.size(0), dtype=torch.long
            )
            # Strip padding: concatenate each target up to its true length,
            # as CTCLoss expects a flat 1-D target tensor.
            targets_concat_l = torch.cat(
                [labeled_texts[i][:labeled_lengths[i]] for i in range(labeled_texts.size(0))]
            )
            loss_real = criterion(preds_l, targets_concat_l, input_lengths_l, labeled_lengths)

            # ==========================================================
            # 2. Unsupervised loss on synthetic data (pseudo-labeling)
            # ==========================================================
            # BUGFIX: generation AND the confidence-scoring forward pass need
            # no gradients; the original only shielded the generator call, so
            # an unused autograd graph was built for the scoring pass.
            model.eval()
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_imgs = generator(z)  # per original comment: (batch, 1, 32, 1024) in [-1, 1]
                preds_fake_eval = model(fake_imgs)
                probs = torch.exp(preds_fake_eval)  # assumes log-softmax outputs
                max_probs, _ = torch.max(probs, dim=2)
                avg_confidence = max_probs.mean(dim=1)
                # Keep only samples the model is confident about.
                mask = avg_confidence > threshold
            model.train()

            loss_fake = torch.tensor(0.0).to(device)
            if mask.sum() > 0:
                confident_imgs = fake_imgs[mask]
                # Fresh forward pass WITH gradients on the confident subset.
                preds_fake_perm = model(confident_imgs).permute(1, 0, 2)
                # Greedy-decode the model's own predictions into CTC targets.
                targets_list_u, target_lengths_u = decode_pseudo_labels(preds_fake_perm.detach())
                # Drop samples whose decoded pseudo-label is empty — CTC
                # cannot train on zero-length targets.
                valid_idx = [i for i, length in enumerate(target_lengths_u) if length > 0]
                if valid_idx:
                    valid_preds_fake_perm = preds_fake_perm[:, valid_idx, :]
                    valid_targets_concat = torch.cat(
                        [targets_list_u[i].to(device) for i in valid_idx]
                    )
                    valid_target_lengths = torch.tensor(
                        [target_lengths_u[i] for i in valid_idx], dtype=torch.long
                    ).to(device)
                    input_lengths_u = torch.full(
                        size=(valid_preds_fake_perm.size(1),),
                        fill_value=valid_preds_fake_perm.size(0),
                        dtype=torch.long,
                    ).to(device)
                    loss_fake = criterion(
                        valid_preds_fake_perm, valid_targets_concat,
                        input_lengths_u, valid_target_lengths,
                    )
                    # Scale down the fake loss so it doesn't overwhelm real data.
                    loss_fake = loss_fake * 0.5

            # Joint update on both losses.
            total_loss = loss_real + loss_fake
            total_loss.backward()
            optimizer.step()

            fake_loss_val = loss_fake.item()
            total_loss_real += loss_real.item()
            total_loss_fake += fake_loss_val if fake_loss_val > 0 else 0

            if step % 20 == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], "
                    f"Real Loss: {loss_real.item():.4f}, "
                    f"Fake Loss: {fake_loss_val if fake_loss_val > 0 else 0:.4f}, "
                    f"Confident Fakes: {mask.sum().item()}/{batch_size}"
                )

        print(f"Epoch {epoch+1} Average Real Loss: {total_loss_real/len(dataloader):.4f}, Average Fake Loss: {total_loss_fake/len(dataloader):.4f}")

        # Checkpoint after every epoch.
        os.makedirs('weights', exist_ok=True)
        torch.save(model.state_dict(), f'weights/crnn_ssl_epoch_{epoch+1}.pth')
def find_latest_checkpoint(pattern, description):
    """Return the newest checkpoint file matching glob `pattern`.

    "Newest" means the largest integer epoch suffix in the filename stem
    (e.g. 'crnn_baseline_epoch_12.pth' beats '..._epoch_2.pth'). Exits the
    process with status 1 if no file matches; `description` names the
    missing weights in the error message.
    """
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        print(f"Error: Could not find {description} weights.")
        sys.exit(1)
    # Sort numerically on the epoch number, not lexically on the path.
    checkpoints.sort(key=lambda x: int(os.path.basename(x).split('_')[-1].split('.')[0]))
    return checkpoints[-1]


if __name__ == "__main__":
    print("Starting Semi-Supervised Learning (SSL) Training Phase...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Dataset (transform is defined at module level).
    data_dir = 'data/iam_words'
    csv_file = 'data/labels.csv'
    dataset = IAMDataset(data_dir=data_dir, csv_file=csv_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    # 2. Load the Baseline CRNN model from its newest checkpoint.
    num_classes = dataset.num_classes
    crnn_model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=num_classes).to(device)
    latest_crnn = find_latest_checkpoint('weights/crnn_baseline_epoch_*.pth', 'baseline CRNN')
    print(f"Loading Baseline CRNN from {latest_crnn}")
    crnn_model.load_state_dict(torch.load(latest_crnn, map_location=device))

    # 3. Load the trained GAN Generator (frozen during SSL).
    generator = Generator(latent_dim=100).to(device)
    latest_gan = find_latest_checkpoint('weights/gan_generator_epoch_*.pth', 'GAN Generator')
    print(f"Loading GAN Generator from {latest_gan}")
    generator.load_state_dict(torch.load(latest_gan, map_location=device))

    # 4. Optimizer & loss: smaller LR for fine-tuning; zero_infinity guards
    # against infinite CTC losses from degenerate alignments.
    optimizer = optim.Adam(crnn_model.parameters(), lr=0.0001)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 5. Start the SSL training loop.
    train_ssl(crnn_model, generator, dataloader, optimizer, criterion, device, epochs=5, threshold=0.8)
    print("SSL Training complete!")