|
|
|
|
|
"""final1.2.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1v6-6x7lqt6gr9VIauNVHIwjvIkewk8eT |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
"""## FINAL 1.2""" |
|
|
|
|
|
|
|
|
|
|
|
pip install torchmetrics lpips |
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torchvision.transforms import ToPILImage, ToTensor |
|
|
from torchvision.utils import make_grid |
|
|
from torchvision.io import write_video |
|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import random |
|
|
import json |
|
|
from IPython.display import Video |
|
|
|
|
|
|
|
|
# Shared torchvision converters: ToPILImage turns a tensor into a PIL image,
# ToTensor turns a PIL image / ndarray into a float tensor.
# NOTE(review): neither name is referenced elsewhere in the code shown here.
tensor_to_image = ToPILImage()

image_to_tensor = ToTensor()
|
|
|
|
|
def get_img_dict(img_dir):
    """Group the ``.png`` / ``.tiff`` files directly inside *img_dir* by prefix.

    The prefix is the part of the file name before the first underscore
    (e.g. ``rgba`` for ``rgba_00001.png``). Files are visited in sorted path
    order, so each list is sorted and the dict keys appear in the order their
    first file sorts.

    Args:
        img_dir: a ``pathlib.Path`` to scan (not recursive).

    Returns:
        dict mapping prefix -> list of ``Path`` objects.
    """
    grouped = {}
    image_paths = sorted(
        entry for entry in img_dir.iterdir()
        if entry.name.endswith(('.png', '.tiff'))
    )
    for path in image_paths:
        prefix = path.name.partition('_')[0]
        grouped.setdefault(prefix, []).append(path)
    return grouped
|
|
|
|
|
|
|
|
def get_sample_dict(sample_dir):
    """Index one extracted sample (scene) directory.

    For every sub-directory whose name contains ``camera`` (sorted), build a
    dict with:

    * ``'scene'`` -- ``get_img_dict`` of the camera directory itself;
    * one entry per sub-directory whose name contains ``obj_`` (sorted),
      keyed by that directory's name, holding its ``get_img_dict``.

    Returns:
        dict mapping camera directory name -> the per-camera dict above.
    """
    sample_dict = {}
    camera_dirs = sorted(entry for entry in sample_dir.iterdir() if 'camera' in entry.name)
    for cam_dir in camera_dirs:
        per_camera = {'scene': get_img_dict(cam_dir)}
        object_dirs = sorted(entry for entry in cam_dir.iterdir() if 'obj_' in entry.name)
        for obj_dir in object_dirs:
            per_camera[obj_dir.name] = get_img_dict(obj_dir)
        sample_dict[cam_dir.name] = per_camera
    return sample_dict
|
|
|
|
|
# Fetch the dataset's top-level files into the current working directory:
# the train/test object-descriptor JSONs, ex_vis.mp4, the README, the notice
# PDF and .gitattributes. The per-scene archives are downloaded separately
# through huggingface_hub.
# NOTE: wget does not overwrite existing files; re-running this cell saves
# extra copies (name.1, name.2, ...). Add -nc to skip files already present.
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/test_obj_descriptors.json

!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/train_obj_descriptors.json

!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/ex_vis.mp4

!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/README.md

!wget "https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/Notice%201%20-%20Unlimited_datasets.pdf"

!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/.gitattributes
|
|
|
|
|
from huggingface_hub import HfApi, hf_hub_download
import random
import os

# Client for the Hugging Face Hub and the dataset repository to read from.
api = HfApi()
repo_id = "Amar-S/MOVi-MC-AC"

# Every file path in the dataset repo (a single listing request).
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")


def _files_under(prefix):
    """Return repo paths starting with *prefix*, leaving out JSON metadata."""
    return [path for path in files if path.startswith(prefix) and not path.endswith(".json")]


train_files = _files_under("train/")
test_files = _files_under("test/")
print(f"Found {len(train_files)} train files and {len(test_files)} test files.")
|
|
|
|
|
import os
import random
import shutil
from huggingface_hub import hf_hub_download

# Fraction of each split's archives to download (the full dataset is large).
SUBSET_FRACTION = 0.005
# Fixed seed: the same archives are picked in every session, so the derived
# train/val split and the reported metrics are reproducible.
SUBSET_SEED = 42


