# NOTE: "Spaces: Running" lines here were hosting-UI scrape residue, not part of the module.
| import torch | |
| import numpy as np | |
| import math | |
| import torch.nn.functional as F | |
| import utils.basic | |
| from typing import Tuple, Union | |
def standardize_test_data(rgbs, trajs, visibs, valids, S_cap=600, only_first=False, seq_len=None):
    """Clean up a test sequence: drop rarely-visible tracks, fill occluded
    coordinates, and optionally crop the sequence in time.

    Args:
        rgbs: sequence of S frames (list or array; only len/slicing used).
        trajs: S,N,2 array of xy tracks.
        visibs: S,N visibility flags.
        valids: S,N validity flags.
        S_cap: hard cap on sequence length.
        only_first: if True, search for the start index maximizing the number
            of tracks visible on the new first frame, and keep only those tracks.
        seq_len: optional desired sequence length.
    Returns:
        rgbs, trajs, visibs, valids, cropped/filtered consistently.
    """
    trajs = trajs.astype(np.float32) # S,N,2
    visibs = visibs.astype(np.float32) # S,N
    valids = valids.astype(np.float32) # S,N
    # require at least two timesteps that are both visible and valid
    visval_ok = np.sum(valids*visibs, axis=0) > 1
    trajs = trajs[:,visval_ok]
    visibs = visibs[:,visval_ok]
    valids = valids[:,visval_ok]
    # fill in missing data
    N = trajs.shape[1]
    for ni in range(N):
        # bugfix: call the helper defined in this module; the old code
        # referenced utils.misc, which is not imported here (only utils.basic is)
        trajs[:,ni] = data_replace_with_nearest(trajs[:,ni], valids[:,ni])
    # use cap or seq_len
    if seq_len is not None:
        S = min(len(rgbs), seq_len)
    else:
        S = len(rgbs)
    S = min(S, S_cap)
    if only_first:
        # we'll find the best frame to start on
        best_count = 0
        best_ind = 0
        for si in range(0,len(rgbs)-64):
            # try this slice
            visibs_ = visibs[si:min(si+S,len(rgbs)+1)] # S,N
            valids_ = valids[si:min(si+S,len(rgbs)+1)] # S,N
            visval_ok0 = (visibs_[0]*valids_[0]) > 0 # N
            visval_okA = np.sum(visibs_*valids_, axis=0) > 1 # N
            all_ok = visval_ok0 & visval_okA
            # print('- slicing %d to %d; sum(ok) %d' % (si, min(si+S,len(rgbs)+1), np.sum(all_ok)))
            if np.sum(all_ok) > best_count:
                best_count = np.sum(all_ok)
                best_ind = si
        si = best_ind
        rgbs = rgbs[si:si+S]
        trajs = trajs[si:si+S]
        visibs = visibs[si:si+S]
        valids = valids[si:si+S]
        # keep only tracks visible on the new first frame
        vis_ok0 = visibs[0] > 0 # N
        trajs = trajs[:,vis_ok0]
        visibs = visibs[:,vis_ok0]
        valids = valids[:,vis_ok0]
        # print('- best_count', best_count, 'best_ind', best_ind)
    if seq_len is not None:
        rgbs = rgbs[:seq_len]
        trajs = trajs[:seq_len]
        visibs = visibs[:seq_len] # bugfix: visibs was not trimmed, causing a shape mismatch below
        valids = valids[:seq_len]
    # req two timesteps valid (after seqlen trim)
    visval_ok = np.sum(visibs*valids, axis=0) > 1
    trajs = trajs[:,visval_ok]
    valids = valids[:,visval_ok]
    visibs = visibs[:,visval_ok]
    return rgbs, trajs, visibs, valids
def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
    """
    Build a 2D sine-cosine positional embedding over an (H, W) grid.
    Thin wrapper around get_2d_sincos_pos_embed_from_grid.
    Args:
    - embed_dim: The embedding dimension.
    - grid_size: Either a single int (square grid) or an (H, W) tuple.
    Returns:
    - pos_embed: Tensor of shape (1, embed_dim, H, W).
    """
    if isinstance(grid_size, tuple):
        gh, gw = grid_size
    else:
        gh = gw = grid_size
    ys = torch.arange(gh, dtype=torch.float)
    xs = torch.arange(gw, dtype=torch.float)
    mesh = torch.stack(torch.meshgrid(xs, ys, indexing="xy"), dim=0)
    mesh = mesh.reshape([2, 1, gh, gw])
    emb = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    # (1, H, W, D) -> (1, D, H, W)
    return emb.reshape(1, gh, gw, -1).permute(0, 3, 1, 2)
