LoliRimuru commited on
Commit
e7c578a
Β·
verified Β·
1 Parent(s): 4bd8e01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -21
app.py CHANGED
@@ -64,30 +64,35 @@ class CompressionArtifactPredictor:
64
  """Predict compression quality levels for all formats."""
65
  img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
66
 
 
67
  with torch.no_grad():
68
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
69
- predictions = self.model(img_tensor).squeeze(0).cpu().float().numpy()
70
 
71
  results = {}
72
  for i, fmt in enumerate(self.compression_formats):
73
  quality_score = float(predictions[i] * 100)
74
 
75
  if quality_score >= 90:
76
- category = "Excellent (Minimal artifacts)"
77
  color = "🟒"
 
78
  elif quality_score >= 70:
79
- category = "Good (Light artifacts)"
80
  color = "🟑"
 
81
  elif quality_score >= 50:
82
- category = "Fair (Moderate artifacts)"
83
  color = "🟠"
 
84
  else:
85
- category = "Poor (Heavy artifacts)"
86
  color = "πŸ”΄"
 
87
 
88
  results[fmt] = {
89
  'quality_score': round(quality_score, 1),
90
  'category': category,
 
91
  'accuracy': self.accuracy_scores[fmt],
92
  'indicator': color
93
  }
@@ -101,7 +106,7 @@ def create_ui():
101
 
102
  def analyze_image(image):
103
  if image is None:
104
- return None, "Please upload an image."
105
 
106
  if isinstance(image, np.ndarray):
107
  image = Image.fromarray(image)
@@ -109,25 +114,45 @@ def create_ui():
109
  image = image.convert('RGB')
110
  results = predictor.predict(image)
111
 
112
- formatted_results = {}
 
 
 
 
 
 
 
 
 
 
113
  for fmt, data in results.items():
114
- formatted_results[f"{data['indicator']} {fmt}"] = {
115
- "Predicted Quality": f"{data['quality_score']}/100",
116
- "Assessment": data['category'],
117
- "Model Accuracy": f"{data['accuracy']}%"
118
- }
 
 
 
 
119
 
 
120
  avg_quality = np.mean([r['quality_score'] for r in results.values()])
121
  if avg_quality >= 85:
122
- overall_status = "βœ… **High Quality Image** - Minimal compression artifacts detected."
123
  elif avg_quality >= 65:
124
- overall_status = "⚠️ **Moderate Quality** - Some compression artifacts present, but usable."
125
  else:
126
  overall_status = "❌ **Low Quality Image** - Significant compression artifacts detected."
127
 
128
- summary = f"### Overall Assessment\n{overall_status}\n\n**Average Quality Score: {avg_quality:.1f}/100**"
 
 
 
 
 
129
 
130
- return formatted_results, summary
131
 
132
  with gr.Blocks(
133
  title="AAL-Plus Image Quality Assessment",
@@ -156,9 +181,9 @@ def create_ui():
156
  analyze_button = gr.Button("πŸ” Analyze Image Quality", variant="primary", size="lg")
157
 
158
  with gr.Column():
159
- results_output = gr.Label(
160
- label="Format-Specific Quality Scores",
161
- num_top_classes=4
162
  )
163
  summary_output = gr.Markdown(
164
  label="Overall Assessment"
@@ -193,7 +218,6 @@ def create_ui():
193
 
194
  return demo
195
 
196
- # ==================== MAIN ====================
197
  if __name__ == "__main__":
198
  demo = create_ui()
199
  demo.launch()
 
64
  """Predict compression quality levels for all formats."""
65
  img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
66
 
67
+ # SIMPLE FULL PRECISION INFERENCE - NO AUTOCAST
68
  with torch.no_grad():
69
+ predictions = self.model(img_tensor).squeeze(0).cpu().numpy()
 
70
 
71
  results = {}
72
  for i, fmt in enumerate(self.compression_formats):
73
  quality_score = float(predictions[i] * 100)
74
 
75
  if quality_score >= 90:
76
+ category = "Excellent"
77
  color = "🟒"
78
+ desc = "Minimal artifacts"
79
  elif quality_score >= 70:
80
+ category = "Good"
81
  color = "🟑"
82
+ desc = "Light artifacts"
83
  elif quality_score >= 50:
84
+ category = "Fair"
85
  color = "🟠"
86
+ desc = "Moderate artifacts"
87
  else:
88
+ category = "Poor"
89
  color = "πŸ”΄"
90
+ desc = "Heavy artifacts"
91
 
92
  results[fmt] = {
93
  'quality_score': round(quality_score, 1),
94
  'category': category,
95
+ 'desc': desc,
96
  'accuracy': self.accuracy_scores[fmt],
97
  'indicator': color
98
  }
 
106
 
107
  def analyze_image(image):
108
  if image is None:
109
+ return "", "Please upload an image."
110
 
111
  if isinstance(image, np.ndarray):
112
  image = Image.fromarray(image)
 
114
  image = image.convert('RGB')
115
  results = predictor.predict(image)
116
 
117
+ # Generate HTML table for results
118
+ html_results = """
119
+ <table style='width:100%; border-collapse: collapse; font-family: inherit;'>
120
+ <tr style='background: #f5f5f5;'>
121
+ <th style='padding:12px; text-align:left; border-bottom: 2px solid #ddd;'>Format</th>
122
+ <th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Quality</th>
123
+ <th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Assessment</th>
124
+ <th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Accuracy</th>
125
+ </tr>
126
+ """
127
+
128
  for fmt, data in results.items():
129
+ html_results += f"""
130
+ <tr style='border-bottom: 1px solid #eee;'>
131
+ <td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
132
+ <td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
133
+ <td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color:#666;'>{data['desc']}</small></td>
134
+ <td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
135
+ </tr>
136
+ """
137
+ html_results += "</table>"
138
 
139
+ # Overall summary
140
  avg_quality = np.mean([r['quality_score'] for r in results.values()])
141
  if avg_quality >= 85:
142
+ overall_status = "βœ… **High Quality Image** - Minimal compression artifacts detected across all formats."
143
  elif avg_quality >= 65:
144
+ overall_status = "⚠️ **Moderate Quality** - Some compression artifacts present, but image remains usable."
145
  else:
146
  overall_status = "❌ **Low Quality Image** - Significant compression artifacts detected."
147
 
148
+ summary = f"""
149
+ ### Overall Assessment
150
+ {overall_status}
151
+
152
+ **Average Quality Score: {avg_quality:.1f}/100**
153
+ """
154
 
155
+ return html_results, summary
156
 
157
  with gr.Blocks(
158
  title="AAL-Plus Image Quality Assessment",
 
181
  analyze_button = gr.Button("πŸ” Analyze Image Quality", variant="primary", size="lg")
182
 
183
  with gr.Column():
184
+ # USE HTML COMPONENT INSTEAD OF LABEL
185
+ results_output = gr.HTML(
186
+ label="Format-Specific Quality Scores"
187
  )
188
  summary_output = gr.Markdown(
189
  label="Overall Assessment"
 
218
 
219
  return demo
220
 
 
221
  if __name__ == "__main__":
222
  demo = create_ui()
223
  demo.launch()