# Source: Hugging Face Space "LoliRimuru/AAL-Plus_Image_Quality_Assessment"
# (app.py, commit 9361d66)
# app.py
import warnings
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from huggingface_hub import hf_hub_download
# Suppress a FutureWarning emitted on import by the `spaces` library
# (third-party noise on HF Spaces; not actionable in this app).
warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")
# ==================== MODEL DEFINITION ====================
class LightweightCompressionNet(nn.Module):
    """Compact CNN mapping an RGB image to four sigmoid quality scores.

    The conv stack progressively downsamples, then a global average pool
    yields a fixed 256-d feature for any input size large enough to survive
    the strided convolutions; a small MLP head emits 4 values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # One (in_channels, out_channels, kernel, stride) tuple per conv stage.
        # NOTE: layers are appended conv, GELU, conv, GELU, ... so the
        # nn.Sequential indices (and hence state_dict keys) match the
        # published checkpoint exactly.
        conv_cfg = [
            (3, 16, 4, 1),
            (16, 32, 4, 1),
            (32, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 4),
            (256, 256, 4, 4),
            (256, 256, 3, 2),
        ]
        stages = []
        for c_in, c_out, kernel, stride in conv_cfg:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=kernel,
                                    stride=stride, padding=0))
            stages.append(nn.GELU())
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stages)
        self.head = nn.Sequential(
            nn.Linear(256, 32), nn.GELU(),
            nn.Linear(32, 4), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 4) tensor of per-format quality predictions."""
        pooled = self.conv_blocks(x)
        return self.head(torch.flatten(pooled, 1))
# ==================== INFERENCE CLASS ====================
class CompressionArtifactPredictor:
    """Loads the AAL-Plus model from the HF Hub and scores compression artifacts.

    For each supported format (JPEG, WebP, AVIF, JXL) the model predicts a
    quality level in [0, 1], exposed to callers as a 0-100 score plus a
    human-readable verdict.
    """

    # Score thresholds (inclusive lower bound) -> verdict; checked in order,
    # so the first matching bound wins. Last bound is -inf as a catch-all.
    _VERDICTS = (
        (90.0, "Excellent", "🟢", "Minimal artifacts"),
        (70.0, "Good", "🟡", "Light artifacts"),
        (50.0, "Fair", "🟠", "Moderate artifacts"),
        (float("-inf"), "Poor", "🔴", "Heavy artifacts"),
    )

    def __init__(self, device: str = "cuda"):
        # Fall back to CPU when CUDA is unavailable (e.g. free HF Spaces tier).
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()
        # Download the checkpoint from the Hugging Face Hub (cached locally
        # after the first call).
        model_path = hf_hub_download(
            repo_id="LoliRimuru/AAL-Plus_Image_Quality_Assessment",
            filename="model.pt"
        )
        # weights_only=True avoids unpickling arbitrary objects from the file.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Edge-pad smaller images, then center-crop, so any input size yields
        # a 512x512 tensor.
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Pad(512, padding_mode='edge'),  # Pad smaller images
            transforms.CenterCrop(512),                # Then crop to 512x512
        ])
        self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
        # Per-format validation accuracy (% of predictions within ±5% of the
        # true quality), reported alongside each score.
        self.accuracy_scores = {
            'JPEG': 99.4,
            'WebP': 97.0,
            'AVIF': 97.1,
            'JXL': 94.8
        }

    @classmethod
    def _assess(cls, quality_score: float) -> tuple:
        """Map a 0-100 quality score to (category, indicator, description)."""
        for bound, category, indicator, desc in cls._VERDICTS:
            if quality_score >= bound:
                return category, indicator, desc
        raise AssertionError("unreachable: last bound is -inf")

    def predict(self, image: "Image.Image") -> dict:
        """Predict compression quality levels for all formats.

        Returns a dict keyed by format name; each value holds the rounded
        0-100 'quality_score', a 'category' label, a short 'desc', the
        model's validation 'accuracy' for that format, and an emoji
        'indicator'.
        """
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        # Full precision, no autocast: inference is cheap and this keeps the
        # scores stable across devices.
        with torch.no_grad():
            predictions = self.model(img_tensor).squeeze(0).cpu().numpy()
        results = {}
        for i, fmt in enumerate(self.compression_formats):
            quality_score = float(predictions[i] * 100)
            category, indicator, desc = self._assess(quality_score)
            results[fmt] = {
                'quality_score': round(quality_score, 1),
                'category': category,
                'desc': desc,
                'accuracy': self.accuracy_scores[fmt],
                'indicator': indicator
            }
        return results
