| import numpy as np |
| import pandas as pd |
| import random |
| import torch |
| import time |
| import os |
| import copy |
| import json |
| import tifffile |
| import h3 |
| import setup |
|
|
| from sklearn.linear_model import RidgeCV |
| from sklearn.preprocessing import MinMaxScaler |
| from sklearn.metrics import average_precision_score |
|
|
| import utils |
| import models |
| import datasets |
|
|
class EvaluatorSNT:
    """Score per-species predictions against the labels stored in `snt_res_5.npy`.

    For every taxon in the benchmark, the model's scores at that taxon's
    labelled locations are compared with the stored labels using average
    precision; the mean over taxa with a defined score is reported.
    """

    def __init__(self, train_params, eval_params):
        """Load the benchmark arrays.

        Args:
            train_params: Training-time parameter dict; `class_to_taxa` is read
                during evaluation.
            eval_params: Evaluation settings; `seed`, `split`, `split_seed`,
                `device` (and `val_frac` when `split` is not 'all') are read
                during evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        # The file holds a dict wrapped in a 0-d object array, hence
        # allow_pickle=True followed by .item().
        # NOTE(review): allow_pickle unpickles arbitrary objects; only load
        # benchmark files from a trusted source.
        D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
        D = D.item()
        self.loc_indices_per_species = D['loc_indices_per_species']
        self.labels_per_species = D['labels_per_species']
        self.taxa = D['taxa']
        self.obs_locs = D['obs_locs']
        self.obs_locs_idx = D['obs_locs_idx']

    def get_labels(self, species):
        """Collect (lon, lat) hexagon centres and binary labels for one species.

        NOTE(review): this method reads `self.data`, which `__init__` of this
        class never assigns, so it raises AttributeError unless the caller has
        set `self.data` beforehand. It presumably expects a mapping
        `{h3_index: {species_id: [...]}}` -- confirm before relying on it.

        Args:
            species: Species identifier; converted to `str` for the lookup.

        Returns:
            Tuple `(obs_locs, gt)`: a float32 array with columns (lon, lat) and
            a float32 array of 0/1 labels, one row per hexagon that has an
            entry for the species.
        """
        species = str(species)
        lat = []
        lon = []
        gt = []
        for hx in self.data:
            if species in self.data[hx]:
                # Only resolve the hexagon centre for hexagons we actually keep.
                cur_lat, cur_lon = h3.h3_to_geo(hx)
                gt.append(int(len(self.data[hx][species]) > 0))
                lat.append(cur_lat)
                lon.append(cur_lon)
        lat = np.array(lat).astype(np.float32)
        lon = np.array(lon).astype(np.float32)
        obs_locs = np.vstack((lon, lat)).T
        gt = np.array(gt).astype(np.float32)
        return obs_locs, gt

    def _match_taxa_to_classes(self):
        """Map each benchmark taxon to its row in the model's class embedding.

        Returns:
            Tuple `(class_idx, in_model)`. `class_idx` is a 1-D int64 array with
            one entry per element of `self.taxa` (0 is used as a placeholder
            for taxa the model does not know). `in_model` is a boolean array
            flagging the taxa found in `train_params['class_to_taxa']`.
        """
        # Convert once instead of once per taxon.
        class_to_taxa = np.asarray(self.train_params['class_to_taxa'])
        num_taxa = len(self.taxa)
        class_idx = np.zeros(num_taxa, dtype=np.int64)
        in_model = np.zeros(num_taxa, dtype=bool)
        for tt_id, tt in enumerate(self.taxa):
            hits = np.where(class_to_taxa == tt)[0]
            if hits.size > 0:
                class_idx[tt_id] = hits[0]
                in_model[tt_id] = True
        return class_idx, in_model

    def run_evaluation(self, model, enc):
        """Compute per-species and mean average precision.

        Args:
            model: Callable as `model(feats, return_feats=True)` and exposing
                `class_emb.weight`.
            enc: Object whose `encode(locs)` turns raw locations into model
                inputs.

        Returns:
            Dict with `mean_average_precision`,
            `per_species_average_precision_all` (NaN for taxa without a valid
            score, including taxa missing from the model),
            `num_eval_species_w_valid_ap` and `num_eval_species_total`.
        """
        results = {}

        np.random.seed(self.eval_params['seed'])
        random.seed(self.eval_params['seed'])

        split = self.eval_params['split']
        assert split in ['all', 'val', 'test']

        taxa_arr = np.asarray(self.taxa)
        per_species_ap = np.zeros(len(self.taxa), dtype=np.float32)

        obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
        loc_feat = enc.encode(obs_locs)

        # Always a 1-D index, so a single-taxon benchmark still yields a
        # (num_locs, 1) prediction matrix.
        class_idx, in_model = self._match_taxa_to_classes()
        classes_of_interest = torch.from_numpy(class_idx)

        with torch.no_grad():
            loc_emb = model(loc_feat, return_feats=True)
            wt = model.class_emb.weight[classes_of_interest, :]
            pred_mtx = torch.matmul(loc_emb, wt.T).cpu().numpy()

        split_rng = np.random.default_rng(self.eval_params['split_seed'])

        for tt_id, tt in enumerate(self.taxa):
            cur_class_of_interest = np.where(taxa_arr == tt)[0][0]
            cur_loc_indices = np.array(self.loc_indices_per_species[cur_class_of_interest])
            cur_labels = np.array(self.labels_per_species[cur_class_of_interest])

            if split != 'all':
                # Draw the permutation for every taxon, even ones the model
                # lacks, so the val/test partition of later taxa does not
                # depend on which classes a particular model was trained on.
                num_val = int(np.floor(len(cur_labels) * self.eval_params['val_frac']))
                idx_rand = split_rng.permutation(len(cur_labels))
                if split == 'val':
                    idx_sel = idx_rand[:num_val]
                else:
                    idx_sel = idx_rand[num_val:]
                cur_loc_indices = cur_loc_indices[idx_sel]
                cur_labels = cur_labels[idx_sel]

            if not in_model[tt_id]:
                # Taxon unknown to the model: no score can be computed.
                per_species_ap[tt_id] = np.nan
                continue

            pred = pred_mtx[cur_loc_indices, tt_id]
            per_species_ap[tt_id] = average_precision_score((cur_labels > 0).astype(np.int32), pred)

        valid_taxa = ~np.isnan(per_species_ap)

        # Only taxa with a defined score contribute to the mean.
        results['mean_average_precision'] = per_species_ap[valid_taxa].mean()
        results['per_species_average_precision_all'] = per_species_ap
        results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
        results['num_eval_species_total'] = len(self.taxa)

        return results

    def report(self, results):
        """Print the summary fields of a `run_evaluation` result dict."""
        for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
            print(f'{field}: {results[field]}')
|
|
class EvaluatorIUCN:
    """Score per-species predictions against the presence lists in `iucn_res_5.json`."""

    def __init__(self, train_params, eval_params):
        """Load the benchmark JSON.

        Args:
            train_params: Training-time parameter dict; `class_to_taxa` is read
                during evaluation.
            eval_params: Evaluation settings; `device` is read during
                evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
            self.data = json.load(f)
        self.obs_locs = np.array(self.data['locs'], dtype=np.float32)
        # JSON object keys are strings; keep the taxa as ints for matching
        # against class_to_taxa.
        self.taxa = [int(tt) for tt in self.data['taxa_presence']]

    def _match_taxa_to_classes(self):
        """Map each benchmark taxon to its row in the model's class embedding.

        Returns:
            Tuple `(class_idx, in_model)`. `class_idx` is a 1-D int64 array with
            one entry per element of `self.taxa` (0 is used as a placeholder
            for taxa the model does not know). `in_model` is a boolean array
            flagging the taxa found in `train_params['class_to_taxa']`.
        """
        # Convert once instead of twice per taxon.
        class_to_taxa = np.asarray(self.train_params['class_to_taxa'])
        num_taxa = len(self.taxa)
        class_idx = np.zeros(num_taxa, dtype=np.int64)
        in_model = np.zeros(num_taxa, dtype=bool)
        for tt_id, tt in enumerate(self.taxa):
            hits = np.where(class_to_taxa == tt)[0]
            if hits.size > 0:
                class_idx[tt_id] = hits[0]
                in_model[tt_id] = True
        return class_idx, in_model

    def run_evaluation(self, model, enc):
        """Compute per-species and mean average precision.

        Args:
            model: Callable as `model(feats, return_feats=True)` and exposing
                `class_emb.weight`.
            enc: Object whose `encode(locs)` turns raw locations into model
                inputs.

        Returns:
            Dict with `per_species_average_precision_all` (NaN for taxa the
            model does not contain), `mean_average_precision` over the
            non-NaN entries, `num_eval_species_w_valid_ap` and
            `num_eval_species_total`.
        """
        results = {}

        per_species_ap = np.zeros(len(self.taxa), dtype=np.float32)
        results['per_species_average_precision_all'] = per_species_ap

        obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
        loc_feat = enc.encode(obs_locs)

        # A fixed-length 1-D index keeps column tt_id of pred_mtx aligned with
        # self.taxa[tt_id] even when some taxa are missing from the model or
        # only one taxon is evaluated.
        class_idx, in_model = self._match_taxa_to_classes()
        classes_of_interest = torch.from_numpy(class_idx)

        with torch.no_grad():
            loc_emb = model(loc_feat, return_feats=True)
            wt = model.class_emb.weight[classes_of_interest, :]
            # Move to host memory: the metric below works on NumPy arrays.
            pred_mtx = torch.matmul(loc_emb, wt.T).cpu().numpy()

        num_locs = self.obs_locs.shape[0]
        for tt_id, tt in enumerate(self.taxa):
            if not in_model[tt_id]:
                # Taxon is not one of the model's classes.
                per_species_ap[tt_id] = np.nan
                continue

            # Entries of taxa_presence[...] are used as indices into the
            # location list; those locations are the positives.
            gt = np.zeros(num_locs, dtype=np.float32)
            gt[self.data['taxa_presence'][str(tt)]] = 1.0
            per_species_ap[tt_id] = average_precision_score(gt, pred_mtx[:, tt_id])

        valid_taxa = ~np.isnan(per_species_ap)

        # Only taxa with a defined score contribute to the mean.
        results['mean_average_precision'] = per_species_ap[valid_taxa].mean()
        results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
        results['num_eval_species_total'] = len(self.taxa)
        return results

    def report(self, results):
        """Print the summary fields of a `run_evaluation` result dict."""
        for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
            print(f'{field}: {results[field]}')
