| import torch |
| import numpy as np |
| import math |
| import datetime |
|
|
class CoordEncoder:
    """Turn coordinate tensors into model input features.

    ``input_enc`` selects the feature scheme used by :meth:`encode`:

    * ``'sin_cos'``     -- ``encode_loc(locs)``.
    * ``'env'``         -- ``bilinear_interpolate(locs, raster)``.
    * ``'sin_cos_env'`` -- both of the above, concatenated along dim 1
      (sinusoidal features first, raster features second).

    Any other value is accepted by the constructor but makes :meth:`encode`
    raise ``NotImplementedError``.
    """

    def __init__(self, input_enc, raster=None):
        # Scheme name; it is only validated when ``encode`` runs.
        self.input_enc = input_enc
        # Passed as ``data`` to ``bilinear_interpolate`` by the 'env' and
        # 'sin_cos_env' schemes; unused by 'sin_cos'.
        self.raster = raster

    def encode(self, locs, normalize=True):
        """Return the feature tensor for ``locs`` under ``self.input_enc``.

        Args:
            locs: coordinate tensor handed to ``encode_loc`` and/or
                ``bilinear_interpolate``.
            normalize: when true, ``locs`` is first passed through
                ``normalize_coords`` (column 0 divided by 180, column 1 by 90).
                Pass False if ``locs`` is already normalized.

        Raises:
            NotImplementedError: if ``self.input_enc`` is not a known scheme.
        """
        if normalize:
            locs = normalize_coords(locs)

        scheme = self.input_enc
        if scheme not in ('sin_cos', 'env', 'sin_cos_env'):
            raise NotImplementedError('Unknown input encoding.')

        pieces = []
        if scheme in ('sin_cos', 'sin_cos_env'):
            pieces.append(encode_loc(locs))
        if scheme in ('env', 'sin_cos_env'):
            pieces.append(bilinear_interpolate(locs, self.raster))

        if len(pieces) == 1:
            return pieces[0]
        return torch.cat(pieces, 1)
|
|
def normalize_coords(locs):
    """Rescale degree coordinates into [-1, 1] and return the result as a new array.

    Column 0 is divided by 180 and column 1 by 90, so inputs in
    [-180, 180] x [-90, 90] map onto [-1, 1] x [-1, 1]. Any further columns
    are copied unchanged.

    The input is never modified. Previously the division happened in place,
    which silently rescaled the caller's tensor, and any numpy array sharing
    its memory via ``torch.from_numpy``.

    Args:
        locs: (N, >=2) torch tensor or numpy-compatible array. Integer inputs
            are promoted to floating point, because true division cannot be
            done in place on an integer array/tensor.

    Returns:
        A new tensor (for tensor input) or numpy array of the same shape, with
        columns 0 and 1 rescaled.
    """
    if isinstance(locs, torch.Tensor):
        if locs.is_floating_point():
            out = locs.clone()
        else:
            out = locs.to(torch.get_default_dtype())
    else:
        out = np.array(locs, copy=True)
        if not np.issubdtype(out.dtype, np.floating):
            out = out.astype(np.float64)

    out[:, 0] /= 180.0
    out[:, 1] /= 90.0
    return out
|
|
def encode_loc(loc_ip, concat_dim=1):
    """Sinusoidally encode coordinates as ``[sin(pi * x), cos(pi * x)]``.

    Args:
        loc_ip: coordinate tensor. Every element is multiplied by pi before
            the sin/cos, so values in [-1, 1] cover one full period.
        concat_dim: dimension along which the sin block and the cos block are
            joined. The default of 1 doubles the feature count of an (N, D)
            input to (N, 2D).

    Returns:
        Tensor holding all sine terms followed by all cosine terms.
    """
    angles = math.pi * loc_ip
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=concat_dim)
|
|
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
    """Sample a gridded raster at continuous locations using bilinear interpolation.

    Args:
        loc_ip: (N, 2) tensor of locations with both columns expected in
            [-1, 1]. Column 0 indexes raster columns: -1 maps to the first
            column and +1 to the last. Column 1 indexes raster rows and is
            flipped: +1 maps to the first row and -1 to the last. Locations
            outside [-1, 1] are clamped to the nearest raster edge; previously
            they wrapped to the opposite edge or raised IndexError. ``loc_ip``
            itself is not modified.
        data: (H, W, C) tensor, indexed as ``data[row, col, channel]``.
        remove_nans_raster: if true, NaNs in ``data`` are overwritten with 0.0
            *in place* before sampling, so the caller's raster is modified.

    Returns:
        (N, C) tensor of interpolated values.
    """
    assert data is not None

    height, width = data.shape[0], data.shape[1]

    # Map [-1, 1] -> [0, 1], then flip the y axis so that +1 lands on row 0.
    loc = (loc_ip.clone() + 1) / 2.0
    loc[:, 1] = 1 - loc[:, 1]

    assert not torch.any(torch.isnan(loc))

    if remove_nans_raster:
        data[torch.isnan(data)] = 0.0

    # Convert to pixel coordinates. Clamping keeps out-of-range locations from
    # producing negative (silently wrapping) or past-the-end indices; for
    # inputs inside [-1, 1] it is a no-op.
    loc[:, 0] = (loc[:, 0] * (width - 1)).clamp(0, width - 1)
    loc[:, 1] = (loc[:, 1] * (height - 1)).clamp(0, height - 1)

    loc_floor = torch.floor(loc)
    loc_int = loc_floor.long()
    xx = loc_int[:, 0]
    yy = loc_int[:, 1]
    # Right/lower neighbours, clamped at the last column/row.
    xx_plus = (xx + 1).clamp(max=width - 1)
    yy_plus = (yy + 1).clamp(max=height - 1)

    # Fractional position inside the cell, shaped (N, 1) so that it
    # broadcasts over the channel dimension.
    loc_delta = loc - loc_floor
    dx = loc_delta[:, 0].unsqueeze(1)
    dy = loc_delta[:, 1].unsqueeze(1)

    interp_val = (data[yy, xx, :] * (1 - dx) * (1 - dy)
                  + data[yy, xx_plus, :] * dx * (1 - dy)
                  + data[yy_plus, xx, :] * (1 - dx) * dy
                  + data[yy_plus, xx_plus, :] * dx * dy)
    return interp_val
