# -*- coding: utf-8 -*-
"""2.2.2.2.2.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1igY4MKIJJTPHgEkdLFI_T5H6sLUoTaLr
"""

# Heat map video and metrics

"""## CODE"""

!pip install torchmetrics lpips

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
import lpips
import os
import random
import shutil
from huggingface_hub import HfApi, hf_hub_download
import tarfile
import json
import cv2
from tqdm import tqdm


def extract_archives(directory):
    """Extract all tar.gz files in a directory."""
    for file in os.listdir(directory):
        if file.endswith(".tar.gz"):
            filepath = os.path.join(directory, file)
            print(f"Extracting {filepath}...")
            with tarfile.open(filepath, 'r:gz') as tar:
                tar.extractall(path=directory)
            # Remove the archive after extraction
            os.remove(filepath)


def download_sequential_data(repo_id="Amar-S/MOVi-MC-AC", sample_ratio=0.01, base_dir="/content/data"):
    """Download data while preserving video sequences."""
    api = HfApi()

    # Create directories
    os.makedirs(f"{base_dir}/train", exist_ok=True)
    os.makedirs(f"{base_dir}/test", exist_ok=True)

    # List all files in the repo
    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

    # Separate train and test archives (each archive contains a complete scene sequence)
    #train_files = [f for f in files if f.startswith("train/") and f.endswith(".tar.gz")]
    test_files = [f for f in files if f.startswith("test/") and f.endswith(".tar.gz")]
    #print(f"Found {len(train_files)} train archives and {len(test_files)} test archives.")

    # Sample complete archives (not individual files) to preserve sequences
    #subset_train = random.sample(train_files, max(1, int(len(train_files) * sample_ratio)))
    subset_test = random.sample(test_files, max(1, int(len(test_files) * sample_ratio)))
    #print(f"Downloading {len(subset_train)} train archives and {len(subset_test)} test archives...")

    # Download training archives
    # for file in subset_train:
    #     print(f"Downloading {file}...")
    #     out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    #     dest_path = f"{base_dir}/train/{os.path.basename(file)}"
    #     shutil.copyfile(out_path, dest_path)

    # Download test archives
    for file in subset_test:
        print(f"Downloading {file}...")
        out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
        dest_path = f"{base_dir}/test/{os.path.basename(file)}"
        shutil.copyfile(out_path, dest_path)

    # Extract all archives
    extract_archives(f"{base_dir}/train")
    extract_archives(f"{base_dir}/test")
    print("Download and extraction complete!")


download_sequential_data()

#extract_archives('/content/data/train')
extract_archives('/content/data/test')
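# Expected directory layout after extraction (inferred from the glob patterns
# and filename formats used by the dataset class below; the exact scene,
# camera, and object ids are illustrative):
#
#   data/test/scene_XXXX/camera_XXXX/rgba_00000.png                  scene RGB frames
#   data/test/scene_XXXX/camera_XXXX/segmentation_00000.png          scene segmentation
#   data/test/scene_XXXX/camera_XXXX/obj_XXXX/rgba_00000.png         per-object amodal RGB
#   data/test/scene_XXXX/camera_XXXX/obj_XXXX/segmentation_00000.png per-object amodal mask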
class VideoAmodalDataset(Dataset):
    def __init__(self, root_dir, split='train', seq_len=8,
                 img_size=(256, 256), max_scenes=4, samples_per_scene=3, max_samples=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.seq_len = seq_len
        self.img_size = img_size
        self.max_scenes = max_scenes
        self.samples_per_scene = samples_per_scene
        self.samples = self._build_sample_index(max_samples)
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self, max_samples):
        samples = []
        scene_paths = sorted((self.root_dir / self.split).glob('scene_*'))[:self.max_scenes]
        for scene_path in scene_paths:
            camera_paths = sorted(scene_path.glob('camera_*'))
            for camera_path in camera_paths:
                obj_paths = sorted(camera_path.glob('obj_*'))
                selected_objs = random.sample(obj_paths, min(self.samples_per_scene, len(obj_paths)))
                for obj_path in selected_objs:
                    rgba_files = sorted(camera_path.glob('rgba_*.png'))
                    frame_ids = [int(p.stem.split('_')[1]) for p in rgba_files]
                    # Create non-overlapping sequences
                    for i in range(0, len(frame_ids) - self.seq_len + 1, self.seq_len):
                        samples.append({
                            'scene': scene_path.name,
                            'camera': camera_path.name,
                            'obj_folder': obj_path.name,
                            'frame_ids': frame_ids[i:i + self.seq_len],
                            'obj_id': int(obj_path.name.split('_')[1])
                        })
                        if max_samples and len(samples) >= max_samples:
                            return samples
        return samples

    def __getitem__(self, idx):
        sample = self.samples[idx]
        base_path = self.root_dir / self.split / sample['scene'] / sample['camera']
        obj_path = base_path / sample['obj_folder']

        rgb_frames = []
        modal_mask_frames = []
        amodal_mask_frames = []
        amodal_rgb_frames = []

        for fid in sample['frame_ids']:
            fid_str = f"{fid:05d}"
            try:
                # Load scene RGB
                rgb = Image.open(base_path / f'rgba_{fid_str}.png').convert('RGB')
                rgb = self.transform(rgb)

                # Load scene segmentation to compute the modal (visible) mask
                seg_map = np.array(Image.open(base_path / f'segmentation_{fid_str}.png'))
                modal_mask_np = (seg_map == sample['obj_id']).astype(np.uint8) * 255
                modal_mask = Image.fromarray(modal_mask_np, mode='L')
                modal_mask = self.transform(modal_mask)

                # Load amodal mask
                amodal_mask = Image.open(obj_path / f'segmentation_{fid_str}.png').convert('L')
                amodal_mask = self.transform(amodal_mask)

                # Load target amodal RGB
                amodal_rgb = Image.open(obj_path / f'rgba_{fid_str}.png').convert('RGB')
                amodal_rgb = self.transform(amodal_rgb)

                rgb_frames.append(rgb)
                modal_mask_frames.append(modal_mask)
                amodal_mask_frames.append(amodal_mask)
                amodal_rgb_frames.append(amodal_rgb)
            except Exception as e:
                print(f"Error loading {base_path}/rgba_{fid_str}.png: {e}")
                # Return empty tensors if loading fails
                empty_rgb = torch.zeros(3, self.img_size[0], self.img_size[1])
                empty_mask = torch.zeros(1, self.img_size[0], self.img_size[1])
                return {
                    'rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'modal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'amodal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'amodal_rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'scene': sample['scene'],
                    'camera': sample['camera'],
                    'object_id': sample['obj_id']
                }

        return {
            'rgb_sequence': torch.stack(rgb_frames),                # Scene RGB
            'modal_masks': torch.stack(modal_mask_frames),          # Modal masks (visible parts)
            'amodal_masks': torch.stack(amodal_mask_frames),        # Amodal masks (complete shape)
            'amodal_rgb_sequence': torch.stack(amodal_rgb_frames),  # Target: complete object RGB
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id']
        }

    def __len__(self):
        return len(self.samples)

import wandb
wandb.login()
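# Optional sanity check (a minimal sketch, assuming the test split was
# downloaded to data/test by the cells above): build a tiny dataset and print
# the tensor shapes one sample provides, so the (seq_len, C, H, W) layout used
# throughout the rest of the notebook is explicit.
sanity_dataset = VideoAmodalDataset(
    root_dir='data',
    split='test',
    seq_len=8,
    img_size=(256, 256),
    max_scenes=1,
    samples_per_scene=1,
    max_samples=1
)
if len(sanity_dataset) > 0:
    sanity_sample = sanity_dataset[0]
    print("rgb_sequence:       ", tuple(sanity_sample['rgb_sequence'].shape))         # (8, 3, 256, 256)
    print("modal_masks:        ", tuple(sanity_sample['modal_masks'].shape))          # (8, 1, 256, 256)
    print("amodal_masks:       ", tuple(sanity_sample['amodal_masks'].shape))         # (8, 1, 256, 256)
    print("amodal_rgb_sequence:", tuple(sanity_sample['amodal_rgb_sequence'].shape))  # (8, 3, 256, 256)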
# Add these imports to your existing imports
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch.nn.functional as F
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torchvision.models import inception_v3
from torchvision.transforms import Resize, Normalize
import lpips


# Add this class for computing metrics
class VideoAmodalMetrics:
    """Compute various metrics for video amodal completion."""

    def __init__(self, device='cuda'):
        self.device = device

        # Initialize LPIPS model
        self.lpips_model = lpips.LPIPS(net='alex').to(device)

        # Initialize Inception model for FID
        self.inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception_model.eval()

        # Preprocessing for Inception
        self.inception_transform = torch.nn.Sequential(
            Resize((299, 299)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    def calculate_psnr(self, pred, target, mask=None):
        """Calculate PSNR between prediction and target."""
        if mask is not None:
            # Only calculate PSNR in masked regions
            pred_masked = pred * mask
            target_masked = target * mask

            # Convert to numpy and calculate PSNR for each frame
            psnr_values = []
            for i in range(pred.shape[0]):  # Over batch
                if pred.dim() == 5:  # (B, C, N, H, W)
                    for j in range(pred.shape[2]):  # Over frames
                        p = pred_masked[i, :, j].permute(1, 2, 0).cpu().numpy()
                        t = target_masked[i, :, j].permute(1, 2, 0).cpu().numpy()
                        m = mask[i, 0, j].cpu().numpy()
                        if m.sum() > 0:  # Only if there are masked pixels
                            psnr_values.append(psnr(t, p, data_range=1.0))
                else:  # (B, C, H, W)
                    p = pred_masked[i].permute(1, 2, 0).cpu().numpy()
                    t = target_masked[i].permute(1, 2, 0).cpu().numpy()
                    m = mask[i, 0].cpu().numpy()
                    if m.sum() > 0:
                        psnr_values.append(psnr(t, p, data_range=1.0))
            return np.mean(psnr_values) if psnr_values else 0.0
        else:
            # Calculate PSNR for the entire image
            mse = F.mse_loss(pred, target)
            psnr_val = 20 * torch.log10(1.0 / torch.sqrt(mse))
            return psnr_val.item()

    def calculate_ssim(self, pred, target, mask=None):
        """Calculate SSIM between prediction and target."""
        ssim_values = []
        for i in range(pred.shape[0]):  # Over batch
            if pred.dim() == 5:  # (B, C, N, H, W)
                for j in range(pred.shape[2]):  # Over frames
                    p = pred[i, :, j].permute(1, 2, 0).cpu().numpy()
                    t = target[i, :, j].permute(1, 2, 0).cpu().numpy()
                    if mask is not None:
                        m = mask[i, 0, j].cpu().numpy()
                        if m.sum() == 0:
                            continue
                    ssim_values.append(ssim(t, p, data_range=1.0, channel_axis=2))
            else:  # (B, C, H, W)
                p = pred[i].permute(1, 2, 0).cpu().numpy()
                t = target[i].permute(1, 2, 0).cpu().numpy()
                if mask is not None:
                    m = mask[i, 0].cpu().numpy()
                    if m.sum() == 0:
                        continue
                ssim_values.append(ssim(t, p, data_range=1.0, channel_axis=2))
        return np.mean(ssim_values) if ssim_values else 0.0

    def calculate_lpips(self, pred, target, mask=None):
        """Calculate LPIPS perceptual distance."""
        # LPIPS expects inputs in [-1, 1]
        pred_norm = pred * 2.0 - 1.0
        target_norm = target * 2.0 - 1.0

        lpips_values = []
        if pred.dim() == 5:  # (B, C, N, H, W)
            for i in range(pred.shape[0]):
                for j in range(pred.shape[2]):
                    p = pred_norm[i, :, j].unsqueeze(0)
                    t = target_norm[i, :, j].unsqueeze(0)
                    with torch.no_grad():
                        lpips_values.append(self.lpips_model(p, t).item())
        else:  # (B, C, H, W)
            with torch.no_grad():
                lpips_val = self.lpips_model(pred_norm, target_norm)
            lpips_values.extend(lpips_val.cpu().numpy().flatten().tolist())
        return np.mean(lpips_values) if lpips_values else 0.0

    def calculate_iou(self, pred_mask, target_mask, threshold=0.5):
        """Calculate IoU for binary masks."""
        pred_binary = (pred_mask > threshold).float()
        target_binary = (target_mask > threshold).float()
        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection
        iou = intersection / (union + 1e-8)
        return iou.item()

    def get_inception_features(self, images):
        """Extract features from the Inception model for FID calculation."""
        with torch.no_grad():
            # Preprocess images
            images_preprocessed = self.inception_transform(images)
            # Get features
            features = self.inception_model(images_preprocessed)
        return features.cpu().numpy()

    def calculate_fid(self, pred, target):
        """Calculate Fréchet Inception Distance."""
        # Reshape if needed: (B, C, N, H, W) -> (B*N, C, H, W)
        if pred.dim() == 5:
            pred = pred.permute(0, 2, 1, 3, 4).reshape(-1, pred.shape[1], pred.shape[3], pred.shape[4])
            target = target.permute(0, 2, 1, 3, 4).reshape(-1, target.shape[1], target.shape[3], target.shape[4])

        # Get features
        pred_features = self.get_inception_features(pred)
        target_features = self.get_inception_features(target)

        # Calculate statistics
        mu1, sigma1 = pred_features.mean(axis=0), np.cov(pred_features, rowvar=False)
        mu2, sigma2 = target_features.mean(axis=0), np.cov(target_features, rowvar=False)

        # Calculate FID
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return fid

    def calculate_all_metrics(self, pred, target, amodal_mask=None):
        """Calculate all metrics at once."""
        metrics = {}
        metrics['psnr'] = self.calculate_psnr(pred, target, amodal_mask)
        metrics['ssim'] = self.calculate_ssim(pred, target, amodal_mask)
        metrics['lpips'] = self.calculate_lpips(pred, target, amodal_mask)
        try:
            metrics['fid'] = self.calculate_fid(pred, target)
        except Exception:
            metrics['fid'] = 0.0

        # IoU for masks (if available)
        if amodal_mask is not None:
            # Create a predicted mask by thresholding the prediction intensity
            pred_intensity = pred.mean(dim=1, keepdim=True)  # Convert to grayscale
            metrics['iou'] = self.calculate_iou(pred_intensity, amodal_mask)
        return metrics


# Add this function to create error heatmaps
def create_error_heatmap(pred, target, mask=None):
    """Create an error heatmap between prediction and target."""
    # Per-pixel absolute error, averaged over color channels (inputs are (B, C, H, W))
    error = torch.abs(pred - target).mean(dim=1)
    if mask is not None:
        error = error * mask.squeeze()
    return error.cpu().numpy()
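# Quick smoke test for the metrics class (a minimal sketch on random tensors;
# the shapes mimic the model's (B, C, N, H, W) outputs used below, and the
# values are meaningless except to confirm every metric runs end to end).
# Note: constructing VideoAmodalMetrics downloads the LPIPS and Inception weights.
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
_metrics = VideoAmodalMetrics(device=_device)
_pred = torch.rand(1, 3, 4, 64, 64, device=_device)
_target = torch.rand(1, 3, 4, 64, 64, device=_device)
_mask = (torch.rand(1, 1, 4, 64, 64, device=_device) > 0.5).float()
print(_metrics.calculate_all_metrics(_pred, _target, _mask))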
# Enhanced training function with metrics and wandb
def train_video_amodal_with_metrics():
    # Initialize wandb
    wandb.init(
        project="video-amodal-completion",
        config={
            'batch_size': 2,
            'seq_len': 6,
            'img_size': (256, 256),
            'num_epochs': 30,
            'learning_rate': 5e-5,
            'max_scenes': 2,
            'samples_per_scene': 2,
            'num_workers': 2,
            'grad_accum_steps': 4
        }
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()
    config = wandb.config

    # Initialize metrics calculator
    metrics_calculator = VideoAmodalMetrics(device)

    # Create datasets
    train_dataset = VideoAmodalDataset(
        root_dir='data',
        split='train',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=config.max_scenes,
        samples_per_scene=config.samples_per_scene,
        max_samples=100
    )
    val_dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=1,
        samples_per_scene=1,
        max_samples=10
    )

    # DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1
    )

    # Model (Video3DUNet, VideoAmodalCompletionLoss, prepare_model_input and
    # prepare_model_target are defined in earlier cells of this notebook)
    model = Video3DUNet(
        in_channels=5,
        out_channels=3,
        sequence_length=config.seq_len
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = VideoAmodalCompletionLoss()

    # Training loop with metrics
    for epoch in range(config.num_epochs):
        model.train()
        epoch_losses = []
        epoch_metrics = {
            'train_psnr': [], 'train_ssim': [], 'train_lpips': [],
            'train_fid': [], 'train_iou': []
        }

        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            # Prepare inputs and targets
            inputs = prepare_model_input(batch).to(device, non_blocking=True)
            targets = prepare_model_target(batch).to(device, non_blocking=True)
            modal_masks = batch['modal_masks'].to(device, non_blocking=True)
            amodal_masks = batch['amodal_masks'].to(device, non_blocking=True)

            # Forward pass
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                loss = loss / config.grad_accum_steps

            # Backward pass
            loss.backward()

            # Calculate metrics periodically
            if i % 10 == 0:
                with torch.no_grad():
                    amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                    batch_metrics = metrics_calculator.calculate_all_metrics(
                        outputs, targets, amodal_masks_3d
                    )
                    for key, value in batch_metrics.items():
                        if f'train_{key}' in epoch_metrics:
                            epoch_metrics[f'train_{key}'].append(value)

            # Gradient accumulation
            if (i + 1) % config.grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                torch.cuda.empty_cache()

            epoch_losses.append(loss_dict['total_loss'])

            # Periodic logging with wandb
            if i % 20 == 0:
                log_dict = {
                    'batch': epoch * len(train_loader) + i,
                    'train_loss': loss_dict['total_loss'],
                    'train_visible_loss': loss_dict['visible_loss'],
                    'train_occluded_loss': loss_dict['occluded_loss'],
                    'train_background_loss': loss_dict['background_loss'],
                    'train_boundary_loss': loss_dict['boundary_loss']
                }
                # Add the latest metrics if available
                for key, values in epoch_metrics.items():
                    if values:
                        log_dict[key] = values[-1]
                wandb.log(log_dict)

                print(f"Batch {i}, Loss: {loss_dict['total_loss']:.4f}")
                print(f"  Visible: {loss_dict['visible_loss']:.4f}, "
                      f"Occluded: {loss_dict['occluded_loss']:.4f}, "
                      f"Background: {loss_dict['background_loss']:.4f}")

        # Validation with metrics
        model.eval()
        val_losses = []
        val_metrics = {
            'val_psnr': [], 'val_ssim': [], 'val_lpips': [],
            'val_fid': [], 'val_iou': []
        }
        with torch.no_grad():
            for batch in val_loader:
                inputs = prepare_model_input(batch).to(device)
                targets = prepare_model_target(batch).to(device)
                modal_masks = batch['modal_masks'].to(device)
                amodal_masks = batch['amodal_masks'].to(device)

                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                val_losses.append(loss_dict['total_loss'])

                # Calculate validation metrics
                amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                batch_metrics = metrics_calculator.calculate_all_metrics(
                    outputs, targets, amodal_masks_3d
                )
                for key, value in batch_metrics.items():
                    if f'val_{key}' in val_metrics:
                        val_metrics[f'val_{key}'].append(value)

        # End-of-epoch logging
        avg_train_loss = np.mean(epoch_losses)
        avg_val_loss = np.mean(val_losses)
        epoch_log = {
            'epoch': epoch,
            'avg_train_loss': avg_train_loss,
            'avg_val_loss': avg_val_loss
        }
        # Add averaged metrics
        for key, values in {**epoch_metrics, **val_metrics}.items():
            if values:
                epoch_log[f'avg_{key}'] = np.mean(values)
        wandb.log(epoch_log)

        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        for key, values in val_metrics.items():
            if values:
                print(f"  {key}: {np.mean(values):.4f}")

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'metrics': {key: np.mean(values) for key, values in val_metrics.items() if values}
        }, f"epoch_{epoch}.pth")

    wandb.finish()


# Enhanced GIF creation with error heatmap
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=200):
    """Create an animated GIF with an error heatmap panel."""
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np

    frames = []
    all_errors = []

    # Calculate errors for all frames first to get a consistent color scale
    for i in range(len(predictions)):
        pred_tensor = predictions[i]
        gt_tensor = gt_amodal_frames[i]
        mask_tensor = amodal_masks[i] if amodal_masks else None
        error = create_error_heatmap(
            pred_tensor.unsqueeze(0),
            gt_tensor.unsqueeze(0),
            mask_tensor.unsqueeze(0) if mask_tensor is not None else None
        )
        all_errors.append(error)

    # Global error range for consistent coloring
    max_error = max(error.max() for error in all_errors)
    min_error = min(error.min() for error in all_errors)

    for i in range(len(predictions)):
        # Scene input
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        # Prediction output
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        # Ground truth amodal
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        # Error heatmap: normalize to [0, 1] using the global range
        error = all_errors[i]
        if max_error > min_error:
            error_normalized = (error - min_error) / (max_error - min_error)
        else:
            error_normalized = error

        # Ensure error is shape (H, W) before applying the colormap
        error_normalized = np.squeeze(error_normalized)
        if error_normalized.ndim == 3:
            error_normalized = error_normalized[0]

        # Apply colormap
        error_colored = cm.jet(error_normalized)                      # (H, W, 4)
        error_rgb = (error_colored[:, :, :3] * 255).astype(np.uint8)  # (H, W, 3)

        # Concatenate all panels side by side
        combined = np.concatenate([scene_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)

        # Add error scale text (simplified; a proper colorbar is added in the
        # enhanced version later in this notebook)
        img_pil = Image.fromarray(combined)
        draw = ImageDraw.Draw(img_pil)
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None
        text = f"Error: {min_error:.3f} - {max_error:.3f}"
        draw.text((combined.shape[1] - 150, 10), text, fill=(255, 255, 255), font=font)

        frames.append(img_pil)

    # Save as animated GIF
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
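# A standalone sketch of the overlapping-window schedule used by the video
# generation function below: windows of `seq_len` frames advance by
# `seq_len // 2`, and predictions in the overlapping half of each window are
# averaged with the previous window's. (Illustrative only; no model is run.)
def _preview_window_schedule(total_frames=24, seq_len=8):
    stride = seq_len // 2
    coverage = [0] * total_frames  # how many windows touch each frame
    for start_idx in range(0, total_frames - seq_len + 1, stride):
        for i in range(start_idx, start_idx + seq_len):
            coverage[i] += 1
        print(f"window [{start_idx:2d}, {start_idx + seq_len - 1:2d}]")
    print("frames covered twice (averaged):",
          [i for i, c in enumerate(coverage) if c == 2])

_preview_window_schedule()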
# Enhanced video generation with metrics
def load_model_and_generate_video_with_metrics(checkpoint_path, dataset, device,
                                               output_path="amodal_completion.mp4", fps=8):
    """Load a trained model and generate a video, computing metrics along the way."""
    import cv2
    from pathlib import Path

    # Initialize metrics calculator
    metrics_calculator = VideoAmodalMetrics(device)

    # Load model
    model = Video3DUNet(in_channels=5, out_channels=3, sequence_length=8).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['train_loss']:.4f}")

    # Get a sample with 24 frames
    sample = dataset[0]
    seq_len = 8
    total_frames = len(sample['rgb_sequence'])
    print(f"Processing {total_frames} frames in windows of {seq_len}")

    all_predictions = []
    all_rgb = []
    all_modal_masks = []
    all_amodal_masks = []
    all_metrics = []

    with torch.no_grad():
        # Process overlapping windows (stride seq_len // 2)
        for start_idx in range(0, total_frames - seq_len + 1, seq_len // 2):
            end_idx = min(start_idx + seq_len, total_frames)

            # Create a batch for this window
            window_batch = {}
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    if value.dim() == 4:
                        window_batch[key] = value[start_idx:end_idx].unsqueeze(0)
                    else:
                        window_batch[key] = value.unsqueeze(0)
                else:
                    window_batch[key] = [value]

            # Get the prediction for this window
            inputs = prepare_model_input(window_batch).to(device)
            pred = model(inputs)

            # Mask to the object region
            amodal_mask = window_batch['amodal_masks'].permute(0, 2, 1, 3, 4).expand_as(pred).to(device)
            pred_masked = pred * amodal_mask

            # Calculate metrics for this window
            target = prepare_model_target(window_batch).to(device)
            window_metrics = metrics_calculator.calculate_all_metrics(pred, target, amodal_mask)
            all_metrics.append(window_metrics)

            # Store results, averaging predictions in the overlapping half-window
            pred_frames = pred_masked.squeeze(0).permute(1, 0, 2, 3).cpu()
            if start_idx == 0:
                all_predictions.extend([pred_frames[i] for i in range(len(pred_frames))])
            else:
                overlap_frames = seq_len // 2
                for i in range(overlap_frames):
                    if len(all_predictions) > start_idx + i:
                        all_predictions[start_idx + i] = (all_predictions[start_idx + i] + pred_frames[i]) / 2.0
                for i in range(overlap_frames, len(pred_frames)):
                    if start_idx + i < total_frames:
                        all_predictions.append(pred_frames[i])

            if start_idx == 0:
                all_rgb = [sample['rgb_sequence'][i] for i in range(total_frames)]
                all_modal_masks = [sample['modal_masks'][i] for i in range(total_frames)]
                all_amodal_masks = [sample['amodal_masks'][i] for i in range(total_frames)]
                all_gt_amodal = [sample['amodal_rgb_sequence'][i] for i in range(total_frames)]

    # Print overall metrics
    print("\nOverall Metrics:")
    avg_metrics = {}
    for key in all_metrics[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in all_metrics])
        print(f"  {key.upper()}: {avg_metrics[key]:.4f}")

    all_predictions = all_predictions[:total_frames]
    print(f"Generated {len(all_predictions)} prediction frames")

    # Create the video: four panels side by side (scene, modal mask, prediction, GT amodal)
    height, width = all_predictions[0].shape[-2:]
    video_width = width * 4
    video_height = height
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (video_width, video_height))

    for i in range(len(all_predictions)):
        scene_rgb = all_rgb[i].permute(1, 2, 0).numpy()
        modal_mask = all_modal_masks[i][0].numpy()
        modal_mask_rgb = np.stack([modal_mask, modal_mask, modal_mask], axis=2)
        pred_rgb = all_predictions[i].permute(1, 2, 0).numpy()
        pred_rgb = np.clip(pred_rgb, 0, 1)

        try:
            gt_amodal = sample['amodal_rgb_sequence'][i].permute(1, 2, 0).numpy()
            amodal_mask_np = all_amodal_masks[i][0].numpy()
            gt_amodal_masked = gt_amodal * amodal_mask_np[:, :, None]
        except Exception:
            gt_amodal_masked = np.zeros_like(pred_rgb)

        combined_frame = np.concatenate([
            scene_rgb, modal_mask_rgb, pred_rgb, gt_amodal_masked
        ], axis=1)
        combined_frame_bgr = cv2.cvtColor((combined_frame * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        out.write(combined_frame_bgr)

        if i % 5 == 0:
            print(f"Processed frame {i+1}/{len(all_predictions)}")

    out.release()
    print(f"Video saved to {output_path}")
    return all_predictions, all_rgb, all_gt_amodal, all_amodal_masks, avg_metrics


# Enhanced run function with all new features
def run_enhanced_video_generation():
    """Run video generation with metrics and error visualization."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load dataset
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=1,
        samples_per_scene=1,
        max_samples=1
    )

    # Generate video with metrics
    checkpoint_path = "video_amodal_model_epoch_4.pth"
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path, dataset, device,
        output_path="amodal_completion_video_with_metrics.mp4",
        fps=8
    )

    # Create the enhanced GIF with error heatmap
    create_gif_with_error_heatmap(
        predictions, rgb_frames, gt_amodal_frames, amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )

    print("Enhanced video generation complete!")
    return metrics

train_video_amodal_with_metrics()

# Simple way to run GIF generation from your trained model
import torch

def run_gif_generation():
    """Simple function to generate GIFs from a trained model."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create test dataset
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=50,
        samples_per_scene=5,
        max_samples=50
    )

    # Generate video with metrics and error heatmap GIF
    checkpoint_path = "epoch_29.pth"  # Change this to your checkpoint file name
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path, dataset, device,
        output_path="amodal_completion_video.mp4",
        fps=6
    )

    # Create GIF with error heatmap
    create_gif_with_error_heatmap(
        predictions, rgb_frames, gt_amodal_frames, amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )

    print("GIF creation complete!")
    print(f"Metrics: {metrics}")

# Just run this:
if __name__ == "__main__":
    run_gif_generation()

import cv2

def draw_amodal_boundary(rgb_image, amodal_mask, color=(255, 0, 255)):
    """Outline the amodal mask on top of the scene RGB image."""
    contours, _ = cv2.findContours(amodal_mask.astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlined = rgb_image.copy()
    cv2.drawContours(outlined, contours, -1, color, thickness=2)
    return outlined
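# Tiny usage example for draw_amodal_boundary (a sketch on synthetic data):
# outline a filled rectangle "mask" on a gray image and check the output shape.
_img = np.full((128, 128, 3), 128, dtype=np.uint8)  # gray scene
_mask = np.zeros((128, 128), dtype=np.uint8)
_mask[32:96, 40:104] = 1                             # binary amodal mask
_outlined = draw_amodal_boundary(_img, _mask)        # magenta contour, thickness 2
print(_outlined.shape, _outlined.dtype)              # (128, 128, 3) uint8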
# Enhanced GIF creation with proper error heatmap and colorbar
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=240):
    """Create an animated GIF with a proper error heatmap and colorbar."""
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize
    import io

    frames = []
    all_errors = []

    # Calculate errors for all frames first to get a consistent color scale
    for i in range(len(predictions)):
        pred_tensor = predictions[i]
        gt_tensor = gt_amodal_frames[i]
        mask_tensor = amodal_masks[i] if amodal_masks else None
        error = create_error_heatmap(
            pred_tensor.unsqueeze(0),
            gt_tensor.unsqueeze(0),
            mask_tensor.unsqueeze(0) if mask_tensor is not None else None
        )
        all_errors.append(error)

    # Global error range for consistent coloring; focus on masked regions only
    masked_errors = []
    for i, error in enumerate(all_errors):
        if amodal_masks is not None:
            mask = amodal_masks[i][0].numpy()
            masked_error = error * mask
            masked_errors.extend(masked_error[masked_error > 0])  # Only non-zero masked regions
        else:
            masked_errors.extend(error.flatten())

    if masked_errors:
        # Use percentiles for better visualization (removes outliers)
        min_error = np.percentile(masked_errors, 5)   # 5th percentile
        max_error = np.percentile(masked_errors, 95)  # 95th percentile
    else:
        min_error = min(error.min() for error in all_errors)
        max_error = max(error.max() for error in all_errors)

    # Ensure we have a reasonable range
    if max_error - min_error < 1e-6:
        max_error = min_error + 1e-6

    print(f"Error range for visualization: {min_error:.4f} to {max_error:.4f}")

    # Render a labeled colorbar with matplotlib once, then reuse it for every frame
    def create_colorbar():
        fig, ax = plt.subplots(figsize=(1, 4))
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # Create the colorbar ('hot' gives the red-yellow-white scale)
        norm = Normalize(vmin=min_error, vmax=max_error)
        sm = cm.ScalarMappable(norm=norm, cmap='hot')
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, orientation='vertical', fraction=1.0)
        cbar.set_label('Prediction Error', color='white', fontsize=10)
        cbar.ax.tick_params(colors='white', labelsize=8)

        # Remove the main axes so only the colorbar remains
        ax.remove()

        # Save to an in-memory buffer and load back as a PIL image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight',
                    facecolor='black', edgecolor='none', dpi=100)
        buf.seek(0)
        colorbar_with_labels = Image.open(buf)
        plt.close()
        return colorbar_with_labels

    colorbar_img = create_colorbar()
    colorbar_width = colorbar_img.width

    for i in range(len(predictions)):
        # Scene input
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        # Prediction output
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        # Ground truth amodal
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        # Error heatmap, masked to the amodal region if available
        error = all_errors[i]
        if amodal_masks is not None:
            mask = amodal_masks[i][0].numpy()
            error = error * mask

        # Ensure error is shape (H, W)
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error[0]

        # Normalize using the global range
        error_normalized = np.clip((error - min_error) / (max_error - min_error), 0, 1)

        # Apply the 'hot' colormap (black -> red -> yellow -> white)
        cmap = cm.get_cmap('hot')
        error_colored = cmap(error_normalized)                        # (H, W, 4)
        error_rgb = (error_colored[:, :, :3] * 255).astype(np.uint8)  # (H, W, 3)

        # Set non-masked regions to black for better visualization
        if amodal_masks is not None:
            mask_3d = np.stack([mask, mask, mask], axis=2)
            error_rgb = error_rgb * mask_3d.astype(np.uint8)

        # Outline the amodal boundary on the scene and concatenate all panels
        highlighted_rgb = draw_amodal_boundary(scene_rgb, amodal_masks[i][0].cpu().numpy())
        combined = np.concatenate([highlighted_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)

        # Convert to PIL and attach the colorbar on the right
        img_pil = Image.fromarray(combined)
        colorbar_resized = colorbar_img.resize((colorbar_width, img_pil.height))
        final_width = img_pil.width + colorbar_width + 10  # 10px spacing
        final_img = Image.new('RGB', (final_width, img_pil.height), color='black')
        final_img.paste(img_pil, (0, 0))
        final_img.paste(colorbar_resized, (img_pil.width + 10, 0))

        # Add frame number
        draw = ImageDraw.Draw(final_img)
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None
        frame_text = f"Frame {i+1}/{len(predictions)}"
        draw.text((10, 10), frame_text, fill=(0, 0, 0), font=font)

        frames.append(final_img)

    # Save as animated GIF
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with proper error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
    print("Colorbar shows errors from low (black/red) to high (yellow/white)")


# Also update the error heatmap calculation to be more sensitive
def create_error_heatmap(pred, target, mask=None):
    """Create an error heatmap between prediction and target with enhanced sensitivity."""
    # Per-pixel L2 error across color channels
    error = torch.sqrt(torch.sum((pred - target) ** 2, dim=1))
    # Alternative: L1 error for different characteristics
    # error = torch.abs(pred - target).mean(dim=1)
    if mask is not None:
        error = error * mask.squeeze()
    return error.cpu().numpy()
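# Quick check of the updated heatmap (a sketch on random tensors): the L2
# variant above returns one value per pixel, so a (1, 3, H, W) input yields a
# (1, H, W) array, which the GIF code then squeezes to (H, W).
_p = torch.rand(1, 3, 32, 32)
_t = torch.rand(1, 3, 32, 32)
_m = torch.ones(1, 1, 32, 32)
_err = create_error_heatmap(_p, _t, _m)
print(_err.shape, float(_err.min()), float(_err.max()))  # (1, 32, 32), values in [0, sqrt(3)]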