hamxaameer commited on
Commit
214cb22
·
verified ·
1 Parent(s): 0a2c2db

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -0
app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Spaces App for MAE Image Reconstruction
3
+ Entry point for Hugging Face deployment
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+ import os
13
+
14
+ # Import local modules
15
+ from mae_model import create_mae_model
16
+ from metrics import calculate_psnr, calculate_ssim, denormalize_for_metrics
17
+
18
+
19
class MAEInference:
    """MAE inference wrapper for Hugging Face Spaces.

    Builds the Masked Autoencoder once at start-up, loads the best
    available checkpoint, and exposes :meth:`reconstruct`, which returns
    display-ready visualizations (original / masked / reconstruction)
    plus a markdown metrics report for a single input image.
    """

    def __init__(self):
        # Use CPU on the Hugging Face free tier, GPU when available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Running on: {self.device}")

        # Input geometry, kept on the instance so the patch-count
        # arithmetic in reconstruct() stays consistent with the model.
        self.img_size = 224
        self.patch_size = 16
        self.num_patches = (self.img_size // self.patch_size) ** 2  # 14*14 = 196

        # ViT-Base encoder + lightweight decoder (see mae_model module).
        self.model = create_mae_model(
            img_size=self.img_size,
            patch_size=self.patch_size,
            encoder_embed_dim=768,
            encoder_depth=12,
            encoder_num_heads=12,
            decoder_embed_dim=384,
            decoder_depth=12,
            decoder_num_heads=6,
            mask_ratio=0.75
        )

        # Load checkpoint weights (falls back to random weights for demo).
        self._load_weights()

        self.model = self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet preprocessing (mean/std match denormalize()).
        self.transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _load_weights(self):
        """Load model weights from the first checkpoint location that works.

        Tries several known paths; if none loads, the model keeps its
        random initialization so the demo still runs.
        """
        checkpoint_paths = [
            "checkpoint_best.pth",                  # Same directory (HF Spaces)
            "mae_checkpoint.pth",                   # Alternative name
            "model/checkpoint_best.pth",            # Model subdirectory
            "/kaggle/working/checkpoint_best.pth",  # Kaggle
        ]

        for path in checkpoint_paths:
            if not os.path.exists(path):
                continue
            try:
                # SECURITY: torch.load unpickles arbitrary objects — only
                # load checkpoints from trusted sources.
                checkpoint = torch.load(path, map_location=self.device)
                # Accept both training checkpoints and raw state dicts.
                if 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                else:
                    self.model.load_state_dict(checkpoint)
                print(f"✓ Loaded weights from: {path}")
                return
            except Exception as e:
                # Best-effort: report and try the next candidate path.
                print(f"Failed to load {path}: {e}")
                continue

        print("⚠ No checkpoint found - using random weights for demo")

    def denormalize(self, tensor):
        """Undo ImageNet normalization and clamp to [0, 1] for display.

        Args:
            tensor: (3, H, W) normalized image tensor (any device).

        Returns:
            (3, H, W) CPU tensor with values in [0, 1].
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        # mean/std live on CPU, so bring the tensor over first.
        if tensor.device.type != 'cpu':
            tensor = tensor.cpu()

        tensor = tensor * std + mean
        tensor = torch.clamp(tensor, 0, 1)
        return tensor

    def create_masked_image(self, image_tensor, mask_indices, patch_size=16):
        """Render the input image with the masked patches painted mid-gray.

        Args:
            image_tensor: (3, H, W) normalized image tensor.
            mask_indices: 1-D tensor of flattened patch indices to gray out.
            patch_size: side length of a square patch, in pixels.

        Returns:
            (H, W, 3) uint8 numpy array suitable for display.
        """
        img = self.denormalize(image_tensor.clone())
        # Derive the patch grid from the tensor itself instead of a
        # hard-coded 224, so non-default input sizes also render correctly.
        num_patches_per_side = img.shape[-1] // patch_size

        for idx in mask_indices:
            row = idx.item() // num_patches_per_side
            col = idx.item() % num_patches_per_side
            # Paint the whole patch mid-gray (0.5 after denormalization).
            img[:,
                row * patch_size:(row + 1) * patch_size,
                col * patch_size:(col + 1) * patch_size] = 0.5

        return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    @torch.no_grad()
    def reconstruct(self, image, mask_ratio=0.75):
        """Reconstruct an image with the MAE and compute quality metrics.

        Args:
            image: PIL image or numpy array (any mode; converted to RGB).
            mask_ratio: fraction of patches to mask, in (0, 1).

        Returns:
            Tuple (original_img, masked_img, recon_img, metrics_text) where
            the images are (H, W, 3) uint8 arrays and metrics_text is a
            markdown string. Returns (None, None, None, message) when no
            image is supplied.
        """
        if image is None:
            return None, None, None, "Please upload an image."

        # Preprocess: normalize input type and color mode.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Forward pass: model returns per-patch predictions and the
        # indices of the masked patches.
        pred, target, mask_indices = self.model(input_tensor, mask_ratio)

        # Fold patch predictions back into a full image.
        reconstructed = self.model.unpatchify(pred)

        # Create visualizations.
        masked_img = self.create_masked_image(
            input_tensor[0].cpu(),
            mask_indices[0].cpu()
        )

        recon_img = self.denormalize(reconstructed[0].cpu())
        recon_img = (recon_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        original_img = self.denormalize(input_tensor[0].cpu())
        original_img = (original_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        # Metrics are computed over the full image (visible + masked
        # regions) in denormalized space.
        pred_denorm = denormalize_for_metrics(reconstructed)
        target_denorm = denormalize_for_metrics(input_tensor)

        psnr = calculate_psnr(pred_denorm, target_denorm)
        ssim = calculate_ssim(pred_denorm, target_denorm)

        # Derive the visible count from the masked count so the two always
        # sum to num_patches (independent int() truncation did not).
        # NOTE(review): assumes the model masks int(mask_ratio * N)
        # patches — confirm against mae_model's masking code.
        num_masked = int(mask_ratio * self.num_patches)
        num_visible = self.num_patches - num_masked

        metrics_text = f"""
### 📊 Reconstruction Metrics
| Metric | Value |
|--------|-------|
| **PSNR** | {psnr:.2f} dB |
| **SSIM** | {ssim:.4f} |
| **Mask Ratio** | {mask_ratio*100:.0f}% |
| **Visible Patches** | {num_visible} / {self.num_patches} |
| **Masked Patches** | {num_masked} / {self.num_patches} |
"""

        return original_img, masked_img, recon_img, metrics_text
159
+
160
+
161
# Initialize model globally (loaded once when app starts)
# NOTE: runs at import time — model construction and checkpoint loading
# happen before the Gradio UI is built, which may take a while on CPU.
print("Initializing MAE model...")
mae = MAEInference()
164
+
165
+
166
def process_image(input_image, mask_ratio):
    """Gradio callback: clamp the mask ratio and run MAE reconstruction.

    Returns the (original, masked, reconstruction, metrics) tuple that the
    UI outputs expect, or placeholder Nones plus a prompt when no image
    has been uploaded yet.
    """
    if input_image is None:
        return None, None, None, "⬆️ Please upload an image to get started."

    # Keep the ratio inside the range the slider advertises, even if the
    # caller passes something out of bounds.
    clamped_ratio = min(0.95, max(0.1, mask_ratio))
    return mae.reconstruct(input_image, clamped_ratio)
173
+
174
+
175
# Create Gradio interface (built at import time; served via demo.launch()).
# Fixed: several user-facing strings were UTF-8 mojibake (e.g. "πŸ“€" for
# 📤, "Γ—" for ×) from a bad encoding round-trip; restored proper glyphs.
with gr.Blocks(
    title="MAE Image Reconstruction",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container { max-width: 1200px !important; }
    .output-image { border-radius: 8px; }
    """
) as demo:

    gr.Markdown("""
    # 🎭 Masked Autoencoder (MAE) Image Reconstruction

    Upload any image to see how the MAE reconstructs it from only **25% visible patches**.
    The model learns powerful visual representations by predicting masked regions.

    > **Try adjusting the mask ratio** to see how the reconstruction quality changes!
    """)

    with gr.Row():
        # Left column: input controls and metrics readout.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📤 Upload Image",
                type="pil",
                height=280
            )

            mask_ratio_slider = gr.Slider(
                minimum=0.1,
                maximum=0.95,
                value=0.75,
                step=0.05,
                label="🎚️ Masking Ratio",
                info="Percentage of patches to mask (default: 75%)"
            )

            reconstruct_btn = gr.Button(
                "🔄 Reconstruct Image",
                variant="primary",
                size="lg"
            )

            metrics_output = gr.Markdown(
                value="⬆️ Upload an image and click **Reconstruct** to see metrics."
            )

        # Right column: side-by-side original / masked / reconstruction.
        with gr.Column(scale=2):
            with gr.Row():
                original_output = gr.Image(
                    label="Original (224×224)",
                    height=224,
                    show_download_button=True
                )
                masked_output = gr.Image(
                    label="Masked Input",
                    height=224,
                    show_download_button=True
                )
                reconstructed_output = gr.Image(
                    label="Reconstruction",
                    height=224,
                    show_download_button=True
                )

    # Event handlers: all three triggers share identical fn/inputs/outputs
    # wiring, so declare the lists once instead of repeating them.
    _inputs = [input_image, mask_ratio_slider]
    _outputs = [original_output, masked_output, reconstructed_output, metrics_output]

    reconstruct_btn.click(fn=process_image, inputs=_inputs, outputs=_outputs)
    mask_ratio_slider.change(fn=process_image, inputs=_inputs, outputs=_outputs)
    input_image.change(fn=process_image, inputs=_inputs, outputs=_outputs)

    gr.Markdown("""
    ---
    ### 🔬 How MAE Works

    1. **Masking**: Randomly mask ~75% of image patches
    2. **Encoding**: Process only visible patches through ViT encoder
    3. **Decoding**: Reconstruct full image using a lightweight decoder

    **Model Architecture:**
    - **Encoder**: ViT-Base (768 dim, 12 layers) — 86M params
    - **Decoder**: ViT-Small (384 dim, 12 layers) — 22M params

    📄 [Original Paper](https://arxiv.org/abs/2111.06377) |
    🔗 [GitHub](https://github.com/facebookresearch/mae)
    """)
273
+
274
+
275
# Launch app (HF Spaces imports this module and serves `demo` itself;
# the guard makes local `python app.py` runs work too).
if __name__ == "__main__":
    demo.launch()