|
|
import math |
|
|
import os |
|
|
import random |
|
|
import io |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
import pillow_jxl |
|
|
from sklearn.model_selection import train_test_split |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
class Config:
    """Central hyper-parameter and path configuration for the training run."""

    # Root directory scanned recursively for *.png source images.
    IMAGE_DIR = "/path/to/images"
    # Side length of the square crop fed to the network.
    CROP_SIZE = 512

    # Formats the model learns to score; list order fixes each format's
    # index into the 4-way output head.
    COMPRESSION_FORMATS = ['jpeg', 'webp', 'avif', 'jxl']
    # Inclusive encoder quality range per format (all normalized by /100).
    QUALITY_RANGES = {
        'jpeg': (0, 100),
        'webp': (0, 100),
        'avif': (0, 100),
        'jxl': (0, 100)
    }

    # DataLoader batch size in source images; each image expands into
    # multiple compressed views inside CompressionDataset.__getitem__.
    BATCH_SIZE = 1
    # Sampling rounds per image in __getitem__ (the "internal" batch).
    INT_BATCH_SIZE = 2
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    # Fraction of images held out for validation.
    VAL_SPLIT = 0.07
    RANDOM_SEED = 42
    # Save a periodic checkpoint every N epochs (besides the best model).
    SAVE_INTERVAL = 5

    # DataLoader worker processes — compression in __getitem__ is CPU-heavy.
    NUM_WORKERS = 32

    CHECKPOINT_DIR = "./checkpoints"
    RESULTS_DIR = "./results"
    LOG_FILE = "./results/training_log.json"
|
|
|
|
|
|
|
|
|
|
|
def ensure_dir(path: str):
    """Create *path* (and any missing parents); no-op if it already exists."""
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
def quality_to_normalized(quality: float, type: str) -> float:
    """Normalize an encoder quality value in [0, 100] to [0, 1].

    All supported formats (jpeg/webp/avif/jxl) share the same 0-100 quality
    scale (see Config.QUALITY_RANGES), so one division covers every case —
    the original per-format branches all returned the identical expression.
    The parameter name ``type`` (which shadows the builtin) is kept for
    backward compatibility with any keyword-argument callers.
    """
    return quality / 100.0
|
|
|
|
|
|
|
|
def normalized_to_quality(normalized: float) -> float:
    """Map a normalized score in [0, 1] back onto the 0-100 quality scale."""
    return 100.0 * normalized
|
|
|
|
|
|
|
|
|
|
|
def compress_image(image: Image.Image, format_name: str, quality: int) -> Image.Image:
    """Round-trip *image* through lossy compression entirely in memory.

    Encodes to the requested format at the given encoder quality, then
    decodes the result back into a PIL image (detached from the buffer).

    Raises:
        ValueError: if *format_name* is not one of jpeg/webp/avif/jxl.
    """
    # Map our lowercase format names to the identifiers PIL expects.
    pil_format_names = {
        'jpeg': 'JPEG',
        'webp': 'WEBP',
        'avif': 'AVIF',
        'jxl': 'JXL',
    }
    if format_name not in pil_format_names:
        raise ValueError(f"Unknown format: {format_name}")

    buffer = io.BytesIO()
    image.save(buffer, format=pil_format_names[format_name], quality=int(quality))
    buffer.seek(0)
    # .copy() decodes eagerly so the returned image outlives the buffer.
    return Image.open(buffer).copy()