|
|
class EvaluatorGeoPrior:
    """Measure how much a location model improves stored vision-model predictions.

    The stored top-k vision probabilities are multiplied by the location
    model's per-class output and the resulting top-1 accuracy is compared with
    the vision-only top-1 accuracy.
    """

    def __init__(self, train_params, eval_params):
        """Load stored vision predictions, observation metadata and the taxon mapping.

        Args:
            train_params: Training-time parameter dict; `class_to_taxa` is read
                here.
            eval_params: Evaluation settings; `batch_size` and `device` are
                read during evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)

        self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
        print('\n', self.data['probs'].shape[0], 'total test observations')

        meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
        # Columns are (longitude, latitude).
        self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)

        # Indexing an .npz archive reads the array from disk each time, so
        # fetch model_to_taxa once and reuse it.
        model_to_taxa = self.data['model_to_taxa']
        self.taxon_map = self.find_mapping_between_models(model_to_taxa, self.train_params['class_to_taxa'])
        print(self.taxon_map.shape[0], 'out of', len(model_to_taxa), 'taxa in both vision and geo models')

    def find_mapping_between_models(self, vision_taxa, geo_taxa):
        """Pair up class indices of taxa known to both models.

        Args:
            vision_taxa: Taxon id per vision-model class.
            geo_taxa: Taxon id per location-model class.

        Returns:
            int32 array of shape (num_shared, 2); column 0 is the vision class
            index and column 1 the (first) matching location-model class index.
            Vision classes without a match are omitted.
        """
        geo_taxa_arr = np.array(geo_taxa)
        pairs = []
        for vision_idx, taxon in enumerate(vision_taxa):
            hits = np.where(geo_taxa_arr == taxon)[0]
            if len(hits) > 0:
                pairs.append((vision_idx, hits[0]))
        # reshape keeps the (0, 2) shape when no taxon is shared.
        return np.array(pairs, dtype=np.int32).reshape(-1, 2)

    def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
        """Expand both predictions to dense arrays indexed by vision class.

        Args:
            geo_pred_ip: Location-model output, one column per location-model
                class.
            vision_top_k_prob: Stored top-k vision probabilities per
                observation.
            vision_top_k_inds: Vision class index for each stored probability.
            vision_taxa: Taxon id per vision-model class (only its length is
                used).
            taxon_map: Output of `find_mapping_between_models`.

        Returns:
            Tuple `(geo_pred, vision_pred)`, both float32 of shape
            (num_observations, num_vision_classes).
        """
        num_obs = geo_pred_ip.shape[0]
        num_classes = len(vision_taxa)

        # Classes outside the stored top-k get probability 0.
        vision_pred = np.zeros((num_obs, num_classes), dtype=np.float32)
        rows = np.arange(num_obs)[..., np.newaxis]
        vision_pred[rows, vision_top_k_inds] = vision_top_k_prob

        # Classes the location model does not know keep a neutral weight of 1,
        # so the product leaves their vision score unchanged.
        geo_pred = np.ones((num_obs, num_classes), dtype=np.float32)
        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]

        return geo_pred, vision_pred

    def run_evaluation(self, model, enc):
        """Compute vision-only and vision-times-location top-1 accuracy.

        Args:
            model: Callable as `model(feats)` returning per-class scores.
            enc: Object whose `encode(locs)` turns raw locations into model
                inputs.

        Returns:
            Dict with float entries `vision_only_top_1` and `vision_geo_top_1`.
        """
        results = {}

        # Read each array from the archive exactly once; indexing self.data
        # inside the batch loop would reload the full array for every batch.
        probs = self.data['probs']
        inds = self.data['inds']
        labels = self.data['labels']
        model_to_taxa = self.data['model_to_taxa']

        num_obs = probs.shape[0]
        # Batch boundaries, with the total appended so the last batch may be short.
        batch_start = np.hstack((np.arange(0, num_obs, self.eval_params['batch_size']), num_obs))
        correct_pred = np.zeros(num_obs)

        print('\nbid\t w geo\t wo geo')
        for bb in range(len(batch_start) - 1):
            batch_inds = np.arange(batch_start[bb], batch_start[bb + 1])

            vision_probs = probs[batch_inds, :]
            vision_inds = inds[batch_inds, :]
            gt = labels[batch_inds]

            obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(self.eval_params['device'])
            loc_feat = enc.encode(obs_locs_batch)

            with torch.no_grad():
                geo_pred = model(loc_feat).cpu().numpy()

            geo_pred, vision_pred = self.convert_to_inat_vision_order(
                geo_pred, vision_probs, vision_inds, model_to_taxa, self.taxon_map)

            comb_pred = np.argmax(vision_pred * geo_pred, 1)
            correct_pred[batch_inds] = (comb_pred == gt)

        # The last stored column is compared with the label as the vision top-1.
        # NOTE(review): assumes the stored indices are sorted by ascending
        # probability -- confirm against the file that produced the archive.
        results['vision_only_top_1'] = float((inds[:, -1] == labels).mean())
        results['vision_geo_top_1'] = float(correct_pred.mean())
        return results

    def report(self, results):
        """Print both accuracies and their difference, rounded to 3 decimals."""
        print('\nOverall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model ', round(results['vision_geo_top_1'], 3))
        print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
|
|
class EvaluatorGeoFeature:
    """Probe model features by ridge-regressing raster values from them."""

    def __init__(self, train_params, eval_params):
        """Record paths, load the country mask and list the rasters to evaluate.

        Args:
            train_params: Training-time parameter dict (stored, not read here).
            eval_params: Evaluation settings; `device` and `cell_size` are read
                during evaluation.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        self.data_path = paths['geo_feature']
        # Boolean mask: True where the mask raster equals 1.
        self.country_mask = tifffile.imread(os.path.join(paths['masks'], 'USA_MASK.tif')) == 1
        # These are file stems on disk ('<name>.tif'); keep the spelling as is.
        self.raster_names = ['ABOVE_GROUND_CARBON', 'ELEVATION', 'LEAF_AREA_INDEX', 'NON_TREE_VEGITATED', 'NOT_VEGITATED', 'POPULATION_DENSITY', 'SNOW_COVER', 'SOIL_MOISTURE', 'TREE_COVER']
        self.raster_names_log_transform = ['POPULATION_DENSITY']

    @staticmethod
    def _rescale_valid(raster, valid_mask, log_transform=False):
        """Rescale the masked entries of `raster` to [0, 1] in place.

        Args:
            raster: Float array; modified in place.
            valid_mask: Boolean array of the same shape selecting the entries
                to rescale.
            log_transform: When true, apply `log1p` to the min-shifted values
                before rescaling.

        Returns:
            The same `raster` object.
        """
        values = raster[valid_mask]
        if log_transform:
            values = np.log1p(values - values.min())
        values = values - values.min()
        peak = values.max()
        # A constant raster has zero range; skip the division so it stays at 0
        # instead of turning into NaN.
        if peak > 0:
            values = values / peak
        raster[valid_mask] = values
        return raster

    def load_raster(self, raster_name, log_transform=False):
        """Read `<raster_name>.tif` and normalise its valid pixels to [0, 1].

        Args:
            raster_name: File stem inside `self.data_path`.
            log_transform: Forwarded to the rescaling step.

        Returns:
            Tuple `(raster, valid_mask)`; `valid_mask` is True where the raster
            is not NaN and the country mask is set.
        """
        raster = tifffile.imread(os.path.join(self.data_path, raster_name + '.tif')).astype(np.float32)
        valid_mask = ~np.isnan(raster) & self.country_mask
        self._rescale_valid(raster, valid_mask, log_transform)
        return raster, valid_mask

    def get_split_labels(self, raster, split_ids, split_of_interest):
        """Return the raster values at pixels whose split id equals `split_of_interest`."""
        inds_y, inds_x = np.where(split_ids == split_of_interest)
        return raster[inds_y, inds_x]

    def get_split_feats(self, model, enc, split_ids, split_of_interest):
        """Compute model features for the pixels of one split.

        Args:
            model: Callable as `model(feats, return_feats=True)`.
            enc: Object whose `encode(locs)` turns raw locations into model
                inputs.
            split_ids: Per-pixel split assignment.
            split_of_interest: Split id to select.

        Returns:
            NumPy array of features as returned by the model.
        """
        locs = utils.coord_grid(self.country_mask.shape, split_ids=split_ids, split_of_interest=split_of_interest)
        locs = torch.from_numpy(locs).to(self.eval_params['device'])
        locs_enc = enc.encode(locs)
        with torch.no_grad():
            feats = model(locs_enc, return_feats=True).cpu().numpy()
        return feats

    def run_evaluation(self, model, enc):
        """Fit a cross-validated ridge regressor per raster and record R^2.

        Features of split 1 are used for fitting and split 2 for testing.

        Args:
            model: Callable as `model(feats, return_feats=True)`.
            enc: Object whose `encode(locs)` turns raw locations into model
                inputs.

        Returns:
            Dict with `train_r2_<name>`, `test_r2_<name>` and `alpha_<name>`
            floats for every raster name.
        """
        results = {}
        for raster_name in self.raster_names:
            do_log_transform = raster_name in self.raster_names_log_transform
            raster, valid_mask = self.load_raster(raster_name, do_log_transform)
            split_ids = utils.create_spatial_split(raster, valid_mask, cell_size=self.eval_params['cell_size'])
            feats_train = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=1)
            feats_test = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=2)
            labels_train = self.get_split_labels(raster, split_ids, 1)
            labels_test = self.get_split_labels(raster, split_ids, 2)
            # Fit the scaler on the training split only, then reuse it for test.
            scaler = MinMaxScaler()
            feats_train_scaled = scaler.fit_transform(feats_train)
            feats_test_scaled = scaler.transform(feats_test)
            # `normalize` was removed from RidgeCV in scikit-learn 1.2; passing
            # it raises TypeError there. Omitting it gives the same
            # un-normalised behaviour as the former normalize=False.
            clf = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=10, fit_intercept=True, scoring='r2').fit(feats_train_scaled, labels_train)
            train_score = clf.score(feats_train_scaled, labels_train)
            test_score = clf.score(feats_test_scaled, labels_test)
            results[f'train_r2_{raster_name}'] = float(train_score)
            results[f'test_r2_{raster_name}'] = float(test_score)
            results[f'alpha_{raster_name}'] = float(clf.alpha_)
        return results

    def report(self, results):
        """Print every `test_r2` entry followed by their mean."""
        report_fields = [x for x in results if 'test_r2' in x]
        for field in report_fields:
            print(f'{field}: {results[field]}')
        print(np.mean([results[field] for field in report_fields]))