def download_subset(split_files, dest_dir, fraction=SUBSET_FRACTION, seed=SUBSET_SEED):
    """Download a seeded random sample of *split_files* into *dest_dir*.

    Each file is fetched through the Hugging Face cache (``hf_hub_download``)
    and then copied flat into *dest_dir* under its base name.

    Args:
        split_files: repo paths to sample from.
        dest_dir: local directory to copy the files into (created if needed).
        fraction: share of *split_files* to take; at least one file is taken
            whenever *split_files* is non-empty (``int(n * fraction)`` would
            otherwise round small splits down to nothing).
        seed: seed for a private RNG, so the global ``random`` state is untouched.

    Returns:
        list of the repo paths that were downloaded.
    """
    os.makedirs(dest_dir, exist_ok=True)
    if split_files:
        k = min(len(split_files), max(1, int(len(split_files) * fraction)))
    else:
        k = 0
    chosen = random.Random(seed).sample(split_files, k)
    for file in chosen:
        out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
        shutil.copyfile(out_path, os.path.join(dest_dir, os.path.basename(file)))
    return chosen


subset_train = download_subset(train_files, "/content/data/train")
subset_test = download_subset(test_files, "/content/data/test")
|
|
|
|
|
import os
import tarfile


def extract_archives(directory):
    """Extract every ``*.tar.gz`` archive directly inside *directory* into it.

    Uses the stdlib ``tarfile`` module instead of shelling out to ``tar``, so
    this runs outside IPython and paths containing spaces are safe. When the
    interpreter supports extraction filters, the ``'data'`` filter is applied
    so archive members cannot write outside *directory*.

    Args:
        directory: folder holding the downloaded archives.

    Returns:
        number of archives extracted; 0 (with a message) if *directory*
        does not exist.
    """
    directory = Path(directory)
    if not directory.is_dir():
        print(f"Skipping missing directory {directory}")
        return 0
    extracted = 0
    for archive in sorted(directory.glob("*.tar.gz")):
        with tarfile.open(archive, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(directory, filter="data")
            else:
                tar.extractall(directory)
        extracted += 1
    return extracted


# Relative to the working directory (/content on Colab).
train_dir = "data/train"
extract_archives(train_dir)

test_dir = "data/test"
extract_archives(test_dir)
|
|
|
|
|
|
|
|
|
|
|
import os

from pathlib import Path


def delete_archives(root):
    """Delete every ``*.tar.gz`` file anywhere under *root*.

    The matches are collected before deleting so the directory walk is not
    mutated while it runs. A file that cannot be removed (``OSError``) is
    reported and skipped rather than aborting the cleanup.

    Args:
        root: directory to search recursively.

    Returns:
        number of archives deleted (0 if *root* is not a directory).
    """
    root = Path(root)
    if not root.is_dir():
        return 0
    count = 0
    for archive in list(root.rglob('*.tar.gz')):
        try:
            archive.unlink()
        except OSError as e:
            print(f"Error deleting {archive}: {e}")
            continue
        print(f"Deleted {archive}")
        count += 1
    return count


root = Path('/content/data')
deleted = delete_archives(root)
print(f"Total deleted: {deleted}")
|
|
|
|
|
pip install torchmetrics lpips |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
|
|
import lpips |
|
|
import matplotlib.pyplot as plt |
|
|
import torch |
|
|
|
|
|
def visualize_results(model, dataloader, device, num_samples=8):
    """Show up to ``num_samples`` predictions, one six-panel figure per sample.

    Panels, left to right: scene RGB, amodal mask, modal mask, ground-truth
    amodal RGB, predicted amodal RGB, and the channel-averaged absolute error
    (with a colour bar). Ground truth and prediction are both multiplied by
    the amodal mask, so everything outside the object is drawn black.

    The model is put in eval mode and run under ``torch.no_grad()``.
    """
    model.eval()
    shown = 0

    with torch.no_grad():
        for batch in dataloader:
            if shown >= num_samples:
                break

            scene = batch['rgb'].to(device)
            modal = batch['modal_mask'].to(device)
            amodal = batch['amodal_mask'].to(device)
            target = batch['amodal_rgb'].to(device)

            prediction = model(torch.cat([scene, modal, amodal], dim=1))
            # Keep only the object region (no background).
            pred_obj = prediction * amodal
            target_obj = target * amodal

            for idx in range(scene.shape[0]):
                if shown >= num_samples:
                    break

                _fig, axes = plt.subplots(1, 6, figsize=(24, 4))
                error_map = torch.abs(pred_obj[idx] - target_obj[idx]).mean(dim=0)
                panels = (
                    (scene[idx].cpu().permute(1, 2, 0), 'Scene RGB', None),
                    (amodal[idx, 0].cpu(), 'Amodal Mask', 'gray'),
                    (modal[idx, 0].cpu(), 'Modal Mask', 'gray'),
                    (target_obj[idx].cpu().permute(1, 2, 0), 'GT Amodal RGB', None),
                    (pred_obj[idx].cpu().permute(1, 2, 0), 'Predicted Amodal RGB', None),
                    (error_map.cpu(), 'Prediction Error', 'hot'),
                )
                for ax, (image, title, cmap) in zip(axes, panels):
                    artist = ax.imshow(image, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')
                # Colour scale for the error heat map (last panel).
                plt.colorbar(artist, ax=axes[5])

                plt.tight_layout()
                plt.show()
                shown += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_metrics(model, dataloader, device):
    """Compute masked mean-squared errors of the predicted amodal RGB.

    Only object pixels contribute. Three regions are scored:

    * ``total``    -- the whole amodal mask,
    * ``occluded`` -- ``occluded_mask * amodal_mask`` (the hidden part),
    * ``visible``  -- ``modal_mask * amodal_mask`` (the visible part).

    Each MSE is the summed squared error inside the region divided by the
    number of masked *elements* (mask pixels x colour channels), which is the
    same per-element scale as ``F.mse_loss``. Dividing by mask pixels alone
    would inflate every value by the number of channels.

    Args:
        model: network fed ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``.
        dataloader: iterable of batches with keys ``'rgb'``, ``'modal_mask'``,
            ``'amodal_mask'``, ``'occluded_mask'`` and ``'amodal_rgb'``.
        device: device to run on.

    Returns:
        dict with ``'total_mse'``, ``'occluded_mse'`` and ``'visible_mse'``;
        a region that never contained any pixel reports 0.
        The model is left in eval mode.
    """
    model.eval()
    sq_err = {'total': 0.0, 'occluded': 0.0, 'visible': 0.0}
    n_elems = {'total': 0.0, 'occluded': 0.0, 'visible': 0.0}

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))

            regions = {
                'total': amodal_mask,
                'occluded': occluded_mask * amodal_mask,
                'visible': modal_mask * amodal_mask,
            }
            for name, region in regions.items():
                # Broadcast the mask over the colour channels so the
                # denominator counts exactly the elements summed below.
                count = region.expand_as(pred).sum()
                if count > 0:
                    err = F.mse_loss(pred * region, gt_amodal_rgb * region, reduction='sum')
                    sq_err[name] += err.item()
                    n_elems[name] += count.item()

    def _mean(name):
        return sq_err[name] / n_elems[name] if n_elems[name] > 0 else 0

    return {
        'total_mse': _mean('total'),
        'occluded_mse': _mean('occluded'),
        'visible_mse': _mean('visible'),
    }
