# aml_app / export_graph_state.py — commit a29c713 ("Added files") by waltertaya
"""
export_graph_state.py
─────────────────────
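Run this inside your Kaggle / Lightning AI notebook AFTER Cell 7
to export everything the Gradio app needs for full graph-lookup inference.

The script reads notebook variables directly: `data`, `account_to_id`,
`best_thr`, `HIDDEN_DIM`, `DROPOUT`, `USE_FOCAL_LOSS`, `FOCAL_GAMMA`,
`SAMPLE_FRAC`, `best_state` (only if best_model.pt is not already exported),
and optionally `edge_feat_df`. It must therefore execute in the notebook's own
namespace, after whichever cells define those variables.

Usage (paste into a new notebook cell after c07-graph):
    %run -i export_graph_state.py
(The `-i` flag is required. Plain `%run` starts from an empty namespace and
cannot see the notebook's variables.)
Or inline:
    exec(open('export_graph_state.py').read())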
"""
import os, torch
from pathlib import Path
# Output directory for every exported artifact. Defaults to the Kaggle working
# directory; set the EXPORT_DIR environment variable to write somewhere else.
EXPORT_DIR = Path(os.getenv("EXPORT_DIR", "/kaggle/working/artifacts/aml-gnn"))
# Create it (and any missing parents) up front so the saves below cannot fail
# on a missing directory.
EXPORT_DIR.mkdir(exist_ok=True, parents=True)
# ── 1. Best model weights ─────────────────────────────────────────────────────
# c09-train may already have written best_model.pt into EXPORT_DIR. An existing
# file is never overwritten. Otherwise, save the in-memory `best_state` if the
# notebook defines it.
model_src = EXPORT_DIR / "best_model.pt"
if model_src.exists():
    print(f"Model weights already at {model_src}")
elif globals().get("best_state") is not None:
    # globals() lookup: `best_state` may be undefined if training was skipped.
    torch.save(best_state, model_src)
    print(f"Saved model weights → {model_src}")
else:
    # Previously this case printed "already at" even though no file existed.
    print(
        f"WARNING: {model_src} does not exist and `best_state` is undefined "
        "or None — model weights were NOT exported."
    )
# ── 2. Graph state (for graph-lookup inference) ───────────────────────────────
# Bundle the graph plus the account → node-id mapping into one file for the
# Gradio app's graph-lookup inference.
# NOTE(review): this pickles non-tensor objects. On recent PyTorch, where
# torch.load defaults to weights_only=True, the loader may need
# weights_only=False (trusted files only). Confirm in the app.
graph_path = EXPORT_DIR / "graph_state.pt"
torch.save(
    {
        "data": data,                    # PyG Data object (x, edge_index, edge_attr)
        "account_to_id": account_to_id,  # dict: account_str → node_id int
        # Edge-feature column names if `edge_feat_df` exists in this namespace
        # and has `.columns`; otherwise an empty list (getattr default).
        "edge_columns": list(getattr(globals().get("edge_feat_df"), "columns", [])),
    },
    graph_path,
)
print(f"Saved graph state → {graph_path}")
print(f" Nodes : {data.x.shape[0]:,}")
print(f" Edges : {data.edge_index.shape[1]:,}")
# ── 3. Config ─────────────────────────────────────────────────────────────────
import json
from datetime import datetime, timezone
# Dimensions and hyper-parameters recorded next to the weights, presumably so
# the app can re-instantiate EdgeGNN with matching shapes. TODO confirm against
# the app's loader.
config_path = EXPORT_DIR / "config.json"
cfg = {
    # Explicit int/float/bool casts turn numpy/torch scalars into plain Python
    # types that json.dumps can serialize.
    "model_class"    : "EdgeGNN",
    "in_dim"         : int(data.x.shape[1]),
    "edge_dim"       : int(data.edge_attr.shape[1]),
    "hidden_dim"     : int(HIDDEN_DIM),
    "dropout"        : float(DROPOUT),
    "best_threshold" : float(best_thr),
    "use_focal_loss" : bool(USE_FOCAL_LOSS),
    "focal_gamma"    : float(FOCAL_GAMMA),
    "sample_frac"    : float(SAMPLE_FRAC),
    "timestamp_utc"  : datetime.now(timezone.utc).isoformat(),
}
# Pin the file encoding instead of relying on the platform locale.
config_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
print(f"Saved config → {config_path}")
# ── 4. Summary ────────────────────────────────────────────────────────────────
# Show the layout of EXPORT_DIR produced by the three sections above.
print("\n✅ Export complete. Copy the artifacts/ folder to your Gradio app directory.")
print(f" {EXPORT_DIR}/")
print(" ├── best_model.pt")
print(" ├── graph_state.pt")
print(" └── config.json")