# netfm-training / train.py
# Author: henribonamy
# Commit 7080774 (verified): "Upload train.py with huggingface_hub"
import json
import os
import sys
import time
import traceback
import torch
from huggingface_hub import HfApi
from netfm.data.datasets import load_and_prepare, PRETRAIN_DATASETS, HOLDOUT_DATASETS
from netfm.models.encoder import GraphSAGEEncoder
from netfm.pretrain.trainer import NetFMPretrainer
from netfm.evaluate.pipeline import run_full_evaluation
# Hub model repo that receives every uploaded artifact; override via $HF_REPO.
HF_REPO = os.environ.get("HF_REPO", "GitHunter/netfm-checkpoints")
# Local working directories. Like HF_REPO, each can be overridden through an
# environment variable of the same name; the defaults are unchanged.
DATA_ROOT = os.environ.get("DATA_ROOT", "./data")
CKPT_DIR = os.environ.get("CKPT_DIR", "./checkpoints")
RESULTS_DIR = os.environ.get("RESULTS_DIR", "./results")
# Encoder and pre-trainer hyper-parameters. main() passes these values to
# GraphSAGEEncoder / NetFMPretrainer and also writes this dict verbatim to
# config.json next to the checkpoint, so the key names are part of that
# file's format.
PRETRAIN_CONFIG = {
    # GraphSAGEEncoder architecture.
    "hidden_channels": 256,
    "out_channels": 128,
    "num_layers": 3,
    "dropout": 0.1,
    # Optimiser settings forwarded to NetFMPretrainer.
    "lr": 1e-3,
    "weight_decay": 1e-4,
    # Per-objective weights forwarded to NetFMPretrainer as lambda_* kwargs
    # (presumably loss weights — confirm against the trainer).
    "lambda_mask": 1.0,
    "lambda_link": 1.0,
    "lambda_subgraph": 0.5,
    "mask_ratio": 0.15,
    "edge_drop_ratio": 0.1,
    # Forwarded to NetFMPretrainer as num_epochs.
    "epochs": 100,
}
def upload_to_hub(local_path: str, repo_path: str) -> None:
    """Upload a local file to the ``HF_REPO`` model repository on the Hub.

    This is deliberately best-effort: any exception raised while creating the
    client or uploading is logged and swallowed, so a failed upload never
    aborts the surrounding run.

    Args:
        local_path: Path of the file on local disk.
        repo_path: Destination path inside the Hub repository.
    """
    try:
        # Client construction is inside the try so it is covered by the same
        # best-effort policy as the upload itself.
        api = HfApi()
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=HF_REPO,
            repo_type="model",
        )
    except Exception as e:  # intentional top-level best-effort boundary
        # Use the flushing log helper so messages are not lost or reordered in
        # buffered job output.
        log(f"Upload failed for {repo_path}: {e}")
    else:
        log(f"Uploaded {local_path} -> {repo_path}")
def log(msg: str) -> None:
    """Write ``msg`` plus a newline to standard output and flush right away.

    Flushing on every call keeps progress visible even when stdout is a pipe
    or file and would otherwise be block-buffered.
    """
    stream = sys.stdout
    print(msg, file=stream, flush=True)
def _log_banner(title: str) -> None:
    """Log ``title`` framed by two 60-character ``=`` separator lines."""
    log("=" * 60)
    log(title)
    log("=" * 60)


