# app.py
"""Gradio app for AAL-Plus image quality assessment.

Loads a lightweight CNN from the Hugging Face Hub and predicts per-format
compression-artifact quality scores (JPEG, WebP, AVIF, JXL) for an
uploaded image, rendering the results as an HTML table plus a markdown
summary.
"""
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Suppress HF spaces warning (internal library, not our code)
warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")


# ==================== MODEL DEFINITION ====================

class LightweightCompressionNet(nn.Module):
    """~2M-parameter CNN that maps a 3x512x512 image to 4 quality scores.

    The head ends in a Sigmoid, so each of the 4 outputs is in [0, 1];
    one output per compression format (JPEG, WebP, AVIF, JXL).

    NOTE: the Sequential layer ordering must not change — checkpoint
    state-dict keys are tied to these module indices.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            nn.GELU(),
            # Collapse the remaining spatial grid to 1x1 so the head is
            # independent of minor spatial-size variation.
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 4) tensor of per-format quality fractions."""
        features = self.conv_blocks(x)
        features = features.view(features.size(0), -1)
        return self.head(features)


# ==================== INFERENCE CLASS ====================

class CompressionArtifactPredictor:
    """Wraps the model with preprocessing and human-readable scoring."""

    # Ordered to match the 4 model outputs.
    COMPRESSION_FORMATS = ('JPEG', 'WebP', 'AVIF', 'JXL')

    # Published validation accuracy per format (±5% tolerance).
    ACCURACY_SCORES = {'JPEG': 99.4, 'WebP': 97.0, 'AVIF': 97.1, 'JXL': 94.8}

    # (min_score, category, indicator, description) — checked in order,
    # first match wins; the final row catches everything below 50.
    _CATEGORIES = (
        (90.0, "Excellent", "🟢", "Minimal artifacts"),
        (70.0, "Good", "🟡", "Light artifacts"),
        (50.0, "Fair", "🟠", "Moderate artifacts"),
        (float("-inf"), "Poor", "🔴", "Heavy artifacts"),
    )

    def __init__(self, device: str = "cuda") -> None:
        """Build the model, download weights from the Hub, and load them.

        Args:
            device: preferred device; falls back to CPU when CUDA is
                unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()

        # Load model from Hugging Face Hub.
        model_path = hf_hub_download(
            repo_id="LoliRimuru/AAL-Plus_Image_Quality_Assessment",
            filename="model.pt",
        )
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        # Accept both wrapped ({'model_state_dict': ...}) and flat
        # state-dict checkpoints.
        state_dict = checkpoint.get('model_state_dict', checkpoint) \
            if isinstance(checkpoint, dict) else checkpoint
        self.model.load_state_dict(state_dict)

        # Pad -> center-crop handles arbitrary input sizes: small images
        # gain replicated edge pixels, large images keep their center.
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Pad(512, padding_mode='edge'),  # Pad smaller images
            transforms.CenterCrop(512),                # Then crop to 512x512
        ])

        self.compression_formats = list(self.COMPRESSION_FORMATS)
        self.accuracy_scores = dict(self.ACCURACY_SCORES)

    @classmethod
    def _categorize(cls, quality_score: float) -> tuple[str, str, str]:
        """Return (category, indicator, description) for a 0-100 score."""
        for threshold, category, indicator, desc in cls._CATEGORIES:
            if quality_score >= threshold:
                return category, indicator, desc
        # Unreachable: the -inf row always matches.
        raise AssertionError("no category matched")

    def predict(self, image: Image.Image) -> dict:
        """Predict compression quality levels for all formats.

        Args:
            image: an RGB PIL image of any size.

        Returns:
            Mapping from format name to a dict with keys
            'quality_score' (0-100, one decimal), 'category', 'desc',
            'accuracy', and 'indicator'.
        """
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Full precision, no autocast; inference_mode also disables
        # autograd bookkeeping entirely.
        with torch.inference_mode():
            predictions = self.model(img_tensor).squeeze(0).cpu().numpy()

        results = {}
        for i, fmt in enumerate(self.compression_formats):
            quality_score = float(predictions[i] * 100)
            category, indicator, desc = self._categorize(quality_score)
            results[fmt] = {
                'quality_score': round(quality_score, 1),
                'category': category,
                'desc': desc,
                'accuracy': self.accuracy_scores[fmt],
                'indicator': indicator,
            }
        return results


# ==================== GRADIO UI ====================

def create_ui() -> gr.Blocks:
    """Build and return the Gradio Blocks demo (does not launch it)."""
    predictor = CompressionArtifactPredictor()

    def analyze_image(image):
        """Run prediction and render (HTML table, markdown summary)."""
        if image is None:
            return "", "Please upload an image."

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')

        print(f"Processing image of size: {image.size}")  # Debug log

        results = predictor.predict(image)

        # Dark-mode compatible: colors come from Gradio CSS variables
        # rather than hard-coded values.
        html_results = """
