import math
import os
import random
import io
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import tqdm
import matplotlib.pyplot as plt
import pillow_jxl  # noqa: F401 -- not referenced directly; presumably imported so Pillow can handle "JXL"
from sklearn.model_selection import train_test_split
import json


# ==================== CONFIGURATION ====================
class Config:
    """Static hyper-parameters and paths for the compression-quality regressor."""

    # Data
    IMAGE_DIR = "/path/to/images"
    CROP_SIZE = 512
    # Order matters: a format's index in this list is its output slot in the model head.
    COMPRESSION_FORMATS = ['jpeg', 'webp', 'avif', 'jxl']
    # Inclusive (min, max) encoder quality sampled per format; also the span used to
    # normalize targets into [0, 1].
    QUALITY_RANGES = {
        'jpeg': (0, 100),
        'webp': (0, 100),
        'avif': (0, 100),
        'jxl': (0, 100)
    }

    # Training
    BATCH_SIZE = 1
    # Rounds per source image; every round emits one variant per format, so a dataset
    # item holds INT_BATCH_SIZE * len(COMPRESSION_FORMATS) images.
    INT_BATCH_SIZE = 2
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    VAL_SPLIT = 0.07
    RANDOM_SEED = 42
    SAVE_INTERVAL = 5  # Save intermediate checkpoints every N epochs
    # A prediction counts as "correct" when within this normalized distance of the target.
    ACC_TOLERANCE = 0.05

    # Model / data loading
    NUM_WORKERS = 32
    # Batches loaded in advance by EACH worker. One batch at the defaults is
    # 8 float32 3x512x512 tensors (~25 MB), so 32 workers * 50 buffers ~1600 batches
    # (~40 GB). Lower these if host memory is tight.
    TRAIN_PREFETCH_FACTOR = 50
    VAL_PREFETCH_FACTOR = 10

    # Paths
    CHECKPOINT_DIR = "./checkpoints"
    RESULTS_DIR = "./results"
    LOG_FILE = "./results/training_log.json"


# ==================== UTILITIES ====================
def ensure_dir(path: str):
    """Create directory ``path`` (and parents) if it does not already exist."""
    Path(path).mkdir(parents=True, exist_ok=True)


def quality_to_normalized(quality: float, type: str) -> float:
    """Map an encoder quality for format ``type`` onto [0, 1].

    The span is ``Config.QUALITY_RANGES[type]``; formats not listed there fall back
    to (0, 100). With the default ranges this is exactly ``quality / 100``.
    A degenerate range (min == max) maps to 0.0.

    Note: the parameter keeps the name ``type`` (shadowing the builtin) so existing
    keyword callers keep working.
    """
    lo, hi = Config.QUALITY_RANGES.get(type, (0, 100))
    span = (hi - lo) or 1  # avoid dividing by zero for a single-value range
    return (quality - lo) / span


def normalized_to_quality(normalized: float, format_name: Optional[str] = None) -> float:
    """Inverse of ``quality_to_normalized``.

    Without ``format_name`` (or for an unknown format) the (0, 100) span is used,
    i.e. ``normalized * 100.0``, matching the original behaviour.
    """
    lo, hi = Config.QUALITY_RANGES.get(format_name, (0, 100))
    span = (hi - lo) or 1
    return float(lo) + normalized * float(span)


# ==================== COMPRESSION ====================
# Pillow format identifiers for each supported codec name.
_PIL_FORMATS = {'jpeg': 'JPEG', 'webp': 'WEBP', 'avif': 'AVIF', 'jxl': 'JXL'}


def compress_image(image: Image.Image, format_name: str, quality: int) -> Image.Image:
    """Encode ``image`` in memory as ``format_name`` at ``quality`` and decode it back.

    Raises:
        ValueError: if ``format_name`` is not one of 'jpeg', 'webp', 'avif', 'jxl'.
    """
    try:
        pil_format = _PIL_FORMATS[format_name]
    except KeyError:
        raise ValueError(f"Unknown format: {format_name}") from None

    buffer = io.BytesIO()
    image.save(buffer, format=pil_format, quality=int(quality))
    buffer.seek(0)
    # copy() materializes the decoded pixels into an image that no longer needs ``buffer``.
    with Image.open(buffer) as decoded:
        return decoded.copy()


