# bfm-finetuned-large / eval_finetuned_large.py
# Uploaded via huggingface_hub by YoshimuraHiroto (commit 4ebca3d, verified).
"""
Re-evaluate BFM Large model AFTER GeoLifeCLEF finetuning on BioCube 28 species.
This script:
1. Builds the BFM Large base model
2. Wraps with BFMRaw (same as finetuning)
3. Loads finetuned checkpoint
4. Evaluates on BioCube test data using paper metrics (MAE, RMSE, F1, Sorensen)
Usage:
conda run -n bfm python eval_finetuned_large.py 2>&1 | tee eval_finetuned_large.log
"""
import argparse
import json
import os
import sys
import time
import warnings
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
# Silence third-party deprecation/user warnings so the evaluation log stays readable.
warnings.filterwarnings("ignore")
# ─── Paths (anchored at this script's directory) ───
PROJECT_ROOT = Path(__file__).resolve().parent
BFM_MODEL_DIR = PROJECT_ROOT / "bfm-model"
# Make the vendored bfm-model package importable without installation.
sys.path.insert(0, str(BFM_MODEL_DIR))
SAFETENSORS_PATH = PROJECT_ROOT / "bfm-pretrained" / "bfm-pretrain-large.safetensors"
FINETUNE_CHECKPOINT = PROJECT_ROOT / "outputs_finetune_large" / "checkpoints" / "best_checkpoint.pth"
STATS_PATH = BFM_MODEL_DIR / "batch_statistics" / "monthly_batches_stats_splitted_channels.json"
TEST_DATA_DIR = PROJECT_ROOT / "dataset" / "test"
RESULTS_DIR = PROJECT_ROOT / "env_sdm_results"
# Per-pair JSON checkpoints are written here so evaluation can resume after a crash.
CHECKPOINT_DIR = RESULTS_DIR / "finetuned_eval_checkpoints"
# Number of GeoLifeCLEF species the model was finetuned on
NUM_FINETUNE_SPECIES = 500
# ─── Model config (Large) ───
MODEL_CONFIG = {
    "embed_dim": 512, "depth": 10, "patch_size": 8,
    "swin_backbone_size": "large", "num_heads": 16, "head_dim": 64,
    "H": 160, "W": 280, "num_latent_tokens": 8,
    "perceiver_latents": 16100, "T": 2,
}
# Swin backbone hyperparameters passed through to BFM as keyword arguments.
SWIN_LARGE_CONFIG = {
    "swin_encoder_depths": (2, 2, 2), "swin_encoder_num_heads": (8, 16, 32),
    "swin_decoder_depths": (2, 2, 2), "swin_decoder_num_heads": (32, 16, 8),
    "swin_window_size": (1, 4, 5), "swin_mlp_ratio": 4.0,
    "swin_qkv_bias": True, "swin_drop_rate": 0.0,
    "swin_attn_drop_rate": 0.0, "swin_drop_path_rate": 0.1,
    "use_lora": False,
}
# ─── Variable definitions ───
# Numeric identifiers of the 28 BioCube species (presumably GBIF taxon IDs — verify
# against the dataset documentation).
SPECIES_VARS = [
    "1340361", "1340503", "1536449", "1898286", "1920506", "2430567",
    "2431885", "2433433", "2434779", "2435240", "2435261", "2437394",
    "2441454", "2473958", "2491534", "2891770", "3034825", "4408498",
    "5218786", "5219073", "5219173", "5219219", "5844449", "8002952",
    "8077224", "8894817", "8909809", "9809229",
]
SURFACE_VARS = ["t2m", "msl", "slt", "z", "u10", "v10", "lsm"]
EDAPHIC_VARS = ["swvl1", "swvl2", "stl1", "stl2"]
ATMOS_VARS = ["z", "t", "u", "v", "q"]
CLIMATE_VARS = [
    "smlt", "tp", "csfr", "avg_sdswrf", "avg_snswrf", "avg_snlwrf",
    "avg_tprate", "avg_sdswrfcs", "sd", "t2m", "d2m",
]
VEGETATION_VARS = ["NDVI"]
LAND_VARS = ["Land"]
AGRICULTURE_VARS = ["Agriculture", "Arable", "Cropland"]
FOREST_VARS = ["Forest"]
REDLIST_VARS = ["RLI"]
MISC_VARS = ["avg_slhtf", "avg_pevr"]
# Pressure levels (hPa) extracted from the atmospheric variables.
ATMOS_LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
# Names of the Batch namedtuple fields iterated during metric computation.
VARIABLE_GROUPS = [
    "surface_variables", "edaphic_variables", "atmospheric_variables",
    "climate_variables", "species_variables", "vegetation_variables",
    "land_variables", "agriculture_variables", "forest_variables",
    "redlist_variables", "misc_variables",
]
# Paper Table 9 reference values (per-species normalized MAE, used for comparison).
PAPER_TABLE9_SPECIES_MAE = {
    "1340361": 0.000064, "1340503": 0.064850, "1536449": 0.034896,
    "1898286": 0.359776, "1920506": 0.272244, "2430567": 0.001872,
    "2431885": 0.012577, "2433433": 0.001001, "2434779": 0.000018,
    "2435240": 0.005042, "2435261": 0.000074, "2437394": 0.000220,
    "2441454": 0.005205, "2473958": 0.051274, "2491534": 0.294113,
    "2891770": 0.046475, "3034825": 0.023135, "4408498": 0.016373,
    "5218786": 0.007713, "5219073": 0.000831, "5219173": 0.006481,
    "5219219": 0.000296, "5844449": 0.006944, "8002952": 0.015838,
    "8077224": 0.361090, "8894817": 0.000336, "8909809": 0.006676,
    "9809229": 0.745430,
}
def build_base_model():
    """Construct the BFM Large architecture (no weights are loaded here)."""
    from bfm_model.bfm.model import BFM

    # Assemble every constructor argument in one mapping, then unpack.
    model_kwargs = dict(
        surface_vars=SURFACE_VARS,
        edaphic_vars=EDAPHIC_VARS,
        atmos_vars=ATMOS_VARS,
        climate_vars=CLIMATE_VARS,
        species_vars=SPECIES_VARS,
        vegetation_vars=VEGETATION_VARS,
        land_vars=LAND_VARS,
        agriculture_vars=AGRICULTURE_VARS,
        forest_vars=FOREST_VARS,
        redlist_vars=REDLIST_VARS,
        misc_vars=MISC_VARS,
        atmos_levels=ATMOS_LEVELS,
        species_num=len(SPECIES_VARS),
        H=MODEL_CONFIG["H"],
        W=MODEL_CONFIG["W"],
        num_latent_tokens=MODEL_CONFIG["num_latent_tokens"],
        backbone_type="swin",
        patch_size=MODEL_CONFIG["patch_size"],
        embed_dim=MODEL_CONFIG["embed_dim"],
        num_heads=MODEL_CONFIG["num_heads"],
        head_dim=MODEL_CONFIG["head_dim"],
        depth=MODEL_CONFIG["depth"],
        perceiver_latents=MODEL_CONFIG["perceiver_latents"],
        batch_size=1,
        td_learning=True,
        use_mask="no",
    )
    model_kwargs.update(SWIN_LARGE_CONFIG)
    return BFM(**model_kwargs)
