LoliRimuru commited on
Commit
9361d66
Β·
verified Β·
1 Parent(s): e7c578a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # app.py
 
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
@@ -7,6 +8,9 @@ import torchvision.transforms as transforms
7
  import numpy as np
8
  from huggingface_hub import hf_hub_download
9
 
 
 
 
10
  # ==================== MODEL DEFINITION ====================
11
  class LightweightCompressionNet(nn.Module):
12
  def __init__(self):
@@ -31,7 +35,6 @@ class LightweightCompressionNet(nn.Module):
31
  features = features.view(features.size(0), -1)
32
  return self.head(features)
33
 
34
-
35
  # ==================== INFERENCE CLASS ====================
36
  class CompressionArtifactPredictor:
37
  def __init__(self, device: str = "cuda"):
@@ -48,8 +51,11 @@ class CompressionArtifactPredictor:
48
  checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
49
  self.model.load_state_dict(checkpoint['model_state_dict'])
50
 
 
51
  self.preprocess = transforms.Compose([
52
  transforms.ToTensor(),
 
 
53
  ])
54
 
55
  self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
@@ -64,7 +70,7 @@ class CompressionArtifactPredictor:
64
  """Predict compression quality levels for all formats."""
65
  img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
66
 
67
- # SIMPLE FULL PRECISION INFERENCE - NO AUTOCAST
68
  with torch.no_grad():
69
  predictions = self.model(img_tensor).squeeze(0).cpu().numpy()
70
 
@@ -99,7 +105,6 @@ class CompressionArtifactPredictor:
99
 
100
  return results
101
 
102
-
103
  # ==================== GRADIO UI ====================
104
  def create_ui():
105
  predictor = CompressionArtifactPredictor()
@@ -112,25 +117,27 @@ def create_ui():
112
  image = Image.fromarray(image)
113
 
114
  image = image.convert('RGB')
 
 
115
  results = predictor.predict(image)
116
 
117
- # Generate HTML table for results
118
  html_results = """
119
  <table style='width:100%; border-collapse: collapse; font-family: inherit;'>
120
- <tr style='background: #f5f5f5;'>
121
- <th style='padding:12px; text-align:left; border-bottom: 2px solid #ddd;'>Format</th>
122
- <th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Quality</th>
123
- <th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Assessment</th>
124
- <th style='padding:12px; text-align:center; border-bottom: 2px solid #ddd;'>Accuracy</th>
125
  </tr>
126
  """
127
 
128
  for fmt, data in results.items():
129
  html_results += f"""
130
- <tr style='border-bottom: 1px solid #eee;'>
131
  <td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
132
  <td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
133
- <td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color:#666;'>{data['desc']}</small></td>
134
  <td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
135
  </tr>
136
  """
@@ -181,7 +188,6 @@ def create_ui():
181
  analyze_button = gr.Button("πŸ” Analyze Image Quality", variant="primary", size="lg")
182
 
183
  with gr.Column():
184
- # USE HTML COMPONENT INSTEAD OF LABEL
185
  results_output = gr.HTML(
186
  label="Format-Specific Quality Scores"
187
  )
 
1
  # app.py
2
+ import warnings
3
  import gradio as gr
4
  import torch
5
  import torch.nn as nn
 
8
  import numpy as np
9
  from huggingface_hub import hf_hub_download
10
 
11
+ # Suppress HF spaces warning (internal library, not our code)
12
+ warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")
13
+
14
  # ==================== MODEL DEFINITION ====================
15
  class LightweightCompressionNet(nn.Module):
16
  def __init__(self):
 
35
  features = features.view(features.size(0), -1)
36
  return self.head(features)
37
 
 
38
  # ==================== INFERENCE CLASS ====================
39
  class CompressionArtifactPredictor:
40
  def __init__(self, device: str = "cuda"):
 
51
  checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
52
  self.model.load_state_dict(checkpoint['model_state_dict'])
53
 
54
+ # FIXED: Add padding->center crop to handle arbitrary sizes
55
  self.preprocess = transforms.Compose([
56
  transforms.ToTensor(),
57
+ transforms.Pad(512, padding_mode='edge'), # Pad smaller images
58
+ transforms.CenterCrop(512), # Then crop to 512x512
59
  ])
60
 
61
  self.compression_formats = ['JPEG', 'WebP', 'AVIF', 'JXL']
 
70
  """Predict compression quality levels for all formats."""
71
  img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
72
 
73
+ # FIXED: Full precision, no autocast
74
  with torch.no_grad():
75
  predictions = self.model(img_tensor).squeeze(0).cpu().numpy()
76
 
 
105
 
106
  return results
107
 
 
108
  # ==================== GRADIO UI ====================
109
  def create_ui():
110
  predictor = CompressionArtifactPredictor()
 
117
  image = Image.fromarray(image)
118
 
119
  image = image.convert('RGB')
120
+ print(f"Processing image of size: {image.size}") # Debug log
121
+
122
  results = predictor.predict(image)
123
 
124
+ # FIXED: Dark mode compatible using CSS variables
125
  html_results = """
126
  <table style='width:100%; border-collapse: collapse; font-family: inherit;'>
127
+ <tr style='background: var(--block-label-background-fill, #f5f5f5);'>
128
+ <th style='padding:12px; text-align:left; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Format</th>
129
+ <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Quality</th>
130
+ <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Assessment</th>
131
+ <th style='padding:12px; text-align:center; border-bottom: 2px solid var(--border-color-primary, #ddd);'>Accuracy</th>
132
  </tr>
133
  """
134
 
135
  for fmt, data in results.items():
136
  html_results += f"""
137
+ <tr style='border-bottom: 1px solid var(--border-color-primary, #eee);'>
138
  <td style='padding:12px; font-weight:500;'>{data['indicator']} {fmt}</td>
139
  <td style='padding:12px; text-align:center;'><strong>{data['quality_score']}/100</strong></td>
140
+ <td style='padding:12px; text-align:center;'>{data['category']}<br><small style='color: var(--body-text-color-subdued, #666);'>{data['desc']}</small></td>
141
  <td style='padding:12px; text-align:center;'>{data['accuracy']}%</td>
142
  </tr>
143
  """
 
188
  analyze_button = gr.Button("πŸ” Analyze Image Quality", variant="primary", size="lg")
189
 
190
  with gr.Column():
 
191
  results_output = gr.HTML(
192
  label="Format-Specific Quality Scores"
193
  )