|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
from pathlib import Path |
|
|
import argparse |
|
|
import json |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
|
|
|
|
|
|
class LightweightCompressionNet(nn.Module):
    """Small CNN mapping an RGB image to four sigmoid scores in [0, 1].

    A stack of strided convolutions progressively downsamples the input,
    an adaptive average pool collapses whatever spatial extent remains
    (making the net input-size agnostic, within the minimum size the conv
    stack itself requires), and a two-layer MLP head emits the 4 scores.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk: channel width grows 3 -> 256 while strided
        # convs shrink the spatial dims; the final adaptive pool reduces
        # the feature map to 1x1 regardless of the remaining size.
        trunk = [
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0), nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0), nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0), nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        # NOTE: attribute names ("conv_blocks", "head") and layer order are
        # part of the checkpoint's state_dict keys — do not rename/reorder.
        self.conv_blocks = nn.Sequential(*trunk)

        # Prediction head: 256-dim pooled features -> 4 scores in [0, 1].
        self.head = nn.Sequential(
            nn.Linear(256, 32), nn.GELU(),
            nn.Linear(32, 4), nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the trunk, flatten the pooled 1x1 features, and score them."""
        pooled = self.conv_blocks(x)
        flat = torch.flatten(pooled, start_dim=1)
        return self.head(flat)
|
|
|
|
|
|
|
|
|
|
|
class CompressionArtifactPredictor:
    """Predicts per-format compression quality/artifact scores for an image.

    Wraps a trained ``LightweightCompressionNet`` and maps its normalized
    [0, 1] outputs back to each format's native quality range.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """Load the trained model checkpoint.

        Args:
            model_path: Path to a checkpoint containing 'model_state_dict'.
            device: Preferred device; silently falls back to CPU when CUDA
                is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)

        # weights_only=True avoids executing arbitrary pickled code from an
        # untrusted checkpoint file. Switch to eval mode after loading.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Tensor conversion only — no resize, since the network's adaptive
        # pooling tolerates variable input sizes.
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Order must match the training-time ordering of the 4 output heads
        # — TODO confirm against the training code.
        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        self.quality_ranges = {
            'jpeg': (0, 100),
            'webp': (0, 100),
            'avif': (0, 100),
            'jxl': (0, 100)
        }

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """
        Predict compression quality/artifact levels for all formats.

        Args:
            image: PIL Image in RGB mode

        Returns:
            Dictionary keyed by format name, each value holding
            'normalized_score', 'predicted_quality', and 'artifact_level'.
        """
        # NOTE(review): very small images will fail inside the conv stack
        # before the adaptive pool — confirm the minimum supported size.
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # BUG FIX: the original used torch.cuda.amp.autocast, which is
        # deprecated and wrong when the model falls back to CPU (the device
        # is chosen dynamically in __init__). torch.autocast keyed on the
        # active device type handles both CPU and CUDA; bfloat16 autocast
        # is supported on both backends.
        with torch.no_grad():
            with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
                predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()

        results = {}
        for i, fmt in enumerate(self.compression_formats):
            normalized_score = float(predictions[i])
            actual_quality = self._denormalize_quality(normalized_score, fmt)

            results[fmt] = {
                'normalized_score': normalized_score,
                'predicted_quality': actual_quality,
                # High score == high quality, so artifact level is its complement.
                'artifact_level': 1.0 - normalized_score
            }

        return results

    def _denormalize_quality(self, normalized: float, fmt: str) -> float:
        """Convert normalized prediction back to original quality range"""
        min_q, max_q = self.quality_ranges[fmt]
        return normalized * (max_q - min_q) + min_q

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Predict quality for a specific format only.

        Raises:
            ValueError: If *format_name* is not one of the supported formats.
        """
        if format_name not in self.compression_formats:
            raise ValueError(f"Unsupported format. Choose from: {self.compression_formats}")

        results = self.predict(image)
        return results[format_name]['predicted_quality']
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load the model, analyze one image, print a report.

    The original hard-coded a placeholder image path ("/path/to/image"),
    which always raised FileNotFoundError; paths are now configurable via
    argparse (already imported but unused), with defaults matching the
    original values. Garbled (mojibake) emoji bytes in the output strings
    are replaced with plain ASCII labels.
    """
    parser = argparse.ArgumentParser(
        description="Predict compression artifact levels for an image."
    )
    parser.add_argument("image", nargs="?", default="/path/to/image",
                        help="Path to the image to analyze")
    parser.add_argument("--model", default="checkpoints/model.pt",
                        help="Path to the trained model checkpoint")
    args = parser.parse_args()

    predictor = CompressionArtifactPredictor(args.model)

    image_path = Path(args.image)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Model expects RGB; convert up-front so grayscale/palette images work.
    image = Image.open(image_path).convert('RGB')
    print(f"\nAnalyzing image: {image_path}")
    print(f"Image size: {image.size[0]}x{image.size[1]}\n")

    results = predictor.predict(image)

    print("=" * 50)
    print("COMPRESSION ARTIFACT ANALYSIS")
    print("=" * 50)

    for fmt, data in results.items():
        print(f"\n{fmt.upper():>4}:")
        print(f"  Predicted Quality: {data['predicted_quality']:>6.1f} / {predictor.quality_ranges[fmt][1]}")
        print(f"  Normalized Score:  {data['normalized_score']:>6.3f}")
        print(f"  Artifact Level:    {data['artifact_level']:>6.3f} (0.0=clean, 1.0=heavily compressed)")

    # Simple aggregate: mean artifact level across all predicted formats.
    avg_artifact_level = sum(r['artifact_level'] for r in results.values()) / len(results)
    print(f"\n{'=' * 50}")
    print(f"Overall artifact level: {avg_artifact_level:.3f}")
    if avg_artifact_level < 0.2:
        print("OK: Image appears to have minimal compression artifacts")
    elif avg_artifact_level < 0.5:
        print("WARNING: Image shows moderate compression artifacts")
    else:
        print("HEAVY: Image exhibits heavy compression artifacts")
|
|
|
|
|
|
|
|
# Standard script-entry guard: run the CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()