|
|
|
|
|
# Standard library
import warnings

# Third-party: UI, inference stack, image handling, and model download.
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from huggingface_hub import hf_hub_download

# Silence FutureWarnings emitted from the `spaces` module
# (presumably the Hugging Face Spaces helper package — confirm source).
warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")
|
|
|
|
|
|
|
|
class LightweightCompressionNet(nn.Module):
    """Small CNN that maps an RGB image to four quality scores in [0, 1].

    The four outputs correspond, positionally, to the compression formats
    consumed by the caller (JPEG, WebP, AVIF, JXL — see the predictor class).

    NOTE: the layer layout — and therefore the ``state_dict`` keys
    ``conv_blocks.<idx>.*`` / ``head.<idx>.*`` — must stay identical to the
    published checkpoint, so only the construction style differs here.
    """

    # (in_channels, out_channels, kernel_size, stride) for each conv stage;
    # each stage is followed by a GELU activation.
    _CONV_SPECS = (
        (3, 16, 4, 1),
        (16, 32, 4, 1),
        (32, 64, 4, 2),
        (64, 128, 4, 2),
        (128, 256, 4, 4),
        (256, 256, 4, 4),
        (256, 256, 3, 2),
    )

    def __init__(self):
        super().__init__()

        stages = []
        for in_ch, out_ch, kernel, stride in self._CONV_SPECS:
            stages.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=0)
            )
            stages.append(nn.GELU())
        # Collapse whatever spatial extent remains down to (B, 256, 1, 1).
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*stages)

        # Regression head; Sigmoid keeps each of the 4 scores in [0, 1].
        self.head = nn.Sequential(
            nn.Linear(256, 32), nn.GELU(),
            nn.Linear(32, 4), nn.Sigmoid()
        )

    def forward(self, x):
        """Return a (batch, 4) tensor of quality predictions for input (B, 3, H, W)."""
        pooled = self.conv_blocks(x)
        flat = torch.flatten(pooled, start_dim=1)
        return self.head(flat)
|
|
|
|
|
|
|
|
class CompressionArtifactPredictor:
    """End-to-end wrapper around :class:`LightweightCompressionNet`.

    Downloads the published checkpoint from the Hugging Face Hub, normalizes
    PIL images to the model's 512x512 input, and turns raw model outputs
    into a per-format quality report.
    """

    # (lower bound, category, indicator, description), checked top-down;
    # the -inf sentinel makes the last band an unconditional fallback.
    _QUALITY_BANDS = (
        (90, "Excellent", "π’", "Minimal artifacts"),
        (70, "Good", "π‘", "Light artifacts"),
        (50, "Fair", "π ", "Moderate artifacts"),
        (float("-inf"), "Poor", "π΄", "Heavy artifacts"),
    )

    def __init__(self, device: str = "cuda"):
        """Load the model onto *device*, falling back to CPU when CUDA is absent."""
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = LightweightCompressionNet().to(self.device)
        self.model.eval()

        # Fetch the released weights from the Hugging Face Hub (cached locally).
        model_path = hf_hub_download(
            repo_id="LoliRimuru/AAL-Plus_Image_Quality_Assessment",
            filename="model.pt"
        )
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Edge-pad generously, then center-crop, so any input becomes 512x512
        # without resampling (small images end up edge-extended, large ones cropped).
        self.preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Pad(512, padding_mode='edge'),
            transforms.CenterCrop(512),
        ])

        # Output order matches the model's four output units.
        self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
        # Published validation accuracy per format (within +/-5% of truth).
        self.accuracy_scores = {
            'JPEG': 99.4,
            'WebP': 97.0,
            'AVIF': 97.1,
            'JXL': 94.8
        }

    def predict(self, image: Image.Image) -> dict:
        """Predict compression quality levels for all formats."""
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            raw_scores = self.model(batch).squeeze(0).cpu().numpy()

        report = {}
        for fmt, raw in zip(self.compression_formats, raw_scores):
            quality = float(raw * 100)

            # Map the 0-100 score onto its qualitative band.
            for lower_bound, category, indicator, desc in self._QUALITY_BANDS:
                if quality >= lower_bound:
                    break

            report[fmt] = {
                'quality_score': round(quality, 1),
                'category': category,
                'desc': desc,
                'accuracy': self.accuracy_scores[fmt],
                'indicator': indicator
            }

        return report
