# -*- coding: utf-8 -*-
"""final1.2.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1v6-6x7lqt6gr9VIauNVHIwjvIkewk8eT
"""

"""## FINAL 1.2"""

!pip install torchmetrics lpips

# PyTorch, Torchvision
import torch
from torch import nn
from torchvision.transforms import ToPILImage, ToTensor
from torchvision.utils import make_grid
from torchvision.io import write_video

# Common
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random
import json
from IPython.display import Video

# Utils from Torchvision
tensor_to_image = ToPILImage()
image_to_tensor = ToTensor()


def get_img_dict(img_dir):
    """Group the image files in a directory by their type prefix (e.g. 'rgba', 'segmentation')."""
    img_files = [x for x in img_dir.iterdir() if x.name.endswith('.png') or x.name.endswith('.tiff')]
    img_files.sort()
    img_dict = {}
    for img_file in img_files:
        img_type = img_file.name.split('_')[0]
        if img_type not in img_dict:
            img_dict[img_type] = []
        img_dict[img_type].append(img_file)
    return img_dict


def get_sample_dict(sample_dir):
    """Build a nested dict for one sample: camera -> {'scene': ..., 'obj_*': ...} image lists."""
    camera_dirs = [x for x in sample_dir.iterdir() if 'camera' in x.name]
    camera_dirs.sort()
    sample_dict = {}
    for cam_dir in camera_dirs:
        cam_dict = {}
        cam_dict['scene'] = get_img_dict(cam_dir)
        obj_dirs = [x for x in cam_dir.iterdir() if 'obj_' in x.name]
        obj_dirs.sort()
        for obj_dir in obj_dirs:
            cam_dict[obj_dir.name] = get_img_dict(obj_dir)
        sample_dict[cam_dir.name] = cam_dict
    return sample_dict

# Download descriptors, README, etc.
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/test_obj_descriptors.json
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/train_obj_descriptors.json
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/ex_vis.mp4
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/README.md
!wget "https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/Notice%201%20-%20Unlimited_datasets.pdf"
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/.gitattributes

# Check that we are pointing at the right Hugging Face repo
from huggingface_hub import HfApi, hf_hub_download
import random, os

api = HfApi()
repo_id = "Amar-S/MOVi-MC-AC"

# List all files in the repo
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

# Separate train and test files
train_files = [f for f in files if f.startswith("train/") and not f.endswith(".json")]
test_files = [f for f in files if f.startswith("test/") and not f.endswith(".json")]
print(f"Found {len(train_files)} train files and {len(test_files)} test files.")

# Download a small fraction (0.5%) of the train/test archives
import os
import random
import shutil
from huggingface_hub import hf_hub_download

os.makedirs("/content/data/train", exist_ok=True)
os.makedirs("/content/data/test", exist_ok=True)

# Sample 0.5% of each split
subset_train = random.sample(train_files, int(len(train_files) * 0.005))
subset_test = random.sample(test_files, int(len(test_files) * 0.005))

# Download the training files
for file in subset_train:
    out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    dest_path = f"/content/data/train/{os.path.basename(file)}"
    shutil.copyfile(out_path, dest_path)  # copy the actual file content instead of renaming the cache symlink

# Download the test files
for file in subset_test:
    out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    dest_path = f"/content/data/test/{os.path.basename(file)}"
    shutil.copyfile(out_path, dest_path)  # same here: copy the file content rather than moving the symlink
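# Optional alternative (not used in this notebook): huggingface_hub's snapshot_download
# can fetch a filtered subset of the repo in a single call instead of looping over
# individual files. This is only a sketch; the allow_patterns and local_dir below are
# illustrative and would need to match the archives you actually want.
# from huggingface_hub import snapshot_download
# snapshot_download(
#     repo_id=repo_id,
#     repo_type="dataset",
#     allow_patterns=["train/*.tar.gz"],
#     local_dir="/content/data/hf_snapshot",
# )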
import os

# Untar all files in data/train
train_dir = "data/train"
for file in os.listdir(train_dir):
    if file.endswith(".tar.gz"):
        filepath = os.path.join(train_dir, file)
        !tar -xzf {filepath} -C {train_dir}

# Untar all files in data/test
test_dir = "data/test"
for file in os.listdir(test_dir):
    if file.endswith(".tar.gz"):
        filepath = os.path.join(test_dir, file)
        !tar -xzf {filepath} -C {test_dir}

import os
from pathlib import Path

# Delete the archives once they have been extracted to free up disk space
root = Path('/content/data')  # or wherever your files live
deleted = 0
for archive in root.rglob('*.tar.gz'):
    try:
        archive.unlink()
        print(f"Deleted {archive}")
        deleted += 1
    except Exception as e:
        print(f"Error deleting {archive}: {e}")

print(f"Total deleted: {deleted}")
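# Quick look at the extracted structure using the helpers defined earlier. This is
# only an illustrative sketch: it assumes at least one scene directory was extracted
# into data/train, and prints the camera/object layout of the first one.
sample_dirs = sorted(d for d in Path("data/train").iterdir() if d.is_dir())
if sample_dirs:
    sample_dict = get_sample_dict(sample_dirs[0])
    for cam_name, cam_dict in sample_dict.items():
        print(cam_name, "->", list(cam_dict.keys()))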
!pip install torchmetrics lpips

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F  # used by evaluate_metrics below
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import lpips


def visualize_results(model, dataloader, device, num_samples=8):
    """Visualize results with properly masked output (no background)."""
    model.eval()
    samples_shown = 0
    with torch.no_grad():
        for batch in dataloader:
            if samples_shown >= num_samples:
                break
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            pred_masked = pred * amodal_mask         # remove background from the prediction
            gt_masked = gt_amodal_rgb * amodal_mask  # mask the GT the same way for a fair comparison

            for i in range(rgb.shape[0]):
                if samples_shown >= num_samples:
                    break
                fig, axes = plt.subplots(1, 6, figsize=(24, 4))

                # Scene RGB
                axes[0].imshow(rgb[i].cpu().permute(1, 2, 0))
                axes[0].set_title('Scene RGB')
                axes[0].axis('off')

                # Amodal mask
                axes[1].imshow(amodal_mask[i, 0].cpu(), cmap='gray')
                axes[1].set_title('Amodal Mask')
                axes[1].axis('off')

                # Modal mask
                axes[2].imshow(modal_mask[i, 0].cpu(), cmap='gray')
                axes[2].set_title('Modal Mask')
                axes[2].axis('off')

                # Ground-truth amodal RGB (masked)
                axes[3].imshow(gt_masked[i].cpu().permute(1, 2, 0))
                axes[3].set_title('GT Amodal RGB')
                axes[3].axis('off')

                # Predicted amodal RGB (masked)
                axes[4].imshow(pred_masked[i].cpu().permute(1, 2, 0))
                axes[4].set_title('Predicted Amodal RGB')
                axes[4].axis('off')

                # Difference heatmap
                diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
                im = axes[5].imshow(diff.cpu(), cmap='hot')
                axes[5].set_title('Prediction Error')
                axes[5].axis('off')
                plt.colorbar(im, ax=axes[5])

                plt.tight_layout()
                plt.show()
                samples_shown += 1


def evaluate_metrics(model, dataloader, device):
    """Compute MSE metrics only within object regions (overall, occluded, visible)."""
    model.eval()
    total_mse = 0
    occluded_mse = 0
    visible_mse = 0
    total_pixels = 0
    occluded_pixels = 0
    visible_pixels = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            # Mask both prediction and ground truth to object regions only
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Overall MSE within the object region. Note that the summed error runs over
            # all three RGB channels, while the divisor counts each mask pixel once, so the
            # reported values are per-mask-pixel errors summed across channels.
            object_pixels = amodal_mask.sum()
            if object_pixels > 0:
                mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
                total_mse += mse.item()
                total_pixels += object_pixels.item()

            # Occluded-region MSE
            occluded_region = occluded_mask * amodal_mask
            occ_pixels = occluded_region.sum()
            if occ_pixels > 0:
                occ_mse = F.mse_loss(pred_masked * occluded_region, gt_masked * occluded_region, reduction='sum')
                occluded_mse += occ_mse.item()
                occluded_pixels += occ_pixels.item()

            # Visible-region MSE
            visible_region = modal_mask * amodal_mask
            vis_pixels = visible_region.sum()
            if vis_pixels > 0:
                vis_mse = F.mse_loss(pred_masked * visible_region, gt_masked * visible_region, reduction='sum')
                visible_mse += vis_mse.item()
                visible_pixels += vis_pixels.item()

    return {
        'total_mse': total_mse / total_pixels if total_pixels > 0 else 0,
        'occluded_mse': occluded_mse / occluded_pixels if occluded_pixels > 0 else 0,
        'visible_mse': visible_mse / visible_pixels if visible_pixels > 0 else 0,
    }


def calculate_metrics(model, dataloader, device):
    """Compute PSNR, SSIM, LPIPS, and IoU between predictions and GT amodal RGBs."""
    model.eval()
    psnr = PeakSignalNoiseRatio().to(device)
    ssim = StructuralSimilarityIndexMeasure().to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Skip images that are too small for LPIPS (it requires at least 64x64)
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                total_lpips += lpips_loss(pred_i, gt_i).item()

                # IoU between the GT amodal mask and the thresholded prediction
                intersection = (amodal_mask[i] * (pred[i] > 0.5)).sum()
                union = ((amodal_mask[i] + (pred[i] > 0.5)) > 0).sum()
                if union > 0:
                    iou = intersection.float() / union.float()
                    total_iou += iou.item()
                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count,
    }
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from PIL import Image, ImageChops
import numpy as np


class ModalAmodalDataset(Dataset):
    def __init__(self, root_dir, split, img_size=(128, 128), max_samples=None,
                 val_split=0.2, use_val_from_train=False):
        self.root_dir = Path(root_dir)
        self.img_size = img_size
        self.max_samples = max_samples
        self.val_split = val_split
        self.use_val_from_train = use_val_from_train
        self.split = split

        if split == 'val' and use_val_from_train:
            # Load from the train folder but use the validation subset
            self.root_dir = self.root_dir / 'train'
        else:
            self.root_dir = self.root_dir / split

        self.samples = self._build_sample_index()

        self.rgb_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self):
        samples = []
        for scene_dir in self.root_dir.iterdir():
            if not scene_dir.is_dir():
                continue
            for camera_dir in scene_dir.iterdir():
                if not camera_dir.name.startswith('camera_'):
                    continue
                rgba_paths = sorted(camera_dir.glob('rgba_*.png'))
                seg_paths = sorted(camera_dir.glob('segmentation_*.png'))
                for obj_dir in camera_dir.iterdir():
                    if not obj_dir.name.startswith('obj_'):
                        continue
                    amodal_paths = sorted(obj_dir.glob('segmentation_*.png'))
                    amodal_rgb_paths = sorted(obj_dir.glob('rgba_*.png'))
                    if not (len(rgba_paths) == len(seg_paths) == len(amodal_paths) == len(amodal_rgb_paths)):
                        continue
                    for rgba_path, seg_path, amodal_path, amodal_rgb_path in zip(
                        rgba_paths, seg_paths, amodal_paths, amodal_rgb_paths
                    ):
                        samples.append({
                            'rgb_path': rgba_path,
                            'seg_path': seg_path,
                            'amodal_path': amodal_path,
                            'amodal_rgb_path': amodal_rgb_path,
                            'object_id': int(obj_dir.name.split('_')[1]),
                            'scene': scene_dir.name,
                            'camera': camera_dir.name,
                        })

        # Limit dataset size if specified
        if self.max_samples is not None and len(samples) > self.max_samples:
            # Randomly sample to get diverse examples
            import random
            random.seed(42)  # for reproducibility
            samples = random.sample(samples, self.max_samples)
            print(f"Dataset limited to {len(samples)} samples")

        # Create a train/val split if using validation data taken from the train folder
        if self.use_val_from_train:
            import random
            random.seed(42)  # ensure reproducible splits
            random.shuffle(samples)
            val_size = int(len(samples) * self.val_split)
            if self.split == 'train':
                samples = samples[val_size:]  # use the remaining samples for training
                print(f"Train split: {len(samples)} samples")
            elif self.split == 'val':
                samples = samples[:val_size]  # use the first samples for validation
                print(f"Validation split: {len(samples)} samples")

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Load images
        rgb = Image.open(sample['rgb_path']).convert('RGB')
        seg_map = np.array(Image.open(sample['seg_path']))
        amodal_mask_img = Image.open(sample['amodal_path']).convert('L')
        amodal_rgb = Image.open(sample['amodal_rgb_path']).convert('RGB')

        # Compute the modal mask (visible part of the object) from the scene segmentation
        modal_mask_np = (seg_map == sample['object_id']).astype(np.uint8) * 255
        modal_mask_img = Image.fromarray(modal_mask_np, mode='L')

        # Transform images and masks
        rgb = self.rgb_transform(rgb)
        modal_mask = self.mask_transform(modal_mask_img)
        amodal_mask = self.mask_transform(amodal_mask_img)
        amodal_rgb = self.rgb_transform(amodal_rgb)

        # Create the occluded mask (parts of the object that are hidden)
        occluded_mask = amodal_mask - modal_mask
        occluded_mask = torch.clamp(occluded_mask, 0, 1)

        return {
            'rgb': rgb,
            'modal_mask': modal_mask,
            'amodal_mask': amodal_mask,
            'occluded_mask': occluded_mask,
            'amodal_rgb': amodal_rgb,
        }
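# Illustrative sanity check of the dataset pipeline (not part of the original training
# flow): it assumes a few scenes were extracted into data/train, builds a tiny dataset,
# and prints the tensor shapes of one item.
_check_ds = ModalAmodalDataset(root_dir="data", split="train", img_size=(128, 128), max_samples=16)
if len(_check_ds) > 0:
    _item = _check_ds[0]
    for _name, _tensor in _item.items():
        print(_name, tuple(_tensor.shape))
    # Expected: rgb / amodal_rgb -> (3, 128, 128); modal / amodal / occluded masks -> (1, 128, 128)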
class ImprovedUNet(nn.Module):
    def __init__(self, in_channels=5, out_channels=3):  # RGB (3) + modal mask (1) + amodal mask (1)
        super().__init__()

        def conv_block(in_ch, out_ch, dropout=0.1):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # Encoder
        self.down1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.middle = conv_block(512, 1024, dropout=0.2)

        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_block1 = conv_block(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_block2 = conv_block(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_block3 = conv_block(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_block4 = conv_block(128, 64)

        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))

        # Bottleneck
        m = self.middle(self.pool4(d4))

        # Decoder with skip connections
        u1 = self.up_block1(torch.cat([self.up1(m), d4], dim=1))
        u2 = self.up_block2(torch.cat([self.up2(u1), d3], dim=1))
        u3 = self.up_block3(torch.cat([self.up3(u2), d2], dim=1))
        u4 = self.up_block4(torch.cat([self.up4(u3), d1], dim=1))

        return torch.sigmoid(self.final(u4))  # ensure the output is in [0, 1]
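# Quick shape check for the U-Net (illustrative only): a random 5-channel input at the
# training resolution should yield a 3-channel prediction of the same spatial size.
_net = ImprovedUNet()
_dummy = torch.randn(1, 5, 128, 128)  # RGB (3) + modal mask (1) + amodal mask (1)
with torch.no_grad():
    _out = _net(_dummy)
print(_out.shape)  # expected: torch.Size([1, 3, 128, 128]), values in [0, 1] after the sigmoid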
class AmodalCompletionLoss(nn.Module):
    """Loss that only considers object regions (ignores background)."""

    def __init__(self, occluded_weight=5.0, visible_weight=1.0):
        super().__init__()
        self.occluded_weight = occluded_weight
        self.visible_weight = visible_weight
        self.lpips_model = lpips.LPIPS(net='alex')

    def forward(self, pred, target, modal_mask, occluded_mask, amodal_mask):
        # Only compute the loss within the amodal mask (object region)
        device = pred.device
        self.lpips_model = self.lpips_model.to(device)

        pred_masked = pred * amodal_mask
        target_masked = target * amodal_mask

        # Loss on visible parts (within the object)
        visible_region = modal_mask * amodal_mask
        if visible_region.sum() > 0:
            visible_loss = F.mse_loss(pred_masked * visible_region, target_masked * visible_region)
        else:
            visible_loss = torch.tensor(0.0).to(pred.device)

        # Loss on occluded parts (within the object)
        occluded_region = occluded_mask * amodal_mask
        if occluded_region.sum() > 0:
            occluded_loss = F.mse_loss(pred_masked * occluded_region, target_masked * occluded_region)
        else:
            occluded_loss = torch.tensor(0.0).to(pred.device)

        # Perceptual (LPIPS) loss; note that it is computed here but not currently
        # added to the total loss below.
        perceptual_loss = self.lpips_model(pred_masked, target_masked).mean()

        # Boundary consistency within the object: a 3x3 box filter over the amodal mask
        # is neither 0 nor 9 only at mask boundaries.
        boundary_mask = F.conv2d(amodal_mask, torch.ones(1, 1, 3, 3).to(amodal_mask.device), padding=1)
        boundary_mask = ((boundary_mask > 0) & (boundary_mask < 9)).float()
        boundary_loss = F.mse_loss(pred_masked * boundary_mask, target_masked * boundary_mask)

        total_loss = (self.visible_weight * visible_loss +
                      self.occluded_weight * occluded_loss +
                      2.0 * boundary_loss)

        return total_loss, visible_loss, occluded_loss, boundary_loss


def train_improved(model, dataloader, optimizer, device, num_epochs):
    model.train()
    criterion = AmodalCompletionLoss()

    for epoch in range(num_epochs):
        total_loss = 0
        for i, batch in enumerate(dataloader):
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)

            optimizer.zero_grad()
            pred = model(input_tensor)
            loss, vis_loss, occ_loss, boundary_loss = criterion(
                pred, gt_amodal_rgb, modal_mask, occluded_mask, amodal_mask
            )
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if i % 16 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] [{i}/{len(dataloader)}] "
                      f"Total: {loss.item():.4f}, Visible: {vis_loss.item():.4f}, "
                      f"Occluded: {occ_loss.item():.4f}, Boundary: {boundary_loss.item():.4f}")

        print(f"Epoch {epoch} Average Loss: {total_loss/len(dataloader):.4f}")


# Usage
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dataset and DataLoader - reduced size for faster training
    data_root = "data"

    # Create the train dataset (80% of the train folder)
    train_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='train',
        img_size=(128, 128),
        max_samples=1000,        # only use 1000 samples total before the split
        val_split=0.2,           # 20% for validation
        use_val_from_train=True  # create the val split from the train folder
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    # Create the validation dataset (20% of the train folder)
    val_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='val',
        img_size=(128, 128),
        max_samples=1000,        # same max_samples to ensure a consistent split
        val_split=0.2,
        use_val_from_train=True  # create the val split from the train folder
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    print(f"Training on {len(train_dataset)} samples, {len(train_loader)} batches per epoch")
    print(f"Validation on {len(val_dataset)} samples, {len(val_loader)} batches")

    # Model and optimizer (optionally resume from a previously saved checkpoint)
    model = ImprovedUNet().to(device)
    model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    # Training
    #train_improved(model, train_loader, optimizer, device, num_epochs=10)

    # Evaluation and visualization
    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)

    # Compute masked-MSE metrics
    metrics = evaluate_metrics(model, val_loader, device)
    print(f"Overall MSE: {metrics['total_mse']:.6f}")
    print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
    print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
    print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")

    # Visualize results
    print("\nGenerating visualizations...")
    visualize_results(model, val_loader, device, num_samples=8)

    # Compute image-quality metrics
    image_metrics = calculate_metrics(model, val_loader, device)
    print(f"PSNR: {image_metrics['psnr']:.4f}")
    print(f"SSIM: {image_metrics['ssim']:.4f}")
    print(f"LPIPS: {image_metrics['lpips']:.4f}")
    print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")
# Optional: save the model
torch.save(model.state_dict(), 'amodal_completion_model.pth')

# Evaluation and visualization on the test split
test_dataset = ModalAmodalDataset(
    root_dir=data_root,
    split='test',
    img_size=(128, 128),
    max_samples=2000  # only use 2000 samples
)
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True
)

print("EVALUATION RESULTS")
print("="*50)

# Compute masked-MSE metrics on the test split
metrics = evaluate_metrics(model, test_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")

# Visualize results
print("\nGenerating visualizations...")
visualize_results(model, test_loader, device, num_samples=16)

# Release the Colab runtime once evaluation is finished
from google.colab import runtime
runtime.unassign()

# Reload the trained model (e.g. in a fresh session) and re-run the validation evaluation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedUNet()
model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

# Evaluation and visualization
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)

# Compute masked-MSE metrics
metrics = evaluate_metrics(model, val_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")

# Visualize results
print("\nGenerating visualizations...")
visualize_results(model, val_loader, device, num_samples=8)

# Compute image-quality metrics
image_metrics = calculate_metrics(model, val_loader, device)
print(f"PSNR: {image_metrics['psnr']:.4f}")
print(f"SSIM: {image_metrics['ssim']:.4f}")
print(f"LPIPS: {image_metrics['lpips']:.4f}")
print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")

# Fresh, untrained model instance (load the saved weights before using it for inference)
model = ImprovedUNet()
model.eval()
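# Illustrative single-image inference sketch (not part of the original notebook): it
# assumes the checkpoint saved above exists and that val_dataset from the earlier cells
# is still defined. It loads the weights, runs the network on one validation sample,
# and masks the prediction to the amodal region.
model = ImprovedUNet().to(device)
model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))
model.eval()

if len(val_dataset) > 0:
    sample = val_dataset[0]
    inp = torch.cat([sample['rgb'], sample['modal_mask'], sample['amodal_mask']], dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(inp).cpu()[0]
    completed = pred * sample['amodal_mask']  # keep only pixels inside the amodal mask
    print("Predicted amodal RGB:", tuple(completed.shape))  # expected: (3, 128, 128)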