# Author: LoliRimuru
# Commit: 3911759 ("v1 fix", verified)
# inference.py
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from pathlib import Path
import argparse
import json
from typing import Dict, Tuple
# ==================== MODEL DEFINITION ====================
class LightweightCompressionNet(nn.Module):
    """Compact CNN that regresses four scores in [0, 1] from an RGB image batch.

    Architecture: seven unpadded ``Conv2d`` layers, each followed by ``GELU``,
    then a global ``AdaptiveAvgPool2d(1)``. The pooled 256-d vector passes
    through ``Linear(256, 32) -> GELU -> Linear(32, 4) -> Sigmoid``.

    Input spatial size: because no conv is padded, both H and W must be at
    least 204 pixels. Smaller inputs shrink below the last 3x3 kernel and the
    forward pass raises.

    NOTE: the submodule order determines the ``state_dict`` key names
    (``conv_blocks.0`` ... ``conv_blocks.12``, ``head.0``, ``head.2``) that
    saved checkpoints are loaded into. Do not reorder or insert layers.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride); every conv uses padding=0.
        conv_specs = [
            (3, 16, 4, 1),
            (16, 32, 4, 1),
            (32, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 4),
            (256, 256, 4, 4),
            (256, 256, 3, 2),
        ]
        stack = []
        for c_in, c_out, k, s in conv_specs:
            stack += [nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=0), nn.GELU()]
        stack.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stack)

        head_layers = [nn.Linear(256, 32), nn.GELU(), nn.Linear(32, 4), nn.Sigmoid()]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to an (N, 4) tensor of sigmoid scores."""
        pooled = self.conv_blocks(x)
        batch_size = pooled.size(0)
        # The pool output is (N, 256, 1, 1); collapse it to (N, 256) for the linear head.
        return self.head(pooled.view(batch_size, -1))
# ==================== INFERENCE PIPELINE ====================
class CompressionArtifactPredictor:
    """Run a trained ``LightweightCompressionNet`` on PIL images.

    The network emits one sigmoid score per entry of ``compression_formats``.
    Each score is reported raw, rescaled into ``quality_ranges[fmt]``, and as an
    "artifact level" equal to ``1 - score``.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """Build the network and load weights from a checkpoint.

        Args:
            model_path: Path to a ``torch.save`` checkpoint that holds the
                weights under the ``'model_state_dict'`` key.
            device: Requested torch device string. A CUDA request falls back to
                CPU when CUDA is unavailable; any other device is used as given.

        Raises:
            KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
            RuntimeError: if the stored weights do not match the architecture.
        """
        requested = torch.device(device)
        # Fall back only for CUDA requests, so an explicit non-CUDA device
        # (e.g. "mps" or "cpu") is honored rather than silently overridden.
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested
        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()
        # weights_only=True restricts unpickling to tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Images are fed at native resolution: no resize/normalize, only PIL -> tensor.
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Order must match the model's 4 output units.
        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        # (min, max) used to rescale each sigmoid score; presumably the encoder
        # quality scale used for the training labels -- TODO confirm.
        self.quality_ranges: Dict[str, Tuple[int, int]] = {
            'jpeg': (0, 100),
            'webp': (0, 100),
            'avif': (0, 100),
            'jxl': (0, 100)
        }

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """Predict compression quality / artifact levels for every format.

        Args:
            image: PIL image. Non-RGB modes (RGBA, L, P, ...) are converted to
                RGB first, because the network's first conv takes 3 channels.
                The image is not resized. Given the unpadded conv stack, both
                sides must be at least 204 px or the forward pass raises.

        Returns:
            ``{fmt: {'normalized_score', 'predicted_quality', 'artifact_level'}}``
            for each format in ``compression_formats``.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        # bf16 autocast is enabled only when the model lives on a CUDA device.
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        use_bf16 = self.device.type == 'cuda'
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_bf16):
            output = self.model(img_tensor)
        # Upcast before .numpy(): NumPy has no bfloat16 dtype.
        predictions = output.squeeze(0).float().cpu().numpy()
        results = {}
        for i, fmt in enumerate(self.compression_formats):
            normalized_score = float(predictions[i])
            results[fmt] = {
                'normalized_score': normalized_score,  # sigmoid output, 0.0 to 1.0
                'predicted_quality': self._denormalize_quality(normalized_score, fmt),
                'artifact_level': 1.0 - normalized_score,  # higher = more artifacts
            }
        return results

    def _denormalize_quality(self, normalized: float, fmt: str) -> float:
        """Linearly map a [0, 1] score onto ``quality_ranges[fmt]``."""
        min_q, max_q = self.quality_ranges[fmt]
        return normalized * (max_q - min_q) + min_q

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality for a single format.

        Raises:
            ValueError: if ``format_name`` is not one of ``compression_formats``.
        """
        if format_name not in self.compression_formats:
            raise ValueError(
                f"Unsupported format {format_name!r}. Choose from: {self.compression_formats}"
            )
        # The network scores every format in one pass; pick out the requested one.
        return self.predict(image)[format_name]['predicted_quality']
# ==================== MAIN ====================
def main(argv=None):
    """Command-line entry point: analyze one image and print a report.

    Args:
        argv: Argument list to parse. ``None`` (the default, used when run as
            a script) makes argparse read ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: if the image path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Predict per-format compression quality and artifact levels for an image.",
    )
    parser.add_argument("image", type=Path, help="path to the image to analyze")
    parser.add_argument(
        "--model", default="checkpoints/model.pt",
        help="checkpoint path (default: %(default)s)",
    )
    parser.add_argument(
        "--device", default="cuda",
        help="torch device; CUDA falls back to CPU when unavailable (default: %(default)s)",
    )
    parser.add_argument(
        "--json", action="store_true",
        help="print results as JSON instead of the formatted report",
    )
    args = parser.parse_args(argv)

    image_path = args.image
    # Validate the input before paying for model construction and checkpoint loading.
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    predictor = CompressionArtifactPredictor(args.model, device=args.device)

    # The context manager closes the file handle; convert() returns an in-memory copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    if not args.json:
        print(f"\n🔍 Analyzing image: {image_path}")
        print(f"📐 Image size: {image.size[0]}x{image.size[1]}\n")

    results = predictor.predict(image)
    # Unweighted mean over all formats.
    avg_artifact_level = sum(r['artifact_level'] for r in results.values()) / len(results)

    if args.json:
        print(json.dumps({
            "image": str(image_path),
            "size": list(image.size),
            "results": results,
            "overall_artifact_level": avg_artifact_level,
        }, indent=2))
        return

    print("=" * 50)
    print("📊 COMPRESSION ARTIFACT ANALYSIS")
    print("=" * 50)
    for fmt, data in results.items():
        print(f"\n{fmt.upper():>4}:")
        print(f" Predicted Quality: {data['predicted_quality']:>6.1f} / {predictor.quality_ranges[fmt][1]}")
        print(f" Normalized Score: {data['normalized_score']:>6.3f}")
        print(f" Artifact Level: {data['artifact_level']:>6.3f} (0.0=clean, 1.0=heavily compressed)")

    print(f"\n{'=' * 50}")
    print(f"Overall artifact level: {avg_artifact_level:.3f}")
    if avg_artifact_level < 0.2:
        print("✅ Image appears to have minimal compression artifacts")
    elif avg_artifact_level < 0.5:
        print("⚠️ Image shows moderate compression artifacts")
    else:
        print("❌ Image exhibits heavy compression artifacts")


if __name__ == "__main__":
    main()