Spaces:
Sleeping
Sleeping
| """ | |
| export_graph_state.py | |
| βββββββββββββββββββββ | |
| Run this inside your Kaggle / Lightning AI notebook AFTER Cell 7 | |
| to export everything the Gradio app needs for full graph-lookup inference. | |
| Usage (paste into a new notebook cell after c07-graph): | |
| %run export_graph_state.py | |
| Or inline: | |
| exec(open('export_graph_state.py').read()) | |
| """ | |
| import os, torch | |
| from pathlib import Path | |
| EXPORT_DIR = Path(os.environ.get("EXPORT_DIR", "/kaggle/working/artifacts/aml-gnn")) | |
| EXPORT_DIR.mkdir(parents=True, exist_ok=True) | |
| # ββ 1. Best model weights βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # (already saved by c09-train if you ran it β just copy) | |
| model_src = EXPORT_DIR / "best_model.pt" | |
| if not model_src.exists() and best_state is not None: | |
| torch.save(best_state, model_src) | |
| print(f"Saved model weights β {model_src}") | |
| else: | |
| print(f"Model weights already at {model_src}") | |
| # ββ 2. Graph state (for graph-lookup inference) βββββββββββββββββββββββββββββββ | |
| graph_path = EXPORT_DIR / "graph_state.pt" | |
| torch.save({ | |
| "data": data, # PyG Data object (node_x, edge_index, edge_attr) | |
| "account_to_id": account_to_id, # dict: account_str β node_id int | |
| "edge_columns": list(edge_feat_df.columns) if 'edge_feat_df' in dir() else [], | |
| }, graph_path) | |
| print(f"Saved graph state β {graph_path}") | |
| print(f" Nodes : {data.x.shape[0]:,}") | |
| print(f" Edges : {data.edge_index.shape[1]:,}") | |
| # ββ 3. Config βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import json | |
| from datetime import datetime, timezone | |
| config_path = EXPORT_DIR / "config.json" | |
| cfg = { | |
| "model_class" : "EdgeGNN", | |
| "in_dim" : int(data.x.shape[1]), | |
| "edge_dim" : int(data.edge_attr.shape[1]), | |
| "hidden_dim" : int(HIDDEN_DIM), | |
| "dropout" : float(DROPOUT), | |
| "best_threshold" : float(best_thr), | |
| "use_focal_loss" : bool(USE_FOCAL_LOSS), | |
| "focal_gamma" : float(FOCAL_GAMMA), | |
| "sample_frac" : float(SAMPLE_FRAC), | |
| "timestamp_utc" : datetime.now(timezone.utc).isoformat(), | |
| } | |
| config_path.write_text(json.dumps(cfg, indent=2)) | |
| print(f"Saved config β {config_path}") | |
| print("\nβ Export complete. Copy the artifacts/ folder to your Gradio app directory.") | |
| print(f" {EXPORT_DIR}/") | |
| print(f" βββ best_model.pt") | |
| print(f" βββ graph_state.pt") | |
| print(f" βββ config.json") | |