File size: 2,731 Bytes
a29c713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
export_graph_state.py
─────────────────────
Run this inside your Kaggle / Lightning AI notebook AFTER Cell 7
to export everything the Gradio app needs for full graph-lookup inference.

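Usage (paste into a new notebook cell after c07-graph; the -i flag is required
so the script runs in the notebook's namespace and can see its variables):
    %run -i export_graph_state.py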

Or inline:
    exec(open('export_graph_state.py').read())
"""

import os, torch
from pathlib import Path

# Destination folder for every exported artifact. The EXPORT_DIR environment
# variable overrides the default Kaggle working-directory location.
EXPORT_DIR = Path(
    os.getenv("EXPORT_DIR", "/kaggle/working/artifacts/aml-gnn")
)
# Create the folder along with any missing parents; an existing folder is fine.
EXPORT_DIR.mkdir(exist_ok=True, parents=True)

# ── 1. Best model weights ─────────────────────────────────────────────────────
#    (already saved by c09-train if you ran it — then nothing needs to be done)
#    Three outcomes:
#      * the file already exists        → leave it untouched
#      * it is missing but `best_state` → save `best_state` to it
#        is available in the notebook
#      * neither                        → warn instead of claiming it exists
model_src = EXPORT_DIR / "best_model.pt"
if model_src.exists():
    print(f"Model weights already at {model_src}")
else:
    # `best_state` is a notebook variable; it is undefined when the training
    # cell was never run, so look it up defensively instead of crashing.
    try:
        _best_state = best_state
    except NameError:
        _best_state = None

    if _best_state is not None:
        torch.save(_best_state, model_src)
        print(f"Saved model weights → {model_src}")
    else:
        print(
            f"WARNING: no model weights found at {model_src} and `best_state` "
            "is not available; run the training cell before exporting."
        )

# ── 2. Graph state (for graph-lookup inference) ───────────────────────────────
#    Bundles the graph object, the account → node-id mapping and the edge
#    feature column names into a single file next to the model weights.
#    Requires the notebook variables `data` and `account_to_id`;
#    `edge_feat_df` is optional.
#    NOTE(review): torch.save pickles these objects as-is, so the loading side
#    presumably needs the same classes importable and, on PyTorch versions
#    where torch.load defaults to weights_only=True, an explicit
#    weights_only=False — confirm in the Gradio app.
graph_path = EXPORT_DIR / "graph_state.pt"
torch.save({
    "data":          data,            # PyG Data object (x, edge_index, edge_attr)
    "account_to_id": account_to_id,  # dict: account_str → node_id int
    # Falls back to an empty list when `edge_feat_df` is not defined in this scope.
    "edge_columns":  list(edge_feat_df.columns) if 'edge_feat_df' in dir() else [],
}, graph_path)
print(f"Saved graph state → {graph_path}")
print(f"  Nodes : {data.x.shape[0]:,}")
print(f"  Edges : {data.edge_index.shape[1]:,}")

# ── 3. Config ─────────────────────────────────────────────────────────────────
#    Writes the model/inference settings as JSON next to the other artifacts.
#    Requires these notebook variables to be defined: data, HIDDEN_DIM, DROPOUT,
#    best_thr, USE_FOCAL_LOSS, FOCAL_GAMMA and SAMPLE_FRAC (a NameError here
#    means the cell that defines the missing one has not been run yet).
import json
from datetime import datetime, timezone
config_path = EXPORT_DIR / "config.json"
cfg = {
    "model_class"    : "EdgeGNN",
    "in_dim"         : int(data.x.shape[1]),          # second dimension of data.x
    "edge_dim"       : int(data.edge_attr.shape[1]),  # second dimension of data.edge_attr
    "hidden_dim"     : int(HIDDEN_DIM),
    "dropout"        : float(DROPOUT),
    "best_threshold" : float(best_thr),
    "use_focal_loss" : bool(USE_FOCAL_LOSS),
    "focal_gamma"    : float(FOCAL_GAMMA),
    "sample_frac"    : float(SAMPLE_FRAC),
    "timestamp_utc"  : datetime.now(timezone.utc).isoformat(),
}
# Values are cast to plain int/float/bool above so json.dumps can serialise them.
# Explicit encoding keeps the file identical regardless of the platform locale.
config_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
print(f"Saved config → {config_path}")

# Final summary: show the export folder and the three artifact files this
# script is meant to produce.
print("\n✅  Export complete. Copy the artifacts/ folder to your Gradio app directory.")
print(f"   {EXPORT_DIR}/")
print("   ├── best_model.pt")
print("   ├── graph_state.pt")
print("   └── config.json")