|
|
|
|
|
|
|
|
class CompressionDataset(Dataset):
    """Expands each source PNG into a group of lossily compressed crops.

    Every item yields ``INT_BATCH_SIZE`` rounds of one crop per compression
    format, each encoded at a uniformly random quality in [0, 100], paired
    with the normalized quality target and the format's index in
    ``Config.COMPRESSION_FORMATS``.
    """

    def __init__(self, image_paths: List[str], is_train: bool = True):
        self.image_paths = image_paths
        self.is_train = is_train

        # Random crop + flips during training; deterministic center crop
        # (and identity flips) for validation.
        self.spatial_transform = transforms.Compose([
            transforms.RandomCrop(Config.CROP_SIZE, pad_if_needed=True) if is_train
            else transforms.CenterCrop(Config.CROP_SIZE),
            transforms.RandomHorizontalFlip(p=0.5) if is_train else nn.Identity(),
            transforms.RandomVerticalFlip(p=0.5) if is_train else nn.Identity(),
        ])
        # Hoisted: the original constructed a fresh ToTensor() per sample.
        self._to_tensor = transforms.ToTensor()

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        path = self.image_paths[idx]
        image = Image.open(path).convert('RGB')
        image = self.spatial_transform(image)

        images: List[torch.Tensor] = []
        targets: List[float] = []
        formats: List[int] = []

        # Replaces four copy-pasted per-format stanzas; iteration order
        # (jpeg, webp, avif, jxl per round) matches the original, so the
        # random.randint call sequence is unchanged.
        for _ in range(Config.INT_BATCH_SIZE):
            for fmt_idx, fmt in enumerate(Config.COMPRESSION_FORMATS):
                quality = random.randint(0, 100)
                compressed = compress_image(image.copy(), fmt, quality)
                images.append(self._to_tensor(compressed))
                targets.append(quality_to_normalized(quality, fmt))
                formats.append(fmt_idx)

        return {
            'images': torch.stack(images),
            'targets': torch.tensor(targets, dtype=torch.float32),
            'formats': torch.tensor(formats, dtype=torch.long)
        }
