LoliRimuru committed on
Commit
4bd8e01
·
verified ·
1 Parent(s): ca31f29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -48
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  import torch.nn as nn
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
- from pathlib import Path
8
  import numpy as np
9
  from huggingface_hub import hf_hub_download
10
 
@@ -49,7 +48,6 @@ class CompressionArtifactPredictor:
49
  checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
50
  self.model.load_state_dict(checkpoint['model_state_dict'])
51
 
52
- # Define preprocessing
53
  self.preprocess = transforms.Compose([
54
  transforms.ToTensor(),
55
  ])
@@ -63,21 +61,17 @@ class CompressionArtifactPredictor:
63
  }
64
 
65
  def predict(self, image: Image.Image) -> dict:
66
- """Predict compression quality levels for all formats with explanations."""
67
- # Preprocess
68
  img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
69
 
70
- # Inference
71
  with torch.no_grad():
72
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
73
  predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()
74
 
75
- # Format results
76
  results = {}
77
  for i, fmt in enumerate(self.compression_formats):
78
- quality_score = float(predictions[i] * 100) # Convert to 0-100 range
79
 
80
- # Determine quality category
81
  if quality_score >= 90:
82
  category = "Excellent (Minimal artifacts)"
83
  color = "🟒"
@@ -109,17 +103,12 @@ def create_ui():
109
  if image is None:
110
  return None, "Please upload an image."
111
 
112
- # Convert numpy array to PIL Image if needed
113
  if isinstance(image, np.ndarray):
114
  image = Image.fromarray(image)
115
 
116
- # Ensure RGB
117
  image = image.convert('RGB')
118
-
119
- # Get predictions
120
  results = predictor.predict(image)
121
 
122
- # Format output as a nice dictionary for Gradio
123
  formatted_results = {}
124
  for fmt, data in results.items():
125
  formatted_results[f"{data['indicator']} {fmt}"] = {
@@ -128,21 +117,18 @@ def create_ui():
128
  "Model Accuracy": f"{data['accuracy']}%"
129
  }
130
 
131
- # Calculate overall score
132
  avg_quality = np.mean([r['quality_score'] for r in results.values()])
133
  if avg_quality >= 85:
134
- overall_status = "βœ… **High Quality Image** - Minimal compression artifacts detected across all formats."
135
  elif avg_quality >= 65:
136
- overall_status = "⚠️ **Moderate Quality** - Some compression artifacts present, but image remains usable."
137
  else:
138
  overall_status = "❌ **Low Quality Image** - Significant compression artifacts detected."
139
 
140
- # Add overall summary
141
  summary = f"### Overall Assessment\n{overall_status}\n\n**Average Quality Score: {avg_quality:.1f}/100**"
142
 
143
  return formatted_results, summary
144
 
145
- # Create Gradio interface
146
  with gr.Blocks(
147
  title="AAL-Plus Image Quality Assessment",
148
  theme=gr.themes.Soft()
@@ -150,19 +136,13 @@ def create_ui():
150
  gr.Markdown(
151
  """
152
  # 🎯 AAL-Plus Image Quality Assessment
153
- ### Detect compression artifacts across multiple formats (JPEG, WebP, AVIF, JXL)
154
 
155
- This lightweight model (~2M parameters) predicts the quality level of your image and identifies
156
- compression artifacts with **97.1% overall accuracy**.
157
 
158
  **How to interpret results:**
159
- - **Quality Score**: 0-100 scale (higher = better quality, fewer artifacts)
160
- - **Assessment**: Text description of artifact level
161
- - **Color Indicators**:
162
- - 🟒 Green = Excellent (90-100)
163
- - 🟑 Yellow = Good (70-90)
164
- - 🟠 Orange = Fair (50-70)
165
- - πŸ”΄ Red = Poor (0-50)
166
  """
167
  )
168
 
@@ -173,7 +153,7 @@ def create_ui():
173
  type="pil",
174
  height=400
175
  )
176
- analyze_button = gr.Button("πŸ” Analyze Image Quality", variant="primary")
177
 
178
  with gr.Column():
179
  results_output = gr.Label(
@@ -184,21 +164,6 @@ def create_ui():
184
  label="Overall Assessment"
185
  )
186
 