def load_and_process_batch(fpath, stats, patch_size=8):
    """Load a raw .pt batch file and process into Batch namedtuple.

    Returns a (batch, raw_species) tuple: the normalized/cropped Batch, plus a
    dict of the species tensors captured BEFORE normalization (cropped only),
    useful for metrics in original units.
    """
    from datetime import datetime
    from bfm_model.bfm.dataloader_monthly import (
        Batch, Metadata, crop_variables, extract_atmospheric_levels, normalize_keys,
    )
    from bfm_model.bfm.scaler import _rescale_recursive, dimensions_to_keep_monthly
    # weights_only=False: the .pt files contain pickled metadata dicts, not just tensors.
    data = torch.load(fpath, map_location="cpu", weights_only=False)
    latitudes = data["batch_metadata"]["latitudes"]
    longitudes = data["batch_metadata"]["longitudes"]
    timestamps = data["batch_metadata"]["timestamp"]
    pressure_levels = data["batch_metadata"]["pressure_levels"]
    species_list = data["batch_metadata"]["species_list"]
    # Crop spatial dims down to the nearest multiple of the patch size.
    H, W = len(latitudes), len(longitudes)
    new_H = (H // patch_size) * patch_size
    new_W = (W // patch_size) * patch_size
    # Snapshot raw (un-normalized) species maps BEFORE _rescale_recursive mutates `data`.
    raw_species = {}
    if "species_variables" in data:
        for sp_id, sp_data in data["species_variables"].items():
            raw_species[sp_id] = sp_data[:, :new_H, :new_W].clone()
    # Normalize all variable groups in place using the precomputed batch statistics.
    _rescale_recursive(
        data, stats,
        dimensions_to_keep_by_key=dimensions_to_keep_monthly,
        mode="normalize", direction="scaled",
    )
    surface_vars = crop_variables(data["surface_variables"], new_H, new_W)
    edaphic_vars = crop_variables(data["edaphic_variables"], new_H, new_W)
    atmospheric_vars = crop_variables(data["atmospheric_variables"], new_H, new_W)
    climate_vars = crop_variables(data["climate_variables"], new_H, new_W)
    species_vars = crop_variables(data["species_variables"], new_H, new_W)
    land_vars = crop_variables(data["land_variables"], new_H, new_W)
    agriculture_vars = crop_variables(data["agriculture_variables"], new_H, new_W)
    forest_vars = crop_variables(data["forest_variables"], new_H, new_W)
    # NDVI can contain NaNs (e.g. over water); zero them out during cropping.
    vegetation_vars = crop_variables(
        data["vegetation_variables"], new_H, new_W, handle_nans=True, nan_mode="zero"
    )
    redlist_vars = crop_variables(data["redlist_variables"], new_H, new_W)
    misc_vars = crop_variables(data["misc_variables"], new_H, new_W)
    latitude_var = torch.tensor(latitudes[:new_H])
    longitude_var = torch.tensor(longitudes[:new_W])
    dt_format = "%Y-%m-%d %H:%M:%S"
    start = datetime.strptime(timestamps[0], dt_format)
    end = datetime.strptime(timestamps[1], dt_format)
    # +1 so a same-month (start == end) pair counts as a lead time of 1 month.
    lead_months = (end.year - start.year) * 12 + (end.month - start.month) + 1
    # Keep only the pressure levels the model was trained on, in ATMOS_LEVELS order.
    atmospheric_vars = extract_atmospheric_levels(
        atmospheric_vars, pressure_levels, ATMOS_LEVELS, level_dim=1
    )
    metadata = Metadata(
        latitudes=latitude_var, longitudes=longitude_var,
        timestamp=timestamps, lead_time=lead_months,
        pressure_levels=pressure_levels, species_list=species_list,
    )
    # normalize_keys: canonicalizes species dict keys (see dataloader_monthly).
    species_vars = normalize_keys(species_vars)
    batch = Batch(
        batch_metadata=metadata,
        surface_variables=surface_vars, edaphic_variables=edaphic_vars,
        atmospheric_variables=atmospheric_vars, climate_variables=climate_vars,
        species_variables=species_vars, vegetation_variables=vegetation_vars,
        land_variables=land_vars, agriculture_variables=agriculture_vars,
        forest_variables=forest_vars, redlist_variables=redlist_vars,
        misc_variables=misc_vars,
    )
    return batch, raw_species
def add_batch_dim(batch):
    """Return a copy of *batch* with a leading batch axis (B=1) on every tensor."""
    from bfm_model.bfm.dataloader_monthly import Batch, Metadata

    group_names = (
        "surface_variables", "edaphic_variables", "atmospheric_variables",
        "climate_variables", "species_variables", "vegetation_variables",
        "land_variables", "agriculture_variables", "forest_variables",
        "redlist_variables", "misc_variables",
    )
    md = batch.batch_metadata
    batched_md = Metadata(
        latitudes=md.latitudes.unsqueeze(0),
        longitudes=md.longitudes.unsqueeze(0),
        timestamp=[md.timestamp],  # wrap in a list to mimic a batch of size 1
        lead_time=md.lead_time,
        pressure_levels=md.pressure_levels,
        species_list=md.species_list,
    )
    # Unsqueeze every tensor in every variable-group dict.
    batched_groups = {
        group: {name: tensor.unsqueeze(0) for name, tensor in getattr(batch, group).items()}
        for group in group_names
    }
    return Batch(batch_metadata=batched_md, **batched_groups)
def ensure_f32(batch):
    """Return *batch* with every float64 tensor cast to float32; other dtypes untouched."""
    from bfm_model.bfm.dataloader_monthly import Batch

    group_names = (
        "surface_variables", "edaphic_variables", "atmospheric_variables",
        "climate_variables", "species_variables", "vegetation_variables",
        "land_variables", "agriculture_variables", "forest_variables",
        "redlist_variables", "misc_variables",
    )

    def _downcast(var_dict):
        out = {}
        for name, tensor in var_dict.items():
            out[name] = tensor.float() if tensor.dtype == torch.float64 else tensor
        return out

    converted = {group: _downcast(getattr(batch, group)) for group in group_names}
    return Batch(batch_metadata=batch.batch_metadata, **converted)
# ─── Paper Metric Functions ───
def compute_mae_eq18(pred_dict, gt_dict, var_names):
    """Per-variable MAE (paper Eq. 18) between prediction step 0 and target step 1.

    Returns (per_var_mae, avg_mae); avg_mae is None when no variable was shared.
    """
    per_var_mae = {}
    for vname in var_names:
        # Skip variables absent from either side.
        if vname not in pred_dict or vname not in gt_dict:
            continue
        prediction = pred_dict[vname][0].cpu().float()
        target = gt_dict[vname][1].float()
        per_var_mae[vname] = float((prediction - target).abs().mean().item())
    if not per_var_mae:
        return per_var_mae, None
    return per_var_mae, float(np.mean(list(per_var_mae.values())))
def compute_rmse_eq19(pred_dict, gt_dict, var_names):
    """Per-variable RMSE (paper Eq. 19) between prediction step 0 and target step 1.

    Returns (per_var_rmse, avg_rmse); avg_rmse is None when no variable was shared.
    """
    per_var_rmse = {}
    for vname in var_names:
        # Skip variables absent from either side.
        if vname not in pred_dict or vname not in gt_dict:
            continue
        prediction = pred_dict[vname][0].cpu().float()
        target = gt_dict[vname][1].float()
        squared_error = (prediction - target) ** 2
        per_var_rmse[vname] = float(torch.sqrt(squared_error.mean()).item())
    if not per_var_rmse:
        return per_var_rmse, None
    return per_var_rmse, float(np.mean(list(per_var_rmse.values())))
def compute_f1_eq20(species_pred, species_gt, threshold=0.0):
    """Mean per-pixel F1 (paper Eq. 20) over species present in both dicts.

    Presence is pred > threshold (prediction) and value > 0 (ground truth);
    pixels with no positives on either side count as a perfect score of 1.
    Returns None when the dicts share no species.
    """
    common = sorted(set(species_pred) & set(species_gt))
    if not common:
        return None
    predictions = torch.stack([species_pred[sp][0].cpu().float() for sp in common])
    targets = torch.stack([species_gt[sp][1].float() for sp in common])
    predicted_presence = (predictions > threshold).int()
    true_presence = (targets > 0).int()
    # Confusion counts per pixel, summed over the species axis.
    tp = (predicted_presence * true_presence).sum(dim=0).float()
    fp = (predicted_presence * (1 - true_presence)).sum(dim=0).float()
    fn = ((1 - predicted_presence) * true_presence).sum(dim=0).float()
    denom = tp + (fp + fn) / 2.0
    f1_map = torch.where(denom > 0, tp / denom, torch.ones_like(tp))
    return float(f1_map.mean().item())
def compute_sorensen_eq22(species_pred, species_gt, threshold=0.0):
    """Per-pixel Sorensen similarity (paper Eq. 22) over shared species.

    Pixels with no presences on either side are NaN and excluded from the
    mean/std. Returns (mean, std, map_as_numpy) or (None, None, None) when
    there are no shared species or no valid pixels.
    """
    common = sorted(set(species_pred) & set(species_gt))
    if not common:
        return None, None, None
    predictions = torch.stack([species_pred[sp][0].cpu().float() for sp in common])
    targets = torch.stack([species_gt[sp][1].float() for sp in common])
    predicted_presence = (predictions > threshold).int()
    true_presence = (targets > 0).int()
    # Sorensen components per pixel: c = shared, b = gt-only, d = pred-only.
    shared = (predicted_presence * true_presence).sum(dim=0).float()
    gt_only = ((1 - predicted_presence) * true_presence).sum(dim=0).float()
    pred_only = (predicted_presence * (1 - true_presence)).sum(dim=0).float()
    denom = 2 * shared + gt_only + pred_only
    sorensen_map = torch.where(
        denom > 0, 2 * shared / denom, torch.full_like(shared, float("nan"))
    )
    valid = sorensen_map[~torch.isnan(sorensen_map)]
    if len(valid) == 0:
        return None, None, None
    return float(valid.mean().item()), float(valid.std().item()), sorensen_map.numpy()
def load_checkpoint_results():
    """Load previously saved per-pair JSON checkpoints, keyed by pair index."""
    results = {}
    if not CHECKPOINT_DIR.exists():
        return results
    for checkpoint_file in sorted(CHECKPOINT_DIR.glob("pair_*.json")):
        payload = json.loads(checkpoint_file.read_text())
        results[payload["pair_index"]] = payload
    return results
def save_checkpoint_result(pair_idx, result):
    """Persist one pair's metrics as an indented JSON checkpoint file."""
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    out_path = CHECKPOINT_DIR / f"pair_{pair_idx:04d}.json"
    # default=str: fall back to str() for anything json can't serialize natively.
    out_path.write_text(json.dumps(result, indent=2, default=str))
def main():
    """CLI entry point: evaluate the finetuned BFM Large on consecutive BioCube test pairs.

    Species metrics come from the finetuned BFMRaw wrapper; all non-species
    variable groups are evaluated with a separately built pretrained-only BFM
    (see the inline comment block below for why two models are needed).
    """
    parser = argparse.ArgumentParser(description="Evaluate Finetuned BFM Large on BioCube")
    parser.add_argument("--max-pairs", type=int, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--species-threshold", type=float, default=0.0)
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to finetuned checkpoint (default: outputs_finetune_large/checkpoints/best_checkpoint.pth)")
    args = parser.parse_args()
    ckpt_path = Path(args.checkpoint) if args.checkpoint else FINETUNE_CHECKPOINT
    print("=" * 70)
    print("Finetuned BFM Large - BioCube 28 Species Re-evaluation")
    print("=" * 70)
    print(f"Finetuned checkpoint: {ckpt_path}")
    print(f"Species threshold: {args.species_threshold}")
    print("=" * 70)
    # NOTE(review): assert-based input validation is stripped under `python -O`.
    assert SAFETENSORS_PATH.exists(), f"Pretrained weights not found: {SAFETENSORS_PATH}"
    assert ckpt_path.exists(), f"Finetuned checkpoint not found: {ckpt_path}"
    assert STATS_PATH.exists(), f"Stats not found: {STATS_PATH}"
    assert TEST_DATA_DIR.exists(), f"Test data not found: {TEST_DATA_DIR}"
    with open(STATS_PATH) as f:
        stats = json.load(f)
    existing_results = {}
    if args.resume:
        existing_results = load_checkpoint_results()
        print(f"Resume: found {len(existing_results)} completed pairs")
    # 1. Build base model and load pretrained weights
    print("\nBuilding BFM Large model...")
    from safetensors.torch import load_file
    base_model = build_base_model()
    total_params = sum(p.numel() for p in base_model.parameters())
    print(f"Base model parameters: {total_params / 1e6:.1f}M")
    print(f"Loading pretrained weights from {SAFETENSORS_PATH.name}...")
    state = load_file(str(SAFETENSORS_PATH), device="cpu")
    # strict=False: checkpoint key set may not match the freshly built model exactly.
    base_model.load_state_dict(state, strict=False)
    # 2. Wrap with BFMRaw and load finetuned weights
    print("Wrapping with BFMRaw...")
    from bfm_finetune.bfm_mod import BFMRaw
    model = BFMRaw(base_model=base_model, n_species=NUM_FINETUNE_SPECIES, mode="eval")
    print(f"Loading finetuned weights from {ckpt_path.name}...")
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=True)
    missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    print(f" Finetuned checkpoint: epoch={ckpt.get('epoch', '?')}, loss={ckpt.get('loss', '?')}")
    print(f" Missing keys: {len(missing)}, Unexpected: {len(unexpected)}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)
    model.eval()
    # The finetuned model processes species through BFMRaw, which uses a different
    # encoder/decoder than the base BFM. For evaluation on BioCube test data,
    # we need to:
    # a) Use the full BFM pipeline for non-species variables (the base_model handles these)
    # b) For species, use the finetuned BFMRaw path
    #
    # However, BFMRaw replaces the base model's encoder/decoder with nn.Identity(),
    # so the full BFM forward pass no longer works for non-species variables.
    # We evaluate ONLY species variables through the finetuned model.
    #
    # For a fair comparison, we also build a separate pretrained-only BFM model
    # for non-species variables.
    print("\nBuilding separate pretrained BFM for non-species evaluation...")
    pretrained_model = build_base_model()
    pretrained_model.load_state_dict(state, strict=False)
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    # List test files
    test_files = sorted(
        [f for f in TEST_DATA_DIR.iterdir() if f.suffix == ".pt"],
        key=lambda x: x.name,
    )
    # Consecutive files form (input, target) pairs, hence len - 1.
    total_pairs = len(test_files) - 1
    max_pairs = args.max_pairs if args.max_pairs else total_pairs
    eval_pairs = min(max_pairs, total_pairs)
    print(f"\nTest files: {len(test_files)}, pairs: {total_pairs}, evaluating: {eval_pairs}")
    # ─── Accumulators ───
    all_var_mae = defaultdict(lambda: defaultdict(list))   # group -> variable -> [mae per pair]
    all_var_rmse = defaultdict(lambda: defaultdict(list))  # group -> variable -> [rmse per pair]
    all_species_mae = defaultdict(list)                    # species id -> [mae per pair]
    all_f1_scores = []
    all_sorensen_means = []
    pair_details = []
    skipped = 0
    start_time = time.time()
    # Re-ingest metrics from checkpointed pairs so aggregates include resumed work.
    for idx, result in sorted(existing_results.items()):
        if idx >= eval_pairs:
            continue
        for gname, gdata in result.get("per_group_mae", {}).items():
            for vname, mae_val in gdata.items():
                all_var_mae[gname][vname].append(mae_val)
        for gname, gdata in result.get("per_group_rmse", {}).items():
            for vname, rmse_val in gdata.items():
                all_var_rmse[gname][vname].append(rmse_val)
        for sp_id, mae_val in result.get("per_species_mae", {}).items():
            all_species_mae[sp_id].append(mae_val)
        if result.get("f1") is not None:
            all_f1_scores.append(result["f1"])
        if result.get("sorensen_mean") is not None:
            all_sorensen_means.append(result["sorensen_mean"])
        pair_details.append(result)
    # ─── Evaluation loop ───
    for i in range(eval_pairs):
        if i in existing_results:
            continue
        # Naive ETA: average time per completed pair times pairs remaining.
        elapsed_total = time.time() - start_time
        remaining = (elapsed_total / max(len(pair_details), 1)) * (eval_pairs - len(pair_details))
        print(f"\n--- Pair {i + 1}/{eval_pairs} "
              f"[done={len(pair_details)}, ETA={remaining/60:.0f}m] ---")
        print(f" {test_files[i].name} -> {test_files[i + 1].name}")
        try:
            t0 = time.time()
            x_batch, _ = load_and_process_batch(str(test_files[i]), stats, MODEL_CONFIG["patch_size"])
            y_batch, _ = load_and_process_batch(str(test_files[i + 1]), stats, MODEL_CONFIG["patch_size"])
            # ─── Pretrained model: full forward pass for non-species metrics ───
            x_batched = add_batch_dim(x_batch)
            from bfm_model.bfm.dataloader_monthly import batch_to_device
            x_batched_dev = batch_to_device(x_batched, device)
            x_batched_dev = ensure_f32(x_batched_dev)
            t_fwd = time.time()
            with torch.no_grad():
                output_pretrained = pretrained_model(
                    x_batched_dev, lead_time=x_batch.batch_metadata.lead_time, batch_size=1
                )
            fwd_time_pretrained = time.time() - t_fwd
            # ─── Finetuned model: species-only forward pass ───
            # Prepare species input for BFMRaw: needs {"species_distribution": [B, T, C, H, W]}
            species_gt_dict = y_batch.species_variables
            species_x_dict = x_batch.species_variables
            # Stack all 28 BioCube species into [T, 28, H, W] tensor
            sp_ids_in_order = sorted(species_x_dict.keys())
            # The finetuned model handles 500 species but we only have 28.
            # We need to place our 28 species into the first 28 channels of the 500-channel input.
            H, W = x_batch.species_variables[sp_ids_in_order[0]].shape[1:]
            T = 2
            n_sp = NUM_FINETUNE_SPECIES
            species_tensor = torch.zeros(T, n_sp, H, W)  # unused channels stay zero
            for idx_sp, sp_id in enumerate(sp_ids_in_order):
                if idx_sp < n_sp:
                    species_tensor[:, idx_sp, :, :] = species_x_dict[sp_id]
            species_batch = {
                "species_distribution": species_tensor.unsqueeze(0).to(device),  # [1, T, 500, H, W]
            }
            t_fwd2 = time.time()
            with torch.no_grad():
                species_output = model(species_batch)
            fwd_time_ft = time.time() - t_fwd2
            # species_output shape: (recon, encoded, feats) in eval mode
            species_pred_tensor = species_output[0]  # [B, 1, 500, H_out, W_out]
            print(f" Forward (pretrained): {fwd_time_pretrained:.1f}s, (finetuned): {fwd_time_ft:.1f}s")
            pair_result = {
                "pair_index": i,
                "input_file": test_files[i].name,
                "target_file": test_files[i + 1].name,
                "forward_time_pretrained_s": round(fwd_time_pretrained, 1),
                "forward_time_finetuned_s": round(fwd_time_ft, 1),
                "per_group_mae": {},
                "per_group_rmse": {},
                "per_species_mae": {},
                "f1": None,
                "sorensen_mean": None,
                "sorensen_std": None,
            }
            # ─── Non-species metrics (from pretrained model) ───
            for gname in VARIABLE_GROUPS:
                if gname == "species_variables":
                    continue
                pred_dict = output_pretrained.get(gname, {})
                gt_dict = getattr(y_batch, gname, {})
                if not pred_dict or not gt_dict:
                    continue
                var_names = list(pred_dict.keys())
                per_var_mae, _ = compute_mae_eq18(pred_dict, gt_dict, var_names)
                per_var_rmse, _ = compute_rmse_eq19(pred_dict, gt_dict, var_names)
                pair_result["per_group_mae"][gname] = per_var_mae
                pair_result["per_group_rmse"][gname] = per_var_rmse
                for vname, mae_val in per_var_mae.items():
                    all_var_mae[gname][vname].append(mae_val)
                for vname, rmse_val in per_var_rmse.items():
                    all_var_rmse[gname][vname].append(rmse_val)
            # ─── Species metrics (from finetuned model) ───
            # Reconstruct per-species predictions from the finetuned output
            sp_pred_for_metrics = {}
            sp_gt_for_metrics = {}
            species_mae_results = {}
            for idx_sp, sp_id in enumerate(sp_ids_in_order):
                if idx_sp >= n_sp:
                    break
                # Predicted: [B, 1, n_sp, H_out, W_out] -> get species idx_sp
                pred_sp = species_pred_tensor[0, 0, idx_sp, :, :]  # [H_out, W_out]
                # Ground truth: [T, H, W]
                gt_sp = species_gt_dict[sp_id]  # [T, H, W], we want T=1
                # Resize prediction to match GT spatial dims if needed
                gt_h, gt_w = gt_sp.shape[1], gt_sp.shape[2]
                pred_h, pred_w = pred_sp.shape[0], pred_sp.shape[1]
                if pred_h != gt_h or pred_w != gt_w:
                    pred_sp = torch.nn.functional.interpolate(
                        pred_sp.unsqueeze(0).unsqueeze(0).cpu().float(),
                        size=(gt_h, gt_w), mode="bilinear", align_corners=False,
                    ).squeeze()
                # Package for metric functions: pred as [1, H, W], gt as [T, H, W]
                sp_pred_for_metrics[sp_id] = pred_sp.unsqueeze(0).cpu().float()
                sp_gt_for_metrics[sp_id] = gt_sp.float()
                # Per-species MAE (against the target timestep, index 1)
                mae = float(torch.mean(torch.abs(pred_sp.cpu().float() - gt_sp[1].float())).item())
                species_mae_results[sp_id] = mae
                all_species_mae[sp_id].append(mae)
            pair_result["per_species_mae"] = species_mae_results
            pair_result["per_group_mae"]["species_variables"] = species_mae_results
            pair_result["per_group_rmse"]["species_variables"] = {}
            for sp_id in sp_ids_in_order:
                if sp_id in sp_pred_for_metrics:
                    p = sp_pred_for_metrics[sp_id][0]
                    g = sp_gt_for_metrics[sp_id][1]
                    rmse = float(torch.sqrt(torch.mean((p - g) ** 2)).item())
                    pair_result["per_group_rmse"]["species_variables"][sp_id] = rmse
                    all_var_rmse["species_variables"][sp_id].append(rmse)
                    all_var_mae["species_variables"][sp_id].append(species_mae_results[sp_id])
            # F1 score
            f1 = compute_f1_eq20(sp_pred_for_metrics, sp_gt_for_metrics, threshold=args.species_threshold)
            pair_result["f1"] = f1
            if f1 is not None:
                all_f1_scores.append(f1)
                print(f" F1 (finetuned): {f1:.4f}")
            # Sorensen
            sor_mean, sor_std, _ = compute_sorensen_eq22(
                sp_pred_for_metrics, sp_gt_for_metrics, threshold=args.species_threshold
            )
            pair_result["sorensen_mean"] = sor_mean
            pair_result["sorensen_std"] = sor_std
            if sor_mean is not None:
                all_sorensen_means.append(sor_mean)
                print(f" Sorensen (finetuned): {sor_mean:.4f} +/- {sor_std:.4f}")
            sp_mae_vals = list(species_mae_results.values())
            if sp_mae_vals:
                print(f" Species MAE (finetuned, norm): {np.mean(sp_mae_vals):.6f}")
            pair_details.append(pair_result)
            save_checkpoint_result(i, pair_result)
            total_time = time.time() - t0
            print(f" Total pair time: {total_time:.1f}s")
        except Exception as e:
            # Keep going on per-pair failures; report the count at the end.
            print(f" ERROR: {e}")
            import traceback
            traceback.print_exc()
            skipped += 1
    # ═══════════════════════════════════════════════════════════
    # AGGREGATE RESULTS
    # ═══════════════════════════════════════════════════════════
    total_elapsed = time.time() - start_time
    print(f"\n{'=' * 70}")
    print(f"AGGREGATE RESULTS ({len(pair_details)}/{eval_pairs} pairs, {total_elapsed/60:.1f} min)")
    print(f"{'=' * 70}")
    if not pair_details:
        print("No successful evaluations!")
        return
    # Per-variable-group MAE/RMSE
    print(f"\n--- MAE (Eq.18) & RMSE (Eq.19) per variable group ---")
    print(f"{'Group':<25s} {'MAE (norm)':>12s} {'RMSE (norm)':>12s} {'N vars':>6s}")
    print("-" * 60)
    group_summary = {}
    for gname in VARIABLE_GROUPS:
        if gname not in all_var_mae:
            continue
        # Average each variable over pairs, then average variables within the group.
        var_maes = {v: np.mean(vals) for v, vals in all_var_mae[gname].items()}
        var_rmses = {v: np.mean(vals) for v, vals in all_var_rmse[gname].items()} if gname in all_var_rmse else {}
        group_mae = np.mean(list(var_maes.values()))
        group_rmse = np.mean(list(var_rmses.values())) if var_rmses else 0.0
        print(f"{gname:<25s} {group_mae:>12.6f} {group_rmse:>12.6f} {len(var_maes):>6d}")
        group_summary[gname] = {
            "mae_normalized": float(group_mae),
            "rmse_normalized": float(group_rmse),
            "per_variable_mae": {v: float(m) for v, m in var_maes.items()},
            "per_variable_rmse": {v: float(r) for v, r in var_rmses.items()},
        }
    # Per-species MAE comparison
    print(f"\n--- Per-species MAE (normalized) vs Paper Table 9 ---")
    print(f"{'Species':<12s} {'Finetuned':>10s} {'Paper':>10s} {'Ratio':>8s}")
    print("-" * 46)
    species_comparison = {}
    for sp_id in SPECIES_VARS:
        if sp_id in all_species_mae:
            our_mae = np.mean(all_species_mae[sp_id])
            paper_mae = PAPER_TABLE9_SPECIES_MAE.get(sp_id)
            ratio = our_mae / paper_mae if paper_mae and paper_mae > 0 else None
            ratio_str = f"{ratio:.2f}x" if ratio else "N/A"
            # NOTE(review): the {paper_mae:>10.6f} format would raise if paper_mae
            # were None; all SPECIES_VARS currently have Table 9 entries.
            print(f"{sp_id:<12s} {our_mae:>10.6f} {paper_mae:>10.6f} {ratio_str:>8s}")
            species_comparison[sp_id] = {
                "our_mae": float(our_mae),
                "paper_mae": float(paper_mae) if paper_mae else None,
                "ratio": float(ratio) if ratio else None,
            }
    if species_comparison:
        our_mean = np.mean([v["our_mae"] for v in species_comparison.values()])
        paper_mean = np.mean([v["paper_mae"] for v in species_comparison.values() if v["paper_mae"]])
        print("-" * 46)
        print(f"{'MEAN':<12s} {our_mean:>10.6f} {paper_mean:>10.6f} {our_mean/paper_mean:.2f}x")
    # F1
    print(f"\n--- F1 Score (Eq.20) ---")
    if all_f1_scores:
        f1_mean = np.mean(all_f1_scores)
        f1_std = np.std(all_f1_scores)
        print(f" F1 (finetuned) = {f1_mean:.4f} +/- {f1_std:.4f} (N={len(all_f1_scores)} pairs)")
        print(f" Paper Table 1 (finetuned, 500 sp): F1 = 0.9964")
    else:
        f1_mean = None
        print(" No F1 scores computed")
    # Sorensen
    print(f"\n--- Sorensen Similarity (Eq.22) ---")
    if all_sorensen_means:
        sor_mean = np.mean(all_sorensen_means)
        sor_std = np.std(all_sorensen_means)
        print(f" S = {sor_mean:.4f} +/- {sor_std:.4f} (N={len(all_sorensen_means)} pairs)")
        print(f" Paper Section 4.1: S = 0.31")
    else:
        sor_mean = None
        print(" No Sorensen scores computed")
    # Comparison table
    print(f"\n{'=' * 80}")
    print("COMPARISON: Finetuned Large vs Pretrained-only Large vs Paper")
    print(f"{'=' * 80}")
    print(f"{'Metric':<30s} {'Finetuned':>12s} {'Pretrained':>12s} {'Paper':>12s}")
    print("-" * 80)
    if all_f1_scores:
        print(f"{'F1 (Eq.20)':<30s} {np.mean(all_f1_scores):>12.4f} {'N/A':>12s} {'0.9964':>12s}")
    if all_sorensen_means:
        print(f"{'Sorensen S (Eq.22)':<30s} {np.mean(all_sorensen_means):>12.4f} {'N/A':>12s} {'0.31':>12s}")
    if species_comparison:
        # our_mean is defined above under the same `if species_comparison:` condition.
        print(f"{'Mean species MAE (norm)':<30s} {our_mean:>12.6f} {'N/A':>12s} {'0.083602':>12s}")
    # Save results
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    results = {
        "model": "BFM-Finetuned-Large-GeoLifeCLEF",
        "config": MODEL_CONFIG,
        "finetune_checkpoint": str(ckpt_path),
        "finetune_epoch": ckpt.get("epoch"),
        "finetune_loss": float(ckpt.get("loss", 0)),
        "n_pairs_evaluated": len(pair_details),
        "n_pairs_total": total_pairs,
        "skipped": skipped,
        "total_time_minutes": round(total_elapsed / 60, 1),
        "species_threshold": args.species_threshold,
        "aggregate": {
            "f1_mean": float(np.mean(all_f1_scores)) if all_f1_scores else None,
            "f1_std": float(np.std(all_f1_scores)) if all_f1_scores else None,
            "sorensen_mean": float(np.mean(all_sorensen_means)) if all_sorensen_means else None,
            "sorensen_std": float(np.std(all_sorensen_means)) if all_sorensen_means else None,
        },
        "per_group_summary": group_summary,
        "species_comparison_with_table9": species_comparison,
        "paper_reference": {
            "table1_f1": 0.9964,
            "section41_sorensen": 0.31,
        },
    }
    results_path = RESULTS_DIR / "bfm_finetuned_eval_large_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\nResults saved to {results_path}")
# Script entry point.
if __name__ == "__main__":
    main()