|
|
|
|
|
|
|
|
|
|
|
def collate_grouped(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate per-sample dicts into batched tensors, stacking key by key.

    Output shapes gain a leading batch dimension: images become
    (B, V, C, H, W), targets and formats become (B, V).
    """
    stacked = {}
    for key in ('images', 'targets', 'formats'):
        stacked[key] = torch.stack([sample[key] for sample in batch])
    return stacked
|
|
|
|
|
|
|
|
|
|
|
class LightweightCompressionNet(nn.Module):
    """Small CNN that predicts a sigmoid quality score per format.

    A stack of strided convolutions pools the crop down to a 256-d global
    feature, which a tiny MLP head maps to 4 outputs in (0, 1) — one per
    entry of Config.COMPRESSION_FORMATS.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) per conv stage;
        # same layer sequence as before, so checkpoints stay compatible.
        conv_specs = [
            (3, 16, 4, 1),
            (16, 32, 4, 1),
            (32, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 4),
            (256, 256, 4, 4),
            (256, 256, 3, 2),
        ]
        layers = []
        for in_ch, out_ch, kernel, stride in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel,
                                    stride=stride, padding=0))
            layers.append(nn.GELU())
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*layers)

        # Bottleneck MLP; Sigmoid keeps every output in (0, 1).
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid()
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for conv/linear layers."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        features = self.conv_blocks(x)
        flat = torch.flatten(features, 1)
        return self.head(flat)
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one training epoch; returns aggregate and per-format metrics.

    Args:
        model: network whose 4-channel output indexes Config.COMPRESSION_FORMATS.
        loader: DataLoader yielding dicts with 'images' (B, V, C, H, W),
            'targets' (B, V) and 'formats' (B, V); must have been built with
            an explicit ``generator=`` (reseeded below).
        criterion: regression loss on the selected channel vs. target.
        optimizer, device, epoch: the usual training-loop plumbing.

    Returns:
        dict with 'loss', 'accuracy' (pct within ±0.05) and
        'per_format_accuracy' (list aligned with COMPRESSION_FORMATS).
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    num_samples = 0

    # Per-format hit counters; "correct" = prediction within 0.05 of target
    # (i.e. within 5 quality points on the 0-100 scale).
    num_formats = len(Config.COMPRESSION_FORMATS)
    per_format_correct = torch.zeros(num_formats, device=device)
    per_format_count = torch.zeros(num_formats, device=device)

    # Re-seed the shuffle generator so each epoch's sample order is
    # reproducible. NOTE(review): raises AttributeError if the loader was
    # created without generator= — confirm all callers pass one.
    loader.generator.manual_seed(Config.RANDOM_SEED + epoch)
    pbar = tqdm.tqdm(loader, desc=f"Epoch {epoch + 1}/{Config.EPOCHS}")

    for batch in pbar:
        images = batch['images'].to(device, non_blocking=True)
        targets = batch['targets'].to(device, non_blocking=True)
        formats = batch['formats'].to(device, non_blocking=True)

        # Fold the per-image group dimension V into the batch dimension.
        B, V, C, H, W = images.shape
        images = images.reshape(B * V, C, H, W)
        targets = targets.reshape(B * V)
        formats = formats.reshape(B * V)

        # torch.cuda.amp.autocast(...) is deprecated since torch 2.4;
        # torch.amp.autocast('cuda', ...) is the supported spelling.
        # bf16 needs no GradScaler, so plain backward() below is correct.
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            predictions = model(images)
            # Pick, per sample, the output channel of its own format.
            pred_correct = torch.gather(predictions, 1, formats.unsqueeze(1)).squeeze(1)
            loss = criterion(pred_correct, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        correct = (torch.abs(pred_correct.detach() - targets) <= 0.05).float()
        acc = correct.mean() * 100

        per_format_correct.scatter_add_(0, formats, correct)
        per_format_count.scatter_add_(0, formats, torch.ones_like(correct))

        # Weight running sums by sample count so epoch metrics are exact
        # even if the last batch is smaller.
        batch_size = B * V
        total_loss += loss.item() * batch_size
        total_acc += acc.item() * batch_size
        num_samples += batch_size

        fmt_str = " | ".join([
            f"{fmt}: {per_format_correct[i].item() / (per_format_count[i].item() + 1e-8) * 100:.1f}%"
            for i, fmt in enumerate(Config.COMPRESSION_FORMATS) if per_format_count[i] > 0
        ])

        pbar.set_postfix_str(f"Avg: {total_loss / num_samples:.4f}, Acc: {total_acc / num_samples:.1f}% | [{fmt_str}]")

    per_format_acc = (per_format_correct / (per_format_count + 1e-8) * 100).cpu().tolist()
    return {
        'loss': total_loss / num_samples,
        'accuracy': total_acc / num_samples,
        'per_format_accuracy': per_format_acc
    }
|
|
|
|
|
|
|
|
def validate(model, loader, criterion, device):
    """Evaluate *model* over *loader* without gradients.

    Returns a dict with 'loss', 'accuracy' (pct of predictions within
    ±0.05 of the target) and 'per_format_accuracy' (list aligned with
    Config.COMPRESSION_FORMATS).
    """
    model.eval()
    running_loss = 0.0
    running_acc = 0.0
    seen = 0

    n_formats = len(Config.COMPRESSION_FORMATS)
    hits_per_format = torch.zeros(n_formats, device=device)
    counts_per_format = torch.zeros(n_formats, device=device)

    with torch.no_grad():
        progress = tqdm.tqdm(loader, desc="Validation", leave=False)
        for batch in progress:
            images = batch['images'].to(device)
            targets = batch['targets'].to(device)
            formats = batch['formats'].to(device)

            # Merge the per-image group dimension into the batch dimension.
            B, V, C, H, W = images.shape
            images = images.reshape(B * V, C, H, W)
            targets = targets.reshape(B * V)
            formats = formats.reshape(B * V)

            predictions = model(images)
            # Per sample, read the output channel matching its format.
            pred_correct = torch.gather(predictions, 1, formats.unsqueeze(1)).squeeze(1)

            loss = criterion(pred_correct, targets)
            within_tol = (torch.abs(pred_correct - targets) <= 0.05).float()
            batch_acc = within_tol.mean() * 100

            hits_per_format.scatter_add_(0, formats, within_tol)
            counts_per_format.scatter_add_(0, formats, torch.ones_like(within_tol))

            count = B * V
            running_loss += loss.item() * count
            running_acc += batch_acc.item() * count
            seen += count

            parts = [
                f"{fmt}: {hits_per_format[i].item() / (counts_per_format[i].item() + 1e-8) * 100:.1f}%"
                for i, fmt in enumerate(Config.COMPRESSION_FORMATS)
                if counts_per_format[i] > 0
            ]
            fmt_str = " | ".join(parts)
            progress.set_postfix_str(
                f"Avg: {running_loss / seen:.4f}, Acc: {running_acc / seen:.1f}% | [{fmt_str}]")

    per_format_acc = (hits_per_format / (counts_per_format + 1e-8) * 100).cpu().tolist()
    return {
        'loss': running_loss / seen,
        'accuracy': running_acc / seen,
        'per_format_accuracy': per_format_acc
    }
|
|
|
|
|
|
|
|
def main():
    """Train LightweightCompressionNet to regress compression quality.

    Scans Config.IMAGE_DIR recursively for PNGs, splits train/val, trains
    for Config.EPOCHS epochs with cosine LR decay, checkpoints the best
    model plus periodic snapshots, and writes a JSON training log.

    Raises:
        ValueError: if no PNG files are found under Config.IMAGE_DIR.
    """
    ensure_dir(Config.CHECKPOINT_DIR)
    ensure_dir(Config.RESULTS_DIR)

    # Prefer CUDA but fall back to CPU so the script still runs on GPU-less
    # machines (was a hard-coded torch.device('cuda'), which crashes there).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(Config.RANDOM_SEED)

    image_paths = [str(p) for p in Path(Config.IMAGE_DIR).rglob("*.png")]
    if not image_paths:
        raise ValueError(f"No PNGs found in {Config.IMAGE_DIR}")

    train_paths, val_paths = train_test_split(
        image_paths, test_size=Config.VAL_SPLIT, random_state=Config.RANDOM_SEED
    )
    print(f"Train: {len(train_paths)} | Val: {len(val_paths)}")

    train_dataset = CompressionDataset(train_paths, is_train=True)
    val_dataset = CompressionDataset(val_paths, is_train=False)

    # The explicit generator= lets train_epoch() reseed shuffling per epoch.
    # NOTE(review): prefetch_factor=50 with 32 workers keeps up to 1600
    # prefetched groups in RAM — confirm this is intentional.
    train_loader = DataLoader(
        train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
        num_workers=Config.NUM_WORKERS, pin_memory=True,
        prefetch_factor=50, collate_fn=collate_grouped, generator=torch.Generator()
    )
    val_loader = DataLoader(
        val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
        num_workers=Config.NUM_WORKERS, pin_memory=True,
        prefetch_factor=10, collate_fn=collate_grouped
    )

    model = LightweightCompressionNet().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=Config.LEARNING_RATE,
        weight_decay=1e-4, betas=(0.9, 0.999)
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=Config.EPOCHS, eta_min=1e-6
    )

    # Size estimate assumes float32 (4 bytes per parameter).
    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nModel: {param_count:,} parameters ({param_count * 4 / 1024:.1f}KB)")

    best_val_loss = float('inf')
    training_log = []

    print("\nStarting training...")
    for epoch in range(Config.EPOCHS):
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
        val_metrics = validate(model, val_loader, criterion, device)

        scheduler.step()

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f" Train: Loss {train_metrics['loss']:.4f} | Acc {train_metrics['accuracy']:.1f}%")
        print(f" Val: Loss {val_metrics['loss']:.4f} | Acc {val_metrics['accuracy']:.1f}%")
        print(f" LR: {optimizer.param_groups[0]['lr']:.2e}")

        print(" Per-Format Train Acc:", " | ".join([
            f"{fmt}: {acc:.1f}%" for fmt, acc in zip(Config.COMPRESSION_FORMATS, train_metrics['per_format_accuracy'])
        ]))
        print(" Per-Format Val Acc: ", " | ".join([
            f"{fmt}: {acc:.1f}%" for fmt, acc in zip(Config.COMPRESSION_FORMATS, val_metrics['per_format_accuracy'])
        ]))

        training_log.append({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
            'val_accuracy': val_metrics['accuracy'],
            'train_per_format_accuracy': dict(zip(Config.COMPRESSION_FORMATS, train_metrics['per_format_accuracy'])),
            'val_per_format_accuracy': dict(zip(Config.COMPRESSION_FORMATS, val_metrics['per_format_accuracy']))
        })

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_loss': best_val_loss,
                'val_accuracy': val_metrics['accuracy']
            }, os.path.join(Config.CHECKPOINT_DIR, "best_model.pt"))
            print("✓ Saved best model")

        if (epoch + 1) % Config.SAVE_INTERVAL == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'train_loss': train_metrics['loss'],
                'val_loss': val_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_accuracy': val_metrics['accuracy']
            }, os.path.join(Config.CHECKPOINT_DIR, f"model_epoch_{epoch + 1:03d}.pt"))
            print(f"✓ Saved checkpoint epoch {epoch + 1}")

        # Flush the log every epoch so a crash mid-run doesn't lose history;
        # the file written after the final epoch is the complete log.
        with open(Config.LOG_FILE, 'w') as f:
            json.dump(training_log, f, indent=2)

    print(f"\nDone! Best val loss: {best_val_loss:.4f}")
    print(f"Results saved to {Config.RESULTS_DIR}")
|
|
|
|
|
|
|
|
# Guarded entry point: lets the module be imported without starting training.
if __name__ == "__main__":
    main()