hiren05 committed on
Commit
035dcbf
·
verified ·
1 Parent(s): ac2e19e

Upload 2 files

Browse files
Files changed (2) hide show
  1. 2_2_2_2_2.py +1182 -0
  2. epoch_29.pth +3 -0
2_2_2_2_2.py ADDED
@@ -0,0 +1,1182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """2.2.2.2.2.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1igY4MKIJJTPHgEkdLFI_T5H6sLUoTaLr
8
+ """
9
+
10
+ #heat map video and metrics
11
+
12
+ """## CODE"""
13
+
14
# Notebook setup step (shell command, not Python): pip install torchmetrics lpips
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from torchvision import transforms
21
+ from pathlib import Path
22
+ from PIL import Image
23
+ import numpy as np
24
+ import matplotlib.pyplot as plt
25
+ from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
26
+ from torchmetrics.image.fid import FrechetInceptionDistance
27
+ import lpips
28
+ import os
29
+ import random
30
+ import shutil
31
+ from huggingface_hub import HfApi, hf_hub_download
32
+ import tarfile
33
+ import json
34
+ import cv2
35
+ from tqdm import tqdm
36
+
37
def download_sequential_data(repo_id="Amar-S/MOVi-MC-AC", sample_ratio=0.01, base_dir="/content/data"):
    """Download a random subset of test archives and unpack them under ``base_dir``.

    Whole ``.tar.gz`` archives are sampled rather than individual files, so
    whatever sequence an archive contains stays intact.  Only the ``test/``
    split is downloaded; the ``train/`` directory is still created and scanned
    so archives placed there by other means get extracted as well.

    Args:
        repo_id: Hugging Face dataset repository to list and download from.
        sample_ratio: Fraction of the available test archives to fetch.  At
            least one archive is fetched whenever any exist, and never more
            than the number available.
        base_dir: Directory that receives the ``train/`` and ``test/`` folders.
    """
    api = HfApi()

    train_dir = f"{base_dir}/train"
    test_dir = f"{base_dir}/test"
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # List every file in the dataset repo and keep only the test archives.
    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    test_files = [f for f in files if f.startswith("test/") and f.endswith(".tar.gz")]

    if test_files:
        # Clamp to the population size: random.sample raises ValueError when
        # asked for more items than exist (empty listing or sample_ratio > 1).
        wanted = max(1, int(len(test_files) * sample_ratio))
        subset_test = random.sample(test_files, min(len(test_files), wanted))
    else:
        print(f"No test archives found in {repo_id}; nothing to download.")
        subset_test = []

    for file in subset_test:
        print(f"Downloading {file}...")
        out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
        # Copy the downloaded file next to the data so extraction (which
        # deletes the archive afterwards) works on the copy.
        dest_path = f"{test_dir}/{os.path.basename(file)}"
        shutil.copyfile(out_path, dest_path)

    extract_archives(train_dir)
    extract_archives(test_dir)

    print("Download and extraction complete!")
81
+
82
def extract_archives(directory):
    """Extract every ``.tar.gz`` archive found directly in ``directory``.

    Each archive is unpacked into ``directory`` itself and deleted afterwards.
    Files with any other suffix are left untouched.

    Args:
        directory: Folder to scan (not recursive).

    Raises:
        FileNotFoundError: If ``directory`` does not exist (from ``os.listdir``).
    """
    for file in os.listdir(directory):
        if not file.endswith(".tar.gz"):
            continue

        filepath = os.path.join(directory, file)
        print(f"Extracting {filepath}...")
        with tarfile.open(filepath, 'r:gz') as tar:
            # The archives come from the network, so refuse members that would
            # land outside `directory` (absolute paths, "..", escaping links).
            # Extraction filters only exist on newer interpreters; fall back
            # to the legacy behaviour where they are unavailable.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=directory, filter="data")
            else:
                tar.extractall(path=directory)
        # Remove the archive after extraction
        os.remove(filepath)
92
+
93
# Fetch the sampled test archives; download_sequential_data() also extracts
# whatever it finds in the train/ and test/ folders.
download_sequential_data()
#extract_archives('/content/data/train')
# Extra pass over the test folder: unpacks any archive still on disk.
extract_archives('/content/data/test')
96
+
97
def extract_archives(directory):
    """Extract every ``.tar.gz`` archive found directly in ``directory``.

    Redefinition of the helper with an extra progress line: the archive path
    is printed a second time once the file has been opened.  Each archive is
    unpacked into ``directory`` and deleted afterwards.

    Args:
        directory: Folder to scan (not recursive).

    Raises:
        FileNotFoundError: If ``directory`` does not exist (from ``os.listdir``).
    """
    for file in os.listdir(directory):
        if not file.endswith(".tar.gz"):
            continue

        filepath = os.path.join(directory, file)
        print(f"Extracting {filepath}...")
        with tarfile.open(filepath, 'r:gz') as tar:
            print(filepath)
            # Downloaded archives are untrusted: use the "data" extraction
            # filter when this interpreter provides it so members cannot be
            # written outside `directory`.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=directory, filter="data")
            else:
                tar.extractall(path=directory)
        # Remove the archive after extraction
        os.remove(filepath)
