| import os |
| import numpy as np |
| import json |
| import pandas as pd |
| from calendar import monthrange |
| import torch |
| import utils |
|
|
class LocationDataset(torch.utils.data.Dataset):
    """Map-style dataset over observation locations and their class ids.

    Each item is the triple ``(loc_feat, loc, class_id)``:
    ``loc_feat`` is the row of encoded features produced by
    ``utils.CoordEncoder``, ``loc`` is the raw location row, and ``class_id``
    is the matching entry of ``labels``.
    """

    def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device):
        """Encode ``locs`` once up front and keep the metadata alongside.

        Args:
            locs: 2-D indexable array/tensor of locations (rows are observations).
            labels: per-observation class ids, indexable like ``locs``.
            classes: mapping passed through unchanged (``get_train_data`` supplies
                taxon_id -> latin name).
            class_to_taxa: passed through unchanged (``get_train_data`` supplies the
                list mapping class id -> taxon id).
            input_enc: name of the coordinate encoding understood by
                ``utils.CoordEncoder``; 'env' and 'sin_cos_env' also load the
                environmental raster via ``load_env``.
            device: target passed to ``.to()`` for the encoder's raster, if any.
        """
        self.input_enc = input_enc

        # Only the environmental encodings need the raster loaded from disk.
        env_raster = load_env() if input_enc in ('env', 'sin_cos_env') else None
        self.enc = utils.CoordEncoder(input_enc, env_raster)

        # Features are computed eagerly for the whole dataset.
        self.locs = locs
        self.loc_feats = self.enc.encode(locs)

        self.labels = labels
        self.classes = classes
        self.class_to_taxa = class_to_taxa

        # Number of distinct label values actually present in ``labels``.
        self.num_classes = len(np.unique(labels))
        self.input_dim = self.loc_feats.shape[1]

        # The raster is moved to ``device`` only after the bulk encoding above.
        # NOTE(review): assumes CoordEncoder exposes its raster as ``.raster``; confirm.
        if self.enc.raster is not None:
            self.enc.raster = self.enc.raster.to(device)

    def __len__(self):
        """Return the number of encoded observations."""
        return self.loc_feats.shape[0]

    def __getitem__(self, index):
        """Return ``(loc_feat, loc, class_id)`` for observation ``index``."""
        return (
            self.loc_feats[index, :],
            self.locs[index, :],
            self.labels[index],
        )
|
|
def load_env():
    """Load the scaled bioclim/elevation raster as a float32 torch tensor.

    The raster directory is read from the 'env' entry of ``paths.json``
    (resolved relative to the current working directory), and the file
    ``bioclim_elevation_scaled.npy`` inside it is loaded via
    ``load_context_feats``.
    """
    with open('paths.json', 'r') as fh:
        env_dir = json.load(fh)['env']
    return load_context_feats(os.path.join(env_dir, 'bioclim_elevation_scaled.npy'))
|
|
def load_context_feats(data_path):
    """Read a ``.npy`` array from ``data_path`` and return it as a float32 tensor.

    The array is cast to float32 before wrapping, so the returned tensor
    shares memory with that cast copy (not with the on-disk array).
    """
    return torch.from_numpy(np.load(data_path).astype(np.float32))
