hamxaameer commited on
Commit
d3dc4bf
·
verified ·
1 Parent(s): fb717b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +533 -220
app.py CHANGED
@@ -1,277 +1,590 @@
1
  """
2
- Hugging Face Spaces App for MAE Image Reconstruction
3
- Entry point for Hugging Face deployment
4
  """
5
 
6
  import gradio as gr
7
  import torch
8
  import torch.nn as nn
9
- import numpy as np
10
- from PIL import Image
11
  from torchvision import transforms
12
- import os
 
 
 
13
 
14
- # Import local modules
15
- from mae_model import create_mae_model
16
- from metrics import calculate_psnr, calculate_ssim, denormalize_for_metrics
17
 
 
 
 
18
 
19
- class MAEInference:
20
- """MAE Inference wrapper for Hugging Face Spaces."""
 
 
 
 
 
 
21
 
22
- def __init__(self):
23
- # Use CPU for Hugging Face free tier, GPU if available
24
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
- print(f"Running on: {self.device}")
26
-
27
- # Create model
28
- self.model = create_mae_model(
29
- img_size=224,
30
- patch_size=16,
31
- encoder_embed_dim=768,
32
- encoder_depth=12,
33
- encoder_num_heads=12,
34
- decoder_embed_dim=384,
35
- decoder_depth=12,
36
- decoder_num_heads=6,
37
- mask_ratio=0.75
38
- )
39
-
40
- # Load checkpoint
41
- self._load_weights()
42
-
43
- self.model = self.model.to(self.device)
44
- self.model.eval()
45
-
46
- # Image transforms
47
- self.transform = transforms.Compose([
48
- transforms.Resize((224, 224)),
49
- transforms.ToTensor(),
50
- transforms.Normalize(
51
- mean=[0.485, 0.456, 0.406],
52
- std=[0.229, 0.224, 0.225]
53
- )
54
- ])
55
 
56
- def _load_weights(self):
57
- """Load model weights from various possible locations."""
58
- # Possible checkpoint locations
59
- checkpoint_paths = [
60
- "checkpoint_best.pth", # Same directory (HF Spaces)
61
- "mae_checkpoint.pth", # Alternative name
62
- "model/checkpoint_best.pth", # Model subdirectory
63
- "/kaggle/working/checkpoint_best.pth", # Kaggle
64
- ]
65
 
66
- for path in checkpoint_paths:
67
- if os.path.exists(path):
68
- try:
69
- checkpoint = torch.load(path, map_location=self.device)
70
- if 'model_state_dict' in checkpoint:
71
- self.model.load_state_dict(checkpoint['model_state_dict'])
72
- else:
73
- self.model.load_state_dict(checkpoint)
74
- print(f"βœ“ Loaded weights from: {path}")
75
- return
76
- except Exception as e:
77
- print(f"Failed to load {path}: {e}")
78
- continue
79
 
80
- print("⚠ No checkpoint found - using random weights for demo")
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- def denormalize(self, tensor):
83
- """Denormalize tensor for display."""
84
- mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
85
- std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
86
-
87
- if tensor.device.type != 'cpu':
88
- tensor = tensor.cpu()
89
-
90
- tensor = tensor * std + mean
91
- tensor = torch.clamp(tensor, 0, 1)
92
- return tensor
93
-
94
- def create_masked_image(self, image_tensor, mask_indices, patch_size=16):
95
- """Create visualization of masked image."""
96
- img = self.denormalize(image_tensor.clone())
97
- num_patches_per_side = 224 // patch_size
98
-
99
- for idx in mask_indices:
100
- row = idx.item() // num_patches_per_side
101
- col = idx.item() % num_patches_per_side
102
- img[:,
103
- row * patch_size:(row + 1) * patch_size,
104
- col * patch_size:(col + 1) * patch_size] = 0.5
105
-
106
- return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
107
 
