# picky / infer.py
# Uploaded by OJ-1 via huggingface_hub (commit 05f4c27, verified)
# Credit to @Rimuru for the ideas and original implementation.
# Trained on PNG illustrations and RAW photos converted to PNG that were then synthetically augmented at various quality levels.
# Got 95.3% overall validation accuracy with the lowest performance being JXL.
# Per-Format Val Acc: jpeg: 99.7% | webp: 96.2% | avif: 96.3% | jxl: 94.3%
# Do not trust this in production: it will fail on edge cases and on images that have been compressed more than once.
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from pathlib import Path
from typing import Dict
# Pillow refuses to open images above MAX_IMAGE_PIXELS (decompression-bomb guard);
# allow up to 120 megapixels so large illustrations/photos can be scored.
Image.MAX_IMAGE_PIXELS = 120_000_000
class LightweightCompressionNet(nn.Module):
    """Small CNN that regresses one normalized quality score per codec.

    Architecture: seven unpadded convolutions (each followed by GELU), global
    average pooling, then a 256 -> 32 -> 4 MLP head ending in a sigmoid, so the
    output has shape ``(batch, 4)`` with values in ``[0, 1]``. Which codec each
    output column corresponds to is decided by the caller.

    Because no convolution is padded, each spatial side of the input must be at
    least 204 pixels; smaller inputs make a convolution fail in ``forward``.

    The submodule indices inside ``conv_blocks`` and ``head`` form the
    ``state_dict`` keys (``conv_blocks.0.weight`` ...), so the layer order must
    stay exactly as is for existing checkpoints to load.
    """

    # (in_channels, out_channels, kernel_size, stride) for each conv stage.
    _CONV_SPECS = (
        (3, 16, 4, 1),
        (16, 32, 4, 1),
        (32, 64, 4, 2),
        (64, 128, 4, 2),
        (128, 256, 4, 4),
        (256, 256, 4, 4),
        (256, 256, 3, 2),
    )

    def __init__(self):
        super().__init__()
        stages = []
        for in_ch, out_ch, kernel, stride in self._CONV_SPECS:
            # Conv followed by GELU keeps convs at even indices (0, 2, ..., 12).
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=0),
                nn.GELU(),
            ]
        # Collapse whatever spatial extent remains to 1x1.
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stages)
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        pooled = self.conv_blocks(x)
        # (batch, 256, 1, 1) -> (batch, 256) for the linear head.
        return self.head(pooled.flatten(start_dim=1))
class CompressionArtifactPredictor:
def __init__(self, model_path: str, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.model = LightweightCompressionNet().to(self.device)
self.model.eval()
checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.preprocess = transforms.Compose([transforms.ToTensor()])
self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
self.quality_ranges = {'jpeg': (0, 100), 'webp': (0, 100), 'avif': (0, 100), 'jxl': (0, 100)}
def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()
results = {}
for i, fmt in enumerate(self.compression_formats):
normalized_score = float(predictions[i])
min_q, max_q = self.quality_ranges[fmt]
results[fmt] = {
'normalized_score': normalized_score,
'predicted_quality': normalized_score * (max_q - min_q) + min_q,
'artifact_level': 1.0 - normalized_score
}
return results
def predict_format(self, image: Image.Image, format_name: str) -> float:
if format_name not in self.compression_formats:
raise ValueError(f"Unsupported format. Choose from: {self.compression_formats}")
return self.predict(image)[format_name]['predicted_quality']
if __name__ == "__main__":
    # Set your image path here!
    image_path = Path("./demo_imgs/cat-q75.jpg")
    # This assumes there isn't any format trickery or multiple rounds of
    # compression: the codec is taken from the file extension alone.
    ext_map = {'.jpg': 'jpeg', '.jpeg': 'jpeg', '.webp': 'webp', '.avif': 'avif', '.jxl': 'jxl'}
    fmt = ext_map.get(image_path.suffix.lower())
    if fmt is None:
        # Fail fast with a clear message before loading the model, instead of
        # passing None to predict_format and getting a generic error.
        raise SystemExit(
            f"Unsupported extension {image_path.suffix!r}; expected one of {sorted(ext_map)}"
        )
    # The context manager closes the file handle; convert() returns an
    # independent, fully loaded image, so it stays usable after the with-block.
    with Image.open(image_path) as src:
        image = src.convert('RGB')
    predictor = CompressionArtifactPredictor("quality_factor_estimation.pt")
    quality = predictor.predict_format(image, fmt)
    print(f"{image_path.name} - estimated to be q={quality:.2f}")