|
|
def launch_eval_run(overrides):
    """Load a checkpoint, run the requested evaluation and print its report.

    Args:
        overrides: Passed to `setup.get_default_params_eval` to build the
            evaluation parameter dict.

    Returns:
        The result dict produced by the selected evaluator.

    Raises:
        NotImplementedError: If `eval_type` is not one of 'snt', 'iucn',
            'geo_prior' or 'geo_feature'.
    """
    eval_params = setup.get_default_params_eval(overrides)

    # Restore the trained model from <exp_base>/<experiment_name>/<ckp_name>.
    # NOTE(review): the checkpoint holds a plain 'params' entry next to the
    # weights; recent PyTorch releases restrict what torch.load unpickles by
    # default -- confirm this still loads on the version in use.
    checkpoint_path = os.path.join(
        eval_params['exp_base'],
        eval_params['experiment_name'],
        eval_params['ckp_name'],
    )
    eval_params['model_path'] = checkpoint_path
    train_params = torch.load(eval_params['model_path'], map_location='cpu')
    model_params = train_params['params']
    model = models.get_model(model_params)
    model.load_state_dict(train_params['state_dict'], strict=True)
    model = model.to(eval_params['device'])
    model.eval()

    # Encoders that use environmental inputs need the raster on the device.
    input_enc = model_params['input_enc']
    raster = None
    if input_enc in ['env', 'sin_cos_env']:
        raster = datasets.load_env().to(eval_params['device'])
    enc = utils.CoordEncoder(input_enc, raster=raster)

    t = time.time()
    eval_type = eval_params['eval_type']
    known_evaluators = (
        ('snt', EvaluatorSNT),
        ('iucn', EvaluatorIUCN),
        ('geo_prior', EvaluatorGeoPrior),
        ('geo_feature', EvaluatorGeoFeature),
    )
    evaluator_cls = None
    for type_name, candidate in known_evaluators:
        if eval_type == type_name:
            evaluator_cls = candidate
            break
    if evaluator_cls is None:
        raise NotImplementedError('Eval type not implemented.')

    if eval_type == 'snt':
        # This benchmark always runs on a fixed, seeded test half.
        eval_params['split'] = 'test'
        eval_params['val_frac'] = 0.50
        eval_params['split_seed'] = 7499

    evaluator = evaluator_cls(model_params, eval_params)
    results = evaluator.run_evaluation(model, enc)
    evaluator.report(results)

    print(f'evaluation completed in {np.around((time.time()-t)/60, 1)} min')
    return results
|
|