108
- @torch.no_grad()
109
- def reconstruct(self, image, mask_ratio=0.75):
110
- """Reconstruct image using MAE."""
111
- if image is None:
112
- return None, None, None, "Please upload an image."
113
-
114
- # Preprocess
115
- if isinstance(image, np.ndarray):
116
- image = Image.fromarray(image)
117
- if image.mode != 'RGB':
118
- image = image.convert('RGB')
119
-
120
- input_tensor = self.transform(image).unsqueeze(0).to(self.device)
121
-
122
- # Forward pass
123
- pred, target, mask_indices = self.model(input_tensor, mask_ratio)
124
-
125
- # Unpatchify
126
- reconstructed = self.model.unpatchify(pred)
127
-
128
- # Create visualizations
129
- masked_img = self.create_masked_image(
130
- input_tensor[0].cpu(),
131
- mask_indices[0].cpu()
132
- )
133
-
134
- recon_img = self.denormalize(reconstructed[0].cpu())
135
- recon_img = (recon_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
136
-
137
- original_img = self.denormalize(input_tensor[0].cpu())
138
- original_img = (original_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- # Calculate metrics
141
- pred_denorm = denormalize_for_metrics(reconstructed)
142
- target_denorm = denormalize_for_metrics(input_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- psnr = calculate_psnr(pred_denorm, target_denorm)
145
- ssim = calculate_ssim(pred_denorm, target_denorm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- metrics_text = f"""
148
- ### πŸ“Š Reconstruction Metrics
149
- | Metric | Value |
150
- |--------|-------|
151
- | **PSNR** | {psnr:.2f} dB |
152
- | **SSIM** | {ssim:.4f} |
153
- | **Mask Ratio** | {mask_ratio*100:.0f}% |
154
- | **Visible Patches** | {int((1-mask_ratio)*196)} / 196 |
155
- | **Masked Patches** | {int(mask_ratio*196)} / 196 |
156
- """
157
 
158
- return original_img, masked_img, recon_img, metrics_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
- # Initialize model globally (loaded once when app starts)
162
- print("Initializing MAE model...")
163
- mae = MAEInference()
164
 
 
 
 
 
 
 
 
 
165
 
166
- def process_image(input_image, mask_ratio):
167
- """Main processing function for Gradio."""
168
- if input_image is None:
169
- return None, None, None, "⬆️ Please upload an image to get started."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- mask_ratio = max(0.1, min(0.95, mask_ratio))
172
- return mae.reconstruct(input_image, mask_ratio)
 
 
 
 
 
 
 
173
 
174
 
175
- # Create Gradio interface
176
- with gr.Blocks(
177
- title="MAE Image Reconstruction",
178
- theme=gr.themes.Soft(),
179
- css="""
180
- .gradio-container { max-width: 1200px !important; }
181
- .output-image { border-radius: 8px; }
182
- """
183
- ) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- gr.Markdown("""
186
- # 🎭 Masked Autoencoder (MAE) Image Reconstruction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- Upload any image to see how the MAE reconstructs it from only **25% visible patches**.
189
- The model learns powerful visual representations by predicting masked regions.
 
190
 
191
- > **Try adjusting the mask ratio** to see how the reconstruction quality changes!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  """)
193
 
194
  with gr.Row():
195
  with gr.Column(scale=1):
196
- input_image = gr.Image(
197
- label="πŸ“€ Upload Image",
198
- type="pil",
199
- height=280
200
- )
201
 
202
  mask_ratio_slider = gr.Slider(
203
- minimum=0.1,
204
- maximum=0.95,
205
- value=0.75,
206
- step=0.05,
207
- label="🎚️ Masking Ratio",
208
- info="Percentage of patches to mask (default: 75%)"
209
  )
210
 
211
- reconstruct_btn = gr.Button(
212
- "πŸ”„ Reconstruct Image",
213
- variant="primary",
214
- size="lg"
215
- )
216
 
217
- metrics_output = gr.Markdown(
218
- value="⬆️ Upload an image and click **Reconstruct** to see metrics."
219
- )
 
 
 
 
 
 
 
220
 
221
  with gr.Column(scale=2):
 
222
  with gr.Row():
223
- original_output = gr.Image(
224
- label="Original (224Γ—224)",
225
- height=224,
226
- show_download_button=True
227
- )
228
- masked_output = gr.Image(
229
- label="Masked Input",
230
- height=224,
231
- show_download_button=True
232
- )
233
- reconstructed_output = gr.Image(
234
- label="Reconstruction",
235
- height=224,
236
- show_download_button=True
237
- )
 
 
 
 
 
 
 
 
 
 
238
 
239
  # Event handlers
240
  reconstruct_btn.click(
241
- fn=process_image,
242
  inputs=[input_image, mask_ratio_slider],
243
  outputs=[original_output, masked_output, reconstructed_output, metrics_output]
244
  )
245
 
246
- mask_ratio_slider.change(
247
- fn=process_image,
248
  inputs=[input_image, mask_ratio_slider],
249
  outputs=[original_output, masked_output, reconstructed_output, metrics_output]
250
  )
251
 
252
- input_image.change(
253
- fn=process_image,
254
- inputs=[input_image, mask_ratio_slider],
255
- outputs=[original_output, masked_output, reconstructed_output, metrics_output]
256
  )
257
-
258
- gr.Markdown("""
259
- ---
260
- ### πŸ”¬ How MAE Works
261
-
262
- 1. **Masking**: Randomly mask ~75% of image patches
263
- 2. **Encoding**: Process only visible patches through ViT encoder
264
- 3. **Decoding**: Reconstruct full image using a lightweight decoder
265
-
266
- **Model Architecture:**
267
- - **Encoder**: ViT-Base (768 dim, 12 layers) β€” 86M params
268
- - **Decoder**: ViT-Small (384 dim, 12 layers) β€” 22M params
269
-
270
- πŸ“„ [Original Paper](https://arxiv.org/abs/2111.06377) |
271
- πŸ”— [GitHub](https://github.com/facebookresearch/mae)
272
- """)
273
 
274
 
275
- # Launch app
276
  if __name__ == "__main__":
277
- demo.launch()
 
1
  """
2
+ 🎭 Masked Autoencoder (MAE) - HuggingFace Spaces App
3
+ Beautiful UI for image reconstruction with detailed metrics
4
  """
5
 
6
  import gradio as gr
7
  import torch
8
  import torch.nn as nn
9
+ import torch.nn.functional as F
 
10
  from torchvision import transforms
11
+ from PIL import Image
12
+ import numpy as np
13
+ from einops import rearrange
14
+ import math
15
 
 
 
 
16
 
17
+ # ============================================================================
18
+ # MODEL ARCHITECTURE
19
+ # ============================================================================
20
 
21
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel == stride == patch_size is a per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, D, H/p, W/p) -> (B, N, D)
        patches = self.proj(x)
        return patches.flatten(2).transpose(1, 2)
34
+
35
+
36
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # 1/sqrt(head_dim) scaling for the dot products.
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # Fused q/k/v projection, then split heads: (3, B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Scaled dot-product attention.
        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
57
+
58
+
59
class MLP(nn.Module):
    """Two-layer feedforward network with GELU activation."""

    def __init__(self, in_features, hidden_features=None):
        super().__init__()
        # Falsy hidden size (None/0) falls back to the conventional 4x expansion.
        hidden = hidden_features or in_features * 4
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, in_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
73
+
74
+
75
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
88
+
89
+
90
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Return (grid_size**2, embed_dim) fixed 2D sin-cos positional embeddings."""
    coords = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Split the embedding dim equally between the two grid axes."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """1D sin-cos embedding of positions: first half sin, second half cos."""
    assert embed_dim % 2 == 0
    # Geometric frequency ladder, as in "Attention Is All You Need".
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    angles = np.einsum('m,d->md', pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
120
+
121
+
122
class ViTEncoder(nn.Module):
    """Vision Transformer encoder (ViT-Base) that processes only visible patches."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        # Fixed (non-learned-init) 2D sin-cos positional embeddings.
        grid_side = int(self.num_patches ** 0.5)
        table = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_side)
        self.pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))

    def forward(self, x, visible_indices):
        # Embed all patches, add positions, then keep only the visible subset.
        tokens = self.patch_embed(x) + self.pos_embed
        dim = tokens.shape[-1]
        gather_idx = visible_indices.unsqueeze(-1).expand(-1, -1, dim)
        tokens = torch.gather(tokens, dim=1, index=gather_idx)
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.norm(tokens)
148
+
149
+
150
class ViTDecoder(nn.Module):
    """Vision Transformer Decoder (ViT-Small).

    Maps encoder latents of the visible patches, plus learned mask tokens for
    the hidden patches, back to per-patch pixel predictions.
    """
    def __init__(self, img_size=224, patch_size=16, embed_dim=384, depth=12,
                 num_heads=6, mlp_ratio=4.0, encoder_embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Project from the (wider) encoder width down to the decoder width.
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, embed_dim)
        # Shared learned token that stands in for every masked patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Per-token head: predicts patch_size*patch_size*3 pixel values.
        self.pred = nn.Linear(embed_dim, patch_size ** 2 * 3)
        self._init_weights()

    def _init_weights(self):
        # Fixed 2D sin-cos positional embeddings; mask token gets small random init.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(self, x, visible_indices, mask_indices):
        """Return per-patch pixel predictions of shape (B, num_patches, p*p*3).

        x: (B, num_visible, encoder_dim) latents from the encoder.
        visible_indices / mask_indices: (B, num_visible) / (B, num_masked)
        patch positions; together they must cover all num_patches slots.
        """
        B, num_visible, _ = x.shape
        num_masked = mask_indices.shape[1]
        x = self.encoder_to_decoder(x)
        mask_tokens = self.mask_token.expand(B, num_masked, -1).to(dtype=x.dtype)
        # Rebuild the full-length token sequence: scatter visible latents and
        # mask tokens into their original patch positions.
        full_tokens = torch.zeros(B, self.num_patches, x.shape[-1], device=x.device, dtype=x.dtype)
        visible_indices_expanded = visible_indices.unsqueeze(-1).expand(-1, -1, x.shape[-1])
        full_tokens.scatter_(1, visible_indices_expanded, x)
        mask_indices_expanded = mask_indices.unsqueeze(-1).expand(-1, -1, x.shape[-1])
        full_tokens.scatter_(1, mask_indices_expanded, mask_tokens)
        # Positions added after scattering so each token gets its true location.
        full_tokens = full_tokens + self.pos_embed.to(dtype=x.dtype)
        for block in self.blocks:
            full_tokens = block(full_tokens)
        full_tokens = self.norm(full_tokens)
        pred = self.pred(full_tokens)
        return pred
