import io
import json
import math
import os
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split


# ==================== CONFIGURATION ====================
class Config:
    """Static hyper-parameters and paths for the JPEG-quality regressor."""

    # Data
    IMAGE_DIR = "/path/to/images"
    CROP_SIZE = 512
    # Range of JPEG qualities sampled as targets. libjpeg clamps quality <= 0
    # to 1, so sampling 0 would produce images identical to quality 1 but
    # with a different label; the range therefore starts at 1.
    QUALITY_MIN = 1
    QUALITY_MAX = 100

    # Training
    BATCH_SIZE = 1       # source images per batch
    INT_BATCH_SIZE = 8   # compressed variants generated per source image
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    VAL_SPLIT = 0.07
    RANDOM_SEED = 42
    SAVE_INTERVAL = 5  # Save intermediate checkpoints every N epochs
    # A prediction counts as "accurate" when |pred - target| <= this value
    # (normalized units, i.e. 0.05 == 5 JPEG quality points).
    ACC_TOLERANCE = 0.05

    # DataLoader
    NUM_WORKERS = 32
    # prefetch_factor is per worker: up to NUM_WORKERS * PREFETCH_FACTOR
    # batches are in flight. One batch is BATCH_SIZE * INT_BATCH_SIZE float32
    # tensors of 3 x 512 x 512 (~25 MB), so a factor of 50 with 32 workers
    # could hold ~40 GB of host memory.
    PREFETCH_FACTOR = 4

    # Paths
    CHECKPOINT_DIR = "./checkpoints"
    RESULTS_DIR = "./results"
    LOG_FILE = "./results/training_log.json"


# ==================== UTILITIES ====================
def ensure_dir(path: str):
    """Create directory ``path`` (and parents) if it does not already exist."""
    Path(path).mkdir(parents=True, exist_ok=True)


def quality_to_normalized(quality: float) -> float:
    """Normalize JPEG quality [0,100] to [0,1]"""
    return quality / 100.0


def normalized_to_quality(normalized: float) -> float:
    """Denormalize back to JPEG quality range"""
    return normalized * 100.0


# ==================== COMPRESSION ====================
def compress_jpeg(image: Image.Image, quality: int) -> Image.Image:
    """Round-trip ``image`` through an in-memory JPEG encode/decode.

    Args:
        image: PIL image in a JPEG-encodable mode (callers pass RGB).
        quality: JPEG quality passed to Pillow's encoder (cast to int).

    Returns:
        A fully decoded PIL image that does not reference the temporary buffer.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=int(quality))
    buffer.seek(0)
    # Image.open is lazy; copy() forces the decode, and the context manager
    # releases the opened image afterwards.
    with Image.open(buffer) as decoded:
        return decoded.copy()


# ==================== DATASET ====================
class CompressionDataset(Dataset):
    """Yield several JPEG-compressed variants of one cropped source image.

    Each item is a dict with:
        'images':  float tensor [INT_BATCH_SIZE, 3, CROP_SIZE, CROP_SIZE]
        'targets': float32 tensor [INT_BATCH_SIZE] of normalized qualities.

    Training items use random crop/flips and fresh random qualities on every
    access. Validation items use a center crop and qualities drawn from a
    per-index seeded RNG, so every epoch evaluates the same samples.
    """

    def __init__(self, image_paths: List[str], is_train: bool = True):
        self.image_paths = image_paths
        self.is_train = is_train
        if is_train:
            self.spatial_transform = transforms.Compose([
                transforms.RandomCrop(Config.CROP_SIZE, pad_if_needed=True),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ])
        else:
            self.spatial_transform = transforms.Compose([
                transforms.CenterCrop(Config.CROP_SIZE),
            ])
        self.to_tensor = transforms.ToTensor()

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Close the file handle right away instead of waiting for GC; many
        # workers each holding lazily-opened files can pile up descriptors.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        image = self.spatial_transform(image)

        # Validation: deterministic qualities per index so the validation
        # loss is comparable across epochs. Training: the module-level RNG,
        # which PyTorch reseeds in each DataLoader worker.
        rng = random if self.is_train else random.Random(Config.RANDOM_SEED + idx)

        # Generate multiple compressed variants of the SAME crop.
        images = []
        targets = []
        for _ in range(Config.INT_BATCH_SIZE):
            quality = rng.randint(Config.QUALITY_MIN, Config.QUALITY_MAX)
            images.append(self.to_tensor(compress_jpeg(image, quality)))
            targets.append(quality_to_normalized(quality))

        return {
            'images': torch.stack(images),  # [INT_BATCH_SIZE, C, H, W]
            'targets': torch.tensor(targets, dtype=torch.float32),  # [INT_BATCH_SIZE]
        }


# ==================== COLLATE ====================
def collate_grouped(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack per-image variant groups into [B, V, C, H, W] images and [B, V] targets."""
    all_images = torch.stack([item['images'] for item in batch])
    all_targets = torch.stack([item['targets'] for item in batch])
    return {'images': all_images, 'targets': all_targets}