187
- # Examples
188
- gr.Examples(
189
- examples=[
190
- ["examples/example1.jpg"],
191
- ["examples/example2.webp"],
192
- ["examples/example3.avif"],
193
- ["examples/example4.jxl"],
194
- ],
195
- inputs=image_input,
196
- outputs=[results_output, summary_output],
197
- fn=analyze_image,
198
- cache_examples=False,
199
- label="Try Example Images"
200
- )
201
-
202
  gr.Markdown(
203
  """
204
  ---
@@ -210,13 +175,10 @@ def create_ui():
210
  | AVIF | 97.1% | 0-100 |
211
  | JXL | 94.8% | 0-100 |
212
 
213
- *Accuracy measured as predictions within Β±5% range of actual quality values*
214
-
215
- **🌍 Environmental Impact**: Model training required ~12 GPU hours on RTX 5090. Model size: 8MB.
216
  """
217
  )
218
 
219
- # Wire up the interface
220
  analyze_button.click(
221
  fn=analyze_image,
222
  inputs=image_input,
 
4
  import torch.nn as nn
5
  from PIL import Image
6
  import torchvision.transforms as transforms
 
7
  import numpy as np
8
  from huggingface_hub import hf_hub_download
9
 
 
48
  checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
49
  self.model.load_state_dict(checkpoint['model_state_dict'])
50
 
 
51
  self.preprocess = transforms.Compose([
52
  transforms.ToTensor(),
53
  ])
 
61
  }
62
 
63
  def predict(self, image: Image.Image) -> dict:
64
+ """Predict compression quality levels for all formats."""
 
65
  img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
66
 
 
67
  with torch.no_grad():
68
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
69
  predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()
70
 
 
71
  results = {}
72
  for i, fmt in enumerate(self.compression_formats):
73
+ quality_score = float(predictions[i] * 100)
74
 
 
75
  if quality_score >= 90:
76
  category = "Excellent (Minimal artifacts)"
77
  color = "🟒"
 
103
  if image is None:
104
  return None, "Please upload an image."
105
 
 
106
  if isinstance(image, np.ndarray):
107
  image = Image.fromarray(image)
108
 
 
109
  image = image.convert('RGB')
 
 
110
  results = predictor.predict(image)
111
 
 
112
  formatted_results = {}
113
  for fmt, data in results.items():
114
  formatted_results[f"{data['indicator']} {fmt}"] = {
 
117
  "Model Accuracy": f"{data['accuracy']}%"
118
  }
119
 
 
120
  avg_quality = np.mean([r['quality_score'] for r in results.values()])
121
  if avg_quality >= 85:
122
+ overall_status = "βœ… **High Quality Image** - Minimal compression artifacts detected."
123
  elif avg_quality >= 65:
124
+ overall_status = "⚠️ **Moderate Quality** - Some compression artifacts present, but usable."
125
  else:
126
  overall_status = "❌ **Low Quality Image** - Significant compression artifacts detected."
127
 
 
128
  summary = f"### Overall Assessment\n{overall_status}\n\n**Average Quality Score: {avg_quality:.1f}/100**"
129
 
130
  return formatted_results, summary
131
 
 
132
  with gr.Blocks(
133
  title="AAL-Plus Image Quality Assessment",
134
  theme=gr.themes.Soft()
 
136
  gr.Markdown(
137
  """
138
  # 🎯 AAL-Plus Image Quality Assessment
139
+ ### Detect compression artifacts across multiple image formats (JPEG, WebP, AVIF, JXL)
140
 
141
+ This lightweight model (~2M parameters, 8MB) predicts quality levels with **97.1% overall accuracy**.
 
142
 
143
  **How to interpret results:**
144
+ - **Quality Score**: 0-100 scale (higher = better quality)
145
+ - **Score Categories**: 🟒 90-100 | 🟑 70-90 | 🟠 50-70 | πŸ”΄ 0-50
 
 
 
 
 
146
  """
147
  )
148
 
 
153
  type="pil",
154
  height=400
155
  )
156
+ analyze_button = gr.Button("πŸ” Analyze Image Quality", variant="primary", size="lg")
157
 
158
  with gr.Column():
159
  results_output = gr.Label(
 
164
  label="Overall Assessment"
165
  )
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  gr.Markdown(
168
  """
169
  ---
 
175
  | AVIF | 97.1% | 0-100 |
176
  | JXL | 94.8% | 0-100 |
177
 
178
+ *Accuracy measured as predictions within Β±5% of actual quality values*
 
 
179
  """
180
  )
181
 
 
182
  analyze_button.click(
183
  fn=analyze_image,
184
  inputs=image_input,