<table style="width:100%; border-collapse:collapse; color:var(--body-text-color);">
  <tr style="border-bottom:2px solid var(--border-color-primary);">
    <th style="text-align:left; padding:8px;">Format</th>
    <th style="text-align:left; padding:8px;">Quality Assessment</th>
    <th style="text-align:left; padding:8px;">Accuracy</th>
  </tr>
"""
        for fmt, data in results.items():
            html_results += f"""
  <tr style="border-bottom:1px solid var(--border-color-primary);">
    <td style="padding:8px;">{data['indicator']} <b>{fmt}</b></td>
    <td style="padding:8px;">
      <b>{data['quality_score']}/100</b> — {data['category']}<br>
      <small style="color:var(--body-text-color-subdued);">{data['desc']}</small>
    </td>
    <td style="padding:8px;">{data['accuracy']}%</td>
  </tr>
"""
        html_results += "</table>"

        # Overall summary across the four per-format scores.
        avg_quality = np.mean([r['quality_score'] for r in results.values()])
        if avg_quality >= 85:
            overall_status = "✅ **High Quality Image** - Minimal compression artifacts detected across all formats."
        elif avg_quality >= 65:
            overall_status = "⚠️ **Moderate Quality** - Some compression artifacts present, but image remains usable."
        else:
            overall_status = "❌ **Low Quality Image** - Significant compression artifacts detected."

        summary = f"""
### Overall Assessment

{overall_status}

**Average Quality Score: {avg_quality:.1f}/100**
"""
        return html_results, summary

    with gr.Blocks(
        title="AAL-Plus Image Quality Assessment",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            """
            # 🎯 AAL-Plus Image Quality Assessment
            ### Detect compression artifacts across multiple image formats (JPEG, WebP, AVIF, JXL)

            This lightweight model (~2M parameters, 8MB) predicts quality levels with **97.1% overall accuracy**.

            **How to interpret results:**
            - **Quality Score**: 0-100 scale (higher = better quality)
            - **Score Categories**: 🟢 90-100 | 🟡 70-90 | 🟠 50-70 | 🔴 0-50
            """
        )

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400,
                )
                analyze_button = gr.Button("🔍 Analyze Image Quality", variant="primary", size="lg")
            with gr.Column():
                results_output = gr.HTML(
                    label="Format-Specific Quality Scores"
                )
                summary_output = gr.Markdown(
                    label="Overall Assessment"
                )

        gr.Markdown(
            """
            ---
            ### 📊 Model Performance

            | Format | Validation Accuracy | Quality Range |
            |--------|---------------------|---------------|
            | JPEG   | 99.4%               | 0-100         |
            | WebP   | 97.0%               | 0-100         |
            | AVIF   | 97.1%               | 0-100         |
            | JXL    | 94.8%               | 0-100         |

            *Accuracy measured as predictions within ±5% of actual quality values*
            """
        )

        # Manual trigger via the button, plus auto-analysis on upload.
        analyze_button.click(
            fn=analyze_image,
            inputs=image_input,
            outputs=[results_output, summary_output],
        )
        image_input.change(
            fn=analyze_image,
            inputs=image_input,
            outputs=[results_output, summary_output],
        )

    return demo


if __name__ == "__main__":
    demo = create_ui()
    demo.launch()