|
|
def load_inat_data(ip_file, taxa_of_interest=None):
    """Load an observation CSV and return per-observation numpy arrays.

    Rows are dropped when latitude/longitude are out of range, when any
    remaining column is NaN, or (if ``taxa_of_interest`` is given) when
    ``taxon_id`` is not in ``taxa_of_interest``.

    Args:
        ip_file: path to a CSV with at least 'taxon_id', 'latitude',
            'longitude' and a date column ('date' or 'observed_on', formatted
            so that characters [0:4], [5:7], [8:10] are year, month, day).
        taxa_of_interest: optional iterable of taxon ids to keep.

    Returns:
        locs: float32 array of shape (N, 2), columns (longitude, latitude).
        taxa: int64 array (N,) of taxon ids.
        users: (N,) array of user indices re-labelled to 0..U-1 from
            'user_id' or 'observer_id'; all -1 when neither column exists.
        dates: float32 array (N,) of day-of-year / 364, rounded to 4 places,
            using non-leap (2018) month lengths.
        years: int64 array (N,) of years.
        obs_ids: values of the 'id' or 'observation_uuid' column, or None when
            the file has neither.
    """
    print('\nLoading ' + str(ip_file))
    data = pd.read_csv(ip_file)

    # Drop rows with out-of-range coordinates (bounds are inclusive; NaN fails).
    num_obs = data.shape[0]
    valid_locs = data['latitude'].between(-90, 90) & data['longitude'].between(-180, 180)
    data = data[valid_locs]
    if (num_obs - data.shape[0]) > 0:
        print(num_obs - data.shape[0], 'items filtered due to invalid locations')

    # Remove these columns before dropna() so that NaNs in them do not cause
    # otherwise-usable rows to be discarded. Reassign rather than mutate in
    # place: ``data`` is a filtered slice of the original frame.
    unused_cols = [c for c in ('accuracy', 'positional_accuracy', 'geoprivacy') if c in data.columns]
    data = data.drop(columns=unused_cols)
    data = data.rename(columns={'observed_on': 'date'})

    num_obs_orig = data.shape[0]
    data = data.dropna()
    size_diff = num_obs_orig - data.shape[0]
    if size_diff > 0:
        print(size_diff, 'observation(s) with a NaN entry out of', num_obs_orig, 'removed')

    if taxa_of_interest is not None:
        num_obs_orig = data.shape[0]
        data = data[data['taxon_id'].isin(taxa_of_interest)]
        print(num_obs_orig - data.shape[0], 'observation(s) out of', num_obs_orig, 'from different taxa removed')

    print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0]))

    locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32)
    # np.int was removed in NumPy 1.24; use an explicit fixed-width type.
    taxa = data['taxon_id'].values.astype(np.int64)

    user_col = next((c for c in ('user_id', 'observer_id') if c in data.columns), None)
    if user_col is not None:
        _, users = np.unique(data[user_col].values.astype(np.int64), return_inverse=True)
    else:
        users = np.full(taxa.shape[0], -1, dtype=np.int64)

    # Explicit integer dtype keeps the month lookup below valid even when no
    # rows survive filtering (an empty default array would be float64).
    date_strs = data['date'].values
    years = np.array([int(d[:4]) for d in date_strs], dtype=np.int64)
    months = np.array([int(d[5:7]) for d in date_strs], dtype=np.int64)
    days = np.array([int(d[8:10]) for d in date_strs], dtype=np.int64)
    # Day-of-year offset of each month's first day in a non-leap year (2018);
    # a Feb 29 date therefore maps onto the same value as Mar 1.
    days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
    dates = days_per_month[months - 1] + days - 1
    dates = np.round(dates / 364.0, 4).astype(np.float32)

    if 'id' in data.columns:
        obs_ids = data['id'].values
    elif 'observation_uuid' in data.columns:
        obs_ids = data['observation_uuid'].values
    else:
        # The file carries no observation identifier column.
        obs_ids = None

    return locs, taxa, users, dates, years, obs_ids
|
|
def choose_aux_species(current_species, num_aux_species, aux_species_seed):
    """Pick ``num_aux_species`` extra taxon ids not already in ``current_species``.

    Candidates are the 'taxon_id' entries of ``geo_prior_train_meta.json``
    in the 'train' directory listed in ``paths.json``, minus
    ``current_species`` (``np.setdiff1d`` also sorts and de-duplicates them).
    The selection is a seeded random permutation, so a given
    ``aux_species_seed`` always yields the same list. Returns ``[]``
    immediately, without touching the filesystem, when ``num_aux_species`` is 0.
    """
    if num_aux_species == 0:
        return []

    with open('paths.json', 'r') as fh:
        train_dir = json.load(fh)['train']
    meta_path = os.path.join(train_dir, 'geo_prior_train_meta.json')
    with open(meta_path, 'r') as fh:
        metadata = json.load(fh)

    candidates = np.setdiff1d([entry['taxon_id'] for entry in metadata], current_species)
    print(f'choosing {num_aux_species} species to add from {len(candidates)} candidates')

    shuffled = np.random.default_rng(aux_species_seed).permutation(len(candidates))
    return list(candidates[shuffled[:num_aux_species]])
