import argparse
import json
from copy import deepcopy
from pathlib import Path
from typing import Optional, cast

import pandas as pd
from torch.utils.data import DataLoader, TensorDataset

from src import eval as e
from src.data import Dataset, Normalizer
from src.data.config import CONFIG_FILENAME, NORMALIZATION_DICT_FILENAME, OUTPUT_FOLDER
from src.eval.baseline_models import get_model_config
from src.galileo import Encoder, GalileoWrapper
from src.utils import config_dir, device
|
|
# Training-set partitions swept by default; a single one can be selected
# with --partition.
PARTITIONS = [
    "default",
    "0.20x_train",
    "0.05x_train",
    "0.01x_train",
]

# Number of repeated runs per partition, keyed by evaluation mode.
RUNS_PER_SPLIT = {"LP": 5}

# Evaluation modes accepted by --eval_mode.
EVAL_MODES = ["KNN-5", "KNN-20", "LP"]
|
|
# Command-line interface of the evaluation script. The parsed namespace is
# exposed as a plain dict named `args`.
argparser = argparse.ArgumentParser()
argparser.add_argument("--model", type=str, default="mmearth_atto")
argparser.add_argument("--benchmark", type=str, default="eurosat")
# NOTE(review): the default "FT" is not listed in EVAL_MODES, so the script
# raises a ValueError during validation unless --eval_mode is passed
# explicitly. Left unchanged here; confirm the intended default.
argparser.add_argument("--eval_mode", type=str, default="FT")
argparser.add_argument(
    "--output_folder",
    type=str,
    default="",
    help="Folder for the results CSV; empty string means the project default",
)
argparser.add_argument("--weights_path", type=str, default="/stage/data/RS_baseline_models")
argparser.add_argument("--batch_size", type=int, default=128)
argparser.add_argument("--num_workers", type=int, default=12)
# store_true already defaults to False and derives dest="sweep_norms".
argparser.add_argument(
    "--sweep_norms",
    action="store_true",
    help="Evaluate every normalization strategy available for the model",
)
argparser.add_argument("--patch_size", type=int, default=4)
argparser.add_argument(
    "--pastis_filepath", type=str, default="/stage/data/presto_eval_sets/pastis"
)
argparser.add_argument("--mados_filepath", type=str, default="/stage/data/presto_eval_sets/mados")
argparser.add_argument(
    "--floods_filepath", type=str, default="/stage/data/presto_eval_sets/floods"
)
argparser.add_argument(
    "--breizhcrops_filepath", type=str, default="/stage/data/presto_eval_sets/breizhcrops"
)
argparser.add_argument("--temporal_pooling", type=str, default="mean")

# Optional overrides: each one replaces the corresponding full sweep with a
# single value.
argparser.add_argument(
    "--norm_dataset",
    type=str,
    default=None,
    help="Normalization statistics to use; incompatible with --sweep_norms",
)
argparser.add_argument(
    "--norm_std_multiplier",
    type=float,
    default=None,
    help="Std multiplier for --norm_dataset; incompatible with --sweep_norms",
)
argparser.add_argument(
    "--lr", type=float, default=None, help="Run only this learning rate"
)
argparser.add_argument(
    "--run_id", type=int, default=None, help="Run only this run id (LP mode)"
)
argparser.add_argument(
    "--partition", type=str, default=None, help="Run only this training partition"
)

args = vars(argparser.parse_args())
|
|
# An empty --output_folder selects the project-wide default output folder.
output_folder = OUTPUT_FOLDER if args["output_folder"] == "" else Path(args["output_folder"])

# Fail early when the weights location is missing.
weights_path = Path(args["weights_path"])
if not weights_path.exists():
    raise ValueError(f"{weights_path} does not exist")