# ==================== GRADIO UI ====================
def create_ui():
    """Build and return the Gradio Blocks app for the quality-assessment demo."""
    # One predictor instance shared by all requests (model is read-only at
    # inference time).
    predictor = CompressionArtifactPredictor()

    def analyze_image(image):
        """Gradio callback: return (results_html, summary_markdown) for an image."""
        if image is None:
            return "", "Please upload an image."
        # Gradio may hand over a numpy array depending on component config.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        print(f"Processing image of size: {image.size}")  # Debug log
        results = predictor.predict(image)
        # Dark-mode compatible table: colors come from Gradio CSS variables
        # with light-mode fallbacks.
        html_results = """
        <table style='width:100%; border-collapse: collapse; font-family: inherit;'>
        <tr style='background: var(--block-label-background-fill, #f5f5f5);'>
        <th style='padding:12px; text-align:left; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Format</th>
        <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Quality</th>
        <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Assessment</th>
        <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Accuracy</th>
        </tr>
        """
        for fmt, data in results.items():
            html_results += f"""
            <tr style='border-bottom: 1px solid var(--border-color-primary, #eee);'>
            <td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
            <td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
            <td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color: var(--body-text-color-subdued, #666);'>{data['desc']}</small></td>
            <td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
            </tr>
            """
        html_results += "</table>"
        # Overall summary across all four formats.
        avg_quality = np.mean([r['quality_score'] for r in results.values()])
        if avg_quality >= 85:
            overall_status = "✅ **High Quality Image** - Minimal compression artifacts detected across all formats."
        elif avg_quality >= 65:
            overall_status = "⚠️ **Moderate Quality** - Some compression artifacts present, but image remains usable."
        else:
            overall_status = "❌ **Low Quality Image** - Significant compression artifacts detected."
        summary = f"""
        ### Overall Assessment
        {overall_status}
        **Average Quality Score: {avg_quality:.1f}/100**
        """
        return html_results, summary

    with gr.Blocks(
        title="AAL-Plus Image Quality Assessment",
        theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown(
            """
            # 🎯 AAL-Plus Image Quality Assessment
            ### Detect compression artifacts across multiple image formats (JPEG, WebP, AVIF, JXL)
            This lightweight model (~2M parameters, 8MB) predicts quality levels with **97.1% overall accuracy**.
            **How to interpret results:**
            - **Quality Score**: 0-100 scale (higher = better quality)
            - **Score Categories**: 🟢 90-100 | 🟡 70-90 | 🟠 50-70 | 🔴 0-50
            """
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )
                analyze_button = gr.Button("🔍 Analyze Image Quality", variant="primary", size="lg")
            with gr.Column():
                results_output = gr.HTML(
                    label="Format-Specific Quality Scores"
                )
                summary_output = gr.Markdown(
                    label="Overall Assessment"
                )
        gr.Markdown(
            """
            ---
            ### 📊 Model Performance
            | Format | Validation Accuracy | Quality Range |
            |--------|---------------------|---------------|
            | JPEG | 99.4% | 0-100 |
            | WebP | 97.0% | 0-100 |
            | AVIF | 97.1% | 0-100 |
            | JXL | 94.8% | 0-100 |
            *Accuracy measured as predictions within ±5% of actual quality values*
            """
        )
        # Run analysis both on explicit button press and automatically when a
        # new image is uploaded/changed.
        analyze_button.click(
            fn=analyze_image,
            inputs=image_input,
            outputs=[results_output, summary_output]
        )
        image_input.change(
            fn=analyze_image,
            inputs=image_input,
            outputs=[results_output, summary_output]
        )
    return demo
if __name__ == "__main__":
    # Build the Gradio app and start the server (blocks until shutdown).
    # Kept as a module-level `demo` binding, the name HF Spaces looks for.
    demo = create_ui()
    demo.launch()