|
|
|
|
|
|
|
|
|
|
|
def calculate_metrics(model, dataloader, device):
    """Compute PSNR, SSIM, LPIPS and an IoU score for amodal RGB predictions.

    Prediction and ground truth are both multiplied by the amodal mask before
    comparison, so only the object region contributes. Scores are computed
    per sample and averaged over the samples that were scored.

    Samples are skipped when the image is smaller than 64x64, or when the
    amodal mask is empty: both masked images would then be all zeros, giving
    an infinite PSNR that would swamp the average.

    Returns:
        dict with ``"psnr"``, ``"ssim"``, ``"lpips"`` and ``"miou"``; all 0
        when no sample was scored. The model is left in eval mode.
    """
    model.eval()
    # Images are in [0, 1] (ToTensor inputs, sigmoid outputs). Fix the data
    # range instead of letting torchmetrics infer it from each sample.
    psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Size guard kept from the original (presumably for SSIM's
                # window / LPIPS's AlexNet backbone).
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue
                # No object pixels: identical all-zero images, PSNR = inf.
                if amodal_mask[i].sum() == 0:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                # LPIPS expects inputs in [-1, 1]; normalize=True rescales
                # the [0, 1] images accordingly.
                total_lpips += lpips_loss(pred_i, gt_i, normalize=True).item()

                # NOTE(review): this IoU compares the amodal mask with the
                # predicted *colour* values thresholded at 0.5, broadcast over
                # the RGB channels; it is not a mask-vs-mask IoU -- confirm intent.
                pred_on = pred[i] > 0.5
                intersection = (amodal_mask[i] * pred_on).sum()
                union = ((amodal_mask[i] + pred_on) > 0).sum()
                if union > 0:
                    total_iou += (intersection.float() / union.float()).item()

                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
|
|
|
|
|
pip install torchmetrics lpips |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
|
|
import lpips |
|
|
import matplotlib.pyplot as plt |
|
|
import torch |
|
|
|
|
|
def visualize_results(model, dataloader, device, num_samples=8):
    """Show up to ``num_samples`` predictions, one six-panel figure per sample.

    Panels, left to right: scene RGB, amodal mask, modal mask, ground-truth
    amodal RGB, predicted amodal RGB, and the channel-averaged absolute error
    (with a colour bar). Ground truth and prediction are both multiplied by
    the amodal mask, so everything outside the object is drawn black.

    The model is put in eval mode and run under ``torch.no_grad()``.
    """
    model.eval()
    shown = 0

    with torch.no_grad():
        for batch in dataloader:
            if shown >= num_samples:
                break

            scene = batch['rgb'].to(device)
            modal = batch['modal_mask'].to(device)
            amodal = batch['amodal_mask'].to(device)
            target = batch['amodal_rgb'].to(device)

            prediction = model(torch.cat([scene, modal, amodal], dim=1))
            # Keep only the object region (no background).
            pred_obj = prediction * amodal
            target_obj = target * amodal

            for idx in range(scene.shape[0]):
                if shown >= num_samples:
                    break

                _fig, axes = plt.subplots(1, 6, figsize=(24, 4))
                error_map = torch.abs(pred_obj[idx] - target_obj[idx]).mean(dim=0)
                panels = (
                    (scene[idx].cpu().permute(1, 2, 0), 'Scene RGB', None),
                    (amodal[idx, 0].cpu(), 'Amodal Mask', 'gray'),
                    (modal[idx, 0].cpu(), 'Modal Mask', 'gray'),
                    (target_obj[idx].cpu().permute(1, 2, 0), 'GT Amodal RGB', None),
                    (pred_obj[idx].cpu().permute(1, 2, 0), 'Predicted Amodal RGB', None),
                    (error_map.cpu(), 'Prediction Error', 'hot'),
                )
                for ax, (image, title, cmap) in zip(axes, panels):
                    artist = ax.imshow(image, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')
                # Colour scale for the error heat map (last panel).
                plt.colorbar(artist, ax=axes[5])

                plt.tight_layout()
                plt.show()
                shown += 1
|
|
|
|
|
|
|
|
def evaluate_metrics(model, dataloader, device):
    """Compute masked mean-squared errors of the predicted amodal RGB.

    Only object pixels contribute. Three regions are scored:

    * ``total``    -- the whole amodal mask,
    * ``occluded`` -- ``occluded_mask * amodal_mask`` (the hidden part),
    * ``visible``  -- ``modal_mask * amodal_mask`` (the visible part).

    Each MSE is the summed squared error inside the region divided by the
    number of masked *elements* (mask pixels x colour channels), which is the
    same per-element scale as ``F.mse_loss``. Dividing by mask pixels alone
    would inflate every value by the number of channels.

    Args:
        model: network fed ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``.
        dataloader: iterable of batches with keys ``'rgb'``, ``'modal_mask'``,
            ``'amodal_mask'``, ``'occluded_mask'`` and ``'amodal_rgb'``.
        device: device to run on.

    Returns:
        dict with ``'total_mse'``, ``'occluded_mse'`` and ``'visible_mse'``;
        a region that never contained any pixel reports 0.
        The model is left in eval mode.
    """
    model.eval()
    sq_err = {'total': 0.0, 'occluded': 0.0, 'visible': 0.0}
    n_elems = {'total': 0.0, 'occluded': 0.0, 'visible': 0.0}

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))

            regions = {
                'total': amodal_mask,
                'occluded': occluded_mask * amodal_mask,
                'visible': modal_mask * amodal_mask,
            }
            for name, region in regions.items():
                # Broadcast the mask over the colour channels so the
                # denominator counts exactly the elements summed below.
                count = region.expand_as(pred).sum()
                if count > 0:
                    err = F.mse_loss(pred * region, gt_amodal_rgb * region, reduction='sum')
                    sq_err[name] += err.item()
                    n_elems[name] += count.item()

    def _mean(name):
        return sq_err[name] / n_elems[name] if n_elems[name] > 0 else 0

    return {
        'total_mse': _mean('total'),
        'occluded_mse': _mean('occluded'),
        'visible_mse': _mean('visible'),
    }