# Unpack the parsed arguments into module-level names used by the rest of
# the script.
model_name = args["model"]
eval_mode = args["eval_mode"]
benchmark_name = args["benchmark"]
batch_size = args["batch_size"]
num_workers = args["num_workers"]
patch_size = args["patch_size"]
sweep_norms = args["sweep_norms"]
pastis_filepath = Path(args["pastis_filepath"])
mados_filepath = Path(args["mados_filepath"])
floods_filepath = Path(args["floods_filepath"])
breizhcrops_filepath = Path(args["breizhcrops_filepath"])
temporal_pooling = args["temporal_pooling"]
norm_dataset = args["norm_dataset"]
norm_std_multiplier = args["norm_std_multiplier"]
arg_lr = args["lr"]
arg_run_id = args["run_id"]
arg_partition = args["partition"]

# A norm sweep and an explicit normalization choice are mutually exclusive.
if sweep_norms:
    if norm_dataset is not None:
        raise ValueError(f"Can't use norm_dataset {norm_dataset} if sweeping norms")
    if norm_std_multiplier is not None:
        raise ValueError(f"Can't use std_multiplier {norm_std_multiplier} if sweeping norms")

# Every explicit norm dataset except "satlas" needs a std multiplier.
if norm_dataset not in (None, "satlas") and norm_std_multiplier is None:
    raise ValueError("If norm_dataset is not None, norm_std_multiplier must be passed")
|
|
|
|
# Geobench benchmarks all share one dataset class and follow the
# "m-<name>.json" naming scheme for both their dataset config file and their
# evaluation config.
_GEOBENCH_NAMES = (
    "eurosat",
    "so2sat",
    "brick-kiln",
    "bigearthnet",
    "cashew-plant",
    "sa-crop-type",
)

# Benchmark name -> description. "config" names the evaluation config file;
# "class" and "kwargs" (absent for the cropharvest entries, which are handled
# by a separate code path) identify the dataset class and its arguments.
# Key order is kept stable because it appears in the validation error message.
BENCHMARKS = {
    **{
        name: {
            "class": e.GeobenchDataset,
            "kwargs": {"dataset_config_file": f"m-{name}.json"},
            "config": f"m-{name}.json",
        }
        for name in _GEOBENCH_NAMES
    },
    "pastis": {
        "class": e.PASTISDataset,
        "kwargs": {"path_to_splits": pastis_filepath},
        "config": "pastis.json",
    },
    "mados": {
        "class": e.MADOSDataset,
        "kwargs": {"path_to_splits": mados_filepath},
        "config": "mados.json",
    },
    "floods": {
        "class": e.Sen1Floods11Dataset,
        "kwargs": {"path_to_splits": floods_filepath},
        "config": "sen1floods11.json",
    },
    "cropharvest_val": {"config": "cropharvest.json"},
    "cropharvest_togo": {"config": "cropharvest.json"},
    "cropharvest_kenya": {"config": "cropharvest.json"},
    "cropharvest_brazil": {"config": "cropharvest.json"},
    "breizhcrops": {
        "class": e.BreizhCropsDataset,
        "kwargs": {"path_to_splits": breizhcrops_filepath},
        "config": "breizhcrops.json",
    },
}
|
|
# Validate the requested evaluation mode and benchmark.
if eval_mode not in EVAL_MODES:
    raise ValueError(f"{eval_mode} not in {EVAL_MODES}")
if benchmark_name not in BENCHMARKS:
    raise ValueError(f"{benchmark_name} not in {BENCHMARKS.keys()}")

# The results file name encodes the run configuration; suffixes are appended
# below whenever a full sweep is narrowed to a single value.
model_name_for_savepath = model_name
savepath_prefix = f"{model_name_for_savepath}_{eval_mode}_{benchmark_name}_{patch_size}"
if benchmark_name == "pastis":
    savepath_prefix = f"{savepath_prefix}_tp{temporal_pooling}"

# Partitions: either the single requested one or the full list.
if arg_partition is None:
    partitions_to_use = deepcopy(PARTITIONS)
else:
    if arg_partition not in PARTITIONS:
        raise ValueError(f"{arg_partition} not in PARTITIONS")
    print(f"Replacing full partition sweep with {arg_partition}")
    partitions_to_use = [arg_partition]

