| """ |
| Re-evaluate BFM Large model AFTER GeoLifeCLEF finetuning on BioCube 28 species. |
| |
| This script: |
| 1. Builds the BFM Large base model |
| 2. Wraps with BFMRaw (same as finetuning) |
| 3. Loads finetuned checkpoint |
| 4. Evaluates on BioCube test data using paper metrics (MAE, RMSE, F1, Sorensen) |
| |
| Usage: |
| conda run -n bfm python eval_finetuned_large.py 2>&1 | tee eval_finetuned_large.log |
| """ |
|
|
import argparse
import json
import sys
import time
import warnings
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch

warnings.filterwarnings("ignore")
|
|
PROJECT_ROOT = Path(__file__).resolve().parent
BFM_MODEL_DIR = PROJECT_ROOT / "bfm-model"
sys.path.insert(0, str(BFM_MODEL_DIR))

SAFETENSORS_PATH = PROJECT_ROOT / "bfm-pretrained" / "bfm-pretrain-large.safetensors"
FINETUNE_CHECKPOINT = PROJECT_ROOT / "outputs_finetune_large" / "checkpoints" / "best_checkpoint.pth"
STATS_PATH = BFM_MODEL_DIR / "batch_statistics" / "monthly_batches_stats_splitted_channels.json"
TEST_DATA_DIR = PROJECT_ROOT / "dataset" / "test"
RESULTS_DIR = PROJECT_ROOT / "env_sdm_results"
CHECKPOINT_DIR = RESULTS_DIR / "finetuned_eval_checkpoints"
|
|
# The finetuned head predicts 500 species channels (the GeoLifeCLEF setup);
# only the 28 BioCube species listed below are present in the test data.
NUM_FINETUNE_SPECIES = 500

# BFM Large architecture hyperparameters; these must match the finetuning run.
MODEL_CONFIG = {
    "embed_dim": 512, "depth": 10, "patch_size": 8,
    "swin_backbone_size": "large", "num_heads": 16, "head_dim": 64,
    "H": 160, "W": 280, "num_latent_tokens": 8,
    "perceiver_latents": 16100, "T": 2,
}
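
# Added sanity check: the spatial grid must tile evenly by the patch size,
# mirroring the cropping logic in load_and_process_batch() below.
assert MODEL_CONFIG["H"] % MODEL_CONFIG["patch_size"] == 0  # 160 / 8 = 20 patches
assert MODEL_CONFIG["W"] % MODEL_CONFIG["patch_size"] == 0  # 280 / 8 = 35 patches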
|
|
SWIN_LARGE_CONFIG = {
    "swin_encoder_depths": (2, 2, 2), "swin_encoder_num_heads": (8, 16, 32),
    "swin_decoder_depths": (2, 2, 2), "swin_decoder_num_heads": (32, 16, 8),
    "swin_window_size": (1, 4, 5), "swin_mlp_ratio": 4.0,
    "swin_qkv_bias": True, "swin_drop_rate": 0.0,
    "swin_attn_drop_rate": 0.0, "swin_drop_path_rate": 0.1,
    "use_lora": False,
}
|
|
# Variable lists for each BioCube modality. SPECIES_VARS holds the 28 species
# IDs evaluated against the paper's Table 9.
SPECIES_VARS = [
    "1340361", "1340503", "1536449", "1898286", "1920506", "2430567",
    "2431885", "2433433", "2434779", "2435240", "2435261", "2437394",
    "2441454", "2473958", "2491534", "2891770", "3034825", "4408498",
    "5218786", "5219073", "5219173", "5219219", "5844449", "8002952",
    "8077224", "8894817", "8909809", "9809229",
]
SURFACE_VARS = ["t2m", "msl", "slt", "z", "u10", "v10", "lsm"]
EDAPHIC_VARS = ["swvl1", "swvl2", "stl1", "stl2"]
ATMOS_VARS = ["z", "t", "u", "v", "q"]
CLIMATE_VARS = [
    "smlt", "tp", "csfr", "avg_sdswrf", "avg_snswrf", "avg_snlwrf",
    "avg_tprate", "avg_sdswrfcs", "sd", "t2m", "d2m",
]
VEGETATION_VARS = ["NDVI"]
LAND_VARS = ["Land"]
AGRICULTURE_VARS = ["Agriculture", "Arable", "Cropland"]
FOREST_VARS = ["Forest"]
REDLIST_VARS = ["RLI"]
MISC_VARS = ["avg_slhtf", "avg_pevr"]
ATMOS_LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]

# Keys of the Batch namedtuple groups iterated during evaluation.
VARIABLE_GROUPS = [
    "surface_variables", "edaphic_variables", "atmospheric_variables",
    "climate_variables", "species_variables", "vegetation_variables",
    "land_variables", "agriculture_variables", "forest_variables",
    "redlist_variables", "misc_variables",
]
|
|
# Per-species MAE (normalized) reported in the paper's Table 9, used as the
# reference column in the final comparison table.
PAPER_TABLE9_SPECIES_MAE = {
    "1340361": 0.000064, "1340503": 0.064850, "1536449": 0.034896,
    "1898286": 0.359776, "1920506": 0.272244, "2430567": 0.001872,
    "2431885": 0.012577, "2433433": 0.001001, "2434779": 0.000018,
    "2435240": 0.005042, "2435261": 0.000074, "2437394": 0.000220,
    "2441454": 0.005205, "2473958": 0.051274, "2491534": 0.294113,
    "2891770": 0.046475, "3034825": 0.023135, "4408498": 0.016373,
    "5218786": 0.007713, "5219073": 0.000831, "5219173": 0.006481,
    "5219219": 0.000296, "5844449": 0.006944, "8002952": 0.015838,
    "8077224": 0.361090, "8894817": 0.000336, "8909809": 0.006676,
    "9809229": 0.745430,
}
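
# Note (added): the mean of these 28 values is 0.083602, which is the paper
# baseline printed as 'Mean species MAE (norm)' in the comparison table below.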
|
|
|
|
def build_base_model():
    """Build the BFM Large base model."""
    from bfm_model.bfm.model import BFM

    model = BFM(
        surface_vars=SURFACE_VARS, edaphic_vars=EDAPHIC_VARS,
        atmos_vars=ATMOS_VARS, climate_vars=CLIMATE_VARS,
        species_vars=SPECIES_VARS, vegetation_vars=VEGETATION_VARS,
        land_vars=LAND_VARS, agriculture_vars=AGRICULTURE_VARS,
        forest_vars=FOREST_VARS, redlist_vars=REDLIST_VARS,
        misc_vars=MISC_VARS, atmos_levels=ATMOS_LEVELS,
        species_num=len(SPECIES_VARS),
        H=MODEL_CONFIG["H"], W=MODEL_CONFIG["W"],
        num_latent_tokens=MODEL_CONFIG["num_latent_tokens"],
        backbone_type="swin", patch_size=MODEL_CONFIG["patch_size"],
        embed_dim=MODEL_CONFIG["embed_dim"],
        num_heads=MODEL_CONFIG["num_heads"],
        head_dim=MODEL_CONFIG["head_dim"],
        depth=MODEL_CONFIG["depth"],
        perceiver_latents=MODEL_CONFIG["perceiver_latents"],
        batch_size=1, td_learning=True, use_mask="no",
        **SWIN_LARGE_CONFIG,
    )
    return model
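
# Usage note (added): main() calls this builder twice -- once as the base for
# the finetuned BFMRaw wrapper and once as the pretrained-only reference model.
# Both copies load the same safetensors weights with strict=False, e.g.:
#   base = build_base_model()
#   base.load_state_dict(load_file(str(SAFETENSORS_PATH)), strict=False)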
|
|
|
|
def load_and_process_batch(fpath, stats, patch_size=8):
    """Load a raw .pt batch file and process it into a Batch namedtuple.

    Returns the normalized, patch-cropped Batch plus a dict of the species
    grids as they were before normalization.
    """
    from datetime import datetime

    from bfm_model.bfm.dataloader_monthly import (
        Batch, Metadata, crop_variables, extract_atmospheric_levels, normalize_keys,
    )
    from bfm_model.bfm.scaler import _rescale_recursive, dimensions_to_keep_monthly

    data = torch.load(fpath, map_location="cpu", weights_only=False)

    latitudes = data["batch_metadata"]["latitudes"]
    longitudes = data["batch_metadata"]["longitudes"]
    timestamps = data["batch_metadata"]["timestamp"]
    pressure_levels = data["batch_metadata"]["pressure_levels"]
    species_list = data["batch_metadata"]["species_list"]

    # Crop the grid so both dimensions are divisible by the patch size.
    H, W = len(latitudes), len(longitudes)
    new_H = (H // patch_size) * patch_size
    new_W = (W // patch_size) * patch_size

    # Keep a copy of the species grids before normalization.
    raw_species = {}
    if "species_variables" in data:
        for sp_id, sp_data in data["species_variables"].items():
            raw_species[sp_id] = sp_data[:, :new_H, :new_W].clone()

    # Normalize all variables in place using the precomputed batch statistics.
    _rescale_recursive(
        data, stats,
        dimensions_to_keep_by_key=dimensions_to_keep_monthly,
        mode="normalize", direction="scaled",
    )

    surface_vars = crop_variables(data["surface_variables"], new_H, new_W)
    edaphic_vars = crop_variables(data["edaphic_variables"], new_H, new_W)
    atmospheric_vars = crop_variables(data["atmospheric_variables"], new_H, new_W)
    climate_vars = crop_variables(data["climate_variables"], new_H, new_W)
    species_vars = crop_variables(data["species_variables"], new_H, new_W)
    land_vars = crop_variables(data["land_variables"], new_H, new_W)
    agriculture_vars = crop_variables(data["agriculture_variables"], new_H, new_W)
    forest_vars = crop_variables(data["forest_variables"], new_H, new_W)
    vegetation_vars = crop_variables(
        data["vegetation_variables"], new_H, new_W, handle_nans=True, nan_mode="zero"
    )
    redlist_vars = crop_variables(data["redlist_variables"], new_H, new_W)
    misc_vars = crop_variables(data["misc_variables"], new_H, new_W)

    latitude_var = torch.tensor(latitudes[:new_H])
    longitude_var = torch.tensor(longitudes[:new_W])

    # Lead time in months between the two timestamps of the batch.
    dt_format = "%Y-%m-%d %H:%M:%S"
    start = datetime.strptime(timestamps[0], dt_format)
    end = datetime.strptime(timestamps[1], dt_format)
    lead_months = (end.year - start.year) * 12 + (end.month - start.month) + 1

    atmospheric_vars = extract_atmospheric_levels(
        atmospheric_vars, pressure_levels, ATMOS_LEVELS, level_dim=1
    )

    metadata = Metadata(
        latitudes=latitude_var, longitudes=longitude_var,
        timestamp=timestamps, lead_time=lead_months,
        pressure_levels=pressure_levels, species_list=species_list,
    )
    species_vars = normalize_keys(species_vars)

    batch = Batch(
        batch_metadata=metadata,
        surface_variables=surface_vars, edaphic_variables=edaphic_vars,
        atmospheric_variables=atmospheric_vars, climate_variables=climate_vars,
        species_variables=species_vars, vegetation_variables=vegetation_vars,
        land_variables=land_vars, agriculture_variables=agriculture_vars,
        forest_variables=forest_vars, redlist_variables=redlist_vars,
        misc_variables=misc_vars,
    )
    return batch, raw_species
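
# Usage sketch (added; the file name is hypothetical):
#   batch, raw_species = load_and_process_batch("dataset/test/some_batch.pt", stats)
# `batch` holds normalized, patch-cropped tensors ready for the model;
# `raw_species` keeps the un-normalized species grids (unused by main() below,
# which computes all species metrics in normalized space).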
|
|
|
|
def add_batch_dim(batch):
    """Add a leading batch dimension (B=1) to every tensor in the batch."""
    from bfm_model.bfm.dataloader_monthly import Batch, Metadata

    def add_dim(var_dict):
        return {k: v.unsqueeze(0) for k, v in var_dict.items()}

    new_md = Metadata(
        latitudes=batch.batch_metadata.latitudes.unsqueeze(0),
        longitudes=batch.batch_metadata.longitudes.unsqueeze(0),
        timestamp=[batch.batch_metadata.timestamp],
        lead_time=batch.batch_metadata.lead_time,
        pressure_levels=batch.batch_metadata.pressure_levels,
        species_list=batch.batch_metadata.species_list,
    )
    return Batch(
        batch_metadata=new_md,
        surface_variables=add_dim(batch.surface_variables),
        edaphic_variables=add_dim(batch.edaphic_variables),
        atmospheric_variables=add_dim(batch.atmospheric_variables),
        climate_variables=add_dim(batch.climate_variables),
        species_variables=add_dim(batch.species_variables),
        vegetation_variables=add_dim(batch.vegetation_variables),
        land_variables=add_dim(batch.land_variables),
        agriculture_variables=add_dim(batch.agriculture_variables),
        redlist_variables=add_dim(batch.redlist_variables),
        forest_variables=add_dim(batch.forest_variables),
        misc_variables=add_dim(batch.misc_variables),
    )
|
|
|
|
def ensure_f32(batch):
    """Convert any float64 tensors in the batch to float32."""
    from bfm_model.bfm.dataloader_monthly import Batch

    def conv(d):
        return {k: v.float() if v.dtype == torch.float64 else v for k, v in d.items()}

    return Batch(
        batch_metadata=batch.batch_metadata,
        surface_variables=conv(batch.surface_variables),
        edaphic_variables=conv(batch.edaphic_variables),
        atmospheric_variables=conv(batch.atmospheric_variables),
        climate_variables=conv(batch.climate_variables),
        species_variables=conv(batch.species_variables),
        vegetation_variables=conv(batch.vegetation_variables),
        land_variables=conv(batch.land_variables),
        agriculture_variables=conv(batch.agriculture_variables),
        forest_variables=conv(batch.forest_variables),
        redlist_variables=conv(batch.redlist_variables),
        misc_variables=conv(batch.misc_variables),
    )
|
|
|
|
# --- Paper metrics (Eq. 18: MAE, Eq. 19: RMSE, Eq. 20: F1, Eq. 22: Sorensen) ---
|
|
def compute_mae_eq18(pred_dict, gt_dict, var_names):
    """MAE (Eq. 18) per variable, in normalized space.

    Convention used throughout this script: predictions are indexed with [0]
    (batch element 0) and ground truth with [1] (the second timestep, T=2).
    """
    per_var_mae = {}
    for vname in var_names:
        if vname not in pred_dict or vname not in gt_dict:
            continue
        p = pred_dict[vname][0].cpu().float()
        g = gt_dict[vname][1].float()
        mae = float(torch.mean(torch.abs(p - g)).item())
        per_var_mae[vname] = mae
    avg_mae = float(np.mean(list(per_var_mae.values()))) if per_var_mae else None
    return per_var_mae, avg_mae
|
|
|
|
def compute_rmse_eq19(pred_dict, gt_dict, var_names):
    """RMSE (Eq. 19) per variable, same indexing convention as compute_mae_eq18."""
    per_var_rmse = {}
    for vname in var_names:
        if vname not in pred_dict or vname not in gt_dict:
            continue
        p = pred_dict[vname][0].cpu().float()
        g = gt_dict[vname][1].float()
        rmse = float(torch.sqrt(torch.mean((p - g) ** 2)).item())
        per_var_rmse[vname] = rmse
    avg_rmse = float(np.mean(list(per_var_rmse.values()))) if per_var_rmse else None
    return per_var_rmse, avg_rmse
|
|
|
|
def compute_f1_eq20(species_pred, species_gt, threshold=0.0):
    """Per-cell F1 (Eq. 20) over the species dimension.

    Presence is pred > threshold and gt > 0. Uses F1 = TP / (TP + (FP + FN) / 2),
    the usual 2TP / (2TP + FP + FN) identity; cells with no positives in either
    map count as a perfect 1.
    """
    sp_ids = sorted(set(species_pred.keys()) & set(species_gt.keys()))
    if not sp_ids:
        return None
    pred_stack = torch.stack([species_pred[sp][0].cpu().float() for sp in sp_ids])
    gt_stack = torch.stack([species_gt[sp][1].float() for sp in sp_ids])
    pred_bin = (pred_stack > threshold).int()
    gt_bin = (gt_stack > 0).int()
    # Counts per grid cell, summed over the species dimension.
    tp = (pred_bin * gt_bin).sum(dim=0).float()
    fp = (pred_bin * (1 - gt_bin)).sum(dim=0).float()
    fn = ((1 - pred_bin) * gt_bin).sum(dim=0).float()
    denom = tp + (fp + fn) / 2.0
    f1_map = torch.where(denom > 0, tp / denom, torch.ones_like(tp))
    return float(f1_map.mean().item())
|
|
|
|
def compute_sorensen_eq22(species_pred, species_gt, threshold=0.0):
    """Sorensen similarity (Eq. 22) per grid cell: S = 2c / (2c + b + d).

    c = species present in both maps, b = present only in the ground truth,
    d = present only in the prediction. Cells with no species at all are NaN
    and excluded from the mean/std.
    """
    sp_ids = sorted(set(species_pred.keys()) & set(species_gt.keys()))
    if not sp_ids:
        return None, None, None
    pred_stack = torch.stack([species_pred[sp][0].cpu().float() for sp in sp_ids])
    gt_stack = torch.stack([species_gt[sp][1].float() for sp in sp_ids])
    pred_bin = (pred_stack > threshold).int()
    gt_bin = (gt_stack > 0).int()
    c = (pred_bin * gt_bin).sum(dim=0).float()
    b = ((1 - pred_bin) * gt_bin).sum(dim=0).float()
    d = (pred_bin * (1 - gt_bin)).sum(dim=0).float()
    denom = 2 * c + b + d
    sorensen_map = torch.where(denom > 0, 2 * c / denom, torch.full_like(c, float("nan")))
    valid = sorensen_map[~torch.isnan(sorensen_map)]
    if len(valid) > 0:
        return float(valid.mean().item()), float(valid.std().item()), sorensen_map.numpy()
    return None, None, None
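

# Hedged sanity check (added for illustration; never called by the pipeline).
# It exercises the metric helpers on a one-species 2x2 toy grid whose
# prediction matches the ground truth exactly, so every metric is perfect.
def _toy_metric_check():
    pred = {"sp": torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])}  # (1, H, W); [0] is read
    gt = {"sp": torch.tensor([[[0.0, 0.0], [0.0, 0.0]],
                              [[1.0, 0.0], [0.0, 1.0]]])}    # (T, H, W); [1] is read
    _, avg_mae = compute_mae_eq18(pred, gt, ["sp"])
    _, avg_rmse = compute_rmse_eq19(pred, gt, ["sp"])
    f1 = compute_f1_eq20(pred, gt)
    s_mean, _, _ = compute_sorensen_eq22(pred, gt)
    assert avg_mae == 0.0 and avg_rmse == 0.0 and f1 == 1.0 and s_mean == 1.0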
|
|
|
|
def load_checkpoint_results():
    """Load per-pair results saved by earlier (possibly interrupted) runs."""
    results = {}
    if CHECKPOINT_DIR.exists():
        for f in sorted(CHECKPOINT_DIR.glob("pair_*.json")):
            with open(f) as fh:
                data = json.load(fh)
            results[data["pair_index"]] = data
    return results
|
|
|
|
def save_checkpoint_result(pair_idx, result):
    """Persist one pair's results so --resume can skip it next time."""
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    path = CHECKPOINT_DIR / f"pair_{pair_idx:04d}.json"
    with open(path, "w") as f:
        json.dump(result, f, indent=2, default=str)
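
# Layout note (added): each pair_XXXX.json mirrors the pair_result dict built
# in main(): pair_index, input/target file names, per_group_mae, per_group_rmse,
# per_species_mae, f1, sorensen_mean and sorensen_std.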
|
|
|
|
def main():
    parser = argparse.ArgumentParser(description="Evaluate Finetuned BFM Large on BioCube")
    parser.add_argument("--max-pairs", type=int, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--species-threshold", type=float, default=0.0)
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to finetuned checkpoint (default: outputs_finetune_large/checkpoints/best_checkpoint.pth)")
    args = parser.parse_args()

    ckpt_path = Path(args.checkpoint) if args.checkpoint else FINETUNE_CHECKPOINT

    print("=" * 70)
    print("Finetuned BFM Large - BioCube 28 Species Re-evaluation")
    print("=" * 70)
    print(f"Finetuned checkpoint: {ckpt_path}")
    print(f"Species threshold: {args.species_threshold}")
    print("=" * 70)

    assert SAFETENSORS_PATH.exists(), f"Pretrained weights not found: {SAFETENSORS_PATH}"
    assert ckpt_path.exists(), f"Finetuned checkpoint not found: {ckpt_path}"
    assert STATS_PATH.exists(), f"Stats not found: {STATS_PATH}"
    assert TEST_DATA_DIR.exists(), f"Test data not found: {TEST_DATA_DIR}"

    with open(STATS_PATH) as f:
        stats = json.load(f)

    existing_results = {}
    if args.resume:
        existing_results = load_checkpoint_results()
        print(f"Resume: found {len(existing_results)} completed pairs")

    # Build the base model and load the pretrained weights.
    print("\nBuilding BFM Large model...")
    from safetensors.torch import load_file
    base_model = build_base_model()
    total_params = sum(p.numel() for p in base_model.parameters())
    print(f"Base model parameters: {total_params / 1e6:.1f}M")

    print(f"Loading pretrained weights from {SAFETENSORS_PATH.name}...")
    state = load_file(str(SAFETENSORS_PATH), device="cpu")
    base_model.load_state_dict(state, strict=False)

    # Wrap with the finetuning head (the same wrapper used during training).
    print("Wrapping with BFMRaw...")
    from bfm_finetune.bfm_mod import BFMRaw
    model = BFMRaw(base_model=base_model, n_species=NUM_FINETUNE_SPECIES, mode="eval")

    print(f"Loading finetuned weights from {ckpt_path.name}...")
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=True)
    # strict=False tolerates the expected head/backbone key mismatches.
    missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    print(f"  Finetuned checkpoint: epoch={ckpt.get('epoch', '?')}, loss={ckpt.get('loss', '?')}")
    print(f"  Missing keys: {len(missing)}, Unexpected: {len(unexpected)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)
    model.eval()
|
|
    # A separate pretrained-only copy scores the non-species variable groups;
    # the finetuned BFMRaw head above is used only for the species metrics.
    print("\nBuilding separate pretrained BFM for non-species evaluation...")
    pretrained_model = build_base_model()
    pretrained_model.load_state_dict(state, strict=False)
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
|
|
    # Consecutive test files form (input, target) evaluation pairs.
    test_files = sorted(
        [f for f in TEST_DATA_DIR.iterdir() if f.suffix == ".pt"],
        key=lambda x: x.name,
    )
    total_pairs = len(test_files) - 1
    max_pairs = args.max_pairs if args.max_pairs else total_pairs
    eval_pairs = min(max_pairs, total_pairs)
    print(f"\nTest files: {len(test_files)}, pairs: {total_pairs}, evaluating: {eval_pairs}")

    # Accumulators aggregated across all evaluated pairs.
    all_var_mae = defaultdict(lambda: defaultdict(list))
    all_var_rmse = defaultdict(lambda: defaultdict(list))
    all_species_mae = defaultdict(list)
    all_f1_scores = []
    all_sorensen_means = []
    pair_details = []
    skipped = 0
    start_time = time.time()
|
|
    # Replay metrics from previously checkpointed pairs when resuming.
    for idx, result in sorted(existing_results.items()):
        if idx >= eval_pairs:
            continue
        for gname, gdata in result.get("per_group_mae", {}).items():
            for vname, mae_val in gdata.items():
                all_var_mae[gname][vname].append(mae_val)
        for gname, gdata in result.get("per_group_rmse", {}).items():
            for vname, rmse_val in gdata.items():
                all_var_rmse[gname][vname].append(rmse_val)
        for sp_id, mae_val in result.get("per_species_mae", {}).items():
            all_species_mae[sp_id].append(mae_val)
        if result.get("f1") is not None:
            all_f1_scores.append(result["f1"])
        if result.get("sorensen_mean") is not None:
            all_sorensen_means.append(result["sorensen_mean"])
        pair_details.append(result)
|
|
    for i in range(eval_pairs):
        if i in existing_results:
            continue

        # Rough ETA; approximate when resuming, since replayed pairs count as done.
        elapsed_total = time.time() - start_time
        remaining = (elapsed_total / max(len(pair_details), 1)) * (eval_pairs - len(pair_details))
        print(f"\n--- Pair {i + 1}/{eval_pairs} "
              f"[done={len(pair_details)}, ETA={remaining/60:.0f}m] ---")
        print(f"  {test_files[i].name} -> {test_files[i + 1].name}")
|
|
        try:
            t0 = time.time()

            x_batch, _ = load_and_process_batch(str(test_files[i]), stats, MODEL_CONFIG["patch_size"])
            y_batch, _ = load_and_process_batch(str(test_files[i + 1]), stats, MODEL_CONFIG["patch_size"])

            # The pretrained model takes the full multimodal batch (B=1, float32).
            x_batched = add_batch_dim(x_batch)
            from bfm_model.bfm.dataloader_monthly import batch_to_device
            x_batched_dev = batch_to_device(x_batched, device)
            x_batched_dev = ensure_f32(x_batched_dev)

            t_fwd = time.time()
            with torch.no_grad():
                output_pretrained = pretrained_model(
                    x_batched_dev, lead_time=x_batch.batch_metadata.lead_time, batch_size=1
                )
            fwd_time_pretrained = time.time() - t_fwd
|
|
            # Species ground truth comes from the target batch; inputs from
            # the source batch.
            species_gt_dict = y_batch.species_variables
            species_x_dict = x_batch.species_variables

            # Deterministic species ordering, shared by input packing and metrics.
            sp_ids_in_order = sorted(species_x_dict.keys())

            # Species tensors are (T, H, W); pack the 28 available species into
            # the first channels of the 500-channel finetuned input, zero-padded.
            H, W = x_batch.species_variables[sp_ids_in_order[0]].shape[1:]
            T = 2
            n_sp = NUM_FINETUNE_SPECIES
            species_tensor = torch.zeros(T, n_sp, H, W)
            for idx_sp, sp_id in enumerate(sp_ids_in_order):
                if idx_sp < n_sp:
                    species_tensor[:, idx_sp, :, :] = species_x_dict[sp_id]

            species_batch = {
                "species_distribution": species_tensor.unsqueeze(0).to(device),
            }

            t_fwd2 = time.time()
            with torch.no_grad():
                species_output = model(species_batch)
            fwd_time_ft = time.time() - t_fwd2
            # First element of the model output holds the species predictions.
            species_pred_tensor = species_output[0]

            print(f"  Forward (pretrained): {fwd_time_pretrained:.1f}s, (finetuned): {fwd_time_ft:.1f}s")
|
|
            pair_result = {
                "pair_index": i,
                "input_file": test_files[i].name,
                "target_file": test_files[i + 1].name,
                "forward_time_pretrained_s": round(fwd_time_pretrained, 1),
                "forward_time_finetuned_s": round(fwd_time_ft, 1),
                "per_group_mae": {},
                "per_group_rmse": {},
                "per_species_mae": {},
                "f1": None,
                "sorensen_mean": None,
                "sorensen_std": None,
            }
|
|
            # Non-species groups are scored with the pretrained-only model.
            for gname in VARIABLE_GROUPS:
                if gname == "species_variables":
                    continue
                pred_dict = output_pretrained.get(gname, {})
                gt_dict = getattr(y_batch, gname, {})
                if not pred_dict or not gt_dict:
                    continue
                var_names = list(pred_dict.keys())
                per_var_mae, _ = compute_mae_eq18(pred_dict, gt_dict, var_names)
                per_var_rmse, _ = compute_rmse_eq19(pred_dict, gt_dict, var_names)
                pair_result["per_group_mae"][gname] = per_var_mae
                pair_result["per_group_rmse"][gname] = per_var_rmse
                for vname, mae_val in per_var_mae.items():
                    all_var_mae[gname][vname].append(mae_val)
                for vname, rmse_val in per_var_rmse.items():
                    all_var_rmse[gname][vname].append(rmse_val)
|
|
            # Species metrics use the finetuned head's predictions.
            sp_pred_for_metrics = {}
            sp_gt_for_metrics = {}
            species_mae_results = {}

            for idx_sp, sp_id in enumerate(sp_ids_in_order):
                if idx_sp >= n_sp:
                    break
                # Predicted grid: batch 0, first predicted timestep.
                pred_sp = species_pred_tensor[0, 0, idx_sp, :, :]

                # Ground-truth grid (T, H, W) for the same species.
                gt_sp = species_gt_dict[sp_id]

                # Resize the prediction to the ground-truth grid if shapes differ.
                gt_h, gt_w = gt_sp.shape[1], gt_sp.shape[2]
                pred_h, pred_w = pred_sp.shape[0], pred_sp.shape[1]
                if pred_h != gt_h or pred_w != gt_w:
                    pred_sp = torch.nn.functional.interpolate(
                        pred_sp.unsqueeze(0).unsqueeze(0).cpu().float(),
                        size=(gt_h, gt_w), mode="bilinear", align_corners=False,
                    ).squeeze()

                # Keep the shapes the metric helpers expect: pred (1, H, W) -> [0],
                # gt (T, H, W) -> [1].
                sp_pred_for_metrics[sp_id] = pred_sp.unsqueeze(0).cpu().float()
                sp_gt_for_metrics[sp_id] = gt_sp.float()

                # MAE (Eq. 18) in normalized space against the second timestep.
                mae = float(torch.mean(torch.abs(pred_sp.cpu().float() - gt_sp[1].float())).item())
                species_mae_results[sp_id] = mae
                all_species_mae[sp_id].append(mae)

            pair_result["per_species_mae"] = species_mae_results
            pair_result["per_group_mae"]["species_variables"] = species_mae_results
            pair_result["per_group_rmse"]["species_variables"] = {}
            for sp_id in sp_ids_in_order:
                if sp_id in sp_pred_for_metrics:
                    p = sp_pred_for_metrics[sp_id][0]
                    g = sp_gt_for_metrics[sp_id][1]
                    rmse = float(torch.sqrt(torch.mean((p - g) ** 2)).item())
                    pair_result["per_group_rmse"]["species_variables"][sp_id] = rmse
                    all_var_rmse["species_variables"][sp_id].append(rmse)
                    all_var_mae["species_variables"][sp_id].append(species_mae_results[sp_id])
|
|
            # Presence/absence metrics over all species (Eqs. 20 and 22).
            f1 = compute_f1_eq20(sp_pred_for_metrics, sp_gt_for_metrics, threshold=args.species_threshold)
            pair_result["f1"] = f1
            if f1 is not None:
                all_f1_scores.append(f1)
                print(f"  F1 (finetuned): {f1:.4f}")

            sor_mean, sor_std, _ = compute_sorensen_eq22(
                sp_pred_for_metrics, sp_gt_for_metrics, threshold=args.species_threshold
            )
            pair_result["sorensen_mean"] = sor_mean
            pair_result["sorensen_std"] = sor_std
            if sor_mean is not None:
                all_sorensen_means.append(sor_mean)
                print(f"  Sorensen (finetuned): {sor_mean:.4f} +/- {sor_std:.4f}")

            sp_mae_vals = list(species_mae_results.values())
            if sp_mae_vals:
                print(f"  Species MAE (finetuned, norm): {np.mean(sp_mae_vals):.6f}")

            pair_details.append(pair_result)
            save_checkpoint_result(i, pair_result)

            total_time = time.time() - t0
            print(f"  Total pair time: {total_time:.1f}s")
|
|
        except Exception as e:
            print(f"  ERROR: {e}")
            import traceback
            traceback.print_exc()
            skipped += 1
|
|
    # ---- Aggregate and report ----
    total_elapsed = time.time() - start_time
    print(f"\n{'=' * 70}")
    print(f"AGGREGATE RESULTS ({len(pair_details)}/{eval_pairs} pairs, {total_elapsed/60:.1f} min)")
    print(f"{'=' * 70}")

    if not pair_details:
        print("No successful evaluations!")
        return
|
|
    # Per-group MAE/RMSE summary table (normalized space).
    print("\n--- MAE (Eq.18) & RMSE (Eq.19) per variable group ---")
    print(f"{'Group':<25s} {'MAE (norm)':>12s} {'RMSE (norm)':>12s} {'N vars':>6s}")
    print("-" * 60)

    group_summary = {}
    for gname in VARIABLE_GROUPS:
        if gname not in all_var_mae:
            continue
        var_maes = {v: np.mean(vals) for v, vals in all_var_mae[gname].items()}
        var_rmses = {v: np.mean(vals) for v, vals in all_var_rmse[gname].items()} if gname in all_var_rmse else {}
        group_mae = np.mean(list(var_maes.values()))
        group_rmse = np.mean(list(var_rmses.values())) if var_rmses else 0.0
        print(f"{gname:<25s} {group_mae:>12.6f} {group_rmse:>12.6f} {len(var_maes):>6d}")
        group_summary[gname] = {
            "mae_normalized": float(group_mae),
            "rmse_normalized": float(group_rmse),
            "per_variable_mae": {v: float(m) for v, m in var_maes.items()},
            "per_variable_rmse": {v: float(r) for v, r in var_rmses.items()},
        }
|
|
    # Per-species comparison against the paper's Table 9.
    print("\n--- Per-species MAE (normalized) vs Paper Table 9 ---")
    print(f"{'Species':<12s} {'Finetuned':>10s} {'Paper':>10s} {'Ratio':>8s}")
    print("-" * 46)

    species_comparison = {}
    for sp_id in SPECIES_VARS:
        if sp_id in all_species_mae:
            our_mae = np.mean(all_species_mae[sp_id])
            paper_mae = PAPER_TABLE9_SPECIES_MAE.get(sp_id)
            ratio = our_mae / paper_mae if paper_mae and paper_mae > 0 else None
            ratio_str = f"{ratio:.2f}x" if ratio else "N/A"
            print(f"{sp_id:<12s} {our_mae:>10.6f} {paper_mae:>10.6f} {ratio_str:>8s}")
            species_comparison[sp_id] = {
                "our_mae": float(our_mae),
                "paper_mae": float(paper_mae) if paper_mae else None,
                "ratio": float(ratio) if ratio else None,
            }

    if species_comparison:
        our_mean = np.mean([v["our_mae"] for v in species_comparison.values()])
        paper_mean = np.mean([v["paper_mae"] for v in species_comparison.values() if v["paper_mae"]])
        print("-" * 46)
        print(f"{'MEAN':<12s} {our_mean:>10.6f} {paper_mean:>10.6f} {our_mean/paper_mean:.2f}x")
|
|
    print("\n--- F1 Score (Eq.20) ---")
    if all_f1_scores:
        f1_mean = np.mean(all_f1_scores)
        f1_std = np.std(all_f1_scores)
        print(f"  F1 (finetuned) = {f1_mean:.4f} +/- {f1_std:.4f} (N={len(all_f1_scores)} pairs)")
        print("  Paper Table 1 (finetuned, 500 sp): F1 = 0.9964")
    else:
        print("  No F1 scores computed")

    print("\n--- Sorensen Similarity (Eq.22) ---")
    if all_sorensen_means:
        sor_mean = np.mean(all_sorensen_means)
        sor_std = np.std(all_sorensen_means)
        print(f"  S = {sor_mean:.4f} +/- {sor_std:.4f} (N={len(all_sorensen_means)} pairs)")
        print("  Paper Section 4.1: S = 0.31")
    else:
        print("  No Sorensen scores computed")
|
|
    # Final side-by-side comparison table.
    print(f"\n{'=' * 80}")
    print("COMPARISON: Finetuned Large vs Pretrained-only Large vs Paper")
    print(f"{'=' * 80}")
    print(f"{'Metric':<30s} {'Finetuned':>12s} {'Pretrained':>12s} {'Paper':>12s}")
    print("-" * 80)
    if all_f1_scores:
        print(f"{'F1 (Eq.20)':<30s} {np.mean(all_f1_scores):>12.4f} {'N/A':>12s} {'0.9964':>12s}")
    if all_sorensen_means:
        print(f"{'Sorensen S (Eq.22)':<30s} {np.mean(all_sorensen_means):>12.4f} {'N/A':>12s} {'0.31':>12s}")
    if species_comparison:
        print(f"{'Mean species MAE (norm)':<30s} {our_mean:>12.6f} {'N/A':>12s} {'0.083602':>12s}")
|
|
    # Persist the full results to JSON.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    results = {
        "model": "BFM-Finetuned-Large-GeoLifeCLEF",
        "config": MODEL_CONFIG,
        "finetune_checkpoint": str(ckpt_path),
        "finetune_epoch": ckpt.get("epoch"),
        "finetune_loss": float(ckpt.get("loss", 0)),
        "n_pairs_evaluated": len(pair_details),
        "n_pairs_total": total_pairs,
        "skipped": skipped,
        "total_time_minutes": round(total_elapsed / 60, 1),
        "species_threshold": args.species_threshold,
        "aggregate": {
            "f1_mean": float(np.mean(all_f1_scores)) if all_f1_scores else None,
            "f1_std": float(np.std(all_f1_scores)) if all_f1_scores else None,
            "sorensen_mean": float(np.mean(all_sorensen_means)) if all_sorensen_means else None,
            "sorensen_std": float(np.std(all_sorensen_means)) if all_sorensen_means else None,
        },
        "per_group_summary": group_summary,
        "species_comparison_with_table9": species_comparison,
        "paper_reference": {
            "table1_f1": 0.9964,
            "section41_sorensen": 0.31,
        },
    }

    results_path = RESULTS_DIR / "bfm_finetuned_eval_large_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\nResults saved to {results_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|