OJ-1 committed on
Commit
05f4c27
·
verified ·
1 Parent(s): a8dfb2b

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo_imgs/cat-q20.jpg filter=lfs diff=lfs merge=lfs -text
37
+ demo_imgs/cat-q75.jpg filter=lfs diff=lfs merge=lfs -text
38
+ demo_imgs/Doughnut-q30.jpg filter=lfs diff=lfs merge=lfs -text
39
+ demo_imgs/fail-case.jpg filter=lfs diff=lfs merge=lfs -text
40
+ demo_imgs/random-screenshot-q90.jpg filter=lfs diff=lfs merge=lfs -text
demo_imgs/Doughnut-q30.jpg ADDED

Git LFS Details

  • SHA256: 5d72f82913e3b615cd5ce1dd95986707a4ba71b9d1deeef47740fd3e3986d038
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
demo_imgs/cat-q20.jpg ADDED

Git LFS Details

  • SHA256: 7c7573e01c83305b429fcb241aab106746deb400b0d525d9814c583b1bd6c54e
  • Pointer size: 131 Bytes
  • Size of remote file: 354 kB
demo_imgs/cat-q75.jpg ADDED

Git LFS Details

  • SHA256: 216f70daf6a86b37811cb6590f2fceae6e204a8f7b16042cae31b7b09470677e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
demo_imgs/fail-case.jpg ADDED

Git LFS Details

  • SHA256: 13038cd468095314706aeed70f3ddf12f8e66306f3c823d83459387c15a4ce98
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
demo_imgs/random-screenshot-q48.jpg ADDED
demo_imgs/random-screenshot-q90.jpg ADDED

Git LFS Details

  • SHA256: 6aebb5f6b3b7bf5aeeba4d878029a8b578cbe83c411a0e5e34a51073491c1664
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
infer.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Credit to @Rimuru for the ideas and original implementation.
2
+ # Trained on PNG illustrations and RAW photos converted to PNG that were then synthetically augmented at various quality levels.
3
+
4
+ # Got 95.3% overall validation accuracy with the lowest performance being JXL.
5
+ # Per-Format Val Acc: jpeg: 99.7% | webp: 96.2% | avif: 96.3% | jxl: 94.3%
6
+
7
+ # Do not trust this for production; it will fail on edge cases and on images that have been compressed more than once.
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from pathlib import Path
14
+ from typing import Dict
15
+
16
+ Image.MAX_IMAGE_PIXELS = 120000000
17
+
18
class LightweightCompressionNet(nn.Module):
    """Small convolutional regressor with four sigmoid outputs.

    Seven unpadded ``Conv2d`` + ``GELU`` stages are followed by a global
    average pool; a two-layer MLP head then maps the pooled 256-dimensional
    feature vector to four values in ``[0, 1]``.

    The submodule names (``conv_blocks``, ``head``) and the position of each
    layer inside its ``nn.Sequential`` determine the ``state_dict`` keys, so
    they are kept fixed here.
    """

    # (in_channels, out_channels, kernel_size, stride) for each conv stage.
    _CONV_SPECS = (
        (3, 16, 4, 1),
        (16, 32, 4, 1),
        (32, 64, 4, 2),
        (64, 128, 4, 2),
        (128, 256, 4, 4),
        (256, 256, 4, 4),
        (256, 256, 3, 2),
    )

    def __init__(self):
        super().__init__()

        stages = []
        for c_in, c_out, kernel, stride in self._CONV_SPECS:
            stages.append(
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=0)
            )
            stages.append(nn.GELU())
        # Collapse whatever spatial extent is left to a single 1x1 cell.
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stages)

        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 4)`` tensor of sigmoid scores for input ``x``."""
        pooled = self.conv_blocks(x)
        # (batch, 256, 1, 1) -> (batch, 256)
        flat = pooled.flatten(start_dim=1)
        return self.head(flat)
40
+
41
+
42
class CompressionArtifactPredictor:
    """Estimate per-format compression quality for an image.

    Wraps a ``LightweightCompressionNet`` checkpoint. The network emits one
    score in ``[0, 1]`` per entry of ``compression_formats``; each score is
    rescaled linearly into that format's ``quality_ranges`` interval.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """Load the checkpoint and put the model into inference mode.

        Args:
            model_path: Path to a checkpoint whose ``'model_state_dict'``
                entry holds the network weights.
            device: Device to run on. It is only honoured when CUDA is
                available; otherwise the predictor runs on the CPU.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)
        # weights_only=True restricts unpickling to tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        self.preprocess = transforms.Compose([transforms.ToTensor()])
        # Order must match the order of the network's four outputs.
        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        self.quality_ranges = {'jpeg': (0, 100), 'webp': (0, 100), 'avif': (0, 100), 'jxl': (0, 100)}

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """Score ``image`` for every supported compression format.

        Args:
            image: Image to analyse. Non-RGB images (grayscale, palette,
                RGBA, ...) are converted to RGB first.

        Returns:
            A dict keyed by format name. Each value holds
            ``'normalized_score'`` (raw network output),
            ``'predicted_quality'`` (score mapped into the format's quality
            range) and ``'artifact_level'`` (``1 - normalized_score``).
        """
        # The first conv layer takes exactly three channels.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # bfloat16 autocast is only used on CUDA. Asking for 'cuda' autocast on
        # a CPU-only machine merely emits a warning and disables itself, so
        # key it on the device that is actually in use.
        on_cuda = self.device.type == 'cuda'
        with torch.no_grad(), torch.amp.autocast(
            self.device.type, dtype=torch.bfloat16, enabled=on_cuda
        ):
            scores = self.model(img_tensor).squeeze(0).cpu().float().tolist()

        results = {}
        for fmt, normalized_score in zip(self.compression_formats, scores):
            min_q, max_q = self.quality_ranges[fmt]
            results[fmt] = {
                'normalized_score': normalized_score,
                'predicted_quality': normalized_score * (max_q - min_q) + min_q,
                'artifact_level': 1.0 - normalized_score,
            }
        return results

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality of ``image`` for a single format.

        Args:
            image: Image to analyse.
            format_name: One of the names in ``compression_formats``.

        Raises:
            ValueError: If ``format_name`` is not a supported format.
        """
        if format_name not in self.compression_formats:
            raise ValueError(f"Unsupported format. Choose from: {self.compression_formats}")
        return self.predict(image)[format_name]['predicted_quality']
73
+
74
+
75
if __name__ == "__main__":
    # Set your image path here!
    image_path = Path("./demo_imgs/cat-q75.jpg")

    # The format is taken from the file extension alone. This assumes that
    # there isn't any format trickery or a chain of different compressions;
    # it was kept simple for the first iteration.
    ext_map = {'.jpg': 'jpeg', '.jpeg': 'jpeg', '.webp': 'webp', '.avif': 'avif', '.jxl': 'jxl'}
    fmt = ext_map.get(image_path.suffix.lower())
    if fmt is None:
        # Fail before loading the model, with a message that names the cause.
        raise SystemExit(
            f"Unsupported image extension {image_path.suffix!r}. Choose from: {sorted(ext_map)}"
        )

    predictor = CompressionArtifactPredictor("quality_factor_estimation.pt")

    # convert() returns a fully loaded copy, so the file handle can be closed
    # as soon as the conversion is done.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    quality = predictor.predict_format(image, fmt)
    print(f"{image_path.name} - estimated to be q={quality:.2f}")
quality_factor_estimation.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7450c4d027dbeb7e686eaf531f45719e245a5c1f89adddc7e1ae0c2d1b7b48f2
3
+ size 9386549