|
|
|
|
|
|
|
|
|
|
|
def calculate_metrics(model, dataloader, device):
    """Compute PSNR, SSIM, LPIPS and an IoU score for amodal RGB predictions.

    Prediction and ground truth are both multiplied by the amodal mask before
    comparison, so only the object region contributes. Scores are computed
    per sample and averaged over the samples that were scored.

    Samples are skipped when the image is smaller than 64x64, or when the
    amodal mask is empty: both masked images would then be all zeros, giving
    an infinite PSNR that would swamp the average.

    Returns:
        dict with ``"psnr"``, ``"ssim"``, ``"lpips"`` and ``"miou"``; all 0
        when no sample was scored. The model is left in eval mode.
    """
    model.eval()
    # Images are in [0, 1] (ToTensor inputs, sigmoid outputs). Fix the data
    # range instead of letting torchmetrics infer it from each sample.
    psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Size guard kept from the original (presumably for SSIM's
                # window / LPIPS's AlexNet backbone).
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue
                # No object pixels: identical all-zero images, PSNR = inf.
                if amodal_mask[i].sum() == 0:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                # LPIPS expects inputs in [-1, 1]; normalize=True rescales
                # the [0, 1] images accordingly.
                total_lpips += lpips_loss(pred_i, gt_i, normalize=True).item()

                # NOTE(review): this IoU compares the amodal mask with the
                # predicted *colour* values thresholded at 0.5, broadcast over
                # the RGB channels; it is not a mask-vs-mask IoU -- confirm intent.
                pred_on = pred[i] > 0.5
                intersection = (amodal_mask[i] * pred_on).sum()
                union = ((amodal_mask[i] + pred_on) > 0).sum()
                if union > 0:
                    total_iou += (intersection.float() / union.float()).item()

                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms |
|
|
from pathlib import Path |
|
|
from PIL import Image, ImageChops |
|
|
import numpy as np |
|
|
|
|
|
class ModalAmodalDataset(Dataset):
    """Per-frame, per-object samples for amodal RGB completion.

    Directory layout read by ``_build_sample_index``::

        <root_dir>/<split>/<scene>/camera_*/rgba_*.png            scene RGB(A)
        <root_dir>/<split>/<scene>/camera_*/segmentation_*.png    scene segmentation
        <root_dir>/<split>/<scene>/camera_*/obj_<id>/segmentation_*.png   amodal mask
        <root_dir>/<split>/<scene>/camera_*/obj_<id>/rgba_*.png           amodal RGB(A)

    Args:
        root_dir: folder containing the split sub-folders.
        split: ``'train'``, ``'val'`` or ``'test'``. With
            ``use_val_from_train=True``, ``'val'`` is read from the ``train``
            folder and carved out of it.
        img_size: ``(H, W)`` every image and mask is resized to.
        max_samples: if set, a seeded random subset of this many samples is
            kept (before any train/val split).
        val_split: fraction of samples held out as ``'val'`` when
            ``use_val_from_train`` is set.
        use_val_from_train: split the ``train`` folder into train/val with a
            fixed seed; datasets built with the same arguments get disjoint splits.

    Each item is a dict with ``'rgb'``, ``'modal_mask'``, ``'amodal_mask'``,
    ``'occluded_mask'`` (amodal minus modal, clamped to [0, 1]) and
    ``'amodal_rgb'`` tensors.
    """

    def __init__(self, root_dir, split, img_size=(128, 128), max_samples=None, val_split=0.2, use_val_from_train=False):
        self.root_dir = Path(root_dir)
        self.img_size = img_size
        self.max_samples = max_samples
        self.val_split = val_split
        self.use_val_from_train = use_val_from_train
        self.split = split

        if split == 'val' and use_val_from_train:
            # Validation samples come out of the training folder.
            self.root_dir = self.root_dir / 'train'
        else:
            self.root_dir = self.root_dir / split

        self.samples = self._build_sample_index()

        self.rgb_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resizing keeps masks binary; Resize's default
        # bilinear filter would blur object edges into fractional values.
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self):
        """Collect one sample per (scene, camera, object, frame) under ``self.root_dir``.

        Directories are visited in sorted order, so the index -- and therefore
        the seeded subsampling and train/val split -- does not depend on the
        filesystem's listing order. An object whose frame counts differ from
        its camera's is skipped. Private ``random.Random(42)`` instances are
        used so the global ``random`` state is left untouched.
        """
        import random

        samples = []
        for scene_dir in sorted(self.root_dir.iterdir()):
            if not scene_dir.is_dir():
                continue
            for camera_dir in sorted(scene_dir.iterdir()):
                if not camera_dir.name.startswith('camera_') or not camera_dir.is_dir():
                    continue

                rgba_paths = sorted(camera_dir.glob('rgba_*.png'))
                seg_paths = sorted(camera_dir.glob('segmentation_*.png'))

                for obj_dir in sorted(camera_dir.iterdir()):
                    if not obj_dir.name.startswith('obj_') or not obj_dir.is_dir():
                        continue

                    amodal_paths = sorted(obj_dir.glob('segmentation_*.png'))
                    amodal_rgb_paths = sorted(obj_dir.glob('rgba_*.png'))

                    if not (len(rgba_paths) == len(seg_paths) == len(amodal_paths) == len(amodal_rgb_paths)):
                        continue

                    for rgba_path, seg_path, amodal_path, amodal_rgb_path in zip(
                        rgba_paths, seg_paths, amodal_paths, amodal_rgb_paths
                    ):
                        samples.append({
                            'rgb_path': rgba_path,
                            'seg_path': seg_path,
                            'amodal_path': amodal_path,
                            'amodal_rgb_path': amodal_rgb_path,
                            'object_id': int(obj_dir.name.split('_')[1]),
                            'scene': scene_dir.name,
                            'camera': camera_dir.name
                        })

        if self.max_samples is not None and len(samples) > self.max_samples:
            samples = random.Random(42).sample(samples, self.max_samples)
            print(f"Dataset limited to {len(samples)} samples")

        if self.use_val_from_train:
            # Same seed for every instance -> train and val see the same
            # permutation and take complementary slices of it.
            random.Random(42).shuffle(samples)

            val_size = int(len(samples) * self.val_split)
            if self.split == 'train':
                samples = samples[val_size:]
                print(f"Train split: {len(samples)} samples")
            elif self.split == 'val':
                samples = samples[:val_size]
                print(f"Validation split: {len(samples)} samples")

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Open each file in a context manager so handles are released promptly.
        with Image.open(sample['rgb_path']) as img:
            rgb = img.convert('RGB')
        with Image.open(sample['seg_path']) as img:
            seg_map = np.array(img)
        with Image.open(sample['amodal_path']) as img:
            amodal_mask_img = img.convert('L')
        with Image.open(sample['amodal_rgb_path']) as img:
            amodal_rgb = img.convert('RGB')

        # NOTE(review): assumes the scene segmentation stores each object's id
        # (the number in its obj_<id> folder) as the pixel value -- confirm
        # against the dataset documentation.
        modal_mask_np = (seg_map == sample['object_id']).astype(np.uint8) * 255
        modal_mask_img = Image.fromarray(modal_mask_np, mode='L')

        rgb = self.rgb_transform(rgb)
        modal_mask = self.mask_transform(modal_mask_img)
        amodal_mask = self.mask_transform(amodal_mask_img)
        amodal_rgb = self.rgb_transform(amodal_rgb)

        # Hidden part of the object: inside the amodal mask but not visible.
        occluded_mask = amodal_mask - modal_mask
        occluded_mask = torch.clamp(occluded_mask, 0, 1)

        return {
            'rgb': rgb,
            'modal_mask': modal_mask,
            'amodal_mask': amodal_mask,
            'occluded_mask': occluded_mask,
            'amodal_rgb': amodal_rgb,
        }
