Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
# app.py
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
@@ -7,6 +8,9 @@ import torchvision.transforms as transforms
|
|
| 7 |
import numpy as np
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
# ==================== MODEL DEFINITION ====================
|
| 11 |
class LightweightCompressionNet(nn.Module):
|
| 12 |
def __init__(self):
|
|
@@ -31,7 +35,6 @@ class LightweightCompressionNet(nn.Module):
|
|
| 31 |
features = features.view(features.size(0), -1)
|
| 32 |
return self.head(features)
|
| 33 |
|
| 34 |
-
|
| 35 |
# ==================== INFERENCE CLASS ====================
|
| 36 |
class CompressionArtifactPredictor:
|
| 37 |
def __init__(self, device: str = "cuda"):
|
|
@@ -48,8 +51,11 @@ class CompressionArtifactPredictor:
|
|
| 48 |
checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
|
| 49 |
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 50 |
|
|
|
|
| 51 |
self.preprocess = transforms.Compose([
|
| 52 |
transforms.ToTensor(),
|
|
|
|
|
|
|
| 53 |
])
|
| 54 |
|
| 55 |
self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
|
|
@@ -64,7 +70,7 @@ class CompressionArtifactPredictor:
|
|
| 64 |
"""Predict compression quality levels for all formats."""
|
| 65 |
img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
|
| 66 |
|
| 67 |
-
#
|
| 68 |
with torch.no_grad():
|
| 69 |
predictions = self.model(img_tensor).squeeze(0).cpu().numpy()
|
| 70 |
|
|
@@ -99,7 +105,6 @@ class CompressionArtifactPredictor:
|
|
| 99 |
|
| 100 |
return results
|
| 101 |
|
| 102 |
-
|
| 103 |
# ==================== GRADIO UI ====================
|
| 104 |
def create_ui():
|
| 105 |
predictor = CompressionArtifactPredictor()
|
|
@@ -112,25 +117,27 @@ def create_ui():
|
|
| 112 |
image = Image.fromarray(image)
|
| 113 |
|
| 114 |
image = image.convert('RGB')
|
|
|
|
|
|
|
| 115 |
results = predictor.predict(image)
|
| 116 |
|
| 117 |
-
#
|
| 118 |
html_results = """
|
| 119 |
<table style='width:100%; border-collapse: collapse; font-family: inherit;'>
|
| 120 |
-
<tr style='background: #f5f5f5;'>
|
| 121 |
-
<th style='padding:12px; text-align:left; border-bottom: 2px solid #ddd;'>Format</th>
|
| 122 |
-
<th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Quality</th>
|
| 123 |
-
<th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Assessment</th>
|
| 124 |
-
<th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Accuracy</th>
|
| 125 |
</tr>
|
| 126 |
"""
|
| 127 |
|
| 128 |
for fmt, data in results.items():
|
| 129 |
html_results += f"""
|
| 130 |
-
<tr style='border-bottom: 1px solid #eee;'>
|
| 131 |
<td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
|
| 132 |
<td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
|
| 133 |
-
<td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color:#666;'>{data['desc']}</small></td>
|
| 134 |
<td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
|
| 135 |
</tr>
|
| 136 |
"""
|
|
@@ -181,7 +188,6 @@ def create_ui():
|
|
| 181 |
analyze_button = gr.Button("🔍 Analyze Image Quality", variant="primary", size="lg")
|
| 182 |
|
| 183 |
with gr.Column():
|
| 184 |
-
# USE HTML COMPONENT INSTEAD OF LABEL
|
| 185 |
results_output = gr.HTML(
|
| 186 |
label="Format-Specific Quality Scores"
|
| 187 |
)
|
|
|
|
| 1 |
# app.py
|
| 2 |
+
import warnings
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
|
| 11 |
+
# Suppress HF spaces warning (internal library, not our code)
|
| 12 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")
|
| 13 |
+
|
| 14 |
# ==================== MODEL DEFINITION ====================
|
| 15 |
class LightweightCompressionNet(nn.Module):
|
| 16 |
def __init__(self):
|
|
|
|
| 35 |
features = features.view(features.size(0), -1)
|
| 36 |
return self.head(features)
|
| 37 |
|
|
|
|
| 38 |
# ==================== INFERENCE CLASS ====================
|
| 39 |
class CompressionArtifactPredictor:
|
| 40 |
def __init__(self, device: str = "cuda"):
|
|
|
|
| 51 |
checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
|
| 52 |
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 53 |
|
| 54 |
+
# FIXED: Add padding->center crop to handle arbitrary sizes
|
| 55 |
self.preprocess = transforms.Compose([
|
| 56 |
transforms.ToTensor(),
|
| 57 |
+
transforms.Pad(512, padding_mode='edge'), # Pad smaller images
|
| 58 |
+
transforms.CenterCrop(512), # Then crop to 512x512
|
| 59 |
])
|
| 60 |
|
| 61 |
self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
|
|
|
|
| 70 |
"""Predict compression quality levels for all formats."""
|
| 71 |
img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
|
| 72 |
|
| 73 |
+
# FIXED: Full precision, no autocast
|
| 74 |
with torch.no_grad():
|
| 75 |
predictions = self.model(img_tensor).squeeze(0).cpu().numpy()
|
| 76 |
|
|
|
|
| 105 |
|
| 106 |
return results
|
| 107 |
|
|
|
|
| 108 |
# ==================== GRADIO UI ====================
|
| 109 |
def create_ui():
|
| 110 |
predictor = CompressionArtifactPredictor()
|
|
|
|
| 117 |
image = Image.fromarray(image)
|
| 118 |
|
| 119 |
image = image.convert('RGB')
|
| 120 |
+
print(f"Processing image of size: {image.size}") # Debug log
|
| 121 |
+
|
| 122 |
results = predictor.predict(image)
|
| 123 |
|
| 124 |
+
# FIXED: Dark mode compatible using CSS variables
|
| 125 |
html_results = """
|
| 126 |
<table style='width:100%; border-collapse: collapse; font-family: inherit;'>
|
| 127 |
+
<tr style='background: var(--block-label-background-fill, #f5f5f5);'>
|
| 128 |
+
<th style='padding:12px; text-align:left; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Format</th>
|
| 129 |
+
<th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Quality</th>
|
| 130 |
+
<th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Assessment</th>
|
| 131 |
+
<th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Accuracy</th>
|
| 132 |
</tr>
|
| 133 |
"""
|
| 134 |
|
| 135 |
for fmt, data in results.items():
|
| 136 |
html_results += f"""
|
| 137 |
+
<tr style='border-bottom: 1px solid var(--border-color-primary, #eee);'>
|
| 138 |
<td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
|
| 139 |
<td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
|
| 140 |
+
<td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color: var(--body-text-color-subdued, #666);'>{data['desc']}</small></td>
|
| 141 |
<td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
|
| 142 |
</tr>
|
| 143 |
"""
|
|
|
|
| 188 |
analyze_button = gr.Button("🔍 Analyze Image Quality", variant="primary", size="lg")
|
| 189 |
|
| 190 |
with gr.Column():
|
|
|
|
| 191 |
results_output = gr.HTML(
|
| 192 |
label="Format-Specific Quality Scores"
|
| 193 |
)
|