# ==================== DATASET ====================
class CompressionDataset(Dataset):
    """Produces many compressed variants of one crop of each source image.

    Every item holds ``Config.INT_BATCH_SIZE * len(Config.COMPRESSION_FORMATS)``
    variants of the SAME crop: for each round, one variant per configured format,
    at a quality drawn from ``Config.QUALITY_RANGES``.

    Training items use a random crop plus random flips and draw fresh qualities on
    every access. Validation items use a center crop and an RNG seeded by the item
    index, so the validation set is identical every epoch.
    """

    def __init__(self, image_paths: List[str], is_train: bool = True):
        self.image_paths = image_paths
        self.is_train = is_train
        self.to_tensor = transforms.ToTensor()
        if is_train:
            steps = [
                transforms.RandomCrop(Config.CROP_SIZE, pad_if_needed=True),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ]
        else:
            steps = [transforms.CenterCrop(Config.CROP_SIZE)]
        self.spatial_transform = transforms.Compose(steps)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        path = self.image_paths[idx]
        with Image.open(path) as src:
            image = src.convert('RGB')
        image = self.spatial_transform(image)

        # Val loss decides which checkpoint is "best", so validation must see the same
        # qualities every epoch; training keeps drawing fresh ones from the global RNG.
        rng = random if self.is_train else random.Random(Config.RANDOM_SEED + idx)

        images, targets, formats = [], [], []
        for _ in range(Config.INT_BATCH_SIZE):
            for fmt_idx, fmt in enumerate(Config.COMPRESSION_FORMATS):
                lo, hi = Config.QUALITY_RANGES[fmt]
                quality = rng.randint(lo, hi)
                compressed = compress_image(image.copy(), fmt, quality)
                images.append(self.to_tensor(compressed))
                targets.append(quality_to_normalized(quality, fmt))
                formats.append(fmt_idx)

        return {
            'images': torch.stack(images),  # [INT_BATCH_SIZE * num_formats, C, H, W]
            'targets': torch.tensor(targets, dtype=torch.float32),
            'formats': torch.tensor(formats, dtype=torch.long)
        }


