# inference.py
"""Predict per-format compression quality / artifact levels for a single image."""
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from pathlib import Path
import argparse
import json
from typing import Dict, Tuple


# ==================== MODEL DEFINITION ====================

class LightweightCompressionNet(nn.Module):
    """Small CNN mapping an RGB image batch to four sigmoid scores.

    The convolutional stack uses no padding and aggressive strides, so each
    spatial side of the input must be at least ``min_input_size()`` pixels.
    ``AdaptiveAvgPool2d(1)`` then collapses whatever spatial extent remains,
    so any size at or above that minimum is accepted.
    """

    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),  # squashes each of the 4 outputs into [0, 1]
        )

    def min_input_size(self) -> Tuple[int, int]:
        """Return the smallest (height, width) the conv stack can process.

        Walks the Conv2d layers from last to first. A layer whose output must
        be at least ``need`` pixels along a side requires an input of
        ``(need - 1) * stride + dilation * (kernel - 1) + 1 - 2 * padding``.
        Assumes numeric (not string) padding, as used in ``conv_blocks``.
        """
        need_h, need_w = 1, 1
        for layer in reversed(self.conv_blocks):
            if not isinstance(layer, nn.Conv2d):
                continue
            (kh, kw), (sh, sw) = layer.kernel_size, layer.stride
            (ph, pw), (dh, dw) = layer.padding, layer.dilation
            need_h = (need_h - 1) * sh + dh * (kh - 1) + 1 - 2 * ph
            need_w = (need_w - 1) * sw + dw * (kw - 1) + 1 - 2 * pw
        return max(need_h, 1), max(need_w, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an (N, 3, H, W) batch to (N, 4) scores."""
        features = self.conv_blocks(x)            # (N, 256, 1, 1) after pooling
        features = features.flatten(start_dim=1)  # (N, 256)
        return self.head(features)


# ==================== INFERENCE PIPELINE ====================

class CompressionArtifactPredictor:
    """Load a trained ``LightweightCompressionNet`` checkpoint and score images.

    Args:
        model_path: Checkpoint saved with ``torch.save``. Either a dict holding
            the weights under ``'model_state_dict'`` or a bare state dict.
        device: Requested torch device string. A CUDA request falls back to
            CPU when CUDA is unavailable; other devices are used as given.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()

        # Load checkpoint (weights_only avoids unpickling arbitrary objects).
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        # Accept both the wrapped training checkpoint and a bare state dict.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        self.model.load_state_dict(checkpoint)

        # ToTensor: PIL (H, W, C) uint8 -> float tensor (C, H, W) scaled to [0, 1].
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Order must match the model's 4 output units.
        self.compression_formats = ['jpeg', 'webp', 'avif', 'jxl']
        self.quality_ranges: Dict[str, Tuple[int, int]] = {
            'jpeg': (0, 100),
            'webp': (0, 100),
            'avif': (0, 100),
            'jxl': (0, 100),
        }
        # The architecture is fixed, so the minimum accepted size is too.
        self._min_size = self.model.min_input_size()

    def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
        """Predict compression quality/artifact levels for all formats.

        Args:
            image: PIL image. Non-RGB modes are converted to RGB first.

        Returns:
            Mapping of format name to a dict with ``normalized_score`` (model
            output in [0, 1]), ``predicted_quality`` (score rescaled to
            ``quality_ranges[fmt]``) and ``artifact_level`` (1 - score).

        Raises:
            ValueError: if the image is smaller than the conv stack allows.
        """
        if image.mode != "RGB":
            # ToTensor keeps the source channel count (1 for 'L', 4 for
            # 'RGBA'), but the first conv layer expects exactly 3 channels.
            image = image.convert("RGB")

        width, height = image.size
        min_h, min_w = self._min_size
        if height < min_h or width < min_w:
            raise ValueError(
                f"Image is {width}x{height}, but the model needs at least "
                f"{min_w}x{min_h} pixels."
            )

        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # bfloat16 autocast only on CUDA: the deprecated torch.cuda.amp.autocast
        # used previously disabled itself off-CUDA, so CPU numerics stay fp32.
        use_cuda = self.device.type == "cuda"
        with torch.no_grad(), torch.autocast(
            device_type="cuda" if use_cuda else "cpu",
            dtype=torch.bfloat16,
            enabled=use_cuda,
        ):
            predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()

        results: Dict[str, Dict[str, float]] = {}
        for fmt, raw_score in zip(self.compression_formats, predictions):
            normalized_score = float(raw_score)
            results[fmt] = {
                'normalized_score': normalized_score,                                  # 0.0 to 1.0
                'predicted_quality': self._denormalize_quality(normalized_score, fmt),  # actual quality range
                'artifact_level': 1.0 - normalized_score,                              # higher = more artifacts
            }
        return results

    def _denormalize_quality(self, normalized: float, fmt: str) -> float:
        """Linearly map a [0, 1] score onto ``quality_ranges[fmt]``."""
        min_q, max_q = self.quality_ranges[fmt]
        return normalized * (max_q - min_q) + min_q

    def predict_format(self, image: Image.Image, format_name: str) -> float:
        """Return the predicted quality for a single format.

        Raises:
            ValueError: if ``format_name`` is not one of ``compression_formats``.
        """
        if format_name not in self.compression_formats:
            raise ValueError(
                f"Unsupported format {format_name!r}. "
                f"Choose from: {self.compression_formats}"
            )
        return self.predict(image)[format_name]['predicted_quality']


# ==================== MAIN ====================

def main(argv=None) -> None:
    """Command-line entry point.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``
            (argparse default).
    """
    parser = argparse.ArgumentParser(
        description="Predict compression quality/artifact levels for an image."
    )
    parser.add_argument("image", type=Path, help="Path to the image to analyze")
    parser.add_argument("--model", default="checkpoints/model.pt",
                        help="Path to the model checkpoint")
    parser.add_argument("--device", default="cuda",
                        help="Torch device (CUDA falls back to CPU if unavailable)")
    parser.add_argument("--json", action="store_true",
                        help="Print raw results as JSON instead of a report")
    args = parser.parse_args(argv)

    image_path = args.image
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Validate the cheap input before paying for model loading.
    predictor = CompressionArtifactPredictor(args.model, device=args.device)

    # Image.open is lazy and keeps the file open; the context manager closes
    # it once convert() has produced an independent in-memory copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    results = predictor.predict(image)

    if args.json:
        print(json.dumps(results, indent=2))
        return

    print(f"\n🔍 Analyzing image: {image_path}")
    print(f"📐 Image size: {image.size[0]}x{image.size[1]}\n")

    print("=" * 50)
    print("📊 COMPRESSION ARTIFACT ANALYSIS")
    print("=" * 50)

    for fmt, data in results.items():
        print(f"\n{fmt.upper():>4}:")
        print(f"  Predicted Quality: {data['predicted_quality']:>6.1f} / {predictor.quality_ranges[fmt][1]}")
        print(f"  Normalized Score:  {data['normalized_score']:>6.3f}")
        print(f"  Artifact Level:    {data['artifact_level']:>6.3f} (0.0=clean, 1.0=heavily compressed)")

    # Overall compression quality score: mean artifact level across formats.
    avg_artifact_level = sum(r['artifact_level'] for r in results.values()) / len(results)
    print(f"\n{'=' * 50}")
    print(f"Overall artifact level: {avg_artifact_level:.3f}")

    if avg_artifact_level < 0.2:
        print("✅ Image appears to have minimal compression artifacts")
    elif avg_artifact_level < 0.5:
        print("⚠️ Image shows moderate compression artifacts")
    else:
        print("❌ Image exhibits heavy compression artifacts")


if __name__ == "__main__":
    main()