# breizhcrops is always evaluated on the default partition only, regardless
# of --partition.
if benchmark_name == "breizhcrops":
    partitions_to_use = ["default"]

# Runs: only linear probing sweeps over run ids.
# NOTE(review): runs_to_use is bound for LP only. The FT branch of the
# evaluation loop also reads it, but FT is currently rejected by the
# EVAL_MODES check above.
if eval_mode == "LP":
    if arg_run_id is None:
        runs_to_use = list(range(RUNS_PER_SPLIT[eval_mode]))
    else:
        print(f"Replacing full run sweep with {arg_run_id}")
        runs_to_use = [arg_run_id]
        savepath_prefix = f"{savepath_prefix}_runid{arg_run_id}"

# Learning rates: probing rates for LP, fine-tuning rates otherwise, unless a
# single rate was requested.
if arg_lr is None:
    lrs_to_use = e.PROBING_LRs[eval_mode] if eval_mode == "LP" else e.FT_LRs
else:
    print(f"Replacing full lr sweep with {arg_lr}")
    lrs_to_use = [arg_lr]
    savepath_prefix = f"{savepath_prefix}_lr{arg_lr}"

savepath = output_folder / f"{savepath_prefix}.csv"
# Create the destination up front so a long evaluation cannot fail at the
# very first to_csv call because the folder is missing.
savepath.parent.mkdir(parents=True, exist_ok=True)

# Resume from a previous results file when one exists. Only empty cells are
# treated as missing: with pandas' default NA tokens the stored norm_op value
# "None" may be read back as NaN (confirm against the installed pandas
# version), in which case the later `norm_op == str(norm_strat)` resume check
# would never match.
if savepath.exists():
    results = pd.read_csv(savepath, keep_default_na=False, na_values=[""])
else:
    results = None
|
|
def _already_recorded(previous, conditions):
    """Return True if `previous` has a row matching every column == value pair.

    `previous` is the results DataFrame loaded so far, or None when there are
    no results yet. `conditions` maps column names to the required values.
    A column missing from `previous` raises KeyError.
    """
    if previous is None:
        return False
    matching = previous
    for column, value in conditions.items():
        matching = matching[matching[column] == value]
    return len(matching) > 0


def _record_result(previous, row, path):
    """Append one result row, rewrite the CSV at `path`, and return the new frame.

    `row` maps column names to scalar values; its key order becomes the column
    order when `previous` is None. The file is rewritten after every row so
    an interrupted sweep can be resumed.
    """
    new_df = pd.DataFrame({column: [value] for column, value in row.items()})
    print(new_df)
    combined = new_df if previous is None else pd.concat([previous, new_df], axis=0)
    combined.to_csv(path, index=False)
    return combined


def _first_f1_key(keys, model_mode):
    """Return the first key containing `model_mode` and "f1" that does not end in "_c".

    Raises IndexError when no key qualifies.
    """
    return [k for k in keys if model_mode in k and "f1" in k and not k.endswith("_c")][0]