|
|
def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file_snt=None):
    """Return the list of taxon ids to train on, or None for "all taxa".

    Args:
        species_set: 'all' (no filtering, returns None) or 'snt_birds'. Any
            other value raises NotImplementedError.
        num_aux_species: number of extra taxa appended via ``choose_aux_species``.
        aux_species_seed: seed forwarded to ``choose_aux_species``.
        taxa_file_snt: JSON file whose 'snt_birds' entry lists the base taxa;
            required (asserted) for 'snt_birds'.
    """
    if species_set == 'all':
        return None
    if species_set != 'snt_birds':
        raise NotImplementedError

    assert taxa_file_snt is not None
    with open(taxa_file_snt, 'r') as fh:
        subsets = json.load(fh)
    selected = list(subsets['snt_birds'])

    # Auxiliary taxa are drawn from outside the base set, then appended.
    extra = choose_aux_species(selected, num_aux_species, aux_species_seed)
    selected.extend(extra)
    return selected
|
|
def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123):
    """Return sorted indices keeping at most ``hard_cap`` observations per class.

    Observations are visited in a random order drawn from
    ``np.random.default_rng(hard_cap_seed)``, and each is kept while its class
    is still under the cap, so the result is deterministic for a given seed.

    Args:
        labels: 1-D sequence of class labels (list, numpy array or CPU tensor).
        hard_cap: maximum number of observations kept per class; -1 disables
            subsampling and returns every index.
        hard_cap_seed: seed for the visiting order.

    Returns:
        Sorted 1-D integer numpy array of selected indices (integer-typed even
        when empty, so it can always be used as an index).
    """
    if hard_cap == -1:
        return np.arange(len(labels))
    print(f'subsampling (up to) {hard_cap} per class for the training set')

    # Normalise to numpy so elements hash by value (torch scalars hash by identity).
    label_arr = np.asarray(labels)
    class_counts = {class_id: 0 for class_id in np.unique(label_arr)}
    rng = np.random.default_rng(hard_cap_seed)

    idx_ss = []
    for i in rng.permutation(len(label_arr)):
        class_id = label_arr[i]
        if class_counts[class_id] < hard_cap:
            idx_ss.append(i)
            class_counts[class_id] += 1

    # Explicit int dtype: np.sort([]) would otherwise yield an unusable float array.
    idx_ss = np.sort(np.asarray(idx_ss, dtype=np.int64))
    print(f'final training set size: {len(idx_ss)}')
    return idx_ss
|
|
def get_train_data(params):
    """Build the training ``LocationDataset`` described by ``params``.

    File locations come from the 'train' entry of ``paths.json`` (relative to
    the current working directory): ``geo_prior_train.csv`` holds the
    observations, ``geo_prior_train_meta.json`` the taxon metadata, and
    ``taxa_subsets.json`` the species subsets.

    Keys read from ``params``: 'species_set', 'num_aux_species',
    'aux_species_seed', 'hard_cap_num_per_class', 'hard_cap_seed',
    'input_enc', 'device'.

    Labels in the returned dataset are contiguous class ids (indices into
    ``class_to_taxa``, i.e. the sorted unique taxon ids), not raw taxon ids.
    """
    with open('paths.json', 'r') as f:
        paths = json.load(f)
    data_dir = paths['train']
    obs_file = os.path.join(data_dir, 'geo_prior_train.csv')
    taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
    taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')

    taxa_of_interest = get_taxa_of_interest(
        params['species_set'], params['num_aux_species'], params['aux_species_seed'], taxa_file_snt
    )

    locs, labels, _, _, _, _ = load_inat_data(obs_file, taxa_of_interest)
    # Map raw taxon ids to contiguous class ids 0..C-1.
    unique_taxa, class_ids = np.unique(labels, return_inverse=True)
    class_to_taxa = unique_taxa.tolist()

    # taxon_id -> latin name; a context manager ensures the file is closed.
    with open(taxa_file, 'r') as f:
        class_info = json.load(f)
    classes = {cc['taxon_id']: cc['latin_name'] for cc in class_info}

    idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'])

    locs = torch.from_numpy(np.array(locs)[idx_ss])
    labels = torch.from_numpy(np.array(class_ids)[idx_ss])

    return LocationDataset(locs, labels, classes, class_to_taxa, params['input_enc'], params['device'])
|
|