|
|
|
|
|
|
|
|
class ImprovedUNet(nn.Module):
    """Four-level U-Net producing an ``out_channels`` image in [0, 1] (sigmoid output).

    Encoder widths are 64/128/256/512 with a 1024-channel bottleneck. Every
    stage is a double 3x3 convolution (BatchNorm + ReLU after each conv,
    Dropout2d between them). Decoder stages upsample with 2x2 transposed
    convolutions and concatenate the matching encoder features. Because of
    the four 2x poolings, input height and width must be divisible by 16.

    Sub-module names (``down1``..``down4``, ``pool1``..``pool4``, ``middle``,
    ``up1``..``up4``, ``up_block1``..``up_block4``, ``final``) and their
    creation order are kept fixed: they define the ``state_dict`` keys of
    saved checkpoints and the order in which weights are initialised.
    """

    def __init__(self, in_channels=5, out_channels=3):
        super().__init__()

        def double_conv(c_in, c_out, dropout=0.1):
            # conv -> BN -> ReLU -> Dropout2d -> conv -> BN -> ReLU
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        widths = (64, 128, 256, 512)

        # Encoder: downN followed by poolN, registered in that order.
        channels = in_channels
        for level, width in enumerate(widths, start=1):
            setattr(self, f'down{level}', double_conv(channels, width))
            setattr(self, f'pool{level}', nn.MaxPool2d(2))
            channels = width

        self.middle = double_conv(widths[-1], 2 * widths[-1], dropout=0.2)

        # Decoder: upN halves the channels, up_blockN fuses the skip connection.
        for level, width in enumerate(reversed(widths), start=1):
            setattr(self, f'up{level}', nn.ConvTranspose2d(2 * width, width, 2, stride=2))
            setattr(self, f'up_block{level}', double_conv(2 * width, width))

        self.final = nn.Conv2d(widths[0], out_channels, 1)

    def forward(self, x):
        skips = []
        for level in range(1, 5):
            x = getattr(self, f'down{level}')(x)
            skips.append(x)
            x = getattr(self, f'pool{level}')(x)

        x = self.middle(x)

        # Deepest skip first: up1 pairs with down4, ..., up4 with down1.
        for level, skip in enumerate(reversed(skips), start=1):
            upsampled = getattr(self, f'up{level}')(x)
            x = getattr(self, f'up_block{level}')(torch.cat([upsampled, skip], dim=1))

        return torch.sigmoid(self.final(x))