188
+
189
+
190
class MaskedAutoencoder(nn.Module):
    """Masked Autoencoder for self-supervised learning.

    Randomly hides a fraction of image patches, encodes only the visible ones
    with a ViT encoder, and reconstructs every patch with a lighter ViT decoder.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
                 decoder_embed_dim=384, decoder_depth=12, decoder_num_heads=6,
                 mlp_ratio=4.0, mask_ratio=0.75):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        self.encoder = ViTEncoder(img_size, patch_size, in_channels, encoder_embed_dim,
                                  encoder_depth, encoder_num_heads, mlp_ratio)
        self.decoder = ViTDecoder(img_size, patch_size, decoder_embed_dim, decoder_depth,
                                  decoder_num_heads, mlp_ratio, encoder_embed_dim)

    def patchify(self, imgs):
        """(B, C, H, W) -> (B, num_patches, patch_size**2 * C).

        Pure-torch equivalent of
        rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)') — avoids the
        einops dependency for this core transform.
        """
        p = self.patch_size
        B, C, H, W = imgs.shape
        h, w = H // p, W // p
        x = imgs.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # B, h, w, p1, p2, C
        return x.reshape(B, h * w, p * p * C)

    def unpatchify(self, x):
        """(B, num_patches, patch_size**2 * 3) -> (B, 3, H, W); inverse of patchify."""
        p = self.patch_size
        h = w = self.img_size // p
        B = x.shape[0]
        x = x.reshape(B, h, w, p, p, 3)
        x = x.permute(0, 5, 1, 3, 2, 4)  # B, C, h, p1, w, p2
        return x.reshape(B, 3, h * p, w * p)

    def random_masking(self, batch_size, device, mask_ratio=None):
        """Sample per-image visible/masked patch index sets (each sorted ascending)."""
        if mask_ratio is None:
            mask_ratio = self.mask_ratio
        num_patches = self.num_patches
        num_visible = int(num_patches * (1 - mask_ratio))
        # Argsort of uniform noise gives an unbiased random permutation per image.
        noise = torch.rand(batch_size, num_patches, device=device)
        ids_shuffle = torch.argsort(noise, dim=1)
        visible_indices = torch.sort(ids_shuffle[:, :num_visible], dim=1)[0]
        mask_indices = torch.sort(ids_shuffle[:, num_visible:], dim=1)[0]
        return visible_indices, mask_indices

    def forward(self, imgs, mask_ratio=None):
        """Return (pred, target, mask_indices).

        pred: (B, num_patches, p*p*3) predictions for ALL patches.
        target: patchified ground-truth pixels.
        mask_indices: (B, num_masked) positions that were hidden from the encoder.
        """
        B = imgs.shape[0]
        device = imgs.device
        visible_indices, mask_indices = self.random_masking(B, device, mask_ratio)
        latent = self.encoder(imgs, visible_indices)
        pred = self.decoder(latent, visible_indices, mask_indices)
        target = self.patchify(imgs)
        return pred, target, mask_indices

    def forward_loss(self, imgs, mask_ratio=None):
        """MSE reconstruction loss computed on the masked patches only."""
        pred, target, mask_indices = self.forward(imgs, mask_ratio)
        mask_indices_expanded = mask_indices.unsqueeze(-1).expand(-1, -1, pred.shape[-1])
        pred_masked = torch.gather(pred, dim=1, index=mask_indices_expanded)
        target_masked = torch.gather(target, dim=1, index=mask_indices_expanded)
        loss = F.mse_loss(pred_masked, target_masked)
        return loss, pred, target, mask_indices
246
 
247
 
248
+ # ============================================================================
249
+ # METRICS
250
+ # ============================================================================
251
 
252
def gaussian_kernel(size=11, sigma=1.5, channels=3, device='cpu'):
    """Depthwise Gaussian window, shape (channels, 1, size, size); each slice sums to 1."""
    coords = torch.arange(size, device=device).float() - size // 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    # Separable 2D window = outer product of the normalized 1D window.
    window = torch.outer(g, g)
    return window.expand(channels, 1, size, size).contiguous()
260
 
261
+
262
def calculate_psnr(pred, target, max_val=1.0):
    """Peak Signal-to-Noise Ratio in dB; returns inf for identical inputs."""
    mse = F.mse_loss(pred, target, reduction='mean')
    if mse == 0:
        return float('inf')
    return (20 * math.log10(max_val) - 10 * torch.log10(mse)).item()
269
+
270
+
271
def calculate_ssim(pred, target, window_size=11, sigma=1.5, data_range=1.0):
    """Mean Structural Similarity Index between two (B, C, H, W) image batches."""
    chans = pred.shape[1]
    pad = window_size // 2
    # Standard SSIM stabilization constants.
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    win = gaussian_kernel(window_size, sigma, chans, pred.device)

    # Local (Gaussian-weighted) means.
    mu_x = F.conv2d(pred, win, padding=pad, groups=chans)
    mu_y = F.conv2d(target, win, padding=pad, groups=chans)
    mu_xx = mu_x ** 2
    mu_yy = mu_y ** 2
    mu_xy = mu_x * mu_y

    # Local variances and covariance.
    var_x = F.conv2d(pred ** 2, win, padding=pad, groups=chans) - mu_xx
    var_y = F.conv2d(target ** 2, win, padding=pad, groups=chans) - mu_yy
    cov_xy = F.conv2d(pred * target, win, padding=pad, groups=chans) - mu_xy

    ssim_map = ((2 * mu_xy + c1) * (2 * cov_xy + c2)) / \
               ((mu_xx + mu_yy + c1) * (var_x + var_y + c2))
    return ssim_map.mean().item()
297
 
298
 
299
def denormalize_for_metrics(tensor):
    """Undo ImageNet normalization on a (B, 3, H, W) batch and clamp to [0, 1]."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)
