# Wildfire-FM/experiments/raw_reference/task_scripts/run_event_analog_taskmodel_seeded.py
# yx21e's picture
# Initial FireWx-FM artifact release
# 80ef3b2 verified
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Dict, List, Tuple
import os
# Optional extra import roots, given as an os.pathsep-separated list in
# WILDFIRE_FM_EXTRA_PYTHONPATH, are pushed onto the front of sys.path before the
# third-party imports that follow. Each entry is inserted at index 0, so later
# entries end up ahead of earlier ones. Empty entries and paths already present
# on sys.path are skipped.
for _p in os.environ.get("WILDFIRE_FM_EXTRA_PYTHONPATH", "").split(os.pathsep):
    if not _p or _p in sys.path:
        continue
    sys.path.insert(0, _p)
import faiss
import hnswlib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
DROP_COLUMNS = {
"Event_ID",
"Incid_Name",
"incident_name_norm",
"wfigs_name",
"Ig_Date",
"weather_date",
"BurnBndAc",
"target_log_burn_acres",
}
CATEGORICAL_COLUMNS = ["Incid_Type", "state_abbr", "county_name", "wfigs_match_type"]
def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the root-mean-squared error between ``y_true`` and ``y_pred``."""
    residual = np.asarray(y_true) - np.asarray(y_pred)
    mean_square = np.mean(residual ** 2)
    return float(np.sqrt(mean_square))
def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the mean absolute percentage error as a fraction (0.1 == 10%).

    Denominators are clipped from below at 1e-6 so zero-valued truths do not
    divide by zero (they instead produce very large per-item errors).
    """
    truth = np.asarray(y_true, dtype=np.float64)
    guess = np.asarray(y_pred, dtype=np.float64)
    safe_truth = np.clip(truth, 1e-6, None)
    relative_error = np.abs(truth - guess) / safe_truth
    return float(np.mean(relative_error))