def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """
    Generate a 2D sin/cos positional embedding from a coordinate grid.
    Args:
    - embed_dim: The embedding dimension (must be even).
    - grid: Tensor whose first axis holds the two coordinate planes.
    Returns:
    - emb: The concatenated 2D positional embedding.
    """
    assert embed_dim % 2 == 0
    # split the dimensions evenly between the two grid axes
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (1, H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (1, H*W, D/2)
    return torch.cat([emb_h, emb_w], dim=2)  # (1, H*W, D)
def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """
    Generate a 1D sin/cos positional embedding for the given positions.
    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: Tensor of positions, any shape (flattened to M entries).
    Returns:
    - emb: Float tensor of shape (1, M, embed_dim).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # frequencies 1 / 10000^(i/(D/2)), computed in double precision
    freqs = torch.arange(half, dtype=torch.double)
    freqs = 1.0 / 10000 ** (freqs / half)
    flat = pos.reshape(-1)  # (M,)
    angles = torch.einsum("m,d->md", flat, freqs)  # (M, D/2), outer product
    emb = torch.cat([angles.sin(), angles.cos()], dim=1)  # (M, D)
    return emb[None].float()
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    Sin/cos positional embedding of 2D coordinates.
    Args:
    - xy: (B, N, 2) coordinates.
    - C: per-axis embedding size.
    - cat_coords: If True, prepend the raw coordinates to the embedding.
    Returns:
    - pe: (B, N, 2*C) tensor, or (B, N, 2*C+2) when cat_coords is True.
    """
    B, N, D = xy.shape
    assert D == 2
    xs = xy[:, :, 0:1]
    ys = xy[:, :, 1:2]
    half = C // 2
    # NOTE: uses a 1000.0 base here (the commented-out variant below uses 10000.0)
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, half)
    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_x[:, :, 0::2] = torch.sin(xs * div_term)
    pe_x[:, :, 1::2] = torch.cos(xs * div_term)
    pe_y[:, :, 0::2] = torch.sin(ys * div_term)
    pe_y[:, :, 1::2] = torch.cos(ys * div_term)
    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, 2*C)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, 2*C+2)
    return pe