305
+
306
+
307
+ # ============================================================================
308
+ # LOAD MODEL
309
+ # ============================================================================
310
+
311
print("πŸš€ Loading MAE model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f" Device: {device}")

# Fallback architecture used when no checkpoint ships with the Space; values
# mirror the training config shown in the UI (ViT-Base encoder / ViT-Small decoder).
DEFAULT_CONFIG = {
    'img_size': 224,
    'patch_size': 16,
    'encoder_embed_dim': 768,
    'encoder_depth': 12,
    'encoder_num_heads': 12,
    'decoder_embed_dim': 384,
    'decoder_depth': 12,
    'decoder_num_heads': 6,
    'mask_ratio': 0.75,
}

# Load trained weights if present; otherwise fall back to random weights so the
# demo UI still starts instead of crashing at import time.
try:
    checkpoint = torch.load('mae_model_weights.pth', map_location=device)
    config = checkpoint['config']
except FileNotFoundError:
    print("⚠ Checkpoint 'mae_model_weights.pth' not found - using random weights for demo")
    checkpoint = None
    config = DEFAULT_CONFIG

model = MaskedAutoencoder(
    img_size=config['img_size'],
    patch_size=config['patch_size'],
    encoder_embed_dim=config['encoder_embed_dim'],
    encoder_depth=config['encoder_depth'],
    encoder_num_heads=config['encoder_num_heads'],
    decoder_embed_dim=config['decoder_embed_dim'],
    decoder_depth=config['decoder_depth'],
    decoder_num_heads=config['decoder_num_heads'],
    mask_ratio=config['mask_ratio']
)
if checkpoint is not None:
    model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
print("βœ… Model loaded successfully!")
333
+
334
+
335
+ # ============================================================================
336
+ # IMAGE PROCESSING
337
+ # ============================================================================
338
+
339
# Resize to the model's 224x224 input and normalize with the standard ImageNet
# mean/std (the same statistics used by the denormalize helpers below).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
344
+
345
+
346
def denormalize(tensor):
    """Undo ImageNet normalization on a single (3, H, W) image and clamp to [0, 1]."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)
350
+
351
+
352
def create_masked_vis(image_tensor, mask_indices, patch_size=16):
    """Render the input with masked patches painted mid-gray, as a uint8 HWC array."""
    canvas = denormalize(image_tensor.clone())
    grid = 224 // patch_size
    for patch_idx in mask_indices:
        r, c = divmod(patch_idx.item(), grid)
        canvas[:, r * patch_size:(r + 1) * patch_size,
               c * patch_size:(c + 1) * patch_size] = 0.5
    return (canvas.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
362
+
363
+
364
+ # ============================================================================
365
+ # INFERENCE FUNCTION
366
+ # ============================================================================
367
+
368
@torch.no_grad()
def reconstruct_image(input_image, mask_ratio_percent):
    """Run one MAE reconstruction pass for the Gradio UI.

    Args:
        input_image: uploaded image (PIL.Image or numpy array), or None.
        mask_ratio_percent: slider value in percent (1-99).

    Returns:
        (original, masked, reconstruction) as uint8 arrays plus a markdown
        metrics report; all-None-ish tuple when no image was provided.
    """
    if input_image is None:
        return None, None, None, "⚠️ Please upload an image first."

    # Convert percentage to ratio
    mask_ratio = mask_ratio_percent / 100.0
    mask_ratio = max(0.01, min(0.99, mask_ratio))  # Clamp between 1% and 99%

    # Convert to PIL if needed
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Process image (resize + ImageNet normalization, add batch dim)
    input_tensor = transform(input_image).unsqueeze(0).to(device)

    # Forward pass with loss
    loss, pred, target, mask_indices = model.forward_loss(input_tensor, mask_ratio)
    reconstructed = model.unpatchify(pred)

    # Original image
    original_img = denormalize(input_tensor[0].cpu())
    original_img = (original_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # Masked image
    masked_img = create_masked_vis(input_tensor[0].cpu(), mask_indices[0].cpu())

    # Reconstructed image
    recon_img = denormalize(reconstructed[0].cpu())
    recon_img = (recon_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # Calculate metrics in denormalized [0, 1] space
    pred_denorm = denormalize_for_metrics(reconstructed)
    target_denorm = denormalize_for_metrics(input_tensor)
    psnr = calculate_psnr(pred_denorm, target_denorm)
    ssim = calculate_ssim(pred_denorm, target_denorm)

    # Determine quality rating (thresholds mirror the guidelines shown below)
    if psnr >= 30 and ssim >= 0.85:
        quality = "🎯 Excellent"
        quality_color = "#10b981"
    elif psnr >= 25 and ssim >= 0.75:
        quality = "βœ… Good"
        quality_color = "#3b82f6"
    elif psnr >= 20 and ssim >= 0.65:
        quality = "⚑ Fair"
        quality_color = "#f59e0b"
    else:
        quality = "πŸ”§ Needs Improvement"
        quality_color = "#ef4444"

    # Create detailed metrics text
    metrics_text = f"""
## πŸ“Š Reconstruction Quality: <span style="color: {quality_color}; font-weight: bold;">{quality}</span>

### 🎯 Detailed Metrics

| Metric | Value | Description |
|--------|-------|-------------|
| **MSE Loss** | `{loss.item():.6f}` | Mean Squared Error (Lower is better) |
| **PSNR** | `{psnr:.2f} dB` | Peak Signal-to-Noise Ratio (Higher is better) |
| **SSIM** | `{ssim:.4f}` | Structural Similarity (Closer to 1 is better) |

### 🎭 Masking Configuration

| Parameter | Value |
|-----------|-------|
| **Masking Ratio** | {mask_ratio*100:.1f}% |
| **Masked Patches** | {mask_indices.shape[1]} / 196 patches |
| **Visible Patches** | {196 - mask_indices.shape[1]} / 196 patches |
| **Patch Size** | 16Γ—16 pixels |

### πŸ—οΈ Model Architecture

- **Encoder**: ViT-Base (768d, 12 layers, 12 heads) ~ 86M parameters
- **Decoder**: ViT-Small (384d, 12 layers, 6 heads) ~ 22M parameters
- **Total Parameters**: ~108M
- **Training Dataset**: TinyImageNet

### πŸ’‘ Quality Guidelines

- **Excellent** (PSNR β‰₯ 30 dB, SSIM β‰₯ 0.85): Near-perfect reconstruction
- **Good** (PSNR β‰₯ 25 dB, SSIM β‰₯ 0.75): High-quality reconstruction
- **Fair** (PSNR β‰₯ 20 dB, SSIM β‰₯ 0.65): Acceptable reconstruction
- **Needs Improvement** (Below thresholds): Challenging conditions

---

πŸ’‘ **Tip**: Lower masking ratios (10-50%) produce better reconstructions. Higher ratios (70-95%) test the model's limits!
"""

    return original_img, masked_img, recon_img, metrics_text
462
+
463
+
464
+ # ============================================================================
465
+ # GRADIO INTERFACE
466
+ # ============================================================================
467
+
468
+ # Custom CSS for beautiful UI
469
# Custom CSS: gradient title, card-style images, and a tinted metrics panel.
custom_css = """
#title {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3em;
    font-weight: bold;
    margin-bottom: 0.5em;
}

