# Credit to @Rimuru for the ideas and original implementation.
# Trained on PNG illustrations and RAW photos converted to PNG that were then
# synthetically augmented at various quality levels.
# Got 95.3% overall validation accuracy with the lowest performance being JXL.
# Per-Format Val Acc: jpeg: 99.7% | webp: 96.2% | avif: 96.3% | jxl: 94.3%
# Do not trust this for production: it will fail on edge cases and on images
# that went through multiple compressions.

from contextlib import nullcontext
from pathlib import Path
from typing import Dict, Union

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# Raise Pillow's decompression-bomb pixel limit so very large photos can be opened.
Image.MAX_IMAGE_PIXELS = 120_000_000


class ImageTooSmallError(ValueError, RuntimeError):
    """Raised when an image is smaller than the network can process.

    Also derives from ``RuntimeError`` because that is the type PyTorch's
    convolution raises for undersized inputs, so handlers written against the
    unchecked behaviour keep working.
    """


class LightweightCompressionNet(nn.Module):
    """Small unpadded CNN mapping a batch of RGB images to four sigmoid scores.

    Every convolution uses ``padding=0``, so inputs must be at least
    ``min_input_size()`` pixels on each side. ``AdaptiveAvgPool2d(1)`` then
    collapses whatever spatial extent remains, so larger images of any aspect
    ratio are accepted.
    """

    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(256, 32), nn.GELU(),
            nn.Linear(32, 4), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 4)`` tensor of scores in [0, 1].

        ``x`` is expected to be batched, ``(batch, 3, H, W)``; the first conv
        takes 3 input channels.
        """
        features = self.conv_blocks(x)
        features = features.flatten(1)  # (batch, 256, 1, 1) -> (batch, 256)
        return self.head(features)

    def min_input_size(self) -> int:
        """Return the smallest image side length (in pixels) the conv stack accepts.

        Walks the convolutions from last to first, inverting
        ``out = floor((in + 2*pad - dilation*(k - 1) - 1) / stride) + 1`` to find
        the input size that still leaves a 1x1 map before pooling (204 for the
        layers above). Only the first component of each kernel/stride/padding/
        dilation pair is used; every layer here is square.
        """
        side = 1
        for layer in reversed(self.conv_blocks):
            if isinstance(layer, nn.Conv2d):
                effective_kernel = layer.dilation[0] * (layer.kernel_size[0] - 1) + 1
                side = (side - 1) * layer.stride[0] + effective_kernel - 2 * layer.padding[0]
        return side


class CompressionArtifactPredictor:
    """Load a trained ``LightweightCompressionNet`` and estimate per-format quality."""

    def __init__(self, model_path: Union[str, Path], device: str = "cuda"):
        """Build the network and load its weights.

        Args:
            model_path: Checkpoint file loaded with ``torch.load``. It may be a
                dict holding the weights under ``'model_state_dict'`` or a bare
                state dict.
            device: Preferred torch device. A CUDA request falls back to CPU
                when CUDA is unavailable; other device types are used as given.
        """
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()
        # weights_only=True refuses to unpickle arbitrary objects from the file.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        self.model.load_state_dict(state_dict)
        self.min_input_size = self.model.min_input_size()

        # Run under bfloat16 autocast only where it exists: CUDA devices with
        # bf16 support. Everywhere else inference stays in float32.
        self._use_bf16 = self.device.type == "cuda" and torch.cuda.is_bf16_supported()

        self.preprocess = transforms.Compose([transforms.ToTensor()])
        self.compression_formats = ["jpeg", "webp", "avif", "jxl"]
        self.quality_ranges = {
            "jpeg": (0, 100),
            "webp": (0, 100),
            "avif": (0, 100),
            "jxl": (0, 100),
        }

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """Score ``image`` for every supported compression format.

        Args:
            image: Any PIL image. It is converted to RGB first (alpha is
                dropped) because the network takes exactly three channels.

        Returns:
            Mapping from format name to a dict with:
              * ``normalized_score``: the network's sigmoid output in [0, 1];
              * ``predicted_quality``: that score mapped linearly onto the
                format's ``quality_ranges`` entry;
              * ``artifact_level``: ``1 - normalized_score``.

        Raises:
            ImageTooSmallError: if either side is below ``self.min_input_size``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        width, height = image.size
        if min(width, height) < self.min_input_size:
            raise ImageTooSmallError(
                f"Image is {width}x{height}; both sides must be at least "
                f"{self.min_input_size} px for this model."
            )

        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        amp_ctx = (
            torch.amp.autocast("cuda", dtype=torch.bfloat16)
            if self._use_bf16
            else nullcontext()
        )
        with torch.no_grad(), amp_ctx:
            # .float() undoes any bf16 output before converting to Python floats.
            scores = self.model(img_tensor).squeeze(0).float().tolist()

        results = {}
        for fmt, score in zip(self.compression_formats, scores):
            min_q, max_q = self.quality_ranges[fmt]
            results[fmt] = {
                "normalized_score": score,
                "predicted_quality": score * (max_q - min_q) + min_q,
                "artifact_level": 1.0 - score,
            }
        return results

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality of ``image`` for a single format.

        Raises:
            ValueError: if ``format_name`` is not in ``self.compression_formats``.
        """
        if format_name not in self.compression_formats:
            raise ValueError(
                f"Unsupported format {format_name!r}. "
                f"Choose from: {self.compression_formats}"
            )
        return self.predict(image)[format_name]["predicted_quality"]


if __name__ == "__main__":
    # Set your image path here!
    image_path = Path("./demo_imgs/cat-q75.jpg")

    # The format is inferred purely from the file extension. This assumes
    # there isn't any format trickery or multiple rounds of compression; it was
    # kept simple for a first iteration.
    ext_map = {".jpg": "jpeg", ".jpeg": "jpeg", ".webp": "webp", ".avif": "avif", ".jxl": "jxl"}
    fmt = ext_map.get(image_path.suffix.lower())
    if fmt is None:
        raise SystemExit(
            f"Unsupported file extension {image_path.suffix!r}; "
            f"expected one of {sorted(ext_map)}"
        )

    predictor = CompressionArtifactPredictor("quality_factor_estimation.pt")

    # Close the file handle once the pixels have been decoded into RGB.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    quality = predictor.predict_format(image, fmt)
    print(f"{image_path.name} - estimated to be q={quality:.2f}")