| # from datasets.dataset import mask2bbox | |
| # from pips2 | |
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    """Sin/cos positional embedding of xy coordinates (pips2-style).

    Args:
        xy: (B, S, 2) coordinates.
        C: embedding dim, must be a multiple of 4.
        temperature: frequency base.
        dtype: default is ignored; the output follows xy's dtype.
        cat_coords: if True, append raw coords -> (B, S, C+2).
    Returns:
        (B, S, C) embedding, or (B, S, C+2) with cat_coords.
    """
    device = xy.device
    dtype = xy.dtype  # output dtype tracks the input, not the kwarg
    B, S, D = xy.shape
    assert(D==2)
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    quarter = C // 4
    omega = torch.arange(quarter, device=device) / (quarter - 1)
    omega = 1. / (temperature ** omega)
    xs = xy[:, :, 0].reshape(-1, 1) * omega[None, :]
    ys = xy[:, :, 1].reshape(-1, 1) * omega[None, :]
    pe = torch.cat((xs.sin(), xs.cos(), ys.sin(), ys.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,N,C+2
    return pe
| # # prevent circular imports | |
| # def mask2bbox(mask): | |
| # if mask.ndim == 3: | |
| # mask = mask[..., 0] | |
| # ys, xs = np.where(mask > 0.4) | |
| # if ys.size == 0 or xs.size==0: | |
| # return np.array((0, 0, 0, 0), dtype=int) | |
| # lt = np.array([np.min(xs), np.min(ys)]) | |
| # rb = np.array([np.max(xs), np.max(ys)]) + 1 | |
| # return np.concatenate([lt, rb]) | |
| # def get_stark_2d_embedding(H, W, C=64, device='cuda:0', temperature=10000, normalize=True): | |
| # scale = 2*math.pi | |
| # mask = torch.ones((1,H,W), dtype=torch.float32, device=device) | |
| # y_embed = mask.cumsum(1, dtype=torch.float32) # cumulative sum along axis 1 (h axis) --> (b, h, w) | |
| # x_embed = mask.cumsum(2, dtype=torch.float32) # cumulative sum along axis 2 (w axis) --> (b, h, w) | |
| # if normalize: | |
| # eps = 1e-6 | |
| # y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale # 2pi * (y / sigma(y)) | |
| # x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale # 2pi * (x / sigma(x)) | |
| # dim_t = torch.arange(C, dtype=torch.float32, device=device) # (0,1,2,...,d/2) | |
| # dim_t = temperature ** (2 * (dim_t // 2) / C) | |
| # pos_x = x_embed[:, :, :, None] / dim_t # (b,h,w,d/2) | |
| # pos_y = y_embed[:, :, :, None] / dim_t # (b,h,w,d/2) | |
| # pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2) | |
| # pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2) | |
| # pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # (b,h,w,d) | |
| # return pos | |
| # def get_1d_embedding(x, C, cat_coords=False): | |
| # B, N, D = x.shape | |
| # assert(D==1) | |
| # div_term = (torch.arange(0, C, 2, device=x.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2)) | |
| # pe_x = torch.zeros(B, N, C, device=x.device, dtype=torch.float32) | |
| # pe_x[:, :, 0::2] = torch.sin(x * div_term) | |
| # pe_x[:, :, 1::2] = torch.cos(x * div_term) | |
| # if cat_coords: | |
| # pe_x = torch.cat([pe, x], dim=2) # B,N,C*2+2 | |
| # return pe_x | |
| # def posemb_sincos_2d(h, w, dim, temperature=10000, dtype=torch.float32, device='cuda:0'): | |
| # y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') | |
| # assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' | |
| # omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) | |
| # omega = 1. / (temperature ** omega) | |
| # y = y.flatten()[:, None] * omega[None, :] | |
| # x = x.flatten()[:, None] * omega[None, :] | |
| # pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) # B,C,H,W | |
| # return pe.type(dtype) | |
def iou(bbox1, bbox2):
    """Intersection-over-union of two axis-aligned boxes given as [x1, y1, x2, y2]."""
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2
    # intersection rectangle, clamped to zero when the boxes are disjoint
    iw = max(0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    # NOTE(review): divides by the union with no zero guard, same as before;
    # two zero-area boxes would raise ZeroDivisionError
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union
| # def get_2d_embedding(xy, C, cat_coords=False): | |
| # B, N, D = xy.shape | |
| # assert(D==2) | |
| # x = xy[:,:,0:1] | |
| # y = xy[:,:,1:2] | |
| # div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2)) | |
| # pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | |
| # pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | |
| # pe_x[:, :, 0::2] = torch.sin(x * div_term) | |
| # pe_x[:, :, 1::2] = torch.cos(x * div_term) | |
| # pe_y[:, :, 0::2] = torch.sin(y * div_term) | |
| # pe_y[:, :, 1::2] = torch.cos(y * div_term) | |
| # pe = torch.cat([pe_x, pe_y], dim=2) # B,N,C*2 | |
| # if cat_coords: | |
| # pe = torch.cat([pe, xy], dim=2) # B,N,C*2+2 | |
| # return pe | |
| # def get_3d_embedding(xyz, C, cat_coords=False): | |
| # B, N, D = xyz.shape | |
| # assert(D==3) | |
| # x = xyz[:,:,0:1] | |
| # y = xyz[:,:,1:2] | |
| # z = xyz[:,:,2:3] | |
| # div_term = (torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2)) | |
| # pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | |
| # pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | |
| # pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | |
| # pe_x[:, :, 0::2] = torch.sin(x * div_term) | |
| # pe_x[:, :, 1::2] = torch.cos(x * div_term) | |
| # pe_y[:, :, 0::2] = torch.sin(y * div_term) | |
| # pe_y[:, :, 1::2] = torch.cos(y * div_term) | |
| # pe_z[:, :, 0::2] = torch.sin(z * div_term) | |
| # pe_z[:, :, 1::2] = torch.cos(z * div_term) | |
| # pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3 | |
| # if cat_coords: | |
| # pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3 | |
| # return pe | |
class SimplePool():
    """Fixed-capacity FIFO pool of scalar-like items (torch tensors or numbers)."""
    def __init__(self, pool_size, version='pt', min_size=1):
        """
        Args:
            pool_size: max number of items retained (oldest evicted first).
            version: 'pt' for torch tensors, 'np' for numpy/python scalars.
            min_size: default minimum item count for mean() to return a number.
        """
        self.pool_size = pool_size
        self.version = version
        self.items = []
        self.min_size = min_size
        if not (version=='pt' or version=='np'):
            # bugfix: format arg was missing, so the message printed the literal '%s'
            print('version = %s; please choose pt or np' % version)
            assert(False) # please choose pt or np
    def __len__(self):
        return len(self.items)
    def mean(self, min_size=None):
        """Mean of pooled items, or NaN if fewer than the threshold are present.

        min_size: None -> use self.min_size; 'half' -> half the capacity;
        otherwise an explicit count.
        """
        if min_size is None:
            pool_size_thresh = self.min_size
        elif min_size=='half':
            pool_size_thresh = self.pool_size/2
        else:
            pool_size_thresh = min_size
        if self.version=='np':
            if len(self.items) >= pool_size_thresh:
                return np.sum(self.items)/float(len(self.items))
            else:
                return np.nan
        if self.version=='pt':
            if len(self.items) >= pool_size_thresh:
                # bugfix: torch.sum does not accept a python list; stack first
                return torch.sum(torch.stack(self.items))/float(len(self.items))
            else:
                # bugfix: torch.from_numpy(np.nan) is invalid (requires an ndarray)
                return torch.tensor(np.nan)
    def sample(self, with_replacement=True):
        """Return a random item; without replacement it is removed from the pool."""
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        else:
            return self.items.pop(idx)
    def fetch(self, num=None):
        """Stack all items into one array/tensor; if num is given, sample num rows."""
        if self.version=='pt':
            item_array = torch.stack(self.items)
        elif self.version=='np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)
            # if there are not that many elements just return however many there are
            # (note: unreachable given the assert above; kept for safety under -O)
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array
    def is_full(self):
        full = len(self.items)==self.pool_size
        return full
    def empty(self):
        self.items = []
    def have_min_size(self):
        return len(self.items) >= self.min_size
    def update(self, items):
        """Append items, evicting the oldest once capacity is reached."""
        for item in items:
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
            else:
                # the pool is full
                # pop from the front
                self.items.pop(0)
                # add to the back
                self.items.append(item)
        return self.items
class SimpleHeap():
    """Fixed-capacity pool that keeps the highest-valued items (value-keyed)."""
    def __init__(self, pool_size, version='pt'):
        """
        Args:
            pool_size: max number of items retained.
            version: 'pt' for torch tensors, 'np' for numpy arrays (affects fetch()).
        """
        self.pool_size = pool_size
        self.version = version
        self.items = []
        self.vals = []
        if not (version=='pt' or version=='np'):
            # bugfix: format arg was missing, so the message printed the literal '%s'
            print('version = %s; please choose pt or np' % version)
            assert(False) # please choose pt or np
    def __len__(self):
        return len(self.items)
    def sample(self, random=True, with_replacement=False, semirandom=False):
        """Return an item: uniformly random, from the harder half, or the max-valued one.

        Without replacement the item (and its value) is removed.
        """
        vals_arr = np.stack(self.vals)
        if random:
            ind = np.random.randint(len(self.items))
        else:
            if semirandom and len(vals_arr)>1:
                # choose from the harder half
                inds = np.argsort(vals_arr) # ascending
                inds = inds[len(vals_arr)//2:]
                ind = np.random.choice(inds)
            else:
                # find the most valuable element
                ind = np.argmax(vals_arr)
        if with_replacement:
            return self.items[ind]
        else:
            item = self.items.pop(ind)
            val = self.vals.pop(ind)
            return item
    def fetch(self, num=None):
        """Stack all items into one array/tensor; if num is given, sample num rows."""
        if self.version=='pt':
            item_array = torch.stack(self.items)
        elif self.version=='np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)
            # if there are not that many elements just return however many there are
            # (note: unreachable given the assert above; kept for safety under -O)
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array
    def is_full(self):
        full = len(self.items)==self.pool_size
        return full
    def empty(self):
        self.items = []
        self.vals = [] # bugfix: vals was left populated, desyncing from items
    def update(self, vals, items):
        """Insert (val, item) pairs; when full, replace the current minimum if val is larger."""
        for val,item in zip(vals, items):
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
                self.vals.append(val)
            else:
                # the pool is full
                # find our least-valuable element
                # and see if we should replace it
                vals_arr = np.stack(self.vals)
                ind = np.argmin(vals_arr)
                if vals_arr[ind] < val:
                    # pop the min
                    self.items.pop(ind)
                    self.vals.pop(ind)
                    # add to the back
                    self.items.append(item)
                    self.vals.append(val)
        return self.items
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples
    Return:
        inds: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    xyz = xyz.float()
    inds = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # running minimum squared distance from each point to the selected set
    dist_to_set = torch.ones((B, N), dtype=torch.float32, device=device) * 1e10
    if deterministic:
        cur = torch.randint(0, 1, (B,), dtype=torch.long, device=device)  # always index 0
    else:
        cur = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_ar = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        if include_ends:
            # pin the first two picks to the sequence endpoints
            if i == 0:
                cur = 0
            elif i == 1:
                cur = N - 1
        inds[:, i] = cur
        centroid = xyz[batch_ar, cur, :].view(B, 1, C)
        d = torch.sum((xyz - centroid) ** 2, -1)
        closer = d < dist_to_set
        dist_to_set[closer] = d[closer]
        cur = torch.max(dist_to_set, -1)[1]
        if npoint > N:
            # if we need more samples than points, randomize subsequent picks
            dist_to_set += torch.randn_like(dist_to_set)
    return inds
def farthest_point_sample_py(xyz, npoint, deterministic=False):
    """Numpy farthest-point sampling over an (N, C) pointcloud; returns npoint indices."""
    N, C = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    # running minimum squared distance from each point to the selected set
    dist_to_set = np.ones(N) * 1e10
    cur = 0 if deterministic else np.random.randint(0, N, dtype=np.int32)
    for i in range(npoint):
        inds[i] = cur
        centroid = xyz[cur, :].reshape(1, C)
        d = np.sum((xyz - centroid) ** 2, -1)
        closer = d < dist_to_set
        dist_to_set[closer] = d[closer]
        cur = np.argmax(dist_to_set, -1)
        if npoint > N:
            # if we need more samples than points, randomize subsequent picks
            dist_to_set += np.random.randn(*dist_to_set.shape)
    return inds
def balanced_ce_loss(pred, gt, pos_weight=0.5, valid=None, dim=None, return_both=False, use_halfmask=False, H=64, W=64):
    """Class-balanced binary cross-entropy on logits.

    gt is read as: >0.95 positive, <0.05 negative, 0.5 "half"/ignore.
    Positive-pixel and negative-pixel losses are averaged separately and mixed
    by pos_weight, so class imbalance does not dominate.

    NOTE(review): the function returns unconditionally partway through; the
    use_halfmask / return_both code after that return is dead (see below).
    `dim`, `return_both`, `use_halfmask`, `H`, `W` therefore have no effect
    on the value actually returned.
    """
    # # pred and gt are the same shape
    # for (a,b) in zip(pred.size(), gt.size()):
    #     if not a==b:
    #         print('mismatch: pred, gt', pred.shape, gt.shape)
    #     assert(a==b) # some shape mismatch!
    pred = pred.reshape(-1)
    gt = gt.reshape(-1)
    device = pred.device
    if valid is not None:
        valid = valid.reshape(-1)
        for (a,b) in zip(pred.size(), valid.size()):
            assert(a==b) # some shape mismatch!
    else:
        valid = torch.ones_like(gt)
    pos = (gt > 0.95).float()
    if use_halfmask:
        # pos_wide also treats "half" (0.5) labels as positive for the BCE target
        pos_wide = (gt >= 0.5).float()
        halfmask = (gt == 0.5).float()
    else:
        pos_wide = pos
    neg = (gt < 0.05).float()
    # numerically stable BCE-with-logits in {-1,+1} label form:
    # loss = max(a,0) + log(exp(-|a|) + ...) with a = -label*pred
    label = pos_wide*2.0 - 1.0
    a = -label * pred
    b = F.relu(a)
    loss = b + torch.log(torch.exp(-b)+torch.exp(a-b))
    if torch.sum(pos*valid)>0:
        pos_loss = loss[(pos*valid) > 0].mean()
    else:
        # no positives: contribute a zero that still participates in autograd
        pos_loss = torch.tensor(0.0, requires_grad=True, device=device)
    if torch.sum(neg*valid)>0:
        neg_loss = loss[(neg*valid) > 0].mean()
    else:
        neg_loss = torch.tensor(0.0, requires_grad=True, device=device)
    balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
    return balanced_loss
    # NOTE(review): everything below this return is unreachable. It also looks
    # inconsistent with the code above (e.g. loss was flattened to 1D, but
    # loss.shape[0] is used as a batch size here) — confirm intent before reviving.
    # pos_loss = utils.basic.reduce_masked_mean(loss, pos*valid, dim=dim)
    # neg_loss = utils.basic.reduce_masked_mean(loss, neg*valid, dim=dim)
    if use_halfmask:
        # here we will find the pixels which are already leaning positive,
        # and encourage them to be more positive
        B = loss.shape[0]
        loss_ = loss.reshape(B,-1)
        mask_ = halfmask.reshape(B,-1) * valid.reshape(B,-1)
        # to avoid the issue where spikes become spikier,
        # we will only apply this loss on batch els where we predicted zero positives
        pred_sig_ = torch.sigmoid(pred).reshape(B,-1)
        no_pred_ = torch.max(pred_sig_.round(), axis=1)[0] < 1 # B
        # and only on batch els where we have negatives available
        have_neg_ = torch.sum(neg, dim=1)>0 # B
        loss_ = loss_[no_pred_ & have_neg_] # N,H*W
        mask_ = mask_[no_pred_ & have_neg_] # N,H*W
        N = loss_.shape[0]
        if N > 0:
            # we want:
            # in the neg pixels,
            # set them to the max loss of the pos pixels,
            # so that they do not contribute to the min
            loss__ = loss_.reshape(-1)
            mask__ = mask_.reshape(-1)
            if torch.sum(mask__)>0:
                # print('loss_', loss_.shape, 'mask_', mask_.shape, 'loss__', loss__.shape, 'mask__', mask__.shape)
                mloss__ = loss__.detach()
                mloss__[mask__==0] = torch.max(loss__[mask__==1])
                mloss_ = mloss__.reshape(N,H*W)
                # now, in each batch el, take a tiny region around the argmin, so we can boost this region
                minloss_mask_ = torch.zeros_like(mloss_).scatter(1,mloss_.argmin(1,True),value=1)
                minloss_mask_ = utils.improc.dilate2d(minloss_mask_.view(N,1,H,W), times=3).reshape(N,H*W)
                loss__ = loss_.reshape(-1)
                minloss_mask__ = minloss_mask_.reshape(-1)
                half_loss = loss__[minloss_mask__>0].mean()
                # print('N', N, 'half_loss', half_loss)
                pos_loss = pos_loss + half_loss
    # if False:
    #     min_pos = 8
    #     # only apply the loss when we have some negatives available,
    #     # otherwise it's a whole "ignore" frame, which may mean
    #     # we are unsure if the target is even there
    #     if torch.sum(mask__==0) > 0: # negatives available
    #         # only apply the loss when the halfmask is larger area than
    #         # min_pos (the number of pixels we want to boost),
    #         # so that indexing will work
    #         if torch.all(torch.sum(mask_==1, dim=1) >= min_pos): # topk indexing will work
    #             # in the pixels we will not use,
    #             # set them to the max of the pixels we may use,
    #             # so that they do not contribute to the min
    #             loss__[mask__==0] = torch.max(loss__[mask__==1])
    #             loss_ = loss__.reshape(B,-1)
    #             half_loss = torch.mean(torch.topk(loss_, min_pos, dim=1, largest=False)[0], dim=1) # B
    #             have_neg = (torch.sum(neg, dim=1)>0).float() # B
    #             pos_loss = pos_loss + half_loss*have_neg
    # half_loss = []
    # for b in range(B):
    #     loss_b = loss_[b]
    #     mask_b = mask_[b]
    #     if torch.sum(mask_b):
    #         inds = torch.nonzero(mask_b).reshape(-1)
    #         half_loss.append(torch.min(loss_b[inds]))
    # if len(half_loss):
    #     # # half_loss_ = half_loss.reshape(-1)
    #     # half_loss = torch.min(half_loss, dim=1)[0] # B
    #     # half_loss = torch.mean(torch.topk(half_loss, 4, dim=1, largest=False)[0], dim=1) # B
    #     pos_loss = pos_loss + torch.stack(half_loss).mean()
    if return_both:
        return pos_loss, neg_loss
    balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
    return balanced_loss
def dice_loss(pred, gt):
    """Dice loss on logits; gt values near 0.5 are treated as ignore.
    Returns a per-batch-element loss vector."""
    # pred and gt are the same shape
    for (a, b) in zip(pred.size(), gt.size()):
        assert(a==b) # some shape mismatch!
    # flatten everything except batch
    prob = pred.sigmoid().flatten(1)
    tgt = gt.flatten(1)
    pos = (tgt > 0.95).float()
    neg = (tgt < 0.05).float()
    # only confident positives/negatives contribute
    valid = (pos + neg).float().clamp(0, 1)
    num = 2 * (prob * pos * valid).sum(1)
    den = (prob * valid).sum(1) + (pos * valid).sum(1)
    # +1 smoothing keeps this defined when the mask is empty
    return 1 - (num + 1) / (den + 1)
def sigmoid_focal_loss(pred, gt, alpha=0.25, gamma=2):#, use_halfmask=False):
    """Sigmoid focal loss on logits; gt values near 0.5 are treated as ignore.
    Returns a per-batch-element loss vector."""
    # pred and gt are the same shape
    for (a, b) in zip(pred.size(), gt.size()):
        assert(a==b) # some shape mismatch!
    # flatten everything except batch
    logits = pred.flatten(1)
    tgt = gt.flatten(1)
    pos = (tgt > 0.95).float()
    neg = (tgt < 0.05).float()
    # if use_halfmask:
    #     pos_wide = (gt >= 0.5).float()
    #     halfmask = (gt == 0.5).float()
    # else:
    #     pos_wide = pos
    valid = (pos + neg).float().clamp(0, 1)  # confident labels only
    prob = logits.sigmoid()
    ce = F.binary_cross_entropy_with_logits(logits, pos, reduction="none")
    p_t = prob * pos + (1 - prob) * (1 - pos)
    focal = ce * ((1 - p_t) ** gamma)
    if alpha >= 0:
        focal = (alpha * pos + (1 - alpha) * (1 - pos)) * focal
    return (focal * valid).sum(1) / (1 + valid.sum(1))
| # def dice_loss(inputs, targets, normalizer=1): | |
| # inputs = inputs.sigmoid() | |
| # inputs = inputs.flatten(1) | |
| # numerator = 2 * (inputs * targets).sum(1) | |
| # denominator = inputs.sum(-1) + targets.sum(-1) | |
| # loss = 1 - (numerator + 1) / (denominator + 1) | |
| # return loss.sum() / normalizer | |
| # def sigmoid_focal_loss(inputs, targets, normalizer=1, alpha=0.25, gamma=2): | |
| # prob = inputs.sigmoid() | |
| # ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") | |
| # p_t = prob * targets + (1 - prob) * (1 - targets) | |
| # loss = ce_loss * ((1 - p_t) ** gamma) | |
| # if alpha >= 0: | |
| # alpha_t = alpha * targets + (1 - alpha) * (1 - targets) | |
| # loss = alpha_t * loss | |
| # return loss.mean(1).sum() / normalizer | |
def data_replace_with_nearest(xys, valids):
    """Replace entries of xys where valids==0 with the temporally nearest valid entry.

    Args:
        xys: (S, ...) array of per-timestep data (modified in place and returned).
        valids: (S,) array of 0/1 validity flags.
    Returns:
        xys with each invalid row overwritten by its nearest valid neighbor.
        If nothing is valid, xys is returned unchanged.
    """
    invalid_idx = np.where(valids==0)[0]
    valid_idx = np.where(valids==1)[0]
    if len(valid_idx) == 0:
        # bugfix: argmin on an empty array crashed when no timestep was valid
        return xys
    for idx in invalid_idx:
        nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
        xys[idx] = xys[nearest]
    return xys
def data_get_traj_from_masks(masks):
    """Derive per-frame point annotations from binary masks.

    Args:
        masks: (S, H, W) or (S, H, W, 1) array; thresholded at 0.1.
    Returns:
        xy_means: (S, 2) mask centroid per frame.
        xy_rands: (S, 2) a randomly chosen in-mask pixel per frame.
        valids: (S,) 1 where the mask is non-empty.
        fills: (S,) fraction of the mask's bounding box that is filled.
    """
    if masks.ndim == 4:
        masks = masks[..., 0]
    S, H, W = masks.shape
    masks = (masks > 0.1).astype(np.float32)
    fills = np.zeros((S))
    xy_means = np.zeros((S, 2))
    xy_rands = np.zeros((S, 2))
    valids = np.zeros((S))
    for si in range(S):
        mask = masks[si]
        if np.sum(mask) == 0:
            continue  # empty frame: leave zeros
        ys, xs = np.where(mask)
        perm = np.random.permutation(len(xs))
        xs, ys = xs[perm], ys[perm]
        x0, x1 = np.min(xs), np.max(xs) + 1
        y0, y1 = np.min(ys), np.max(ys) + 1
        xy_means[si] = np.array([xs.mean(), ys.mean()])
        xy_rands[si] = np.array([xs[0], ys[0]])  # first of the shuffled pixels
        valids[si] = 1
        fills[si] = np.mean(mask[y0:y1, x0:x1])
    # print('fills', fills)
    return xy_means, xy_rands, valids, fills
def data_zoom(zoom, xys, visibs, rgbs, valids=None, masks=None, masks2=None, masks3=None, masks4=None):
    """Randomly crop (zoom into) a video around a smoothed trajectory,
    simulating camera motion, and update tracks/visibility to the crop frame.

    Args:
        zoom: zoom factor; crop size is (H//zoom, W//zoom).
        xys: S,N,2 track coordinates (modified in place to crop coords).
        visibs: S,N visibility (modified in place: out-of-crop points set to 0).
        rgbs: S,H,W,3 frames.
        valids: optional S,N validity, used only to pick the anchor frame.
        masks..masks4: optional S,H,W(,?) arrays cropped alongside rgbs.
    Returns:
        (xys, visibs, rgbs, valids) if valids is given, else (xys, visibs, rgbs).
        NOTE(review): cropped masks are computed but never returned (the
        returning code for them is commented out below) — confirm intent.
    """
    S, H, W, C = rgbs.shape
    S,N,D = xys.shape
    _, H, W, C = rgbs.shape
    assert(C==3)
    crop_W = int(W//zoom)
    crop_H = int(H//zoom)
    if np.random.rand() < 0.25: # follow-crop
        # start with xy traj
        # smooth_xys = xys.copy()
        # follow one randomly chosen track
        smooth_xys = xys[:,np.random.randint(N)].reshape(S,1,2)
        # make it inbounds
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
        # smooth it out, to remove info about the traj, and simulate camera motion
        for _ in range(S*3):
            for ii in range(S):
                if ii==0:
                    smooth_xys[ii] = (smooth_xys[ii] + smooth_xys[ii+1])/2.0
                elif ii==S-1:
                    smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii])/2.0
                else:
                    smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
    else: # static (no-hint) crop
        # zero-vel on random available coordinate
        if valids is not None:
            visval = visibs*valids # S,N
            visval = np.sum(visval, axis=1) # S
        else:
            visval = np.sum(visibs, axis=1) # S
        # anchor on a frame with at-least-average visibility (always non-empty)
        anchor_inds = np.nonzero(visval >= np.mean(visval))[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        # print('ind', ind)
        smooth_xys = xys[ind:ind+1].repeat(S,axis=0)
        smooth_xys = smooth_xys.mean(axis=1, keepdims=True)
        # xmid = np.random.randint(crop_W//2, W-crop_W//2)
        # ymid = np.random.randint(crop_H//2, H-crop_H//2)
        # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S,1,2))
        for _ in range(4): # smooth out
            for ii in range(S):
                if ii==0:
                    alt_xys[ii] = (alt_xys[ii] + alt_xys[ii+1])/2.0
                elif ii==S-1:
                    alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii])/2.0
                else:
                    alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
        smooth_xys = smooth_xys + alt_xys
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    rgbs_crop = []
    if masks is not None:
        masks_crop = []
    if masks2 is not None:
        masks2_crop = []
    if masks3 is not None:
        masks3_crop = []
    if masks4 is not None:
        masks4_crop = []
    offsets = []
    for si in range(S):
        # per-frame crop window centered on the smoothed trajectory
        xy_mid = smooth_xys[si].squeeze(0).round().astype(np.int32) # 2
        xmid, ymid = xy_mid[0], xy_mid[1]
        x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
        y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
        offset = np.array([x0, y0]).reshape(1,2)
        rgbs_crop.append(rgbs[si,y0:y1,x0:x1])
        if masks is not None:
            masks_crop.append(masks[si,y0:y1,x0:x1])
        if masks2 is not None:
            masks2_crop.append(masks2[si,y0:y1,x0:x1])
        if masks3 is not None:
            masks3_crop.append(masks3[si,y0:y1,x0:x1])
        if masks4 is not None:
            masks4_crop.append(masks4[si,y0:y1,x0:x1])
        # shift tracks into crop coordinates
        xys[si] -= offset
        offsets.append(offset)
    rgbs = np.stack(rgbs_crop, axis=0)
    if masks is not None:
        masks = np.stack(masks_crop, axis=0)
    if masks2 is not None:
        masks2 = np.stack(masks2_crop, axis=0)
    if masks3 is not None:
        masks3 = np.stack(masks3_crop, axis=0)
    if masks4 is not None:
        masks4 = np.stack(masks4_crop, axis=0)
    # update visibility annotations
    for si in range(S):
        oob_inds = np.logical_or(
            np.logical_or(xys[si,:,0] < 0, xys[si,:,0] > crop_W-1),
            np.logical_or(xys[si,:,1] < 0, xys[si,:,1] > crop_H-1))
        visibs[si,oob_inds] = 0
    # if masks4 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2, masks3, masks4
    # if masks3 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2, masks3
    # if masks2 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2
    # if masks is not None:
    #     return xys, visibs, valids, rgbs, masks
    # else:
    #     return xys, visibs, valids, rgbs
    if valids is not None:
        return xys, visibs, rgbs, valids
    else:
        return xys, visibs, rgbs
def data_zoom_bbox(zoom, bboxes, visibs, rgbs):#, valids=None):
    """Crop a zoomed-in (H//zoom, W//zoom) window from each frame, either
    loosely following the bbox trajectory or anchored at a random visible
    frame, and shift the annotations into crop coordinates.

    Args:
        zoom: float > 1; crop size is (H//zoom, W//zoom).
        bboxes: (S,4) per-frame [x0,y0,x1,y1]. Mutated in place (shifted by
            the per-frame crop offset) and also returned clamped to crop bounds.
        visibs: (S,) per-frame visibility; zeroed where the box center leaves
            the crop. Mutated in place and returned.
        rgbs: (S,H,W,3) frames.

    Returns:
        bboxes, visibs, rgbs -- rgbs is (S, H//zoom, W//zoom, 3).
    """
    S, H, W, C = rgbs.shape
    assert(C==3)
    crop_W = int(W//zoom)
    crop_H = int(H//zoom)
    xys = bboxes[:,0:2]*0.5 + bboxes[:,2:4]*0.5  # box centers, S,2
    if np.random.rand() < 0.25: # follow-crop
        # start with the center traj, window centers kept inbounds
        smooth_xys = np.clip(xys.copy(), [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
        # smooth it out, to remove info about the traj, and simulate camera motion
        # (averaging clipped values stays inbounds, so no re-clip needed)
        for _ in range(S*3):
            for ii in range(1,S-1):
                smooth_xys[ii] = (smooth_xys[ii-1] + smoo_xys_mid(smooth_xys, ii) + smooth_xys[ii+1])/3.0 if False else (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
    else: # static (no-hint) crop
        # zero-vel window anchored at a randomly-chosen visible frame's center
        anchor_inds = np.nonzero(visibs.reshape(-1)>0.5)[0]
        if len(anchor_inds) == 0:
            # nothing visible: fall back to any frame instead of crashing
            anchor_inds = np.arange(S)
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        smooth_xys = xys[ind:ind+1].repeat(S,axis=0)
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center;
        # float dtype so the box-filter smoothing does not truncate
        alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S,2)).astype(np.float64)
        for _ in range(3):
            for ii in range(1,S-1):
                alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
        smooth_xys = smooth_xys + alt_xys
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    rgbs_crop = []
    for si in range(S):
        xy_mid = smooth_xys[si].round().astype(np.int32)
        xmid, ymid = xy_mid[0], xy_mid[1]
        x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
        y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
        offset = np.array([x0, y0]).reshape(2)
        rgbs_crop.append(rgbs[si,y0:y1,x0:x1])
        # shift annotations into crop coordinates
        xys[si] -= offset
        bboxes[si,0:2] -= offset
        bboxes[si,2:4] -= offset
    rgbs = np.stack(rgbs_crop, axis=0)
    # update visibility annotations: the box center must be inside the CROP
    # (not the original frame); avoid a 1px edge
    for si in range(S):
        x, y = xys[si,0], xys[si,1]
        if x < 1 or x > crop_W-2 or y < 1 or y > crop_H-2:
            visibs[si] = 0
    # clamp boxes to the CROP bounds (the returned frames are crop-sized)
    xys0 = np.minimum(np.maximum(bboxes[:,0:2], np.zeros((2,), dtype=int)), np.array([crop_W, crop_H]) - 1) # S,2
    xys1 = np.minimum(np.maximum(bboxes[:,2:4], np.zeros((2,), dtype=int)), np.array([crop_W, crop_H]) - 1) # S,2
    bboxes = np.concatenate([xys0, xys1], axis=1)
    return bboxes, visibs, rgbs
def data_pad_if_necessary(rgbs, masks, masks2=None):
    """Zero-pad the video when the tracked mask is huge relative to the frame.

    A bbox is fit to each frame's mask; over the frames whose mask area is at
    least half the max area, if the mean box width reaches W/2 the video is
    padded by W//4 on each side (and likewise H//4 for height).

    Returns (rgbs, masks) or (rgbs, masks, masks2) when masks2 is given.
    """
    S, H, W, C = rgbs.shape
    # per-frame mask area normalized by its max, used as a visibility proxy
    areas = (masks > 0).reshape(S, -1).sum(axis=1)
    vis = areas / np.max(areas)
    boxes = np.stack([mask2bbox(m) for m in masks])
    wh = (boxes[:, 2:4] - boxes[:, 0:2])[vis > 0.5]  # sizes on "visible" frames
    if np.mean(wh[:, 0]) >= W / 2:
        wpad = ((0, 0), (0, 0), (W // 4, W // 4), (0, 0))
        rgbs = np.pad(rgbs, wpad, mode="constant")
        masks = np.pad(masks, wpad[:3], mode="constant")
        if masks2 is not None:
            masks2 = np.pad(masks2, wpad[:3], mode="constant")
    if np.mean(wh[:, 1]) >= H / 2:
        hpad = ((0, 0), (H // 4, H // 4), (0, 0), (0, 0))
        rgbs = np.pad(rgbs, hpad, mode="constant")
        masks = np.pad(masks, hpad[:3], mode="constant")
        if masks2 is not None:
            # NOTE(review): masks2 is padded with 0.5 here but with 0 in the
            # width branch above -- preserved as-is; confirm the asymmetry
            # is intentional.
            masks2 = np.pad(masks2, hpad[:3], mode="constant", constant_values=0.5)
    if masks2 is not None:
        return rgbs, masks, masks2
    return rgbs, masks
def data_pad_if_necessary_b(rgbs, bboxes, visibs):
    """Bbox variant of data_pad_if_necessary: zero-pad when the mean visible
    box width (height) reaches half the frame width (height), shifting the
    boxes to match. bboxes is updated in place and also returned.
    """
    S, H, W, C = rgbs.shape
    wh = (bboxes[:, 2:4] - bboxes[:, 0:2])[visibs > 0.5]  # sizes on visible frames
    if np.mean(wh[:, 0]) >= W / 2:
        dx = W // 4
        rgbs = np.pad(rgbs, ((0, 0), (0, 0), (dx, dx), (0, 0)), mode="constant")
        bboxes[:, [0, 2]] += dx
    if np.mean(wh[:, 1]) >= H / 2:
        dy = H // 4
        rgbs = np.pad(rgbs, ((0, 0), (dy, dy), (0, 0), (0, 0)), mode="constant")
        bboxes[:, [1, 3]] += dy
    return rgbs, bboxes
def posenc(x, min_deg, max_deg):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
    Args:
        x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
        min_deg: int, the minimum (inclusive) degree of the encoding.
        max_deg: int, the maximum (exclusive) degree of the encoding.
    Returns:
        encoded: torch.Tensor, x concatenated with its sin/cos features along
        the last dim, shape (..., D * (1 + 2*(max_deg - min_deg))). When
        min_deg == max_deg, x is returned unchanged (no encoding).
    """
    if min_deg == max_deg:
        return x
    scales = torch.tensor(
        [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
    )
    # (..., n_scales, D) flattened to (..., n_scales*D), ordered scale-major
    xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
    # [sin(xb), sin(xb + pi/2)] == [sin(xb), cos(xb)]
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
    return torch.cat([x] + [four_feat], dim=-1)