def _write_json(path: str, obj) -> None:
    """Write ``obj`` to ``path`` as indented UTF-8 JSON.

    ``default=str`` stringifies any value ``json`` cannot encode natively
    (e.g. a tensor or NumPy scalar, should one appear), so an unexpected type
    degrades to a string instead of crashing the run after training.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=str)


def _log_device(device: torch.device) -> None:
    """Log the selected device and, on CUDA, the name and memory of GPU 0."""
    log(f"Device: {device}")
    if device.type == "cuda":
        log(f"GPU: {torch.cuda.get_device_name(0)}")
        # The property is ``total_memory`` (bytes). The previous ``total_mem``
        # does not exist and raised AttributeError on every CUDA machine.
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        log(f"VRAM: {total_gb:.1f} GB")


def _load_pretrain_graphs():
    """Load every dataset in PRETRAIN_DATASETS.

    Returns:
        A tuple ``(all_graphs, dataset_stats)``. ``all_graphs`` holds the
        ``.data`` attribute of every successfully loaded graph.
        ``dataset_stats`` maps each dataset name to summary stats, or to
        ``{"error": message}`` when loading failed.
    """
    all_graphs = []
    dataset_stats = {}
    for name in PRETRAIN_DATASETS:
        log(f"Loading {name}...")
        t0 = time.time()
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            if not graphs:
                # Report an explicit reason instead of a bare IndexError
                # from graphs[0].
                raise ValueError(f"load_and_prepare returned no graphs for {name}")
            first = graphs[0]
            stats = {
                "num_graphs": len(graphs),
                # Node/edge counts describe the first graph only.
                "num_nodes": first.num_nodes,
                "num_edges": first.num_edges,
            }
            data = [g.data for g in graphs]
        except Exception as e:  # best-effort: one bad dataset must not abort the run
            log(f" -> FAILED: {e}")
            traceback.print_exc()
            dataset_stats[name] = {"error": str(e)}
            continue
        elapsed = time.time() - t0
        stats["load_time_s"] = round(elapsed, 1)
        dataset_stats[name] = stats
        # Extend only after stats succeed, so a dataset recorded as an error
        # never contributes training graphs.
        all_graphs.extend(data)
        log(f" -> {len(graphs)} graph(s), {first.num_nodes} nodes, {elapsed:.1f}s")
    return all_graphs, dataset_stats


def _pretrain(all_graphs, device: torch.device):
    """Build the encoder, pre-train it, and save/upload artifacts.

    Returns the encoder object that was handed to the trainer.
    """
    encoder = GraphSAGEEncoder(
        # NOTE(review): hard-coded input feature width; presumably must match
        # the node features produced by load_and_prepare — confirm.
        in_channels=6,
        hidden_channels=PRETRAIN_CONFIG["hidden_channels"],
        out_channels=PRETRAIN_CONFIG["out_channels"],
        num_layers=PRETRAIN_CONFIG["num_layers"],
        dropout=PRETRAIN_CONFIG["dropout"],
    )
    log(f"Model parameters: {sum(p.numel() for p in encoder.parameters()):,}")
    trainer = NetFMPretrainer(
        encoder=encoder,
        device=device,
        lr=PRETRAIN_CONFIG["lr"],
        weight_decay=PRETRAIN_CONFIG["weight_decay"],
        lambda_mask=PRETRAIN_CONFIG["lambda_mask"],
        lambda_link=PRETRAIN_CONFIG["lambda_link"],
        lambda_subgraph=PRETRAIN_CONFIG["lambda_subgraph"],
        mask_ratio=PRETRAIN_CONFIG["mask_ratio"],
        edge_drop_ratio=PRETRAIN_CONFIG["edge_drop_ratio"],
        num_epochs=PRETRAIN_CONFIG["epochs"],
    )
    history = trainer.train(all_graphs)

    ckpt_path = os.path.join(CKPT_DIR, "netfm_pretrained.pt")
    trainer.save(ckpt_path)
    # Upload the checkpoint first. Previously it was uploaded only after the
    # history JSON dump, so a serialization error there stranded the model.
    upload_to_hub(ckpt_path, "netfm_pretrained.pt")

    config_path = os.path.join(CKPT_DIR, "config.json")
    _write_json(config_path, PRETRAIN_CONFIG)
    upload_to_hub(config_path, "config.json")

    history_path = os.path.join(RESULTS_DIR, "pretrain_history.json")
    _write_json(history_path, history)
    upload_to_hub(history_path, "pretrain_history.json")
    # The same encoder object is returned for evaluation; presumably the
    # trainer updates it in place — confirm against NetFMPretrainer.
    return encoder


def _evaluate(encoder, device: torch.device) -> dict:
    """Run run_full_evaluation on the first graph of each HOLDOUT_DATASETS entry.

    Returns the merged results, with ``{"error": message}`` entries for
    datasets that failed.
    """
    all_results = {}
    for name in HOLDOUT_DATASETS:
        log(f"Evaluating on {name}...")
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            if not graphs:
                raise ValueError(f"load_and_prepare returned no graphs for {name}")
            results = run_full_evaluation(encoder, graphs[0].data, device, name)
            all_results.update(results)
        except Exception as e:  # best-effort: evaluate the remaining datasets
            log(f" FAILED: {e}")
            traceback.print_exc()
            all_results[name] = {"error": str(e)}
            continue
        # Summary logging sits outside the try so it can never overwrite an
        # already-merged successful result with an error entry. Results are
        # presumably keyed by dataset name, as the original indexing assumed.
        log(f" Done: {list(results.get(name, {}))}")
    return all_results


def main() -> None:
    """Run the full pipeline: load data, pre-train, evaluate, upload.

    Every artifact is written under CKPT_DIR / RESULTS_DIR and uploaded
    (best-effort) to HF_REPO. Exits with status 1 if no pre-training graph
    could be loaded.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _log_device(device)

    _log_banner("PHASE 1: Loading and preparing datasets")
    all_graphs, dataset_stats = _load_pretrain_graphs()
    stats_path = os.path.join(RESULTS_DIR, "dataset_stats.json")
    _write_json(stats_path, dataset_stats)
    upload_to_hub(stats_path, "dataset_stats.json")
    log(f"\nTotal pre-training graphs: {len(all_graphs)}")
    if not all_graphs:
        log("ERROR: No graphs loaded. Exiting.")
        sys.exit(1)

    _log_banner("PHASE 2: Pre-training")
    encoder = _pretrain(all_graphs, device)

    _log_banner("PHASE 3: Evaluation on held-out datasets")
    all_results = _evaluate(encoder, device)
    results_path = os.path.join(RESULTS_DIR, "evaluation_results.json")
    _write_json(results_path, all_results)
    upload_to_hub(results_path, "evaluation_results.json")

    _log_banner("DONE! All results uploaded to HuggingFace Hub.")
    log(f"Checkpoint: {HF_REPO}")
    log(json.dumps(all_results, indent=2, default=str))
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    # main() returns None, so a normal finish still exits with status 0.
    sys.exit(main())