|
|
def rand_samples(batch_size, device, rand_type='uniform'):
    """Draw random locations in normalized [-1, 1] coordinates.

    Samples are drawn with ``torch.rand`` and then moved to ``device``, so the
    random stream is the same whatever the target device.

    Args:
        batch_size: number of locations to draw.
        device: device the result is moved to.
        rand_type:
            ``'uniform'``: both columns are independently uniform in [-1, 1).
            ``'spherical'``: column 0 (``lon``) is uniform in [-1, 1).
            Column 1 (``lat``) is ``1 - 2*acos(2v - 1)/pi`` with ``v``
            uniform. Read as longitude/latitude scaled by 180/90, this gives
            points uniformly distributed over the sphere's surface.

    Returns:
        (batch_size, 2) tensor; column 0 is ``lon``, column 1 is ``lat``.

    Raises:
        NotImplementedError: for an unknown ``rand_type``. Previously this
            surfaced as an UnboundLocalError.
    """
    if rand_type == 'spherical':
        rand_loc = torch.rand(batch_size, 2).to(device)
        theta1 = 2.0*math.pi*rand_loc[:, 0]
        theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
        lat = 1.0 - 2.0*theta2/math.pi
        lon = (theta1/math.pi) - 1.0
        rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
    elif rand_type == 'uniform':
        rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0
    else:
        raise NotImplementedError('Unknown rand_type: {!r}'.format(rand_type))

    return rand_loc
|
|
def get_time_stamp():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``.

    Fractional seconds are dropped, so the result contains only digits and
    dashes.
    """
    # str(datetime) has the form 'YYYY-MM-DD HH:MM:SS[.ffffff]'.
    day_part, _, clock_part = str(datetime.datetime.now()).partition(' ')
    hours, minutes, seconds = clock_part.split(':')
    whole_seconds = seconds.partition('.')[0]
    return '{}-{}-{}-{}'.format(day_part, hours, minutes, whole_seconds)
|
|
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
    """Build an evenly spaced (lon, lat) grid in degrees.

    The grid has ``grid_size[0]`` rows and ``grid_size[1]`` columns. Column
    index maps to longitude from -180 to 180, and row index maps to latitude
    from 90 (first row) down to -90 (last row).

    Args:
        grid_size: (rows, cols) of the grid.
        split_ids: optional (rows, cols) array of labels. Only used when
            ``split_of_interest`` is also given.
        split_of_interest: label whose cells should be returned.

    Returns:
        float32 array of shape (M, 2): column 0 is longitude, column 1 is
        latitude. M is rows*cols (row-major order) when no split is
        selected; otherwise it is the number of cells where
        ``split_ids == split_of_interest``.
    """
    n_rows, n_cols = grid_size[0], grid_size[1]
    lon_axis = np.linspace(-180, 180, n_cols)
    lat_axis = np.linspace(90, -90, n_rows)
    lon_grid, lat_grid = np.meshgrid(lon_axis, lat_axis)
    grid = np.stack((lon_grid, lat_grid), axis=-1).astype(np.float32)

    if split_ids is not None and split_of_interest is not None:
        rows, cols = np.where(split_ids == split_of_interest)
        return grid[rows, cols, :]

    return grid.reshape(-1, 2)
| |
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25, rng=None):
    """Label raster cells with a checkerboard train/held-out split.

    Cells are grouped into ``cell_size`` x ``cell_size`` pixel squares laid
    out like a checkerboard. The top-left square is labelled 2 and
    alternating squares are labelled 1. The labels are then multiplied by
    ``mask``.

    Args:
        raster: only its first two dimensions (H, W) are used.
        mask: multiplied element-wise into the (H, W) label array.
            NOTE(review): presumably a 0/1 validity mask of shape (H, W);
            confirm with callers.
        train_amt: fraction of the cells labelled 1 to keep. When below 1.0,
            ``int(n_train * (1 - train_amt))`` of them are chosen at random
            without replacement and set to 0.
        cell_size: side length of a checkerboard square, in pixels.
        rng: optional ``numpy.random.Generator`` or ``RandomState`` used for
            that subsampling, which makes the split reproducible. When None,
            the global ``np.random`` state is used, as before.

    Returns:
        (H, W) float array with labels 0 (masked out or subsampled away),
        1 (kept train cells) and 2 (checkerboard held-out cells).
    """
    height, width = raster.shape[0], raster.shape[1]
    split_ids = np.ones((height, width))

    # Each horizontal band of rows shifts its first held-out square by one
    # square relative to the previous band, producing the checkerboard.
    for band, ii in enumerate(np.arange(0, height, cell_size)):
        start = 0 if band % 2 == 0 else cell_size
        for jj in np.arange(start, width, cell_size*2):
            split_ids[ii:ii+cell_size, jj:jj+cell_size] = 2

    split_ids = split_ids*mask

    if train_amt < 1.0:
        tr_y, tr_x = np.where(split_ids == 1)
        n_drop = int(len(tr_y)*(1.0 - train_amt))
        choice = np.random.choice if rng is None else rng.choice
        inds = choice(len(tr_y), n_drop, replace=False)
        split_ids[tr_y[inds], tr_x[inds]] = 0

    return split_ids
|
|