# app.py
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Suppress HF spaces warning (internal library, not our code)
warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")


# ==================== MODEL DEFINITION ====================
class LightweightCompressionNet(nn.Module):
    """Small CNN regressor producing four sigmoid scores per image.

    The module layout (and therefore the ``state_dict`` keys such as
    ``conv_blocks.0.weight`` or ``head.2.bias``) must stay exactly as is so the
    published checkpoint loads; do not insert, remove or reorder layers.
    """

    def __init__(self):
        super().__init__()
        # Unpadded convolutions; the strides shrink the map before global pooling.
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 4)`` tensor of scores in ``[0, 1]``.

        ``x`` is assumed to be ``(batch, 3, H, W)``; with these unpadded,
        strided convolutions H and W must each be at least 204 pixels.
        """
        pooled = self.conv_blocks(x)  # (batch, 256, 1, 1) after global pooling
        return self.head(torch.flatten(pooled, start_dim=1))


# ==================== INFERENCE CLASS ====================
class CompressionArtifactPredictor:
    """Score an image's compression quality for several codecs.

    The pretrained weights are downloaded from the Hugging Face Hub when the
    predictor is constructed.
    """

    # Side length of the square window fed to the model.
    CROP_SIZE = 512

    # (minimum score, category, indicator, description), checked top to bottom;
    # anything below the last threshold gets _FALLBACK_GRADE.
    _GRADES = (
        (90, "Excellent", "🟢", "Minimal artifacts"),
        (70, "Good", "🟡", "Light artifacts"),
        (50, "Fair", "🟠", "Moderate artifacts"),
    )
    _FALLBACK_GRADE = ("Poor", "🔴", "Heavy artifacts")

    def __init__(self, device: str = "cuda"):
        """Build the model, download and load its checkpoint.

        Args:
            device: Requested torch device string. A CUDA device falls back to
                CPU when CUDA is unavailable; other devices are used as given.
        """
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

        self.model = LightweightCompressionNet().to(self.device)

        # Load model from Hugging Face Hub
        model_path = hf_hub_download(
            repo_id="LoliRimuru/AAL-Plus_Image_Quality_Assessment",
            filename="model.pt",
        )
        # weights_only=True makes torch.load refuse arbitrary pickled objects.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Model output i is reported under compression_formats[i].
        self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
        # Static per-format accuracy figures (%) passed through to the results.
        self.accuracy_scores = {
            'JPEG': 99.4,
            'WebP': 97.0,
            'AVIF': 97.1,
            'JXL': 94.8,
        }

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Return a ``(channels, S, S)`` float tensor centred on ``image``.

        ``S`` is ``CROP_SIZE``. The result equals the output of
        ``Compose([ToTensor(), Pad(S, padding_mode='edge'), CenterCrop(S)])``:
        larger images are centre-cropped, smaller ones are centred with their
        border pixels replicated outward. Only the S x S window is converted
        to a tensor, so large uploads never allocate a padded full-size copy.
        """
        size = self.CROP_SIZE
        width, height = image.size
        # Same offsets (incl. round-half-to-even) that CenterCrop computes on
        # an image padded by `size` on every side, shifted back by that padding.
        top = int(round((height + size) / 2.0)) - size
        left = int(round((width + size) / 2.0)) - size
        # In-bounds part of the window; PIL would zero-fill outside, not replicate.
        box = (
            max(left, 0),
            max(top, 0),
            min(left + size, width),
            min(top + size, height),
        )
        tensor = transforms.functional.to_tensor(image.crop(box))
        # Replicate edges for the part of the window outside the image,
        # ordered [left, top, right, bottom] as torchvision expects.
        padding = [box[0] - left, box[1] - top, left + size - box[2], top + size - box[3]]
        if any(padding):
            tensor = transforms.functional.pad(tensor, padding, padding_mode="edge")
        return tensor

    @classmethod
    def _categorize(cls, quality_score: float) -> tuple:
        """Map a 0-100 score to ``(category, indicator, desc)``."""
        for threshold, category, indicator, desc in cls._GRADES:
            if quality_score >= threshold:
                return category, indicator, desc
        return cls._FALLBACK_GRADE

    def predict(self, image: Image.Image) -> dict:
        """Predict compression quality levels for all formats.

        Args:
            image: Any PIL image; non-RGB modes are converted to RGB first
                because the model's first convolution expects 3 channels.

        Returns:
            Mapping from format name to a dict with ``quality_score`` (model
            output x 100, rounded to one decimal), ``category``, ``desc``,
            ``accuracy`` and ``indicator``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        # Full precision on purpose (no autocast).
        with torch.inference_mode():
            scores = self.model(batch).squeeze(0).cpu().numpy()

        results = {}
        for fmt, raw in zip(self.compression_formats, scores):
            # Round before categorising so the shown score and its label agree
            # (e.g. 89.96 is displayed as 90.0 and must be "Excellent").
            quality_score = round(float(raw * 100), 1)
            category, indicator, desc = self._categorize(quality_score)
            results[fmt] = {
                'quality_score': quality_score,
                'category': category,
                'desc': desc,
                'accuracy': self.accuracy_scores[fmt],
                'indicator': indicator,
            }
        return results


# NOTE(review): the create_ui definition below is still collapsed onto one line; reformat it as a whole.
# ==================== GRADIO UI ==================== def create_ui(): predictor = CompressionArtifactPredictor() def analyze_image(image): if image is None: return "", "Please upload an image." if isinstance(image, np.ndarray): image = Image.fromarray(image) image = image.convert('RGB') print(f"Processing image of size: {image.size}") # Debug log results = predictor.predict(image) # FIXED: Dark mode compatible using CSS variables html_results = """
| Format | Quality | Assessment | Accuracy |
|---|---|---|---|
| {data['indicator']} {fmt} | {data['quality_score']}/100 | {data['category']} {data['desc']} |
{data['accuracy']}% |