Image-to-Video
zzwustc's picture
Upload folder using huggingface_hub
ef296aa verified
raw
history blame
1.05 kB
import torch
def compute_metrics(gt, pred, time):
    """Collect flow-error and occlusion-IoU metrics into a single dict.

    Args:
        gt: Mapping read at keys ``"flow"`` and ``"alpha"``.
        pred: Mapping read at keys ``"flow"`` and ``"alpha"``.
        time: Value stored unchanged under the ``"time"`` key
            (presumably a runtime measurement -- confirm with caller).

    Returns:
        Dict with keys ``"epe_all"``, ``"epe_occ"``, ``"epe_vis"``, ``"iou"``
        (each moved to CPU and converted to a NumPy array) and ``"time"``.
    """
    # gt["alpha"] serves as the visibility mask that splits the flow error.
    all_err, occ_err, vis_err = get_epe(pred["flow"], gt["flow"], gt["alpha"])
    occlusion_iou = get_iou(gt["alpha"], pred["alpha"])

    tensor_metrics = (
        ("epe_all", all_err),
        ("epe_occ", occ_err),
        ("epe_vis", vis_err),
        ("iou", occlusion_iou),
    )
    metrics = {key: value.cpu().numpy() for key, value in tensor_metrics}
    metrics["time"] = time
    return metrics
def get_epe(pred, label, vis):
    """Compute end-point error overall and split by a visibility mask.

    The per-position error is the L2 norm of ``pred - label`` over the last
    axis. Every reduction runs over dims 1, 2 and 3, leaving one value per
    index of dim 0 (presumably the batch axis -- confirm with caller).

    Args:
        pred: Predicted vectors; the last axis holds the components.
        label: Reference vectors, same layout as ``pred``.
        vis: Mask with one fewer trailing axis than ``pred``. It is used as a
            multiplicative weight, and ``1 - vis`` weights the complement
            (presumably 1 = visible, 0 = occluded -- confirm).

    Returns:
        Tuple ``(epe_all, epe_occ, epe_vis)``: the plain mean error, the
        ``(1 - vis)``-weighted mean and the ``vis``-weighted mean.

    Note:
        The weighted means divide by the summed weights without any guard
        against a zero denominator.
    """
    reduce_axes = (1, 2, 3)

    # keepdim=True keeps a trailing singleton axis so the mask can broadcast.
    err = (pred - label).norm(p=2, dim=-1, keepdim=True)

    visible = vis[..., None]
    occluded = 1 - visible

    epe_all = err.mean(dim=reduce_axes)
    epe_occ = (err * occluded).sum(dim=reduce_axes) / occluded.sum(dim=reduce_axes)
    epe_vis = (err * visible).sum(dim=reduce_axes) / visible.sum(dim=reduce_axes)
    return epe_all, epe_occ, epe_vis
def get_iou(vis1, vis2):
    """Compute intersection-over-union of the complements of two masks.

    Each mask is turned into a boolean map via ``(1 - vis).bool()``, so any
    position whose value is not exactly 1 counts as set. Counts are summed
    over dims 1 and 2, giving one ratio per index of dim 0.

    Args:
        vis1: First mask tensor (presumably visibility, 1 = visible -- confirm).
        vis2: Second mask tensor, same layout as ``vis1``.

    Returns:
        Tensor of intersection counts divided by union counts. The division
        is not guarded against an empty union.
    """

    def _count(mask):
        # Number of set positions per dim-0 entry, as floats for the division.
        return mask.float().sum(dim=[1, 2])

    occluded_a = (1 - vis1).bool()
    occluded_b = (1 - vis2).bool()

    overlap = _count(occluded_a.logical_and(occluded_b))
    combined = _count(occluded_a.logical_or(occluded_b))
    return overlap / combined