|
|
|
|
|
class AmodalCompletionLoss(nn.Module):
    """Weighted reconstruction loss restricted to the object (amodal) region.

    Prediction and target are both multiplied by the amodal mask, so the
    background never contributes. The total loss is::

        visible_weight    * MSE on the visible part  (modal_mask * amodal_mask)
      + occluded_weight   * MSE on the hidden part   (occluded_mask * amodal_mask)
      + 2.0               * MSE on the boundary band of the amodal mask
      + perceptual_weight * LPIPS(pred, target)      (only if perceptual_weight != 0)

    Each MSE uses ``F.mse_loss``'s default mean over the whole tensor, so
    zeros outside a region count in the denominator.

    Args:
        occluded_weight: weight of the occluded-region MSE.
        visible_weight: weight of the visible-region MSE.
        perceptual_weight: weight of the LPIPS term. The default 0 gives the
            same objective as before; LPIPS is then not evaluated at all.
            (Previously it was computed every step but never added to the total.)
    """

    def __init__(self, occluded_weight=5.0, visible_weight=1.0, perceptual_weight=0.0):
        super().__init__()
        self.occluded_weight = occluded_weight
        self.visible_weight = visible_weight
        self.perceptual_weight = perceptual_weight
        self.lpips_model = lpips.LPIPS(net='alex')

    def forward(self, pred, target, modal_mask, occluded_mask, amodal_mask):
        """Return ``(total_loss, visible_loss, occluded_loss, boundary_loss)``.

        The masks are expected to be single-channel: the 3x3 boundary filter
        below is a one-channel convolution.
        """
        pred_masked = pred * amodal_mask
        target_masked = target * amodal_mask

        visible_region = modal_mask * amodal_mask
        if visible_region.sum() > 0:
            visible_loss = F.mse_loss(pred_masked * visible_region, target_masked * visible_region)
        else:
            visible_loss = torch.tensor(0.0, device=pred.device)

        occluded_region = occluded_mask * amodal_mask
        if occluded_region.sum() > 0:
            occluded_loss = F.mse_loss(pred_masked * occluded_region, target_masked * occluded_region)
        else:
            occluded_loss = torch.tensor(0.0, device=pred.device)

        # A pixel is on the boundary when its 3x3 neighbourhood (zero-padded)
        # is partly, but not entirely, inside the amodal mask.
        kernel = torch.ones(1, 1, 3, 3, device=amodal_mask.device, dtype=amodal_mask.dtype)
        boundary_mask = F.conv2d(amodal_mask, kernel, padding=1)
        boundary_mask = ((boundary_mask > 0) & (boundary_mask < 9)).float()
        boundary_loss = F.mse_loss(pred_masked * boundary_mask, target_masked * boundary_mask)

        total_loss = (self.visible_weight * visible_loss +
                      self.occluded_weight * occluded_loss +
                      2.0 * boundary_loss)

        if self.perceptual_weight:
            self.lpips_model = self.lpips_model.to(pred.device)
            # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1].
            perceptual_loss = self.lpips_model(pred_masked, target_masked, normalize=True).mean()
            total_loss = total_loss + self.perceptual_weight * perceptual_loss

        return total_loss, visible_loss, occluded_loss, boundary_loss