108
+
109
#extract_archives('/content/data/train')
# Run the extraction helper once more on the test folder.
extract_archives('/content/data/test')
111
+
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
class VideoAmodalDataset(Dataset):
    """Fixed-length frame sequences pairing a scene view with one object's amodal targets.

    Directory layout read by this class::

        <root_dir>/<split>/scene_*/camera_*/rgba_<id>.png
        <root_dir>/<split>/scene_*/camera_*/segmentation_<id>.png
        <root_dir>/<split>/scene_*/camera_*/obj_*/rgba_<id>.png
        <root_dir>/<split>/scene_*/camera_*/obj_*/segmentation_<id>.png

    ``<id>`` is parsed as an integer while indexing and re-formatted as a
    five-digit zero-padded number when loading, so files must use that padding.

    Args:
        root_dir: Dataset root containing the split folders.
        split: Name of the split sub-folder.
        seq_len: Number of consecutive frames per sample.
        img_size: Size passed to ``transforms.Resize`` for every image and mask.
        max_scenes: Only the first ``max_scenes`` scenes (sorted) are indexed.
        samples_per_scene: Maximum number of objects randomly drawn per camera.
        max_samples: Optional cap on the total number of indexed samples.
    """

    def __init__(self, root_dir, split='train', seq_len=8, img_size=(256,256),
                 max_scenes=4, samples_per_scene=3, max_samples=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.seq_len = seq_len
        self.img_size = img_size
        self.max_scenes = max_scenes
        self.samples_per_scene = samples_per_scene

        self.samples = self._build_sample_index(max_samples)

        # The same resize + tensor conversion is applied to RGB frames and masks.
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self, max_samples):
        """Return the list of sample descriptors (scene, camera, object, frame ids).

        Frames are cut into non-overlapping windows of ``seq_len``; a trailing
        partial window is dropped.  Indexing stops as soon as ``max_samples``
        descriptors have been collected (when ``max_samples`` is truthy).
        """
        samples = []
        scene_paths = sorted((self.root_dir / self.split).glob('scene_*'))[:self.max_scenes]

        for scene_path in scene_paths:
            for camera_path in sorted(scene_path.glob('camera_*')):
                obj_paths = sorted(camera_path.glob('obj_*'))
                selected_objs = random.sample(obj_paths, min(self.samples_per_scene, len(obj_paths)))

                # The frame list depends only on the camera, so scan it once
                # instead of once per selected object.
                rgba_files = sorted(camera_path.glob('rgba_*.png'))
                frame_ids = [int(p.stem.split('_')[1]) for p in rgba_files]
                window_starts = range(0, len(frame_ids) - self.seq_len + 1, self.seq_len)

                for obj_path in selected_objs:
                    obj_id = int(obj_path.name.split('_')[1])

                    for start in window_starts:
                        samples.append({
                            'scene': scene_path.name,
                            'camera': camera_path.name,
                            'obj_folder': obj_path.name,
                            'frame_ids': frame_ids[start:start + self.seq_len],
                            'obj_id': obj_id
                        })

                        if max_samples and len(samples) >= max_samples:
                            return samples

        return samples

    def _load_frame(self, base_path, obj_path, fid_str, obj_id):
        """Load one frame and return ``(rgb, modal_mask, amodal_mask, amodal_rgb)`` tensors.

        Every file is opened in a ``with`` block so its handle is closed as
        soon as the pixel data has been converted.
        """
        # Scene RGB
        with Image.open(base_path / f'rgba_{fid_str}.png') as img:
            rgb = self.transform(img.convert('RGB'))

        # Modal (visible) mask: pixels of the scene segmentation equal to the object id.
        with Image.open(base_path / f'segmentation_{fid_str}.png') as img:
            seg_map = np.array(img)
        modal_mask_np = (seg_map == obj_id).astype(np.uint8) * 255
        modal_mask = self.transform(Image.fromarray(modal_mask_np, mode='L'))

        # Amodal mask comes from the object's own segmentation image.
        with Image.open(obj_path / f'segmentation_{fid_str}.png') as img:
            amodal_mask = self.transform(img.convert('L'))

        # Target: the object's own RGB rendering.
        with Image.open(obj_path / f'rgba_{fid_str}.png') as img:
            amodal_rgb = self.transform(img.convert('RGB'))

        return rgb, modal_mask, amodal_mask, amodal_rgb

    def _empty_sample(self, sample):
        """Return an all-zero sample with the usual keys and ``seq_len`` frames."""
        empty_rgb = torch.zeros(3, self.img_size[0], self.img_size[1])
        empty_mask = torch.zeros(1, self.img_size[0], self.img_size[1])

        return {
            'rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'modal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'amodal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'amodal_rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id']
        }

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of stacked per-frame tensors.

        If any frame of the sequence fails to load, the error is printed and an
        all-zero sample is returned instead, so one bad file does not abort a
        whole run.
        """
        sample = self.samples[idx]
        base_path = self.root_dir / self.split / sample['scene'] / sample['camera']
        obj_path = base_path / sample['obj_folder']

        rgb_frames = []
        modal_mask_frames = []
        amodal_mask_frames = []
        amodal_rgb_frames = []

        for fid in sample['frame_ids']:
            fid_str = f"{fid:05d}"

            try:
                rgb, modal_mask, amodal_mask, amodal_rgb = self._load_frame(
                    base_path, obj_path, fid_str, sample['obj_id']
                )
            except Exception as e:
                # Deliberately broad: loading is best-effort and falls back to zeros.
                print(f"Error loading {base_path}/rgba_{fid_str}.png: {e}")
                return self._empty_sample(sample)

            rgb_frames.append(rgb)
            modal_mask_frames.append(modal_mask)
            amodal_mask_frames.append(amodal_mask)
            amodal_rgb_frames.append(amodal_rgb)

        return {
            'rgb_sequence': torch.stack(rgb_frames),                 # Scene RGB
            'modal_masks': torch.stack(modal_mask_frames),           # Modal masks (visible parts)
            'amodal_masks': torch.stack(amodal_mask_frames),         # Amodal masks (complete shape)
            'amodal_rgb_sequence': torch.stack(amodal_rgb_frames),   # Target: complete object RGB
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id']
        }

    def __len__(self):
        """Number of indexed sequences."""
        return len(self.samples)
254
+
255
+ import wandb
256
+
257
+ wandb.login()
258
+
259
+ # Add these imports to your existing imports
260
+ import numpy as np
261
+ from skimage.metrics import structural_similarity as ssim
262
+ from skimage.metrics import peak_signal_noise_ratio as psnr
263
+ import torch.nn.functional as F
264
+ from scipy import linalg
265
+ import matplotlib.pyplot as plt
266
+ import matplotlib.cm as cm
267
+ from torchvision.models import inception_v3
268
+ from torchvision.transforms import Resize, Normalize
269
+ import lpips
270
+
271
+ # Add this class for computing metrics
272
class VideoAmodalMetrics:
    """Compute PSNR, SSIM, LPIPS, an Inception-based Fréchet distance and IoU.

    All image inputs are expected in the ``[0, 1]`` range and shaped either
    ``(B, C, H, W)`` or ``(B, C, N, H, W)`` (``N`` = frames); 5-D inputs are
    scored frame by frame.
    """

    def __init__(self, device='cuda'):
        self.device = device
        # Initialize LPIPS model
        self.lpips_model = lpips.LPIPS(net='alex').to(device)

        # Initialize Inception model for FID
        self.inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception_model.eval()

        # Preprocessing for Inception
        self.inception_transform = torch.nn.Sequential(
            Resize((299, 299)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    @staticmethod
    def _as_frames(tensor):
        """Return ``tensor`` as a float32 ``(num_frames, C, H, W)`` batch without grad.

        5-D input ``(B, C, N, H, W)`` is flattened batch-major (all frames of
        sample 0, then sample 1, ...); 4-D input is returned as is.
        """
        tensor = tensor.detach().float()
        if tensor.dim() == 5:
            tensor = tensor.permute(0, 2, 1, 3, 4)
            return tensor.reshape(-1, *tensor.shape[2:])
        return tensor

    def calculate_psnr(self, pred, target, mask=None):
        """Calculate PSNR between prediction and target (data range 1.0).

        Without ``mask`` a single PSNR over all elements is returned.  With
        ``mask`` the squared error of each frame is averaged over the masked
        elements only (``mask > 0``), frames with an empty mask are skipped,
        and the mean over the remaining frames is returned (``0.0`` if none).
        """
        if mask is None:
            # Calculate PSNR for entire image
            mse = F.mse_loss(pred, target)
            psnr_val = 20 * torch.log10(1.0 / torch.sqrt(mse))
            return psnr_val.item()

        pred_frames = self._as_frames(pred)
        target_frames = self._as_frames(target)
        mask_frames = self._as_frames(mask).expand_as(pred_frames)

        psnr_values = []
        for p, t, m in zip(pred_frames, target_frames, mask_frames):
            covered = int((m > 0).sum())
            if covered == 0:  # Only if there are masked pixels
                continue

            # Normalise by the number of masked elements, not by the frame
            # size: zeros outside the mask would otherwise dilute the error
            # and inflate the score.
            mse = (((p - t) * m) ** 2).sum().item() / covered
            psnr_values.append(float('inf') if mse == 0 else 10.0 * np.log10(1.0 / mse))

        return np.mean(psnr_values) if psnr_values else 0.0

    def calculate_ssim(self, pred, target, mask=None):
        """Calculate mean per-frame SSIM between prediction and target.

        SSIM is always computed on the full frame; ``mask`` is only used to
        skip frames whose first mask channel is entirely zero.
        """
        pred_frames = self._as_frames(pred)
        target_frames = self._as_frames(target)
        mask_frames = self._as_frames(mask) if mask is not None else None

        ssim_values = []
        for idx in range(pred_frames.shape[0]):
            if mask_frames is not None and mask_frames[idx, 0].sum() == 0:
                continue

            p = pred_frames[idx].permute(1, 2, 0).cpu().numpy()
            t = target_frames[idx].permute(1, 2, 0).cpu().numpy()
            ssim_values.append(ssim(t, p, data_range=1.0, channel_axis=2))

        return np.mean(ssim_values) if ssim_values else 0.0

    def calculate_lpips(self, pred, target, mask=None):
        """Calculate mean LPIPS perceptual distance over all frames.

        ``mask`` is accepted for signature symmetry with the other metrics but
        is not used.
        """
        # LPIPS expects inputs in [-1, 1]
        pred_norm = self._as_frames(pred) * 2.0 - 1.0
        target_norm = self._as_frames(target) * 2.0 - 1.0

        if pred_norm.shape[0] == 0:
            return 0.0

        with torch.no_grad():
            distances = self.lpips_model(pred_norm, target_norm)

        # One distance per frame, so the batched mean equals the per-frame mean.
        return float(distances.mean().item())

    def calculate_iou(self, pred_mask, target_mask, threshold=0.5):
        """Calculate IoU of the two masks after binarising both at ``threshold``."""
        pred_binary = (pred_mask > threshold).float()
        target_binary = (target_mask > threshold).float()

        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection

        # Epsilon keeps the ratio finite when both masks are empty.
        iou = intersection / (union + 1e-8)
        return iou.item()

    def get_inception_features(self, images):
        """Run the Inception network on ``images`` and return its output as a numpy array.

        NOTE(review): the network's output is used directly, with no
        replacement of the final layer visible here — presumably these are
        class logits rather than pooled features; confirm before comparing
        with published FID numbers.
        """
        with torch.no_grad():
            # Preprocess images
            images_preprocessed = self.inception_transform(images.detach().float())

            # Get features
            features = self.inception_model(images_preprocessed)
            return features.cpu().numpy()

    def calculate_fid(self, pred, target):
        """Calculate a Fréchet distance between Inception outputs of ``pred`` and ``target``.

        Frames of 5-D inputs are flattened into the batch dimension first.
        The covariance estimate needs several frames to be meaningful.
        """
        pred = self._as_frames(pred)
        target = self._as_frames(target)

        # Get features
        pred_features = self.get_inception_features(pred)
        target_features = self.get_inception_features(target)

        # Calculate statistics
        mu1, sigma1 = pred_features.mean(axis=0), np.cov(pred_features, rowvar=False)
        mu2, sigma2 = target_features.mean(axis=0), np.cov(target_features, rowvar=False)

        # Calculate FID
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            # Numerical noise can leave a tiny imaginary part; keep the real part.
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return fid

    def calculate_all_metrics(self, pred, target, amodal_mask=None):
        """Calculate all metrics at once and return them in a dict.

        Keys: ``psnr``, ``ssim``, ``lpips``, ``fid`` and, when ``amodal_mask``
        is given, ``iou``.  ``fid`` falls back to ``0.0`` if its computation
        raises.
        """
        metrics = {}

        metrics['psnr'] = self.calculate_psnr(pred, target, amodal_mask)
        metrics['ssim'] = self.calculate_ssim(pred, target, amodal_mask)
        metrics['lpips'] = self.calculate_lpips(pred, target, amodal_mask)

        try:
            metrics['fid'] = self.calculate_fid(pred, target)
        except Exception:
            # Best-effort metric: keep going, but let KeyboardInterrupt and
            # SystemExit propagate (a bare except would swallow them).
            metrics['fid'] = 0.0

        # IoU for masks (if available)
        if amodal_mask is not None:
            # The predicted "mask" is the channel-mean intensity of the
            # prediction, thresholded inside calculate_iou.
            pred_intensity = pred.detach().float().mean(dim=1, keepdim=True)
            metrics['iou'] = self.calculate_iou(pred_intensity, amodal_mask)

        return metrics
445
+
446
+ # Add this function to create error heatmaps
447
def create_error_heatmap(pred, target, mask=None):
    """Create a per-pixel absolute-error map between prediction and target.

    Args:
        pred: Tensor shaped ``(C, H, W)`` or ``(B, C, H, W)``.
        target: Tensor with the same shape as ``pred``.
        mask: Optional tensor; it is squeezed and multiplied into the error.

    Returns:
        numpy array of the channel-averaged error, shaped ``(H, W)`` for 3-D
        input and ``(B, H, W)`` for 4-D input.
    """
    # Average over the colour channels.  Indexing the channel axis from the
    # end keeps this correct for batched input too; dim=0 would average the
    # batch axis and leave the channels unreduced.
    error = torch.abs(pred - target).mean(dim=-3)

    if mask is not None:
        error = error * mask.squeeze()

    return error.cpu().numpy()
456
+
457
+ # Enhanced training function with metrics and wandb
458
def train_video_amodal_with_metrics():
    """Train the video amodal-completion model while logging losses and metrics to wandb.

    Hyper-parameters come from the wandb config created below.  Each epoch
    runs a training pass with gradient accumulation, a validation pass,
    logging, and writes a checkpoint named ``epoch_<n>.pth``.

    Relies on ``Video3DUNet``, ``VideoAmodalCompletionLoss``,
    ``prepare_model_input`` and ``prepare_model_target`` being defined
    elsewhere in the file.
    """
    # Initialize wandb
    wandb.init(
        project="video-amodal-completion",
        config={
            'batch_size': 2,
            'seq_len': 6,
            'img_size': (256, 256),
            'num_epochs': 30,
            'learning_rate': 5e-5,
            'max_scenes': 2,
            'samples_per_scene': 2,
            'num_workers': 2,
            'grad_accum_steps': 4
        }
    )

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        torch.cuda.empty_cache()

        config = wandb.config

        # Mixed precision is only meaningful on CUDA.  The scaler guards fp16
        # gradients against underflow and becomes a no-op when disabled.
        use_amp = device.type == 'cuda'
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        # Initialize metrics calculator
        metrics_calculator = VideoAmodalMetrics(device)

        train_dataset = VideoAmodalDataset(
            root_dir='data',
            split='train',
            seq_len=config.seq_len,
            img_size=config.img_size,
            max_scenes=config.max_scenes,
            samples_per_scene=config.samples_per_scene,
            max_samples=100
        )

        val_dataset = VideoAmodalDataset(
            root_dir='data',
            split='test',
            seq_len=config.seq_len,
            img_size=config.img_size,
            max_scenes=1,
            samples_per_scene=1,
            max_samples=10
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1
        )

        model = Video3DUNet(
            in_channels=5,
            out_channels=3,
            sequence_length=config.seq_len
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
        criterion = VideoAmodalCompletionLoss()

        num_batches = len(train_loader)

        # Training loop with metrics
        for epoch in range(config.num_epochs):
            model.train()
            epoch_losses = []
            epoch_metrics = {
                'train_psnr': [],
                'train_ssim': [],
                'train_lpips': [],
                'train_fid': [],
                'train_iou': []
            }

            for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
                inputs = prepare_model_input(batch).to(device, non_blocking=True)
                targets = prepare_model_target(batch).to(device, non_blocking=True)
                modal_masks = batch['modal_masks'].to(device, non_blocking=True)
                amodal_masks = batch['amodal_masks'].to(device, non_blocking=True)

                # Forward pass
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = model(inputs)
                    loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                    # Scale so the accumulated gradient is an average over the group.
                    loss = loss / config.grad_accum_steps

                # Backward pass
                scaler.scale(loss).backward()

                # Calculate metrics periodically
                if i % 10 == 0:
                    with torch.no_grad():
                        # Swap dims 1 and 2 of the mask batch before scoring.
                        amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                        # Detach and upcast: under autocast the outputs may be
                        # half precision and they still carry autograd history.
                        batch_metrics = metrics_calculator.calculate_all_metrics(
                            outputs.detach().float(), targets, amodal_masks_3d
                        )

                    for key, value in batch_metrics.items():
                        if f'train_{key}' in epoch_metrics:
                            epoch_metrics[f'train_{key}'].append(value)

                # Gradient accumulation.  Also step on the final batch so a
                # trailing partial group is applied and cleared instead of
                # leaking into the next epoch.
                is_last_batch = (i + 1) == num_batches
                if (i + 1) % config.grad_accum_steps == 0 or is_last_batch:
                    # Gradients must be unscaled before clipping by norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()

                epoch_losses.append(loss_dict['total_loss'])

                # Periodic logging with wandb
                if i % 20 == 0:
                    log_dict = {
                        'batch': epoch * len(train_loader) + i,
                        'train_loss': loss_dict['total_loss'],
                        'train_visible_loss': loss_dict['visible_loss'],
                        'train_occluded_loss': loss_dict['occluded_loss'],
                        'train_background_loss': loss_dict['background_loss'],
                        'train_boundary_loss': loss_dict['boundary_loss']
                    }

                    # Add latest metrics if available
                    for key, values in epoch_metrics.items():
                        if values:
                            log_dict[key] = values[-1]

                    wandb.log(log_dict)

                    print(f"Batch {i}, Loss: {loss_dict['total_loss']:.4f}")
                    print(f"  Visible: {loss_dict['visible_loss']:.4f}, "
                          f"Occluded: {loss_dict['occluded_loss']:.4f}, "
                          f"Background: {loss_dict['background_loss']:.4f}")

            # Validation with metrics
            model.eval()
            val_losses = []
            val_metrics = {
                'val_psnr': [],
                'val_ssim': [],
                'val_lpips': [],
                'val_fid': [],
                'val_iou': []
            }

            with torch.no_grad():
                for batch in val_loader:
                    inputs = prepare_model_input(batch).to(device)
                    targets = prepare_model_target(batch).to(device)
                    modal_masks = batch['modal_masks'].to(device)
                    amodal_masks = batch['amodal_masks'].to(device)

                    outputs = model(inputs)
                    loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                    val_losses.append(loss_dict['total_loss'])

                    # Calculate validation metrics
                    amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                    batch_metrics = metrics_calculator.calculate_all_metrics(
                        outputs, targets, amodal_masks_3d
                    )

                    for key, value in batch_metrics.items():
                        if f'val_{key}' in val_metrics:
                            val_metrics[f'val_{key}'].append(value)

            # End of epoch logging
            avg_train_loss = np.mean(epoch_losses)
            avg_val_loss = np.mean(val_losses)

            epoch_log = {
                'epoch': epoch,
                'avg_train_loss': avg_train_loss,
                'avg_val_loss': avg_val_loss
            }

            # Add averaged metrics
            for key, values in {**epoch_metrics, **val_metrics}.items():
                if values:
                    epoch_log[f'avg_{key}'] = np.mean(values)

            wandb.log(epoch_log)

            print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

            # Log metrics
            for key, values in val_metrics.items():
                if values:
                    print(f"  {key}: {np.mean(values):.4f}")

            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'metrics': {key: np.mean(values) for key, values in val_metrics.items() if values}
            }, f"epoch_{epoch}.pth")
    finally:
        # Close the wandb run even when training is interrupted or fails.
        wandb.finish()
674
+
675
+ # Enhanced GIF creation with error heatmap
676
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=200):
    """Create an animated GIF showing scene, prediction, ground truth and error heatmap.

    Each GIF frame is the horizontal concatenation of four panels.  The
    heatmap uses one colour scale shared by all frames, taken from the global
    minimum and maximum error.

    Args:
        predictions: Sequence of ``(C, H, W)`` CPU tensors with values in ``[0, 1]``.
        rgb_frames: Scene frames, same length and layout as ``predictions``.
        gt_amodal_frames: Ground-truth frames, same length and layout.
        amodal_masks: Optional sequence of per-frame masks; ``None`` or an
            empty sequence disables masking of the error.
        output_path: Destination file for the GIF.
        duration: Per-frame display time passed to the GIF writer.

    Raises:
        ValueError: If ``predictions`` is empty.
    """
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np

    if len(predictions) == 0:
        raise ValueError("predictions is empty; cannot build a GIF")

    # Explicit checks rather than truthiness, so a tensor of masks is accepted
    # as well as a list.
    has_masks = amodal_masks is not None and len(amodal_masks) > 0

    # Calculate errors for all frames first to get consistent color scale
    all_errors = []
    for i in range(len(predictions)):
        mask_batch = amodal_masks[i].unsqueeze(0) if has_masks else None
        error = create_error_heatmap(
            predictions[i].unsqueeze(0), gt_amodal_frames[i].unsqueeze(0), mask_batch
        )

        # Reduce to a single (H, W) map.  If a channel axis is still present,
        # average it instead of keeping only the first channel.
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error.mean(axis=0)

        all_errors.append(error)

    # Get global error range for consistent coloring
    max_error = max(error.max() for error in all_errors)
    min_error = min(error.min() for error in all_errors)

    # The default font and the caption do not change between frames.
    try:
        font = ImageFont.load_default()
    except OSError:
        font = None
    text = f"Error: {min_error:.3f} - {max_error:.3f}"

    frames = []
    for i, error in enumerate(all_errors):
        # Scene input
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        # Prediction output
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)

        # Ground truth amodal
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        # Normalize error to [0, 1] using global range
        if max_error > min_error:
            error_normalized = (error - min_error) / (max_error - min_error)
        else:
            error_normalized = error

        # Apply colormap and drop the alpha channel
        error_colored = cm.jet(error_normalized)                      # (H, W, 4)
        error_rgb = (error_colored[:, :, :3] * 255).astype(np.uint8)  # (H, W, 3)

        combined = np.concatenate([scene_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)

        # Add the error range as text near the top-right corner
        img_pil = Image.fromarray(combined)
        draw = ImageDraw.Draw(img_pil)
        draw.text((combined.shape[1] - 150, 10), text, fill=(255, 255, 255), font=font)

        frames.append(img_pil)

    # Save as animated GIF
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
760
+
761
+ # Enhanced video generation with metrics
762
def load_model_and_generate_video_with_metrics(checkpoint_path, dataset, device,
                                               output_path="amodal_completion.mp4", fps=8,
                                               seq_len=8):
    """Load a trained model, predict over sliding windows, score them and write a video.

    The first sample of ``dataset`` is processed in windows of ``seq_len``
    frames that advance by ``seq_len // 2``; predictions for frames covered by
    two windows are averaged.  Each video frame shows four panels side by
    side: scene, modal mask, masked prediction, masked ground truth.

    Args:
        checkpoint_path: Checkpoint file holding ``model_state_dict``,
            ``epoch`` and ``train_loss``.
        dataset: Indexable dataset; only ``dataset[0]`` is used.
        device: Torch device for the model and metrics.
        output_path: Destination of the mp4 file.
        fps: Frame rate of the written video.
        seq_len: Window length; must match the length the model was built for.

    Returns:
        Tuple ``(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
        avg_metrics)`` where the first four are per-frame lists and the last
        is a dict of metrics averaged over all windows.

    Raises:
        ValueError: If the sample holds fewer than ``seq_len`` frames.
    """
    import cv2

    # Initialize metrics calculator
    metrics_calculator = VideoAmodalMetrics(device)

    model = Video3DUNet(in_channels=5, out_channels=3, sequence_length=seq_len).to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['train_loss']:.4f}")

    sample = dataset[0]
    total_frames = len(sample['rgb_sequence'])
    if total_frames < seq_len:
        raise ValueError(
            f"sample has {total_frames} frames but a window needs {seq_len}"
        )

    print(f"Processing {total_frames} frames in windows of {seq_len}")

    # Per-frame reference data does not depend on the windowing, so collect it
    # up front; this also guarantees every returned list is defined.
    all_rgb = [sample['rgb_sequence'][i] for i in range(total_frames)]
    all_modal_masks = [sample['modal_masks'][i] for i in range(total_frames)]
    all_amodal_masks = [sample['amodal_masks'][i] for i in range(total_frames)]
    all_gt_amodal = [sample['amodal_rgb_sequence'][i] for i in range(total_frames)]

    all_predictions = []
    all_metrics = []

    # Windows advance by half their length; never let the step reach zero.
    hop = max(1, seq_len // 2)

    with torch.no_grad():
        for start_idx in range(0, total_frames - seq_len + 1, hop):
            end_idx = min(start_idx + seq_len, total_frames)

            # Build a batch of size one for this window: 4-D tensors are
            # sliced along the frame axis, everything else is passed through.
            window_batch = {}
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    if value.dim() == 4:
                        window_batch[key] = value[start_idx:end_idx].unsqueeze(0)
                    else:
                        window_batch[key] = value.unsqueeze(0)
                else:
                    window_batch[key] = [value]

            # Get prediction for this window
            inputs = prepare_model_input(window_batch).to(device)
            pred = model(inputs)

            # Mask to object region
            amodal_mask = window_batch['amodal_masks'].permute(0, 2, 1, 3, 4).expand_as(pred).to(device)
            pred_masked = pred * amodal_mask

            # Metrics are computed on the unmasked prediction for this window
            target = prepare_model_target(window_batch).to(device)
            window_metrics = metrics_calculator.calculate_all_metrics(pred, target, amodal_mask)
            all_metrics.append(window_metrics)

            pred_frames = pred_masked.squeeze(0).permute(1, 0, 2, 3).cpu()

            # Merge into the running list: average frames already predicted
            # by the previous window, append the new ones.
            for offset in range(len(pred_frames)):
                frame_idx = start_idx + offset
                if frame_idx < len(all_predictions):
                    all_predictions[frame_idx] = (all_predictions[frame_idx] + pred_frames[offset]) / 2.0
                elif frame_idx < total_frames:
                    all_predictions.append(pred_frames[offset])

    # Print overall metrics
    print("\nOverall Metrics:")
    avg_metrics = {}
    for key in all_metrics[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in all_metrics])
        print(f"  {key.upper()}: {avg_metrics[key]:.4f}")

    all_predictions = all_predictions[:total_frames]
    print(f"Generated {len(all_predictions)} prediction frames")

    # Create video: four panels side by side
    height, width = all_predictions[0].shape[-2:]
    video_width = width * 4
    video_height = height

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (video_width, video_height))

    try:
        for i in range(len(all_predictions)):
            scene_rgb = all_rgb[i].permute(1, 2, 0).numpy()
            modal_mask = all_modal_masks[i][0].numpy()
            modal_mask_rgb = np.stack([modal_mask, modal_mask, modal_mask], axis=2)

            pred_rgb = all_predictions[i].permute(1, 2, 0).numpy()
            pred_rgb = np.clip(pred_rgb, 0, 1)

            try:
                gt_amodal = all_gt_amodal[i].permute(1, 2, 0).numpy()
                amodal_mask_np = all_amodal_masks[i][0].numpy()
                gt_amodal_masked = gt_amodal * amodal_mask_np[:, :, None]
            except Exception:
                # Best-effort panel: fall back to a black image.
                gt_amodal_masked = np.zeros_like(pred_rgb)

            combined_frame = np.concatenate([
                scene_rgb,
                modal_mask_rgb,
                pred_rgb,
                gt_amodal_masked
            ], axis=1)

            combined_frame_bgr = cv2.cvtColor((combined_frame * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
            out.write(combined_frame_bgr)

            if i % 5 == 0:
                print(f"Processed frame {i+1}/{len(all_predictions)}")
    finally:
        # Always finalise the file, even if a frame fails to render.
        out.release()

    print(f"Video saved to {output_path}")

    return all_predictions, all_rgb, all_gt_amodal, all_amodal_masks, avg_metrics
893
+
894
+ # Enhanced run function with all new features
895
def run_enhanced_video_generation():
    """Render the metrics video and the error-heatmap GIF for one test sample.

    Builds a single-sample test dataset, runs the checkpointed model over it
    through ``load_model_and_generate_video_with_metrics`` and then feeds the
    resulting frames to ``create_gif_with_error_heatmap``.

    Returns:
        The metrics object produced by
        ``load_model_and_generate_video_with_metrics``.
    """
    # Prefer the GPU when one is visible, otherwise stay on the CPU.
    target_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(target_name)

    # One scene / one sample keeps this run short.
    dataset_options = dict(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=1,
        samples_per_scene=1,
        max_samples=1,
    )
    dataset = VideoAmodalDataset(**dataset_options)

    # Inference + MP4 export; the helper hands back everything the GIF needs.
    checkpoint_path = "video_amodal_model_epoch_4.pth"
    generation_output = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video_with_metrics.mp4",
        fps=8,
    )
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = generation_output

    # Side-by-side GIF with the per-pixel error panel.
    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150,
    )

    print("Enhanced video generation complete!")
    return metrics
932
+
933
+ train_video_amodal_with_metrics()
934
+
935
+ # Simple way to run GIF generation from your trained model
936
+
937
+ import torch
938
+
939
def run_gif_generation(checkpoint_path="epoch_29.pth",
                       video_path="amodal_completion_video.mp4",
                       gif_path="amodal_completion_with_error.gif",
                       fps=6,
                       gif_frame_duration=150):
    """Generate the comparison video and the error-heatmap GIF from a checkpoint.

    All parameters default to the values that used to be hard-coded, so a bare
    ``run_gif_generation()`` call behaves as before.

    Args:
        checkpoint_path: Checkpoint file handed to
            ``load_model_and_generate_video_with_metrics``.
        video_path: Where the MP4 is written.
        gif_path: Where the animated GIF is written.
        fps: Frame rate passed to the video writer.
        gif_frame_duration: ``duration`` value forwarded to
            ``create_gif_with_error_heatmap``.

    Returns:
        The metrics object returned by
        ``load_model_and_generate_video_with_metrics`` (also printed).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create test dataset
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=50,
        samples_per_scene=5,
        max_samples=50,
    )

    # Run inference and write the MP4; keep the frames for the GIF below.
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = (
        load_model_and_generate_video_with_metrics(
            checkpoint_path,
            dataset,
            device,
            output_path=video_path,
            fps=fps,
        )
    )

    # Create GIF with error heatmap
    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path=gif_path,
        duration=gif_frame_duration,
    )

    print("GIF creation complete!")
    print(f"Metrics: {metrics}")
    return metrics