#subtitle {
    text-align: center;
    color: #6b7280;
    font-size: 1.2em;
    margin-bottom: 2em;
}

.gradio-container {
    max-width: 1400px;
    margin: auto;
}

#image-output img {
    border-radius: 12px;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
}

#metrics-box {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    border-radius: 12px;
    padding: 20px;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="MAE Image Reconstruction") as demo:
    gr.HTML("""
<h1 id="title">🎭 Masked Autoencoder (MAE)</h1>
<p id="subtitle">Self-Supervised Image Reconstruction with Vision Transformers</p>
    """)

    with gr.Row():
        # Left column: input image and controls.
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“€ Upload & Configure")
            input_image = gr.Image(label="Upload Image", type="pil", height=300)

            # Slider works in whole percent; reconstruct_image converts to a ratio.
            mask_ratio_slider = gr.Slider(
                minimum=1,
                maximum=99,
                value=75,
                step=1,
                label="🎭 Masking Ratio (%)",
                info="Percentage of image patches to hide (1% = easy, 99% = extremely hard)"
            )

            with gr.Row():
                clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
                reconstruct_btn = gr.Button("πŸ”„ Reconstruct", variant="primary", size="lg")

            gr.Markdown("""
### ℹ️ How It Works

1. **Upload** any image
2. **Adjust** the masking ratio
3. **Click** Reconstruct
4. **View** the results & metrics

The model randomly masks patches of your image and reconstructs the full image from only the visible parts!
            """)

        # Right column: result images and the metrics report.
        with gr.Column(scale=2):
            gr.Markdown("### πŸ–ΌοΈ Reconstruction Results")
            with gr.Row():
                original_output = gr.Image(label="πŸ“· Original (224Γ—224)", elem_id="image-output")
                masked_output = gr.Image(label="🎭 Masked Input", elem_id="image-output")
                reconstructed_output = gr.Image(label="✨ Reconstruction", elem_id="image-output")

            gr.Markdown("### πŸ“Š Quality Metrics & Analysis")
            metrics_output = gr.Markdown(value="Upload an image and click **Reconstruct** to see detailed metrics.", elem_id="metrics-box")

            gr.Markdown("""
---
### 🎯 Try These Examples:

- **Easy (10-30% masking)**: Clear reconstruction, tests basic capability
- **Medium (40-60% masking)**: Balanced challenge, realistic scenarios
- **Hard (70-85% masking)**: Significant challenge, impressive results
- **Extreme (90-99% masking)**: Model's absolute limits

### πŸ”¬ About MAE

Masked Autoencoders (MAE) are self-supervised learning models that learn visual representations by reconstructing masked images. This implementation uses:
- **Asymmetric Encoder-Decoder**: Efficient processing of visible patches
- **ViT Architecture**: Transformer-based vision understanding
- **High Masking Ratio**: Learns robust features from limited information

πŸ“„ **Paper**: [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) (He et al., 2021)
            """)

    # Event handlers
    # Button click runs a full reconstruction pass.
    reconstruct_btn.click(
        fn=reconstruct_image,
        inputs=[input_image, mask_ratio_slider],
        outputs=[original_output, masked_output, reconstructed_output, metrics_output]
    )

    # Re-run when the slider is released (not on every drag tick).
    mask_ratio_slider.release(
        fn=reconstruct_image,
        inputs=[input_image, mask_ratio_slider],
        outputs=[original_output, masked_output, reconstructed_output, metrics_output]
    )

    # Clear resets the input, all three image panes, and the metrics panel.
    clear_btn.click(
        fn=lambda: (None, None, None, None, "Upload an image to begin."),
        outputs=[input_image, original_output, masked_output, reconstructed_output, metrics_output]
    )


# Launch
if __name__ == "__main__":
    demo.launch(share=False)