# ==================== COLLATE ====================
def collate_grouped(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack dataset items into tensors with a leading batch axis B.

    Resulting shapes: images [B, V, C, H, W], targets [B, V], formats [B, V], where
    V is the number of variants per source image.
    """
    return {
        key: torch.stack([item[key] for item in batch])
        for key in ('images', 'targets', 'formats')
    }


# ==================== MODEL ====================
class LightweightCompressionNet(nn.Module):
    """Small CNN predicting one sigmoid-squashed quality score per compression format.

    Input is (B, 3, H, W); the size comments below assume H = W = 512
    (Config.CROP_SIZE). Output is (B, num_outputs) with values in (0, 1).
    """

    def __init__(self, num_outputs: int = len(Config.COMPRESSION_FORMATS)):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            # Stride-1 layers keep full resolution so fine compression artifacts survive.
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),     # 512 -> 509
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),    # 509 -> 506
            # Then downsample aggressively.
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),    # 506 -> 252
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),   # 252 -> 125
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),  # 125 -> 31
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),  # 31 -> 7
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),  # 7 -> 3
            nn.AdaptiveAvgPool2d(1)                                              # 3 -> 1
        )

        # Keep head simple and small; layer order is unchanged so old state_dicts still load.
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, num_outputs),
            nn.Sigmoid()
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Conv2d / Linear layer."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.conv_blocks(x)  # (B, 256, 1, 1)
        features = features.flatten(1)  # (B, 256)
        return self.head(features)


# ==================== TRAINING ====================
def _flatten_batch(batch, device):
    """Move a collated batch to ``device`` and fold the variant axis V into the batch axis."""
    images = batch['images'].to(device, non_blocking=True)
    targets = batch['targets'].to(device, non_blocking=True)
    formats = batch['formats'].to(device, non_blocking=True)
    B, V, C, H, W = images.shape
    return images.reshape(B * V, C, H, W), targets.reshape(B * V), formats.reshape(B * V)


def _own_format_predictions(predictions, formats):
    """For each image, pick the head output belonging to that image's own format."""
    return torch.gather(predictions, 1, formats.unsqueeze(1)).squeeze(1)


def _update_format_counts(pred, targets, formats, per_format_correct, per_format_count):
    """Accumulate, per format, how many predictions land within ACC_TOLERANCE of target."""
    correct = (torch.abs(pred - targets) <= Config.ACC_TOLERANCE).float()
    per_format_correct.scatter_add_(0, formats, correct)
    per_format_count.scatter_add_(0, formats, torch.ones_like(correct))


def _epoch_metrics(total_loss, num_samples, per_format_correct, per_format_count):
    """Summarize running totals; loss/accuracy are NaN when no samples were seen."""
    correct = per_format_correct.tolist()
    counts = per_format_count.tolist()
    per_format_acc = [c / (n + 1e-8) * 100 for c, n in zip(correct, counts)]
    if num_samples == 0:
        return {'loss': float('nan'), 'accuracy': float('nan'),
                'per_format_accuracy': per_format_acc}
    return {
        'loss': total_loss / num_samples,
        'accuracy': sum(correct) / num_samples * 100,
        'per_format_accuracy': per_format_acc  # [jpeg_acc, webp_acc, ...]
    }


def _progress_text(total_loss, num_samples, per_format_correct, per_format_count) -> str:
    """Build the tqdm postfix: running loss, overall accuracy, and seen-format accuracies."""
    metrics = _epoch_metrics(total_loss, num_samples, per_format_correct, per_format_count)
    counts = per_format_count.tolist()
    fmt_str = " | ".join(
        f"{fmt}: {acc:.1f}%"
        for fmt, acc, n in zip(Config.COMPRESSION_FORMATS, metrics['per_format_accuracy'], counts)
        if n > 0
    )
    return f"Avg: {metrics['loss']:.4f}, Acc: {metrics['accuracy']:.1f}% | [{fmt_str}]"


def train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimization pass over ``loader``.

    Each image is scored only on the head output of its own format. The forward
    pass and loss run under bfloat16 autocast; gradients are clipped to norm 1.0.

    Args:
        epoch: 0-based epoch index; used to re-seed ``loader.generator`` (when the
            loader has one) so shuffling is reproducible per epoch.

    Returns:
        Dict with 'loss' (sample-weighted mean), 'accuracy' (% of predictions within
        Config.ACC_TOLERANCE of the target) and 'per_format_accuracy' (list ordered
        like Config.COMPRESSION_FORMATS).
    """
    model.train()
    device = torch.device(device)
    num_formats = len(Config.COMPRESSION_FORMATS)
    per_format_correct = torch.zeros(num_formats, device=device)
    per_format_count = torch.zeros(num_formats, device=device)
    total_loss = 0.0
    num_samples = 0

    if loader.generator is not None:
        loader.generator.manual_seed(Config.RANDOM_SEED + epoch)

    pbar = tqdm.tqdm(loader, desc=f"Epoch {epoch + 1}/{Config.EPOCHS}")
    for batch in pbar:
        images, targets, formats = _flatten_batch(batch, device)

        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            predictions = model(images)  # [B*V, num_formats]
            pred_correct = _own_format_predictions(predictions, formats)
            loss = criterion(pred_correct, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        _update_format_counts(pred_correct.detach(), targets, formats,
                              per_format_correct, per_format_count)
        batch_size = targets.numel()
        total_loss += loss.item() * batch_size
        num_samples += batch_size

        pbar.set_postfix_str(
            _progress_text(total_loss, num_samples, per_format_correct, per_format_count))

    return _epoch_metrics(total_loss, num_samples, per_format_correct, per_format_count)


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients, in full precision.

    Returns the same dict layout as ``train_epoch``.
    """
    model.eval()
    device = torch.device(device)
    num_formats = len(Config.COMPRESSION_FORMATS)
    per_format_correct = torch.zeros(num_formats, device=device)
    per_format_count = torch.zeros(num_formats, device=device)
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        pbar = tqdm.tqdm(loader, desc="Validation", leave=False)
        for batch in pbar:
            images, targets, formats = _flatten_batch(batch, device)
            pred_correct = _own_format_predictions(model(images), formats)
            loss = criterion(pred_correct, targets)

            _update_format_counts(pred_correct, targets, formats,
                                  per_format_correct, per_format_count)
            batch_size = targets.numel()
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            pbar.set_postfix_str(
                _progress_text(total_loss, num_samples, per_format_correct, per_format_count))

    return _epoch_metrics(total_loss, num_samples, per_format_correct, per_format_count)


# ==================== MAIN ====================
def _format_accs(accs) -> str:
    """Render a per-format accuracy list as 'fmt: xx.x%' pairs."""
    return " | ".join(f"{fmt}: {acc:.1f}%" for fmt, acc in zip(Config.COMPRESSION_FORMATS, accs))


def main():
    """Train LightweightCompressionNet on the PNGs found under Config.IMAGE_DIR.

    Keeps the lowest-val-loss weights in ``best_model.pt``, writes a resumable
    checkpoint every Config.SAVE_INTERVAL epochs, and rewrites Config.LOG_FILE
    after every epoch.

    Raises:
        ValueError: if no ``*.png`` files are found under Config.IMAGE_DIR.
    """
    ensure_dir(Config.CHECKPOINT_DIR)
    ensure_dir(Config.RESULTS_DIR)
    ensure_dir(str(Path(Config.LOG_FILE).parent))  # LOG_FILE need not live in RESULTS_DIR

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_pinned_memory = device.type == 'cuda'
    torch.manual_seed(Config.RANDOM_SEED)

    # rglob walks subfolders in filesystem order; sort so the seeded split is reproducible.
    image_paths = sorted(str(p) for p in Path(Config.IMAGE_DIR).rglob("*.png"))
    if not image_paths:
        raise ValueError(f"No PNGs found in {Config.IMAGE_DIR}")

    train_paths, val_paths = train_test_split(
        image_paths, test_size=Config.VAL_SPLIT, random_state=Config.RANDOM_SEED
    )
    print(f"Train: {len(train_paths)} | Val: {len(val_paths)}")

    train_dataset = CompressionDataset(train_paths, is_train=True)
    val_dataset = CompressionDataset(val_paths, is_train=False)

    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=use_pinned_memory,
        prefetch_factor=Config.TRAIN_PREFETCH_FACTOR,
        collate_fn=collate_grouped,
        # train_epoch re-seeds this generator each epoch for reproducible shuffling.
        generator=torch.Generator()
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=Config.NUM_WORKERS,
        pin_memory=use_pinned_memory,
        prefetch_factor=Config.VAL_PREFETCH_FACTOR,
        collate_fn=collate_grouped
    )

    model = LightweightCompressionNet().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=Config.EPOCHS, eta_min=1e-6
    )

    param_count = sum(p.numel() for p in model.parameters())
    # Size estimate assumes 4-byte (float32) parameters.
    print(f"\nModel: {param_count:,} parameters ({param_count * 4 / 1024:.1f}KB)")

    best_val_loss = float('inf')
    training_log = []

    print("\nStarting training...")
    for epoch in range(Config.EPOCHS):
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
        val_metrics = validate(model, val_loader, criterion, device)
        # Read the LR before stepping so the summary reports the rate this epoch used.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Train: Loss {train_metrics['loss']:.4f} | Acc {train_metrics['accuracy']:.1f}%")
        print(f"  Val:   Loss {val_metrics['loss']:.4f} | Acc {val_metrics['accuracy']:.1f}%")
        print(f"  LR: {epoch_lr:.2e}")
        print("  Per-Format Train Acc:", _format_accs(train_metrics['per_format_accuracy']))
        print("  Per-Format Val Acc:  ", _format_accs(val_metrics['per_format_accuracy']))

        training_log.append({
            'epoch': epoch + 1,
            'lr': epoch_lr,
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
            'val_accuracy': val_metrics['accuracy'],
            'train_per_format_accuracy': dict(zip(Config.COMPRESSION_FORMATS,
                                                  train_metrics['per_format_accuracy'])),
            'val_per_format_accuracy': dict(zip(Config.COMPRESSION_FORMATS,
                                                val_metrics['per_format_accuracy']))
        })

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_loss': best_val_loss,
                'val_accuracy': val_metrics['accuracy']
            }, os.path.join(Config.CHECKPOINT_DIR, "best_model.pt"))
            print("✓ Saved best model")

        if (epoch + 1) % Config.SAVE_INTERVAL == 0:
            # Optimizer/scheduler state are included so training can be resumed.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_metrics['loss'],
                'val_loss': val_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_accuracy': val_metrics['accuracy']
            }, os.path.join(Config.CHECKPOINT_DIR, f"model_epoch_{epoch + 1:03d}.pt"))
            print(f"✓ Saved checkpoint epoch {epoch + 1}")

        # Rewrite the log every epoch so an interrupted run still keeps its history.
        with open(Config.LOG_FILE, 'w') as f:
            json.dump(training_log, f, indent=2)

    # Plotting code...

    print(f"\nDone! Best val loss: {best_val_loss:.4f}")
    print(f"Results saved to {Config.RESULTS_DIR}")


if __name__ == "__main__":
    main()