# ==================== MODEL ====================
class LightweightCompressionNet(nn.Module):
    """Small CNN regressing normalized JPEG quality in [0, 1] from an RGB crop.

    Spatial size for a 512x512 input (all convolutions unpadded):
    512 -> 509 -> 506 -> 252 -> 125 -> 31 -> 7 -> 3 -> 1.
    Layer order is kept fixed so state_dict keys match existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            # Stride 1 first: keep full resolution so fine compression
            # artifacts are seen before any downsampling.
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),     # 512 -> 509
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),    # 509 -> 506
            # Then downsample quickly.
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),    # 506 -> 252
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),   # 252 -> 125
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),  # 125 -> 31
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),  # 31 -> 7
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),  # 7 -> 3
            nn.AdaptiveAvgPool2d(1),  # 3 -> 1 (parameter-free global average)
        )
        # Keep head simple and small; Sigmoid bounds outputs to [0, 1] to
        # match the normalized targets.
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for all conv and linear layers."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # NOTE: Xavier/Glorot init is derived for linear/tanh-like
                # activations; Kaiming init is the usual choice for the
                # ReLU family (incl. GELU). Kept to preserve training behavior.
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images [N, 3, H, W] to predicted normalized qualities [N]."""
        features = self.conv_blocks(x).flatten(1)  # [N, 256]
        return self.head(features).squeeze(1)


# ==================== TRAINING ====================
def _flatten_variants(batch: Dict[str, torch.Tensor], device):
    """Move a grouped batch to ``device`` and merge the group/variant dims.

    [B, V, C, H, W] images and [B, V] targets become [B*V, C, H, W] and [B*V],
    so every compressed variant is processed as an independent sample.
    """
    images = batch['images'].to(device, non_blocking=True)
    targets = batch['targets'].to(device, non_blocking=True)
    return images.flatten(0, 1), targets.flatten(0, 1)