|
|
|
|
|
|
|
|
def train_improved(model, dataloader, optimizer, device, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs with ``AmodalCompletionLoss``.

    Each batch is fed as ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``
    and the prediction is compared with ``amodal_rgb``. Progress is printed
    every 16 batches, plus an average loss at the end of every epoch.

    Args:
        model: network to train (switched to train mode).
        dataloader: iterable of batches with keys ``'rgb'``, ``'modal_mask'``,
            ``'amodal_mask'``, ``'occluded_mask'`` and ``'amodal_rgb'``;
            must support ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        device: device to run on.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        list of per-epoch average losses (``nan`` for an epoch with no
        batches). Callers that ignore the return value are unaffected.
    """
    model.train()
    criterion = AmodalCompletionLoss()
    num_batches = len(dataloader)
    epoch_averages = []

    for epoch in range(num_epochs):
        total_loss = 0
        for i, batch in enumerate(dataloader):
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)

            optimizer.zero_grad()
            pred = model(input_tensor)

            loss, vis_loss, occ_loss, boundary_loss = criterion(
                pred, gt_amodal_rgb, modal_mask, occluded_mask, amodal_mask
            )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if i % 16 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] [{i}/{len(dataloader)}] "
                      f"Total: {loss.item():.4f}, Visible: {vis_loss.item():.4f}, "
                      f"Occluded: {occ_loss.item():.4f}, Boundary: {boundary_loss.item():.4f}")

        # An empty dataloader would otherwise raise ZeroDivisionError here.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")
        epoch_averages.append(avg_loss)

    return epoch_averages
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Folder holding the extracted train/ and test/ splits.
    data_root = "data"

    # Train and val both read data/train. The dataset caps it at max_samples,
    # shuffles with a fixed seed and holds out val_split of it, so two
    # instances built with the same arguments get complementary splits.
    train_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='train',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    val_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='val',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    print(f"Training on {len(train_dataset)} samples, {len(train_loader)} batches per epoch")
    print(f"Validation on {len(val_dataset)} samples, {len(val_loader)} batches")

    # Resume from the previously saved checkpoint.
    model = ImprovedUNet().to(device)
    model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))

    # NOTE(review): train_improved is never called in this script, so the
    # optimizer is unused and the checkpoint saved below contains the weights
    # that were just loaded. Call
    # train_improved(model, train_loader, optimizer, device, num_epochs)
    # here to fine-tune.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    def print_region_mse(metrics):
        """Print the masked MSEs returned by evaluate_metrics and their ratio."""
        print(f"Overall MSE: {metrics['total_mse']:.6f}")
        print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
        print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
        # visible_mse can be 0 (e.g. no visible pixels evaluated); don't divide by it.
        if metrics['visible_mse'] > 0:
            print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
        else:
            print("Occluded/Visible MSE Ratio: n/a (visible MSE is 0)")

    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)

    metrics = evaluate_metrics(model, val_loader, device)
    print_region_mse(metrics)

    print("\nGenerating visualizations...")
    visualize_results(model, val_loader, device, num_samples=8)

    image_metrics = calculate_metrics(model, val_loader, device)
    print(f"PSNR: {image_metrics['psnr']:.4f}")
    print(f"SSIM: {image_metrics['ssim']:.4f}")
    print(f"LPIPS: {image_metrics['lpips']:.4f}")
    print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")

    torch.save(model.state_dict(), 'amodal_completion_model.pth')

    # Held-out test split (data/test).
    test_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='test',
        img_size=(128, 128),
        max_samples=2000
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        # Keep the final partial batch: dropping it would silently exclude
        # up to batch_size - 1 test samples from the metrics.
        drop_last=False
    )

    print("EVALUATION RESULTS")
    print("="*50)

    metrics = evaluate_metrics(model, test_loader, device)
    print_region_mse(metrics)

    print("\nGenerating visualizations...")
    visualize_results(model, test_loader, device, num_samples=16)
|
|
|
|
|
# Release the Colab runtime once the work above has finished.
# NOTE(review): unassign() disconnects this Colab session, so code after this
# point does not run in the same runtime. The re-evaluation below needs a
# fresh session in which val_loader and the helper functions are defined
# again -- confirm this is meant to be the last cell executed.
from google.colab import runtime

runtime.unassign()
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the network and restore the saved weights. torch.load only returns
# the state dict; it must be passed to load_state_dict, otherwise the
# evaluation below runs on randomly initialised weights.
model = ImprovedUNet()
state_dict = torch.load('amodal_completion_model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
model.eval()

print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)

# NOTE(review): val_loader is not created in this cell; it must already exist
# from an earlier cell in the same session.
metrics = evaluate_metrics(model, val_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
# visible_mse can be 0 (e.g. no visible pixels evaluated); don't divide by it.
if metrics['visible_mse'] > 0:
    print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
else:
    print("Occluded/Visible MSE Ratio: n/a (visible MSE is 0)")

print("\nGenerating visualizations...")
visualize_results(model, val_loader, device, num_samples=8)

image_metrics = calculate_metrics(model, val_loader, device)
print(f"PSNR: {image_metrics['psnr']:.4f}")
print(f"SSIM: {image_metrics['ssim']:.4f}")
print(f"LPIPS: {image_metrics['lpips']:.4f}")
print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")
|
|
|
|
|
# NOTE(review): this rebinds `model` to a freshly constructed ImprovedUNet
# with no checkpoint loaded, discarding the model evaluated above. If the
# trained weights are wanted, load them with
# model.load_state_dict(torch.load(...)). Confirm this cell is intentional.
model = ImprovedUNet()

# Put the BatchNorm2d / Dropout2d layers into inference behaviour.
model.eval()