if "cropharvest" not in benchmark_name:
    # ---- Image benchmarks described by a JSON config ----
    benchmark = cast(dict, BENCHMARKS[benchmark_name])
    config_name: str = cast(str, benchmark["config"])
    with (Path("src/eval/configs_v2") / Path(config_name)).open("r") as f:
        config = json.load(f)
    task_type = config["task_type"]
    # Classification tasks use pooled embeddings.
    do_pool = task_type == "cls"

    # KNN-style evaluation only works on classification tasks. Checked before
    # any model or data is loaded, and the message now reports the actual
    # task type (it previously always said "cls").
    knn_modes = ["KNN-5", "KNN-20", "K-Means"]
    if (eval_mode in knn_modes) and (task_type != "cls"):
        raise ValueError(
            f"{eval_mode} not supported for {benchmark_name} of task type {task_type}"
        )

    # Benchmarks whose band info lists "s1" are treated as Sentinel-1 inputs.
    s1_or_s2 = "s1" if "s1" in config["band_info"] else "s2"

    # NOTE(review): any KeyError raised while configuring or constructing a
    # baseline model silently falls back to GalileoWrapper. Presumably this is
    # how unknown model names select Galileo; confirm that get_model_config
    # is the only intended source of the KeyError.
    try:
        model_dict = get_model_config(model_name, weights_path, s1_or_s2)
        if ("satmae" in model_name) and (config_name == "m-bigearthnet.json"):
            print(f"Updating position embeddings for {model_name}")
            model_dict["args"]["img_size"] = 120
        encoder = model_dict["model"](
            **model_dict["args"], do_pool=do_pool, temporal_pooling=temporal_pooling
        ).to(device)
    except KeyError:
        encoder = GalileoWrapper(
            pretrained_path=weights_path,
            patch_size=patch_size,
            do_pool=do_pool,
            add_layernorm_on_exit=eval_mode != "FT",
        ).to(device)

    # Normalization used when neither the CLI nor the benchmark config
    # selects one.
    if benchmark_name == "mados":
        default_norm_strat = {
            "stats": "dataset",
            "type": "norm_no_clip",
            "std_multiplier": 2.0,
        }
    elif s1_or_s2 == "s1":
        default_norm_strat = {
            "stats": "OURS_S1",
            "type": "norm_no_clip",
            "std_multiplier": 2.0,
        }
    else:
        default_norm_strat = {"stats": "OURS", "type": "norm_no_clip", "std_multiplier": 2.0}

    # Choose the normalization strategies to evaluate.
    norms_for_model = e.get_all_norm_strats(model_name, s1_or_s2)
    if sweep_norms:
        norms_to_use = norms_for_model
    elif norm_dataset == "satlas":
        norms_to_use = [{"type": "satlas"}]
    elif norm_dataset is not None:
        norm_type, _ = e.norm_type_from_model_name(model_name)
        if norm_std_multiplier is not None:
            norms_to_use = [
                {
                    "type": norm_type,
                    "stats": norm_dataset,
                    "std_multiplier": norm_std_multiplier,
                }
            ]
        else:
            norms_to_use = [
                norm for norm in norms_for_model if norm.get("stats", "") == norm_dataset
            ]
    elif model_name in config.get("models", ()):
        # The benchmark config has an entry for this model: pass no explicit
        # normalization override to the loaders.
        norms_to_use = [None]
    else:
        print(f"No norm strat; using default: {default_norm_strat}")
        norms_to_use = [default_norm_strat]

    for train_partition in partitions_to_use:
        for norm_strat in norms_to_use:
            print(
                f"Running {train_partition} for {model_name}, {benchmark_name} with norm_strat {norm_strat}"
            )
            # Results store the strategy as its string representation.
            norm_key = str(norm_strat)
            is_knn = eval_mode in knn_modes

            # For KNN one row covers the whole (partition, norm) pair, so a
            # finished pair can be skipped before the loaders are built.
            if is_knn and _already_recorded(
                results, {"partition": train_partition, "norm_op": norm_key}
            ):
                print(f"{train_partition}, {norm_strat} in results - skipping")
                continue

            loaders = e.get_loaders(
                benchmark,
                config,
                model_name,
                batch_size,
                num_workers,
                eval_mode,
                train_partition=train_partition,
                norm_ops=norm_strat,
            )
            print(f"In eval, {len(loaders['train'])}")

            if is_knn:
                train_embeddings, train_labels = e.get_embeddings(
                    data_loader=loaders["train"], model=encoder, device=device
                )
                test_embeddings, test_labels = e.get_embeddings(
                    data_loader=loaders["test"], model=encoder, device=device
                )
                test_result = e.run_knn(
                    eval_type=eval_mode,
                    train_embeddings=train_embeddings,
                    train_labels=train_labels,
                    test_embeddings=test_embeddings,
                    test_labels=test_labels,
                    num_classes=config["num_classes"],
                    is_multilabel=config["is_multilabel"],
                    device=device,
                )

                val_embeddings, val_labels = e.get_embeddings(
                    data_loader=loaders["valid"], model=encoder, device=device
                )
                val_result = e.run_knn(
                    eval_type=eval_mode,
                    train_embeddings=train_embeddings,
                    train_labels=train_labels,
                    test_embeddings=val_embeddings,
                    test_labels=val_labels,
                    num_classes=config["num_classes"],
                    is_multilabel=config["is_multilabel"],
                    device=device,
                )
                results = _record_result(
                    results,
                    {
                        "model_name": model_name,
                        "benchmark": benchmark_name,
                        "partition": train_partition,
                        "test": test_result,
                        "val": val_result,
                        "norm_op": norm_key,
                    },
                    savepath,
                )

            elif eval_mode == "LP":
                # Only anysat on segmentation tasks subsamples train tokens.
                if (model_name == "anysat") and task_type == "seg":
                    train_subsample: Optional[float] = 1 / 16
                else:
                    train_subsample = None
                # Embeddings are computed once and reused for every
                # (run_id, lr) probe.
                t_e, t_l = e.get_embeddings(
                    data_loader=loaders["train"],
                    model=encoder,
                    device=device,
                    subsample_tokens=train_subsample,
                )
                v_e, v_l = e.get_embeddings(
                    data_loader=loaders["valid"], model=encoder, device=device
                )
                te_e, te_l = e.get_embeddings(
                    data_loader=loaders["test"], model=encoder, device=device
                )
                embedding_loaders = {
                    "train": DataLoader(
                        TensorDataset(t_e, t_l), batch_size=batch_size, shuffle=True
                    ),
                    "valid": DataLoader(
                        TensorDataset(v_e, v_l), batch_size=batch_size, shuffle=False
                    ),
                    "test": DataLoader(
                        TensorDataset(te_e, te_l), batch_size=batch_size, shuffle=False
                    ),
                }

                for run_id in runs_to_use:
                    for lr in lrs_to_use:
                        if _already_recorded(
                            results,
                            {
                                "partition": train_partition,
                                "lr": lr,
                                "run_id": run_id,
                                "norm_op": norm_key,
                            },
                        ):
                            print(f"{train_partition}, {run_id}, {lr} in results - skipping")
                            continue
                        if task_type == "cls":
                            val, test = e.train_and_eval_probe_cls(
                                lr=lr,
                                config=config,
                                loaders=embedding_loaders,
                                in_features=encoder.dim,
                                device=device,
                            )
                        elif task_type == "seg":
                            val, test = e.train_and_eval_probe_seg(
                                lr=lr,
                                config=config,
                                loaders=embedding_loaders,
                                in_features=encoder.dim,
                                grid_size=encoder.grid_size,
                                device=device,
                            )
                        else:
                            raise ValueError(
                                f"task_type must be cls or seg, not {config['task_type']}"
                            )

                        results = _record_result(
                            results,
                            {
                                "model_name": model_name,
                                "benchmark": benchmark_name,
                                "partition": train_partition,
                                "val": val,
                                "test": test,
                                "lr": lr,
                                "run_id": run_id,
                                "norm_op": norm_key,
                            },
                            savepath,
                        )

            elif eval_mode == "FT":
                cache_dir = output_folder / "ft_cache"
                cache_dir.mkdir(parents=True, exist_ok=True)
                for run_id in runs_to_use:
                    for lr in lrs_to_use:
                        if _already_recorded(
                            results,
                            {
                                "partition": train_partition,
                                "lr": lr,
                                "run_id": run_id,
                                "norm_op": norm_key,
                            },
                        ):
                            print(f"{train_partition}, {run_id}, {lr} in results - skipping")
                            continue
                        if task_type == "cls":
                            val, test = e.finetune_and_eval_cls(
                                lr=lr,
                                config=config,
                                loaders=loaders,
                                encoder=encoder,
                                device=device,
                                cache_dir=cache_dir,
                            )
                        elif task_type == "seg":
                            val, test = e.finetune_and_eval_seg(
                                lr=lr,
                                config=config,
                                loaders=loaders,
                                encoder=encoder,
                                device=device,
                            )
                        else:
                            # Without this, an unknown task type would record
                            # stale or undefined val/test values.
                            raise ValueError(
                                f"task_type must be cls or seg, not {config['task_type']}"
                            )
                        results = _record_result(
                            results,
                            {
                                "model_name": model_name,
                                "benchmark": benchmark_name,
                                "partition": train_partition,
                                "val": val,
                                "test": test,
                                "lr": lr,
                                "run_id": run_id,
                                "norm_op": norm_key,
                            },
                            savepath,
                        )