def _accuracy_pct(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Percentage of predictions within Config.ACC_TOLERANCE of their target."""
    hits = (predictions.float() - targets).abs() <= Config.ACC_TOLERANCE
    return hits.float().mean().item() * 100


def train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Train ``model`` for one epoch.

    The forward pass runs under bfloat16 autocast, the loss is computed in
    float32, and gradients are clipped to a total norm of 1.0.

    Returns:
        dict with per-sample mean 'loss' and 'accuracy' (percent).
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    num_samples = 0
    device_type = torch.device(device).type

    # Reseed the loader's generator so shuffling and worker base seeds are
    # reproducible per epoch.
    if loader.generator is not None:
        loader.generator.manual_seed(Config.RANDOM_SEED + epoch)

    pbar = tqdm.tqdm(loader, desc=f"Epoch {epoch + 1}/{Config.EPOCHS}")
    for batch in pbar:
        images, targets = _flatten_variants(batch, device)

        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            predictions = model(images)
        loss = criterion(predictions.float(), targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        acc = _accuracy_pct(predictions.detach(), targets)
        batch_size = targets.numel()
        total_loss += loss.item() * batch_size
        total_acc += acc * batch_size
        num_samples += batch_size

        pbar.set_postfix_str(
            f"Loss: {loss.item():.4f}, Avg: {total_loss / num_samples:.4f}, Acc: {total_acc / num_samples:.1f}%"
        )

    return {'loss': total_loss / num_samples, 'accuracy': total_acc / num_samples}


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in full precision without gradients.

    Returns:
        dict with per-sample mean 'loss' and 'accuracy' (percent).
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    num_samples = 0

    with torch.no_grad():
        pbar = tqdm.tqdm(loader, desc="Validation", leave=False)
        for batch in pbar:
            images, targets = _flatten_variants(batch, device)

            predictions = model(images)
            loss = criterion(predictions, targets)
            acc = _accuracy_pct(predictions, targets)

            batch_size = targets.numel()
            total_loss += loss.item() * batch_size
            total_acc += acc * batch_size
            num_samples += batch_size

            pbar.set_postfix_str(
                f"Avg Loss: {total_loss / num_samples:.4f}, Avg Acc: {total_acc / num_samples:.1f}%"
            )

    return {'loss': total_loss / num_samples, 'accuracy': total_acc / num_samples}


# ==================== MAIN ====================
def _find_pngs(root: str) -> List[str]:
    """Return PNG files under ``root`` (recursive, suffix case-insensitive), sorted.

    Sorting makes the seeded train/val split independent of filesystem
    traversal order.
    """
    return sorted(
        str(p) for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() == ".png"
    )


def _make_loader(dataset, *, shuffle: bool, pin_memory: bool, generator=None):
    """Build a DataLoader with the shared batch/worker/prefetch/collate settings."""
    extra = {}
    if Config.NUM_WORKERS > 0:
        # prefetch_factor is only valid when worker processes are used.
        extra['prefetch_factor'] = Config.PREFETCH_FACTOR
    return DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=Config.NUM_WORKERS,
        pin_memory=pin_memory,
        collate_fn=collate_grouped,
        generator=generator,
        **extra,
    )


def main():
    """Train LightweightCompressionNet on the PNGs under Config.IMAGE_DIR."""
    ensure_dir(Config.CHECKPOINT_DIR)
    ensure_dir(Config.RESULTS_DIR)
    ensure_dir(str(Path(Config.LOG_FILE).parent))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    random.seed(Config.RANDOM_SEED)
    np.random.seed(Config.RANDOM_SEED)
    torch.manual_seed(Config.RANDOM_SEED)

    image_paths = _find_pngs(Config.IMAGE_DIR)
    if not image_paths:
        raise ValueError(f"No PNGs found in {Config.IMAGE_DIR}")

    train_paths, val_paths = train_test_split(
        image_paths, test_size=Config.VAL_SPLIT, random_state=Config.RANDOM_SEED
    )
    print(f"Train: {len(train_paths)} | Val: {len(val_paths)}")

    train_dataset = CompressionDataset(train_paths, is_train=True)
    val_dataset = CompressionDataset(val_paths, is_train=False)

    pin_memory = device.type == 'cuda'
    # Dedicated generator: train_epoch reseeds it every epoch.
    train_loader = _make_loader(
        train_dataset, shuffle=True, pin_memory=pin_memory, generator=torch.Generator()
    )
    val_loader = _make_loader(val_dataset, shuffle=False, pin_memory=pin_memory)

    model = LightweightCompressionNet().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=1e-4,
        betas=(0.9, 0.999),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=Config.EPOCHS, eta_min=1e-6
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel: {param_count:,} parameters ({param_count * 4 / 1024:.1f}KB)")

    best_val_loss = float('inf')
    training_log = []

    print("\nStarting training...")
    for epoch in range(Config.EPOCHS):
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
        val_metrics = validate(model, val_loader, criterion, device)
        # Capture the LR actually used this epoch before the scheduler advances it.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        print(
            f"\nEpoch {epoch + 1} | "
            f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['accuracy']:.1f}% | "
            f"Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['accuracy']:.1f}% | "
            f"LR: {epoch_lr:.2e}"
        )

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_loss': best_val_loss,
                'val_accuracy': val_metrics['accuracy'],
            }, os.path.join(Config.CHECKPOINT_DIR, "best_model.pt"))
            print("✓ Saved best model")

        if (epoch + 1) % Config.SAVE_INTERVAL == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'train_loss': train_metrics['loss'],
                'val_loss': val_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_accuracy': val_metrics['accuracy'],
            }, os.path.join(Config.CHECKPOINT_DIR, f"model_epoch_{epoch + 1:03d}.pt"))
            print(f"✓ Saved checkpoint epoch {epoch + 1}")

        training_log.append({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
            'val_accuracy': val_metrics['accuracy'],
            'lr': epoch_lr,
        })

        # Rewrite the full log each epoch so it survives an interrupted run.
        with open(Config.LOG_FILE, 'w', encoding='utf-8') as f:
            json.dump(training_log, f, indent=2)

    # Plotting code...

    print(f"\nDone! Best val loss: {best_val_loss:.4f}")
    print(f"Results saved to {Config.RESULTS_DIR}")


if __name__ == "__main__":
    main()