File size: 5,450 Bytes
db95d37 3911759 db95d37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# inference.py
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from pathlib import Path
import argparse
import json
from typing import Dict, Tuple
# ==================== MODEL DEFINITION ====================
class LightweightCompressionNet(nn.Module):
    """Compact CNN that turns a 3-channel image batch into four sigmoid scores.

    Seven unpadded convolutions, each followed by GELU, progressively shrink
    the input; a global average pool reduces what is left to one 256-vector
    per sample, and a two-layer MLP head ending in a sigmoid emits four
    values between 0 and 1.

    Because none of the convolutions pad, the input must be at least
    204x204 pixels for the last convolution to have a valid window.
    """

    def __init__(self):
        super().__init__()

        # One row per conv stage: (in_channels, out_channels, kernel, stride).
        stage_specs = (
            (3, 16, 4, 1),
            (16, 32, 4, 1),
            (32, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 4),
            (256, 256, 4, 4),
            (256, 256, 3, 2),
        )

        # Layers are appended in a fixed order so the Sequential indices
        # (and therefore the state_dict keys) stay conv at even positions,
        # GELU at odd positions, pooling last.
        stages = []
        for c_in, c_out, kernel, stride in stage_specs:
            stages.append(
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=0)
            )
            stages.append(nn.GELU())
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stages)

        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 4)`` tensor of scores for the image batch ``x``."""
        pooled = self.conv_blocks(x)
        # Pooling leaves (batch, 256, 1, 1); collapse everything after batch.
        flat = pooled.flatten(start_dim=1)
        return self.head(flat)
# ==================== INFERENCE PIPELINE ====================
class CompressionArtifactPredictor:
    """Load a trained ``LightweightCompressionNet`` and score images with it.

    The network emits four values per image; they are interpreted, in order,
    as normalized quality scores for the formats in ``compression_formats``.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """Build the network, load its weights and set up preprocessing.

        Args:
            model_path: Path to a checkpoint whose ``'model_state_dict'``
                entry holds the network weights.
            device: Requested torch device. Whenever CUDA is unavailable the
                predictor falls back to the CPU.

        Raises:
            KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)

        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a checkpoint cannot execute arbitrary code on load.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # The only preprocessing step is PIL image -> tensor conversion.
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Order matters: index i of the model output belongs to format i.
        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        self.quality_ranges = {
            'jpeg': (0, 100),
            'webp': (0, 100),
            'avif': (0, 100),
            'jxl': (0, 100),
        }

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """Predict compression quality/artifact levels for all formats.

        Args:
            image: PIL image. Images that are not already in RGB mode are
                converted, because the network expects exactly three input
                channels. The network uses unpadded convolutions, so images
                smaller than 204 pixels on either side make the forward
                pass fail.

        Returns:
            A dictionary keyed by format name. Each value holds
            ``'normalized_score'`` (the raw model output, 0.0 to 1.0),
            ``'predicted_quality'`` (the score mapped onto the format's
            quality range) and ``'artifact_level'``
            (``1.0 - normalized_score``; higher means more artifacts).
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if self.device.type == "cuda":
                # bfloat16 autocast is a GPU-side speed-up; it is entered
                # only on CUDA so CPU runs stay in float32 without the
                # "CUDA is not available" warning.
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    raw = self.model(batch)
            else:
                raw = self.model(batch)

        # Upcast before leaving torch so bfloat16 outputs become plain floats.
        scores = raw.squeeze(0).float().cpu().tolist()

        results = {}
        for fmt, normalized_score in zip(self.compression_formats, scores):
            results[fmt] = {
                'normalized_score': normalized_score,
                'predicted_quality': self._denormalize_quality(normalized_score, fmt),
                'artifact_level': 1.0 - normalized_score,
            }
        return results

    def _denormalize_quality(self, normalized: float, fmt: str) -> float:
        """Map a 0-1 score linearly onto the quality range of ``fmt``."""
        min_q, max_q = self.quality_ranges[fmt]
        return normalized * (max_q - min_q) + min_q

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality for a single format.

        Raises:
            ValueError: If ``format_name`` is not in ``compression_formats``.
        """
        if format_name not in self.compression_formats:
            raise ValueError(f"Unsupported format. Choose from: {self.compression_formats}")
        return self.predict(image)[format_name]['predicted_quality']
# ==================== MAIN ====================
def main(argv=None):
    """Score one image and print a per-format compression artifact report.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv``. With no arguments the image path and
            checkpoint path fall back to the built-in defaults.

    Raises:
        FileNotFoundError: If the image path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Predict compression artifact levels for an image."
    )
    parser.add_argument("image", nargs="?", default="/path/to/image",
                        help="path of the image to analyze")
    parser.add_argument("--model", default="checkpoints/model.pt",
                        help="path of the model checkpoint")
    args = parser.parse_args(argv)

    # Initialize predictor
    predictor = CompressionArtifactPredictor(args.model)

    # Load image
    image_path = Path(args.image)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")
    # convert() returns an independent copy, so the file handle can be
    # released as soon as the block exits.
    with Image.open(image_path) as opened:
        image = opened.convert('RGB')

    print(f"\n🔍 Analyzing image: {image_path}")
    print(f"📐 Image size: {image.size[0]}x{image.size[1]}\n")

    # Run prediction
    results = predictor.predict(image)

    print("=" * 50)
    print("📊 COMPRESSION ARTIFACT ANALYSIS")
    print("=" * 50)
    for fmt, data in results.items():
        print(f"\n{fmt.upper():>4}:")
        print(f" Predicted Quality: {data['predicted_quality']:>6.1f} / {predictor.quality_ranges[fmt][1]}")
        print(f" Normalized Score: {data['normalized_score']:>6.3f}")
        print(f" Artifact Level: {data['artifact_level']:>6.3f} (0.0=clean, 1.0=heavily compressed)")

    # Overall score: the unweighted mean of the per-format artifact levels.
    avg_artifact_level = sum(r['artifact_level'] for r in results.values()) / len(results)
    print(f"\n{'=' * 50}")
    print(f"Overall artifact level: {avg_artifact_level:.3f}")
    if avg_artifact_level < 0.2:
        print("✅ Image appears to have minimal compression artifacts")
    elif avg_artifact_level < 0.5:
        print("⚠️ Image shows moderate compression artifacts")
    else:
        print("❌ Image exhibits heavy compression artifacts")
# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()