# -*- coding: utf-8 -*-
"""final1.2.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1v6-6x7lqt6gr9VIauNVHIwjvIkewk8eT
"""
"""## FINAL 1.2"""
!pip install torchmetrics lpips
# PyTorch, Torchvision
import torch
from torch import nn
from torchvision.transforms import ToPILImage, ToTensor
from torchvision.utils import make_grid
from torchvision.io import write_video
# Common
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random
import json
from IPython.display import Video
# Utils from Torchvision
tensor_to_image = ToPILImage()
image_to_tensor = ToTensor()
def get_img_dict(img_dir):
    """Group image files in a directory by their type prefix (e.g. 'rgba', 'segmentation')."""
    img_files = [x for x in img_dir.iterdir() if x.name.endswith('.png') or x.name.endswith('.tiff')]
    img_files.sort()
    img_dict = {}
    for img_file in img_files:
        img_type = img_file.name.split('_')[0]
        if img_type not in img_dict:
            img_dict[img_type] = []
        img_dict[img_type].append(img_file)
    return img_dict
def get_sample_dict(sample_dir):
    """Build a nested dict for one sample: camera -> {'scene': imgs, 'obj_<k>': imgs}."""
    camera_dirs = [x for x in sample_dir.iterdir() if 'camera' in x.name]
    camera_dirs.sort()
    sample_dict = {}
    for cam_dir in camera_dirs:
        cam_dict = {}
        cam_dict['scene'] = get_img_dict(cam_dir)
        obj_dirs = [x for x in cam_dir.iterdir() if 'obj_' in x.name]
        obj_dirs.sort()
        for obj_dir in obj_dirs:
            cam_dict[obj_dir.name] = get_img_dict(obj_dir)
        sample_dict[cam_dir.name] = cam_dict
    return sample_dict
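# A minimal usage sketch for the two helpers above, runnable once the archives
# downloaded below have been extracted. The sample id is a placeholder: point it
# at any extracted scene directory (the helpers assume a
# <sample>/<camera_*>/<obj_*>/ layout with 'type_frame' image filenames).
sample_path = Path('data/train/000001')  # hypothetical sample id
if sample_path.exists():
    sample = get_sample_dict(sample_path)
    for cam_name, cam_dict in sample.items():
        print(cam_name, list(cam_dict.keys()))  # 'scene', 'obj_0001', ...
        print('  scene image types:', {k: len(v) for k, v in cam_dict['scene'].items()})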
# Download descriptors, README, example video, etc.
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/test_obj_descriptors.json
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/train_obj_descriptors.json
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/ex_vis.mp4
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/README.md
!wget "https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/Notice%201%20-%20Unlimited_datasets.pdf"
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/.gitattributes
# Test that you are on the right Hugging Face repo
from huggingface_hub import HfApi, hf_hub_download
import random, os
api = HfApi()
repo_id = "Amar-S/MOVi-MC-AC"
# List all files in the repo
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
# Separate train and test files
train_files = [f for f in files if f.startswith("train/") and not f.endswith(".json")]
test_files = [f for f in files if f.startswith("test/") and not f.endswith(".json")]
print(f"Found {len(train_files)} train files and {len(test_files)} test files.")
# Download 0.5% of the train/test archives
import os
import random
import shutil
from huggingface_hub import hf_hub_download
os.makedirs("/content/data/train", exist_ok=True)
os.makedirs("/content/data/test", exist_ok=True)
# Sample 0.5% of each split
subset_train = random.sample(train_files, int(len(train_files) * 0.005))
subset_test = random.sample(test_files, int(len(test_files) * 0.005))
# Download the training files
for file in subset_train:
    out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    dest_path = f"/content/data/train/{os.path.basename(file)}"
    shutil.copyfile(out_path, dest_path)  # COPY the actual file content instead of renaming the cache symlink
# Download the test files
for file in subset_test:
    out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    dest_path = f"/content/data/test/{os.path.basename(file)}"
    shutil.copyfile(out_path, dest_path)  # COPY the actual file content here as well
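# Alternative sketch (assumes a recent huggingface_hub with local_dir support):
# local_dir materializes the real file under the repo-relative path, making the
# shutil.copyfile step above unnecessary. Filenames already carry the "train/"
# or "test/" prefix, hence local_dir points at the data root.
# for file in subset_train + subset_test:
#     hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file,
#                     local_dir="/content/data")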
import os
# Untar all archives in data/train
train_dir = "data/train"
for file in os.listdir(train_dir):
    if file.endswith(".tar.gz"):
        filepath = os.path.join(train_dir, file)
        !tar -xzf {filepath} -C {train_dir}
# Untar all archives in data/test
test_dir = "data/test"
for file in os.listdir(test_dir):
    if file.endswith(".tar.gz"):
        filepath = os.path.join(test_dir, file)
        !tar -xzf {filepath} -C {test_dir}
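# Optional sanity check (a sketch): count the extracted scene directories per split.
for split_dir in ("data/train", "data/test"):
    split_path = Path(split_dir)
    if split_path.exists():
        n_scenes = sum(1 for p in split_path.iterdir() if p.is_dir())
        print(f"{split_dir}: {n_scenes} extracted scene directories")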
import os
from pathlib import Path
root = Path('/content/data')  # or wherever your files live
deleted = 0
for archive in root.rglob('*.tar.gz'):
    try:
        archive.unlink()
        print(f"Deleted {archive}")
        deleted += 1
    except Exception as e:
        print(f"Error deleting {archive}: {e}")
print(f"Total deleted: {deleted}")
!pip install torchmetrics lpips
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F  # used by evaluate_metrics below
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import lpips
def visualize_results(model, dataloader, device, num_samples=8):
    """Visualize results with properly masked output (no background)."""
    model.eval()
    samples_shown = 0
    with torch.no_grad():
        for batch in dataloader:
            if samples_shown >= num_samples:
                break
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)
            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)
            pred_masked = pred * amodal_mask  # remove background from prediction
            gt_masked = gt_amodal_rgb * amodal_mask  # mask GT consistently
            for i in range(rgb.shape[0]):
                if samples_shown >= num_samples:
                    break
                fig, axes = plt.subplots(1, 6, figsize=(24, 4))
                # Scene RGB
                axes[0].imshow(rgb[i].cpu().permute(1, 2, 0))
                axes[0].set_title('Scene RGB')
                axes[0].axis('off')
                # Amodal mask
                axes[1].imshow(amodal_mask[i, 0].cpu(), cmap='gray')
                axes[1].set_title('Amodal Mask')
                axes[1].axis('off')
                # Modal mask
                axes[2].imshow(modal_mask[i, 0].cpu(), cmap='gray')
                axes[2].set_title('Modal Mask')
                axes[2].axis('off')
                # Ground-truth amodal RGB (masked)
                axes[3].imshow(gt_masked[i].cpu().permute(1, 2, 0))
                axes[3].set_title('GT Amodal RGB')
                axes[3].axis('off')
                # Predicted amodal RGB (masked)
                axes[4].imshow(pred_masked[i].cpu().permute(1, 2, 0))
                axes[4].set_title('Predicted Amodal RGB')
                axes[4].axis('off')
                # Per-pixel difference heatmap
                diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
                im = axes[5].imshow(diff.cpu(), cmap='hot')
                axes[5].set_title('Prediction Error')
                axes[5].axis('off')
                plt.colorbar(im, ax=axes[5])
                plt.tight_layout()
                plt.show()
                samples_shown += 1
def evaluate_metrics(model, dataloader, device):
    """Compute MSE metrics only within object regions."""
    model.eval()
    total_mse = 0
    occluded_mse = 0
    visible_mse = 0
    total_pixels = 0
    occluded_pixels = 0
    visible_pixels = 0
    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)
            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)
            # Mask both prediction and ground truth to object regions only
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask
            # Overall MSE within the object region
            object_pixels = amodal_mask.sum()
            if object_pixels > 0:
                mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
                total_mse += mse.item()
                total_pixels += object_pixels.item()
            # Occluded-region MSE
            occluded_region = occluded_mask * amodal_mask
            occ_pixels = occluded_region.sum()
            if occ_pixels > 0:
                occ_mse = F.mse_loss(pred_masked * occluded_region,
                                     gt_masked * occluded_region, reduction='sum')
                occluded_mse += occ_mse.item()
                occluded_pixels += occ_pixels.item()
            # Visible-region MSE
            visible_region = modal_mask * amodal_mask
            vis_pixels = visible_region.sum()
            if vis_pixels > 0:
                vis_mse = F.mse_loss(pred_masked * visible_region,
                                     gt_masked * visible_region, reduction='sum')
                visible_mse += vis_mse.item()
                visible_pixels += vis_pixels.item()
    return {
        'total_mse': total_mse / total_pixels if total_pixels > 0 else 0,
        'occluded_mse': occluded_mse / occluded_pixels if occluded_pixels > 0 else 0,
        'visible_mse': visible_mse / visible_pixels if visible_pixels > 0 else 0,
    }
def calculate_metrics(model, dataloader, device):
    """Compute PSNR, SSIM, LPIPS, and IoU between predictions and GT amodal RGBs."""
    model.eval()
    psnr = PeakSignalNoiseRatio().to(device)
    ssim = StructuralSimilarityIndexMeasure().to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)
    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0
    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)
            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask
            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)
                # Skip images too small for LPIPS (it requires >= 64x64 inputs)
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue
                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                total_lpips += lpips_loss(pred_i, gt_i).item()
                # IoU between the GT amodal mask and the thresholded prediction
                intersection = (amodal_mask[i] * (pred[i] > 0.5)).sum()
                union = ((amodal_mask[i] + (pred[i] > 0.5)) > 0).sum()
                if union > 0:
                    iou = intersection.float() / union.float()
                    total_iou += iou.item()
                count += 1
    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}
    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from PIL import Image, ImageChops
import numpy as np
class ModalAmodalDataset(Dataset):
    def __init__(self, root_dir, split, img_size=(128, 128), max_samples=None, val_split=0.2, use_val_from_train=False):
        self.root_dir = Path(root_dir)
        self.img_size = img_size
        self.max_samples = max_samples
        self.val_split = val_split
        self.use_val_from_train = use_val_from_train
        self.split = split
        if split == 'val' and use_val_from_train:
            # Load from the train folder but use the validation subset
            self.root_dir = self.root_dir / 'train'
        else:
            self.root_dir = self.root_dir / split
        self.samples = self._build_sample_index()
        self.rgb_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self):
        samples = []
        for scene_dir in self.root_dir.iterdir():
            if not scene_dir.is_dir():
                continue
            for camera_dir in scene_dir.iterdir():
                if not camera_dir.name.startswith('camera_'):
                    continue
                rgba_paths = sorted(camera_dir.glob('rgba_*.png'))
                seg_paths = sorted(camera_dir.glob('segmentation_*.png'))
                for obj_dir in camera_dir.iterdir():
                    if not obj_dir.name.startswith('obj_'):
                        continue
                    amodal_paths = sorted(obj_dir.glob('segmentation_*.png'))
                    amodal_rgb_paths = sorted(obj_dir.glob('rgba_*.png'))
                    if not (len(rgba_paths) == len(seg_paths) == len(amodal_paths) == len(amodal_rgb_paths)):
                        continue
                    for rgba_path, seg_path, amodal_path, amodal_rgb_path in zip(
                        rgba_paths, seg_paths, amodal_paths, amodal_rgb_paths
                    ):
                        samples.append({
                            'rgb_path': rgba_path,
                            'seg_path': seg_path,
                            'amodal_path': amodal_path,
                            'amodal_rgb_path': amodal_rgb_path,
                            'object_id': int(obj_dir.name.split('_')[1]),
                            'scene': scene_dir.name,
                            'camera': camera_dir.name
                        })
        # Limit dataset size if specified
        if self.max_samples is not None and len(samples) > self.max_samples:
            # Randomly sample to get diverse examples
            import random
            random.seed(42)  # for reproducibility
            samples = random.sample(samples, self.max_samples)
            print(f"Dataset limited to {len(samples)} samples")
        # Create the train/val split if using validation from train
        if self.use_val_from_train:
            import random
            random.seed(42)  # ensure reproducible splits
            random.shuffle(samples)
            val_size = int(len(samples) * self.val_split)
            if self.split == 'train':
                samples = samples[val_size:]  # remaining samples for training
                print(f"Train split: {len(samples)} samples")
            elif self.split == 'val':
                samples = samples[:val_size]  # first samples for validation
                print(f"Validation split: {len(samples)} samples")
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Load images
        rgb = Image.open(sample['rgb_path']).convert('RGB')
        seg_map = np.array(Image.open(sample['seg_path']))
        amodal_mask_img = Image.open(sample['amodal_path']).convert('L')
        amodal_rgb = Image.open(sample['amodal_rgb_path']).convert('RGB')
        # Compute the modal mask (visible part)
        modal_mask_np = (seg_map == sample['object_id']).astype(np.uint8) * 255
        modal_mask_img = Image.fromarray(modal_mask_np, mode='L')
        # Transform images and masks
        rgb = self.rgb_transform(rgb)
        modal_mask = self.mask_transform(modal_mask_img)
        amodal_mask = self.mask_transform(amodal_mask_img)
        amodal_rgb = self.rgb_transform(amodal_rgb)
        # Occluded mask (hidden parts): amodal minus modal
        occluded_mask = amodal_mask - modal_mask
        occluded_mask = torch.clamp(occluded_mask, 0, 1)
        return {
            'rgb': rgb,
            'modal_mask': modal_mask,
            'amodal_mask': amodal_mask,
            'occluded_mask': occluded_mask,
            'amodal_rgb': amodal_rgb,
        }
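# A minimal smoke test for the dataset, assuming the archives were extracted
# under data/ as in the download cells above; the sample count is a placeholder.
if Path('data/train').exists():
    _ds = ModalAmodalDataset(root_dir='data', split='train', img_size=(128, 128), max_samples=8)
    if len(_ds) > 0:
        _item = _ds[0]
        for k, v in _item.items():
            print(k, tuple(v.shape))  # rgb/amodal_rgb: (3, 128, 128); masks: (1, 128, 128)
        # occluded = clamp(amodal - modal), so it can never exceed the amodal mask
        assert torch.all(_item['occluded_mask'] <= _item['amodal_mask'] + 1e-6)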
class ImprovedUNet(nn.Module):
    def __init__(self, in_channels=5, out_channels=3):  # RGB + modal_mask + amodal_mask
        super().__init__()

        def conv_block(in_ch, out_ch, dropout=0.1):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.down1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        # Bottleneck
        self.middle = conv_block(512, 1024, dropout=0.2)
        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_block1 = conv_block(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_block2 = conv_block(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_block3 = conv_block(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_block4 = conv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))
        # Bottleneck
        m = self.middle(self.pool4(d4))
        # Decoder with skip connections
        u1 = self.up_block1(torch.cat([self.up1(m), d4], dim=1))
        u2 = self.up_block2(torch.cat([self.up2(u1), d3], dim=1))
        u3 = self.up_block3(torch.cat([self.up3(u2), d2], dim=1))
        u4 = self.up_block4(torch.cat([self.up4(u3), d1], dim=1))
        return torch.sigmoid(self.final(u4))  # ensure the output is in [0, 1]
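# Quick shape sanity check (a sketch, not part of training): the network takes a
# 5-channel input (RGB + modal mask + amodal mask) and returns a 3-channel RGB
# prediction at the same spatial resolution.
_x = torch.randn(2, 5, 128, 128)
with torch.no_grad():
    _y = ImprovedUNet()(_x)
print(_y.shape)  # expected: torch.Size([2, 3, 128, 128])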
class AmodalCompletionLoss(nn.Module):
    """Loss that only considers object regions (ignores background)."""
    def __init__(self, occluded_weight=5.0, visible_weight=1.0):
        super().__init__()
        self.occluded_weight = occluded_weight
        self.visible_weight = visible_weight
        self.lpips_model = lpips.LPIPS(net='alex')

    def forward(self, pred, target, modal_mask, occluded_mask, amodal_mask):
        # Only compute the loss within the amodal mask (object region)
        device = pred.device
        self.lpips_model = self.lpips_model.to(device)
        pred_masked = pred * amodal_mask
        target_masked = target * amodal_mask
        # Loss on visible parts (within the object)
        visible_region = modal_mask * amodal_mask
        if visible_region.sum() > 0:
            visible_loss = F.mse_loss(pred_masked * visible_region, target_masked * visible_region)
        else:
            visible_loss = torch.tensor(0.0).to(device)
        # Loss on occluded parts (within the object)
        occluded_region = occluded_mask * amodal_mask
        if occluded_region.sum() > 0:
            occluded_loss = F.mse_loss(pred_masked * occluded_region, target_masked * occluded_region)
        else:
            occluded_loss = torch.tensor(0.0).to(device)
        # Perceptual term (computed here but not added into the total below)
        perceptual_loss = self.lpips_model(pred_masked, target_masked).mean()
        # Boundary consistency within the object: a 3x3 box filter over the mask
        # is strictly between 0 and 9 exactly on the mask boundary
        boundary_mask = F.conv2d(amodal_mask, torch.ones(1, 1, 3, 3).to(amodal_mask.device), padding=1)
        boundary_mask = ((boundary_mask > 0) & (boundary_mask < 9)).float()
        boundary_loss = F.mse_loss(pred_masked * boundary_mask, target_masked * boundary_mask)
        total_loss = (self.visible_weight * visible_loss +
                      self.occluded_weight * occluded_loss +
                      2.0 * boundary_loss)
        return total_loss, visible_loss, occluded_loss, boundary_loss
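# Smoke test for the loss on dummy tensors (a sketch; lpips downloads the AlexNet
# weights on first use). The masks are random binary maps with modal a subset of
# amodal, so all components should come out as finite scalars.
_criterion = AmodalCompletionLoss()
_pred = torch.rand(2, 3, 64, 64)
_target = torch.rand(2, 3, 64, 64)
_amodal = (torch.rand(2, 1, 64, 64) > 0.5).float()
_modal = _amodal * (torch.rand(2, 1, 64, 64) > 0.5).float()  # modal subset of amodal
_occluded = _amodal - _modal
_total, _vis, _occ, _bnd = _criterion(_pred, _target, _modal, _occluded, _amodal)
print(f"total={_total.item():.4f} visible={_vis.item():.4f} "
      f"occluded={_occ.item():.4f} boundary={_bnd.item():.4f}")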
def train_improved(model, dataloader, optimizer, device, num_epochs):
    model.train()
    criterion = AmodalCompletionLoss()
    for epoch in range(num_epochs):
        total_loss = 0
        for i, batch in enumerate(dataloader):
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)
            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            optimizer.zero_grad()
            pred = model(input_tensor)
            loss, vis_loss, occ_loss, boundary_loss = criterion(
                pred, gt_amodal_rgb, modal_mask, occluded_mask, amodal_mask
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if i % 16 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] [{i}/{len(dataloader)}] "
                      f"Total: {loss.item():.4f}, Visible: {vis_loss.item():.4f}, "
                      f"Occluded: {occ_loss.item():.4f}, Boundary: {boundary_loss.item():.4f}")
        print(f"Epoch {epoch} Average Loss: {total_loss/len(dataloader):.4f}")
# Usage
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Dataset and DataLoader - reduced size for faster training
    data_root = "data"
    # Create the train dataset (80% of the train folder)
    train_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='train',
        img_size=(128, 128),
        max_samples=1000,        # only use 1000 samples total before the split
        val_split=0.2,           # 20% for validation
        use_val_from_train=True  # create the val split from the train folder
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )
    # Create the validation dataset (20% of the train folder)
    val_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='val',
        img_size=(128, 128),
        max_samples=1000,        # same max_samples to ensure a consistent split
        val_split=0.2,
        use_val_from_train=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    print(f"Training on {len(train_dataset)} samples, {len(train_loader)} batches per epoch")
    print(f"Validation on {len(val_dataset)} samples, {len(val_loader)} batches")
    # Model and optimizer; load previously saved weights (requires a prior training run)
    model = ImprovedUNet().to(device)
    model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    # Training
    #train_improved(model, train_loader, optimizer, device, num_epochs=10)
    # Evaluation and visualization
    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)
    # Compute masked-MSE metrics
    metrics = evaluate_metrics(model, val_loader, device)
    print(f"Overall MSE: {metrics['total_mse']:.6f}")
    print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
    print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
    print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
    # Visualize results
    print("\nGenerating visualizations...")
    visualize_results(model, val_loader, device, num_samples=8)
    # Compute image-quality metrics
    image_metrics = calculate_metrics(model, val_loader, device)
    print(f"PSNR: {image_metrics['psnr']:.4f}")
    print(f"SSIM: {image_metrics['ssim']:.4f}")
    print(f"LPIPS: {image_metrics['lpips']:.4f}")
    print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")
# Optional: save the model
torch.save(model.state_dict(), 'amodal_completion_model.pth')
# Evaluation and visualization on the test split
test_dataset = ModalAmodalDataset(
    root_dir=data_root,
    split='test',
    img_size=(128, 128),
    max_samples=2000  # only use up to 2000 test samples
)
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True
)
print("EVALUATION RESULTS")
print("="*50)
# Compute masked-MSE metrics
metrics = evaluate_metrics(model, test_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
# Visualize results
print("\nGenerating visualizations...")
visualize_results(model, test_loader, device, num_samples=16)
# Release the Colab runtime when finished
from google.colab import runtime
runtime.unassign()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedUNet()  # replace with the actual class name if it differs
model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()
# Evaluation and Visualization
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
# Compute metrics
metrics = evaluate_metrics(model, val_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
# Visualize results
print("\nGenerating visualizations...")
visualize_results(model, val_loader, device, num_samples=8)
# Compute metrics
image_metrics = calculate_metrics(model, val_loader, device)
print(f"PSNR: {image_metrics['psnr']:.4f}")
print(f"SSIM: {image_metrics['ssim']:.4f}")
print(f"LPIPS: {image_metrics['lpips']:.4f}")
print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")