981
+
982
+ # Just run this:
983
# Script entry point: GIF generation only runs when the file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_gif_generation()
985
+
986
+ import cv2
987
+
988
def draw_amodal_boundary(rgb_image, amodal_mask, color=(255, 0, 255)):
    """Return a copy of ``rgb_image`` with the outline of ``amodal_mask`` drawn on it.

    Args:
        rgb_image: NumPy image array to annotate; it is copied, not modified.
        amodal_mask: NumPy mask array; it is cast to ``uint8`` before contour
            extraction, so fractional values below 1 are truncated to 0.
        color: Colour tuple passed to ``cv2.drawContours``.

    Returns:
        The annotated copy of ``rgb_image``.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contour list is second-to-last in both.
    contours = cv2.findContours(
        amodal_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]

    outlined = rgb_image.copy()
    cv2.drawContours(outlined, contours, -1, color, thickness=2)
    return outlined
993
+
994
+ # Enhanced GIF creation with proper error heatmap and colorbar
995
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=240):
    """Create an animated GIF: scene | prediction | ground truth | error heatmap | colorbar.

    Args:
        predictions: Sequence of per-frame prediction tensors; each is permuted
            with ``(1, 2, 0)``, i.e. treated as channel-first 3-D tensors.
            Values are clipped to [0, 1] before conversion to 8-bit.
        rgb_frames: Sequence of scene frames, same layout as ``predictions``
            (assumed to already be in [0, 1] - they are scaled by 255 unclipped).
        gt_amodal_frames: Sequence of ground-truth frames, same layout/assumption.
        amodal_masks: Sequence of per-frame mask tensors whose ``[0]`` slice is
            the 2-D mask, or ``None``/empty to skip masking and the boundary
            overlay.
        output_path: Destination file for the GIF.
        duration: Per-frame ``duration`` value passed to ``PIL.Image.save``.

    Raises:
        ValueError: If ``predictions`` is empty.
    """
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize
    import io

    num_frames = len(predictions)
    if num_frames == 0:
        raise ValueError("create_gif_with_error_heatmap needs at least one prediction frame")

    # One consistent definition of "masks are available" for every stage below.
    has_masks = amodal_masks is not None and len(amodal_masks) > 0

    def frame_mask(idx):
        """Return the 2-D mask of frame ``idx`` as a NumPy array."""
        return amodal_masks[idx][0].cpu().numpy()

    # Calculate errors for all frames first to get a consistent color scale.
    all_errors = []
    for i in range(num_frames):
        mask_tensor = amodal_masks[i].unsqueeze(0) if has_masks else None
        all_errors.append(
            create_error_heatmap(
                predictions[i].unsqueeze(0),
                gt_amodal_frames[i].unsqueeze(0),
                mask_tensor,
            )
        )

    # Collect the error samples that define the global color range. With masks,
    # only non-zero errors inside the mask count so the background does not
    # drag the percentiles towards zero.
    if has_masks:
        samples = []
        for i, error in enumerate(all_errors):
            masked_error = error * frame_mask(i)
            samples.append(masked_error[masked_error > 0])
    else:
        samples = [error.ravel() for error in all_errors]
    masked_errors = np.concatenate(samples)

    if masked_errors.size:
        # Percentiles instead of min/max so a few outliers do not flatten the map.
        min_error, max_error = np.percentile(masked_errors, [5, 95])
    else:
        min_error = min(error.min() for error in all_errors)
        max_error = max(error.max() for error in all_errors)

    # Ensure we have a reasonable (non-zero) range to divide by.
    if max_error - min_error < 1e-6:
        max_error = min_error + 1e-6

    print(f"Error range for visualization: {min_error:.4f} to {max_error:.4f}")

    # plt.get_cmap is available across Matplotlib versions; cm.get_cmap was
    # deprecated in 3.7 and removed in 3.9.
    hot_cmap = plt.get_cmap('hot')

    def create_colorbar():
        """Render a labelled vertical 'hot' colorbar for the global error range."""
        fig, ax = plt.subplots(figsize=(1, 4))
        try:
            fig.patch.set_facecolor('black')
            ax.set_facecolor('black')

            sm = cm.ScalarMappable(norm=Normalize(vmin=min_error, vmax=max_error), cmap='hot')
            sm.set_array([])

            cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=1.0)
            cbar.set_label('Prediction Error', color='white', fontsize=10)
            cbar.ax.tick_params(colors='white', labelsize=8)

            # Only the colorbar axes should remain in the saved image.
            ax.remove()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight',
                        facecolor='black', edgecolor='none', dpi=100)
        finally:
            # Close this specific figure even if rendering failed.
            plt.close(fig)

        buf.seek(0)
        # convert() forces the lazy PNG decode so the buffer is no longer needed.
        return Image.open(buf).convert('RGB')

    # Create colorbar once; it is identical for every frame.
    colorbar_img = create_colorbar()
    colorbar_width = colorbar_img.width

    # Load the label font once instead of once per frame.
    try:
        font = ImageFont.load_default()
    except (OSError, ImportError):
        font = None

    spacing = 10  # px gap between the image strip and the colorbar
    frames = []
    for i in range(num_frames):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        mask = frame_mask(i) if has_masks else None

        error = all_errors[i]
        if mask is not None:
            error = error * mask

        # Ensure error is shape (H, W)
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error[0]

        # Normalize with the global range, then colorize.
        error_normalized = np.clip((error - min_error) / (max_error - min_error), 0, 1)
        error_rgb = (hot_cmap(error_normalized)[:, :, :3] * 255).astype(np.uint8)

        if mask is not None:
            # Black out everything outside the object and outline it on the scene.
            error_rgb = error_rgb * mask[:, :, None].astype(np.uint8)
            scene_panel = draw_amodal_boundary(scene_rgb, mask)
        else:
            scene_panel = scene_rgb

        combined = np.concatenate([scene_panel, pred_rgb, gt_rgb, error_rgb], axis=1)
        img_pil = Image.fromarray(combined)

        # Resize colorbar to match image height and paste it to the right.
        colorbar_resized = colorbar_img.resize((colorbar_width, img_pil.height))
        final_width = img_pil.width + colorbar_width + spacing
        final_img = Image.new('RGB', (final_width, img_pil.height), color='black')
        final_img.paste(img_pil, (0, 0))
        final_img.paste(colorbar_resized, (img_pil.width + spacing, 0))

        # Add frame number
        frame_text = f"Frame {i+1}/{num_frames}"
        ImageDraw.Draw(final_img).text((10, 10), frame_text, fill=(0, 0, 0), font=font)

        frames.append(final_img)

    # Save as animated GIF
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0,
    )

    print(f"GIF with proper error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
    print("Colorbar shows errors from low (black/red) to high (yellow/white)")
1168
+
1169
+ # Also update the error heatmap calculation to be more sensitive
1170
def create_error_heatmap(pred, target, mask=None):
    """Compute the per-pixel L2 error between ``pred`` and ``target``.

    The squared difference is summed over dimension 1 (the channel axis of a
    batch-first tensor) and square-rooted, so the result has one fewer
    dimension than the inputs.

    Args:
        pred: Prediction tensor.
        target: Tensor broadcast-compatible with ``pred``.
        mask: Optional tensor multiplied into the error. A mask with a
            singleton channel axis (one more dimension than the error map) has
            only that axis removed; other shapes are used as-is and must
            broadcast against the error map.

    Returns:
        The error map as a NumPy array.
    """
    # L2 norm across dimension 1 (color channels).
    error = torch.sqrt(torch.sum((pred - target) ** 2, dim=1))

    if mask is not None:
        # Squeeze ONLY the channel axis: a blanket squeeze() would also drop a
        # batch/height/width of size 1 and silently mis-broadcast.
        if mask.dim() == error.dim() + 1 and mask.shape[1] == 1:
            mask = mask.squeeze(1)
        error = error * mask.to(error.device)

    # detach() so inputs that still track gradients can be converted.
    return error.detach().cpu().numpy()
1182
+
epoch_29.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e600161d395086f90ad1d27abb9e2b676255c30391ce5f94acd6675e66c2ab7b
3
+ size 372700024