|
|
|
|
|
|
|
|
def create_ui():
    """Build and return the Gradio Blocks demo for AAL-Plus quality assessment.

    Instantiates a single CompressionArtifactPredictor (which downloads the
    checkpoint from the Hugging Face Hub on first construction) and wires it
    to an upload-and-analyze UI: an HTML table of per-format scores plus a
    markdown overall summary.
    """
    predictor = CompressionArtifactPredictor()

    def analyze_image(image):
        """Run the predictor on one uploaded image.

        Returns an (html_table, markdown_summary) pair for the two output
        components; on empty input returns a prompt message instead.
        """
        if image is None:
            return "", "Please upload an image."

        # The input component is type="pil", but tolerate ndarray just in case.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Model preprocessing assumes 3-channel RGB input.
        image = image.convert('RGB')
        print(f"Processing image of size: {image.size}")

        results = predictor.predict(image)

        # Per-format results table; CSS variables keep it theme-aware with
        # hard-coded fallbacks for environments that lack the variables.
        html_results = """
        <table style='width:100%; border-collapse: collapse; font-family: inherit;'>
            <tr style='background: var(--block-label-background-fill, #f5f5f5);'>
                <th style='padding:12px; text-align:left; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Format</th>
                <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Quality</th>
                <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Assessment</th>
                <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Accuracy</th>
            </tr>
        """

        # One table row per compression format, in predictor order.
        for fmt, data in results.items():
            html_results += f"""
            <tr style='border-bottom: 1px solid var(--border-color-primary, #eee);'>
                <td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
                <td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
                <td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color: var(--body-text-color-subdued, #666);'>{data['desc']}</small></td>
                <td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
            </tr>
            """
        html_results += "</table>"

        # Aggregate the four per-format scores into one headline verdict.
        avg_quality = np.mean([r['quality_score'] for r in results.values()])
        if avg_quality >= 85:
            overall_status = "β **High Quality Image** - Minimal compression artifacts detected across all formats."
        elif avg_quality >= 65:
            overall_status = "β οΈ **Moderate Quality** - Some compression artifacts present, but image remains usable."
        else:
            overall_status = "β **Low Quality Image** - Significant compression artifacts detected."

        summary = f"""
        ### Overall Assessment
        {overall_status}

        **Average Quality Score: {avg_quality:.1f}/100**
        """

        return html_results, summary

    with gr.Blocks(
        title="AAL-Plus Image Quality Assessment",
        theme=gr.themes.Soft()
    ) as demo:
        # Header / usage instructions.
        gr.Markdown(
            """
            # π― AAL-Plus Image Quality Assessment
            ### Detect compression artifacts across multiple image formats (JPEG, WebP, AVIF, JXL)

            This lightweight model (~2M parameters, 8MB) predicts quality levels with **97.1% overall accuracy**.

            **How to interpret results:**
            - **Quality Score**: 0-100 scale (higher = better quality)
            - **Score Categories**: π’ 90-100 | π‘ 70-90 | π 50-70 | π΄ 0-50
            """
        )

        # Left column: input + trigger; right column: rendered results.
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )
                analyze_button = gr.Button("π Analyze Image Quality", variant="primary", size="lg")

            with gr.Column():
                results_output = gr.HTML(
                    label="Format-Specific Quality Scores"
                )
                summary_output = gr.Markdown(
                    label="Overall Assessment"
                )

        # Static footer: published per-format validation accuracy.
        gr.Markdown(
            """
            ---
            ### π Model Performance
            | Format | Validation Accuracy | Quality Range |
            |--------|---------------------|---------------|
            | JPEG | 99.4% | 0-100 |
            | WebP | 97.0% | 0-100 |
            | AVIF | 97.1% | 0-100 |
            | JXL | 94.8% | 0-100 |

            *Accuracy measured as predictions within Β±5% of actual quality values*
            """
        )

        # Explicit button trigger...
        analyze_button.click(
            fn=analyze_image,
            inputs=image_input,
            outputs=[results_output, summary_output]
        )

        # ...and auto-analysis whenever the uploaded image changes.
        image_input.change(
            fn=analyze_image,
            inputs=image_input,
            outputs=[results_output, summary_output]
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the Gradio interface and start the local server.
    app = create_ui()
    app.launch()