def r2_score_manual(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the coefficient of determination R^2 of ``y_pred`` against ``y_true``.

    Returns 0.0 when ``y_true`` has no positive variance (total sum of squares
    is zero or NaN), where R^2 is undefined.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    guess = np.asarray(y_pred, dtype=np.float64)
    residual_ss = float(np.sum((truth - guess) ** 2))
    total_ss = float(np.sum((truth - truth.mean()) ** 2))
    if not total_ss > 0:
        return 0.0
    return float(1.0 - residual_ss / total_ss)
def spearman_corr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return Spearman's rank correlation, or 0.0 when it is undefined (NaN)."""
    truth = pd.Series(np.asarray(y_true, dtype=np.float64))
    guess = pd.Series(np.asarray(y_pred, dtype=np.float64))
    rho = truth.corr(guess, method="spearman")
    if pd.isna(rho):
        return 0.0
    return float(rho)
def build_splits(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split events chronologically by ``Ig_Date`` into ~60/20/20 train/val/test.

    Rows are ordered by ignition date, so validation and test events never
    precede training events. With at least three events every partition
    receives at least one row. With fewer, the later partitions may be empty
    (two events give one train row and one validation row).

    Returns:
        ``(train, val, test)`` copies, indexed by position in the date-sorted
        table.
    """
    # mergesort is stable: events sharing an ignition date keep their file
    # order instead of the unspecified tie order of the default quicksort, so
    # the split is reproducible.
    ordered = df.sort_values("Ig_Date", kind="mergesort").reset_index(drop=True)
    n = len(ordered)
    train_end = max(int(round(n * 0.6)), 1)
    if n >= 3:
        # Reserve at least one event each for validation and test. Without this,
        # n == 3 rounds train_end to 2 and leaves validation empty.
        train_end = min(train_end, n - 2)
    val_end = max(int(round(n * 0.8)), train_end + 1)
    val_end = min(val_end, n - 1) if n >= 3 else n
    train = ordered.iloc[:train_end].copy()
    val = ordered.iloc[train_end:val_end].copy()
    test = ordered.iloc[val_end:].copy()
    if len(val) == 0 and len(test) > 1:
        # Defensive fallback: borrow the earliest test event for validation.
        val = test.iloc[:1].copy()
        test = test.iloc[1:].copy()
    return train, val, test
def feature_columns(df: pd.DataFrame, feature_profile: str = "all") -> Tuple[List[str], List[str]]:
    """Choose model input columns from ``df``.

    Returns ``(numeric, categorical)``:

    * categorical: entries of ``CATEGORICAL_COLUMNS`` that exist in ``df``;
    * numeric: every other numeric-dtype column, in ``df`` column order, that is
      not listed in ``DROP_COLUMNS``.

    With ``feature_profile == "weather_fm"`` only ``weather_``-prefixed numeric
    columns are kept and no categorical columns are returned.
    """
    categorical = [name for name in CATEGORICAL_COLUMNS if name in df.columns]

    def _usable_numeric(name) -> bool:
        if name in DROP_COLUMNS or name in categorical:
            return False
        return pd.api.types.is_numeric_dtype(df[name])

    numeric = [name for name in df.columns if _usable_numeric(name)]
    if feature_profile != "weather_fm":
        return numeric, categorical
    weather_only = [name for name in numeric if name.startswith("weather_")]
    return weather_only, []
def make_preprocessor(numeric_cols: List[str], categorical_cols: List[str]) -> ColumnTransformer:
    """Build the (unfitted) column encoder used for every retrieval variant.

    Numeric columns are median-imputed then standardized. Categorical columns
    are imputed with their most frequent value then one-hot encoded, ignoring
    categories unseen at fit time. Any other column is dropped.
    """
    numeric_pipeline = Pipeline(
        steps=[
            ("impute", SimpleImputer(strategy="median")),
            ("scale", StandardScaler()),
        ]
    )
    categorical_pipeline = Pipeline(
        steps=[
            ("impute", SimpleImputer(strategy="most_frequent")),
            ("onehot", OneHotEncoder(handle_unknown="ignore")),
        ]
    )
    branches = [
        ("num", numeric_pipeline, numeric_cols),
        ("cat", categorical_pipeline, categorical_cols),
    ]
    return ColumnTransformer(transformers=branches, remainder="drop")
def to_dense_float32(x) -> np.ndarray:
    """Return ``x`` as a dense float32 ndarray, densifying sparse input via ``toarray()``."""
    dense = x.toarray() if hasattr(x, "toarray") else x
    return np.asarray(dense, dtype=np.float32)
def weighted_prediction(sim: np.ndarray, targets: np.ndarray) -> float:
    """Average ``targets`` weighted by similarity.

    Similarities in [-1, 1] are mapped affinely to [0, 1] and floored at 1e-6,
    so every neighbour contributes and the weight sum is never zero.
    """
    affinity = (np.asarray(sim, dtype=np.float64) + 1.0) / 2.0
    weights = np.maximum(affinity, 1e-6)
    weighted_sum = np.sum(weights * targets)
    return float(weighted_sum / np.sum(weights))
def graded_relevance(query_target: float, retrieved_targets: np.ndarray) -> np.ndarray:
    """Grade each retrieved target by closeness to ``query_target``.

    Absolute gaps of <= 0.5 score 3, <= 1.0 score 2, and <= 1.5 score 1.
    Anything larger (or NaN) scores 0.
    """
    gap = np.abs(np.asarray(retrieved_targets, dtype=np.float64) - float(query_target))
    tiers = ((0.5, 3.0), (1.0, 2.0), (1.5, 1.0))
    conditions = [gap <= tolerance for tolerance, _ in tiers]
    grades = [grade for _, grade in tiers]
    return np.select(conditions, grades, default=0.0)
def dcg(relevance: np.ndarray) -> float:
    """Discounted cumulative gain: sum of ``rel_i / log2(i + 1)`` for 1-based rank i."""
    gains = np.asarray(relevance, dtype=np.float64)
    if not gains.size:
        return 0.0
    ranks = np.arange(1, gains.size + 1, dtype=np.float64)
    discounts = 1.0 / np.log2(ranks + 1.0)
    return float(np.sum(gains * discounts))
def ndcg_at_k(relevance: np.ndarray, ideal_relevance: np.ndarray, k: int) -> float:
    """Normalized DCG of the first ``k`` grades.

    ``ideal_relevance`` is expected to be the best achievable ordering (sorted
    descending by the caller). Returns 0.0 when the ideal DCG is not positive.
    """
    ideal_gain = dcg(np.asarray(ideal_relevance, dtype=np.float64)[:k])
    if not ideal_gain > 0:
        return 0.0
    achieved_gain = dcg(np.asarray(relevance, dtype=np.float64)[:k])
    return float(achieved_gain / ideal_gain)
def _knn_search(
    name: str,
    query_vec: np.ndarray,
    library_vec: np.ndarray,
    k_eff: int,
    hnsw_seed: int = 100,
) -> Tuple[np.ndarray, np.ndarray]:
    """Find the ``k_eff`` most cosine-similar library rows for every query row.

    Returns ``(knn_idx, knn_sim)``, one row per query, with neighbours ordered
    from most to least similar. Raises ``ValueError`` for an unknown backend.
    """
    if name == "cosine_exact":
        sim_all = cosine_similarity(query_vec, library_vec)
        # Stable sort: equally similar library rows come back in library order
        # rather than the unspecified tie order of the default quicksort.
        knn_idx = np.argsort(-sim_all, axis=1, kind="stable")[:, :k_eff]
        return knn_idx, np.take_along_axis(sim_all, knn_idx, axis=1)
    # The index backends are fed L2-normalised vectors (the clip guards
    # all-zero rows), so inner product equals cosine similarity.
    library_norm = library_vec / np.clip(np.linalg.norm(library_vec, axis=1, keepdims=True), 1e-12, None)
    query_norm = query_vec / np.clip(np.linalg.norm(query_vec, axis=1, keepdims=True), 1e-12, None)
    if name == "faiss_flat_ip":
        index = faiss.IndexFlatIP(library_norm.shape[1])
        index.add(library_norm.astype(np.float32))
        knn_sim, knn_idx = index.search(query_norm.astype(np.float32), k_eff)
        return knn_idx, knn_sim
    if name == "hnsw_cosine":
        index = hnswlib.Index(space="cosine", dim=library_norm.shape[1])
        index.init_index(
            max_elements=library_norm.shape[0],
            ef_construction=100,
            M=16,
            random_seed=hnsw_seed,
        )
        index.add_items(library_norm.astype(np.float32), np.arange(library_norm.shape[0]))
        # hnswlib needs the query-time ef to be at least k.
        index.set_ef(max(50, k_eff))
        knn_idx, dist = index.knn_query(query_norm.astype(np.float32), k=k_eff)
        # hnswlib's "cosine" space reports distance = 1 - cosine similarity.
        return knn_idx, 1.0 - dist
    raise ValueError(name)


def score_backend(
    name: str,
    query_vec: np.ndarray,
    library_vec: np.ndarray,
    query_df: pd.DataFrame,
    library_df: pd.DataFrame,
    k: int,
    mode: str,
    *,
    hnsw_seed: int = 100,
) -> Tuple[Dict[str, float], pd.DataFrame]:
    """Retrieve analog events for each query and score retrieval and prediction.

    Args:
        name: backend, one of ``"cosine_exact"``, ``"faiss_flat_ip"``,
            ``"hnsw_cosine"``.
        query_vec / library_vec: encoded feature rows aligned with
            ``query_df`` / ``library_df``.
        query_df / library_df: event tables providing ``target_log_burn_acres``;
            ``query_df`` must also provide ``Event_ID`` and ``BurnBndAc``.
        k: requested neighbour count. It is capped at the library size, and the
            capped value is reported as ``effective_k``.
        mode: ``"mean"`` averages neighbour targets. Any other value uses the
            similarity-weighted average from ``weighted_prediction``.
        hnsw_seed: ``random_seed`` passed to hnswlib graph construction. The
            default of 100 matches hnswlib's own default.

    Returns:
        ``(metrics, rows)``. ``metrics`` holds log-space and acre-space
        regression errors plus nDCG@5/10 and hit@1/5/10. A hit is any
        neighbour with relevance >= 2, i.e. within 1.0 log unit. ``rows`` has
        one record per query.

    Raises:
        ValueError: unknown backend, or no neighbours to retrieve (empty library
            or ``k < 1``).
    """
    target_lib = library_df["target_log_burn_acres"].to_numpy(dtype=np.float64)
    k_eff = min(int(k), int(library_vec.shape[0]))
    if k_eff < 1:
        raise ValueError(
            f"score_backend needs k >= 1 and a non-empty library "
            f"(k={k}, library rows={library_vec.shape[0]})"
        )
    knn_idx, knn_sim = _knn_search(name, query_vec, library_vec, k_eff, hnsw_seed)

    # Read per-query columns once instead of building query_df.iloc[i] per field.
    query_ids = query_df["Event_ID"].tolist()
    true_log = query_df["target_log_burn_acres"].to_numpy(dtype=np.float64)

    rows: List[Dict[str, object]] = []
    preds: List[float] = []
    ndcg5: List[float] = []
    ndcg10: List[float] = []
    hit1: List[float] = []
    hit5: List[float] = []
    hit10: List[float] = []
    best_abs_delta: List[float] = []
    for i, (event_id, raw_target) in enumerate(zip(query_ids, true_log)):
        query_target = float(raw_target)
        top_sim = knn_sim[i]
        top_targets = target_lib[knn_idx[i]]
        relevance = graded_relevance(query_target, top_targets)
        # The ideal ranking considers the whole library, not only the retrieved k.
        ideal_relevance = np.sort(graded_relevance(query_target, target_lib))[::-1]
        closest_delta = float(np.abs(top_targets - query_target).min())
        ndcg5.append(ndcg_at_k(relevance, ideal_relevance, 5))
        ndcg10.append(ndcg_at_k(relevance, ideal_relevance, 10))
        hit1.append(float(relevance[:1].max() >= 2.0))
        hit5.append(float(relevance[: min(5, k_eff)].max() >= 2.0))
        hit10.append(float(relevance[: min(10, k_eff)].max() >= 2.0))
        best_abs_delta.append(closest_delta)
        pred = float(np.mean(top_targets)) if mode == "mean" else weighted_prediction(top_sim, top_targets)
        preds.append(pred)
        rows.append(
            {
                "query_event_id": event_id,
                "true_log_burn_acres": query_target,
                "pred_log_burn_acres": pred,
                "backend": name,
                "k": k,
                "effective_k": k_eff,
                "mode": mode,
                "top_relevance": relevance.tolist(),
                "best_abs_log_delta": closest_delta,
            }
        )

    pred_arr = np.asarray(preds, dtype=np.float64)
    true_acres = query_df["BurnBndAc"].to_numpy(dtype=np.float64)
    # NOTE(review): assumes target_log_burn_acres == ln(BurnBndAc) (not log1p);
    # confirm against the table builder.
    pred_acres = np.exp(pred_arr)
    metrics = {
        "count": int(len(query_df)),
        "log_mae": float(np.mean(np.abs(true_log - pred_arr))),
        "log_rmse": rmse(true_log, pred_arr),
        "log_r2": r2_score_manual(true_log, pred_arr),
        "log_spearman": spearman_corr(true_log, pred_arr),
        "log_median_ae": float(np.median(np.abs(true_log - pred_arr))),
        "acres_mae": float(np.mean(np.abs(true_acres - pred_acres))),
        "acres_rmse": rmse(true_acres, pred_acres),
        "acres_median_ae": float(np.median(np.abs(true_acres - pred_acres))),
        "acres_mape": mape(true_acres, pred_acres),
        "ndcg_at_5": float(np.mean(ndcg5)) if ndcg5 else 0.0,
        "ndcg_at_10": float(np.mean(ndcg10)) if ndcg10 else 0.0,
        "hit_at_1_tol1": float(np.mean(hit1)) if hit1 else 0.0,
        "hit_at_5_tol1": float(np.mean(hit5)) if hit5 else 0.0,
        "hit_at_10_tol1": float(np.mean(hit10)) if hit10 else 0.0,
        "mean_best_abs_log_delta_at_k": float(np.mean(best_abs_delta)) if best_abs_delta else 0.0,
    }
    return metrics, pd.DataFrame(rows)
def target_weight_vectors(
    train_vec: np.ndarray,
    val_vec: np.ndarray,
    test_vec: np.ndarray,
    target: np.ndarray,
    *,
    floor: float = 0.25,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reweight feature columns by their correlation with the training target.

    Each column's weight is ``floor + |r| / max|r|``, where ``r`` is its Pearson
    correlation with ``target`` computed on ``train_vec`` only, so validation
    and test rows do not influence the weights. Non-finite correlations
    (e.g. constant columns) count as 0. The same weights scale all three
    matrices.

    Args:
        train_vec, val_vec, test_vec: encoded feature matrices with the same
            columns.
        target: training targets aligned with the rows of ``train_vec``.
        floor: minimum weight, so uncorrelated features are down-weighted
            rather than removed.

    Returns:
        ``(train, val, test)`` matrices multiplied column-wise by the weights.
    """
    x = np.asarray(train_vec, dtype=np.float64)
    y = np.asarray(target, dtype=np.float64)
    y = y - y.mean()
    x_centered = x - x.mean(axis=0, keepdims=True)
    denom = np.clip(np.sqrt(np.sum(x_centered**2, axis=0)) * np.sqrt(np.sum(y**2)), 1e-12, None)
    corr = np.abs(np.sum(x_centered * y[:, None], axis=0) / denom)
    corr = np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0)
    # A zero-column feature matrix has no maximum; leave the weights empty.
    peak = float(corr.max()) if corr.size else 0.0
    if peak > 0:
        corr = corr / peak
    weights = (floor + corr).astype(np.float32)
    return train_vec * weights, val_vec * weights, test_vec * weights
def _improves(score: float, best_score: float | None, selection_metric: str) -> bool:
    """Return True when ``score`` should replace the current ``best_score``.

    ``ndcg_at_10`` is maximized and ``log_mae`` is minimized. Comparisons are
    strict, so the earliest candidate wins ties. A NaN score never wins, and
    any non-NaN score replaces a NaN incumbent.
    """
    if best_score is None:
        return True
    if np.isnan(score):
        return False
    if np.isnan(best_score):
        return True
    if selection_metric == "ndcg_at_10":
        return score > best_score
    return score < best_score


def main() -> None:
    """Run the analog-retrieval benchmark end to end.

    Steps:
    1. Read ``--event-table`` and split it chronologically into train/val/test.
    2. Encode features with a preprocessor fit on train only.
    3. Grid-search retrieval candidates on the validation split: vector variant
       x backend x k x prediction mode.
    4. Score the best candidate on the test split.
    5. Write ``val_retrieval_examples.csv``, ``test_retrieval_examples.csv`` and
       ``summary.json`` into ``--output-dir``, and print the summary.

    Raises:
        SystemExit: missing required columns, an empty split, or no usable
            features for the chosen profile.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--event-table", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--selection-metric", choices=("log_mae", "ndcg_at_10"), default="ndcg_at_10")
    parser.add_argument("--feature-profile", choices=("all", "weather_fm"), default="all")
    parser.add_argument("--fm-family", type=str, default="")
    # NOTE(review): --seed is only recorded in summary.json; main does not
    # forward it to any retrieval backend. Confirm this is intended.
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    df = pd.read_csv(args.event_table)
    required_columns = ("Event_ID", "Ig_Date", "BurnBndAc", "target_log_burn_acres")
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise SystemExit(f"Event table {args.event_table} is missing required columns: {missing}")
    df["Ig_Date"] = pd.to_datetime(df["Ig_Date"])
    train_df, val_df, test_df = build_splits(df)
    # Encoding and retrieval cannot run on an empty partition; fail with a
    # clear message instead of an opaque error deep inside the preprocessor.
    empty_splits = [
        label for label, part in (("train", train_df), ("val", val_df), ("test", test_df)) if part.empty
    ]
    if empty_splits:
        raise SystemExit(
            f"Chronological split of {len(df)} events left empty partition(s) {empty_splits}; "
            "more events are needed"
        )

    numeric_cols, categorical_cols = feature_columns(df, feature_profile=args.feature_profile)
    if not numeric_cols and not categorical_cols:
        raise SystemExit(f"No usable features found for profile={args.feature_profile}")
    x_cols = numeric_cols + categorical_cols
    pre = make_preprocessor(numeric_cols, categorical_cols)
    train_vec = to_dense_float32(pre.fit_transform(train_df[x_cols]))
    val_vec = to_dense_float32(pre.transform(val_df[x_cols]))
    test_vec = to_dense_float32(pre.transform(test_df[x_cols]))
    weighted_train_vec, weighted_val_vec, weighted_test_vec = target_weight_vectors(
        train_vec,
        val_vec,
        test_vec,
        train_df["target_log_burn_acres"].to_numpy(dtype=np.float64),
    )
    vector_variants = {
        "standard": (train_vec, val_vec, test_vec),
        "target_weighted": (weighted_train_vec, weighted_val_vec, weighted_test_vec),
    }

    # Model selection uses only the validation split; the training events act
    # as the retrieval library.
    candidate_validation: List[Dict[str, object]] = []
    best = None
    best_score = None
    best_val_rows = None
    for variant, (lib_vec, v_vec, _) in vector_variants.items():
        for backend in ("cosine_exact", "faiss_flat_ip", "hnsw_cosine"):
            for k in (1, 3, 5, 10):
                for mode in ("mean", "weighted"):
                    val_metrics, val_rows = score_backend(backend, v_vec, lib_vec, val_df, train_df, k, mode)
                    candidate_validation.append(
                        {"variant": variant, "backend": backend, "k": k, "mode": mode, "val_metrics": val_metrics}
                    )
                    score = float(val_metrics[args.selection_metric])
                    if _improves(score, best_score, args.selection_metric):
                        best_score = score
                        best = {"variant": variant, "backend": backend, "k": k, "mode": mode}
                        best_val_rows = val_rows
    if best is None or best_val_rows is None:
        raise RuntimeError("No retrieval candidate was evaluated")

    best_train_vec, _, best_test_vec = vector_variants[str(best["variant"])]
    test_metrics, best_test_rows = score_backend(
        str(best["backend"]),
        best_test_vec,
        best_train_vec,
        test_df,
        train_df,
        int(best["k"]),
        str(best["mode"]),
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    best_val_rows.to_csv(args.output_dir / "val_retrieval_examples.csv", index=False)
    best_test_rows.to_csv(args.output_dir / "test_retrieval_examples.csv", index=False)
    summary = {
        "task_id": "wildfire_analog_retrieval_taskmodels",
        "task_form": "event_level_retrieval_with_induced_outcome_error",
        "event_table": str(args.event_table),
        "output_dir": str(args.output_dir),
        "feature_profile": args.feature_profile,
        "seed": int(args.seed),
        "split_sizes": {
            "train": int(len(train_df)),
            "val": int(len(val_df)),
            "test": int(len(test_df)),
        },
        "feature_columns": {"numeric": numeric_cols, "categorical": categorical_cols},
        "candidate_validation": candidate_validation,
        "selected_retrieval": best,
        "selection_metric": args.selection_metric,
        "test_metrics": test_metrics,
        "model_family": "popular_open_source_retrieval_backends_with_train_only_target_weighting",
        "fm_family": (args.fm_family or "weather_fm_derived_features") if args.feature_profile == "weather_fm" else None,
    }
    (args.output_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()