else:
    # ---- CropHarvest benchmarks ----
    input_name_to_args = {
        "cropharvest_val": {"country": "Togo", "eval_mode": "val"},
        "cropharvest_togo": {"country": "Togo", "eval_mode": "test"},
        "cropharvest_kenya": {"country": "Kenya", "eval_mode": "test"},
        "cropharvest_brazil": {"country": "Brazil", "eval_mode": "test"},
    }

    if model_name == "presto":
        model_name_to_save = "presto"
        model_dict = get_model_config(model_name, weights_path, s1_or_s2="cropharvest")
        encoder = model_dict["model"].load_pretrained().encoder.to(device)
        val_task_ts = e.PrestoBinaryCropHarvestEval(
            **input_name_to_args[benchmark_name],
            normalizer=e.PrestoNormalizer(),
        )
    elif model_name == "anysat":
        model_name_to_save = "anysat"
        model_dict = get_model_config(model_name, weights_path, s1_or_s2="s2")
        encoder = model_dict["model"](do_pool=True).to(device)
        val_task_ts = e.AnySatBinaryCropHarvestEval(
            **input_name_to_args[benchmark_name],
            normalizer=e.AnySatNormalizer(),
        )
    else:
        # Any other model name is loaded as an Encoder from weights_path.
        # model_name_to_save was previously left unbound on this path, which
        # raised NameError when the first result row was built.
        model_name_to_save = model_name
        encoder = Encoder.load_from_folder(weights_path).to(device)
        # The training config stored next to the weights says which band
        # groups to ignore.
        config_path = weights_path / CONFIG_FILENAME
        with config_path.open("r") as f:
            config = json.load(f)
        ignore_band_groups = config["training"].get("ignore_band_groups", None)
        print("Running with default normalization (OURS, std=2)")
        val_task_ts = e.BinaryCropHarvestEval(
            normalizer=Normalizer(
                normalizing_dicts=Dataset.load_normalization_values(
                    config_dir / NORMALIZATION_DICT_FILENAME
                ),
                std_multiplier=2,
            ),
            **input_name_to_args[benchmark_name],
            ignore_band_groups=ignore_band_groups,
        )

    # Only the default partition is supported here; every other partition is
    # skipped with a message.
    partition_to_float = {
        "default": None,
    }
    for partition in partitions_to_use:
        if _already_recorded(results, {"partition": partition}):
            print(f"{partition} in results - skipping")
            continue
        if partition not in partition_to_float:
            print(f"partition {partition} too small for cropharvest - skipping")
            continue

        output = val_task_ts.evaluate_model_on_task(
            encoder,
            model_modes=["KNNat20Classifier", "LogisticRegression"],
            fraction=partition_to_float[partition],
        )
        # Keep one f1 score per classifier: the first matching key that does
        # not end in "_c".
        output_keys = list(output.keys())
        k_key = _first_f1_key(output_keys, "KNNat20")
        lr_key = _first_f1_key(output_keys, "LogisticRegression")

        results = _record_result(
            results,
            {
                "model_name": model_name_to_save,
                "benchmark": benchmark_name,
                "knn-20": output[k_key],
                "lr": output[lr_key],
                "partition": partition,
            },
            savepath,
        )
|
|