File size: 4,124 Bytes
05f4c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Credit to @Rimuru for the ideas and original implementation.
# Trained on PNG illustrations and RAW photos converted to PNG that were then synthetically augmented at various quality levels.

# Got 95.3% overall validation accuracy with the lowest performance being JXL.
# Per-Format Val Acc:   jpeg: 99.7% | webp: 96.2% | avif: 96.3% | jxl: 94.3%

# Do not trust this for production: it will fail on edge cases and on images that have been compressed more than once.

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from pathlib import Path
from typing import Dict

Image.MAX_IMAGE_PIXELS = 120000000

class LightweightCompressionNet(nn.Module):
    """Small unpadded CNN that maps an image batch to four scores in [0, 1].

    The convolutional trunk takes a 3-channel input and ends in a global
    average pool, so any spatial size is accepted as long as it survives the
    unpadded convolutions; working through the kernel/stride chain below,
    that means at least 204x204 pixels. The head squashes the pooled
    256-dimensional feature vector to four sigmoid outputs.
    """

    # (in_channels, out_channels, kernel_size, stride) for each conv layer.
    # Every convolution is followed by a GELU, which keeps the positions
    # inside ``conv_blocks`` (and therefore the state-dict keys) fixed.
    _CONV_SPECS = (
        (3, 16, 4, 1),
        (16, 32, 4, 1),
        (32, 64, 4, 2),
        (64, 128, 4, 2),
        (128, 256, 4, 4),
        (256, 256, 4, 4),
        (256, 256, 3, 2),
    )

    def __init__(self):
        super().__init__()

        trunk = []
        for in_ch, out_ch, kernel, step in self._CONV_SPECS:
            trunk.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=step, padding=0)
            )
            trunk.append(nn.GELU())
        # Collapse whatever spatial extent is left to a single 1x1 cell.
        trunk.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*trunk)

        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 4)`` tensor of sigmoid scores for batch ``x``."""
        pooled = self.conv_blocks(x)
        # (batch, 256, 1, 1) -> (batch, 256)
        flat = pooled.flatten(1)
        return self.head(flat)


class CompressionArtifactPredictor:
    """Estimate per-codec compression quality of an image.

    Loads a ``LightweightCompressionNet`` checkpoint and converts the
    network's four outputs (one each for JPEG, WebP, AVIF and JPEG XL, in
    that order) into quality estimates on each format's quality range.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """Load the checkpoint and put the model in evaluation mode.

        Args:
            model_path: Path to a checkpoint file whose
                ``'model_state_dict'`` entry holds the network weights.
            device: Requested torch device. When CUDA is not available the
                predictor runs on the CPU regardless of this value.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.preprocess = transforms.Compose([transforms.ToTensor()])
        # Order matches the order of the network's four outputs.
        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        # (min, max) quality used to rescale the normalized network output.
        self.quality_ranges = {'jpeg': (0, 100), 'webp': (0, 100), 'avif': (0, 100), 'jxl': (0, 100)}

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """Run the model on one image and report scores for every format.

        Args:
            image: The image to analyse. PIL images in a mode other than
                RGB are converted to RGB first, because the network's first
                convolution accepts exactly three channels.

        Returns:
            A dict keyed by format name. Each value holds
            ``'normalized_score'`` (raw network output),
            ``'predicted_quality'`` (the score rescaled to the format's
            quality range) and ``'artifact_level'`` (``1 - score``).
        """
        # Guard with isinstance so non-PIL inputs that ToTensor accepts
        # are passed through unchanged.
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # bfloat16 autocast is only wanted on CUDA. Asking for 'cuda'
        # autocast on a machine without CUDA makes PyTorch emit a warning on
        # every call, so key the context on the device actually in use.
        on_cuda = self.device.type == 'cuda'
        with torch.no_grad():
            with torch.amp.autocast(self.device.type, dtype=torch.bfloat16, enabled=on_cuda):
                predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()

        results = {}
        for fmt, raw_score in zip(self.compression_formats, predictions):
            normalized_score = float(raw_score)
            min_q, max_q = self.quality_ranges[fmt]
            results[fmt] = {
                'normalized_score': normalized_score,
                'predicted_quality': normalized_score * (max_q - min_q) + min_q,
                'artifact_level': 1.0 - normalized_score
            }
        return results

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality of ``image`` for a single format.

        Args:
            image: The image to analyse.
            format_name: One of the names in ``self.compression_formats``.

        Raises:
            ValueError: If ``format_name`` is not a supported format.
        """
        # Validate before running the model so bad input fails fast.
        if format_name not in self.compression_formats:
            raise ValueError(f"Unsupported format. Choose from: {self.compression_formats}")
        return self.predict(image)[format_name]['predicted_quality']


if __name__ == "__main__":
    # Set your image path here!
    image_path = Path("./demo_imgs/cat-q75.jpg")

    # This assumes that there isn't any format trickery or repeated
    # compression: the codec is inferred purely from the file extension.
    # Kept simple for the first iteration.
    ext_map = {'.jpg': 'jpeg', '.jpeg': 'jpeg', '.webp': 'webp', '.avif': 'avif', '.jxl': 'jxl'}
    fmt = ext_map.get(image_path.suffix.lower())
    if fmt is None:
        # Fail before loading the checkpoint, and name the offending
        # extension instead of passing None on to the predictor.
        raise ValueError(
            f"Cannot infer compression format from extension {image_path.suffix!r}. "
            f"Supported extensions: {sorted(ext_map)}"
        )

    predictor = CompressionArtifactPredictor("quality_factor_estimation.pt")

    # Close the underlying file deterministically once the pixels are decoded.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    quality = predictor.predict_format(image, fmt)
    print(f"{image_path.name} - estimated to be q={quality:.2f}")