"""Pre-train a NetFM GraphSAGE encoder and evaluate it on held-out datasets.

Dataset statistics, the checkpoint, its config, the training history and the
evaluation results are written locally and uploaded to the HuggingFace Hub
repository named by the ``HF_REPO`` environment variable
(default ``GitHunter/netfm-checkpoints``).
"""
import json
import os
import sys
import time
import traceback

import torch
from huggingface_hub import HfApi

from netfm.data.datasets import load_and_prepare, PRETRAIN_DATASETS, HOLDOUT_DATASETS
from netfm.models.encoder import GraphSAGEEncoder
from netfm.pretrain.trainer import NetFMPretrainer
from netfm.evaluate.pipeline import run_full_evaluation

HF_REPO = os.environ.get("HF_REPO", "GitHunter/netfm-checkpoints")
DATA_ROOT = "./data"
CKPT_DIR = "./checkpoints"
RESULTS_DIR = "./results"

# Number of input node features given to the encoder. Presumably this must
# match the feature dimension produced by ``load_and_prepare`` -- TODO confirm.
IN_CHANNELS = 6

PRETRAIN_CONFIG = {
    "hidden_channels": 256,
    "out_channels": 128,
    "num_layers": 3,
    "dropout": 0.1,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "lambda_mask": 1.0,
    "lambda_link": 1.0,
    "lambda_subgraph": 0.5,
    "mask_ratio": 0.15,
    "edge_drop_ratio": 0.1,
    "epochs": 100,
}


def log(msg: str) -> None:
    """Print *msg* and flush stdout immediately."""
    print(msg, flush=True)


def upload_to_hub(local_path: str, repo_path: str) -> bool:
    """Upload a local file to the HuggingFace Hub repository ``HF_REPO``.

    Uploading is best-effort: any failure is logged instead of raised, so a
    Hub problem does not abort a long training run.

    Args:
        local_path: Path of the file on local disk.
        repo_path: Destination path inside the Hub repository.

    Returns:
        True if the upload succeeded, False if it raised.
    """
    try:
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=HF_REPO,
            repo_type="model",
        )
    except Exception as e:  # deliberate best-effort boundary
        log(f"Upload failed for {repo_path}: {e}")
        return False
    log(f"Uploaded {local_path} -> {repo_path}")
    return True


def _write_json(path: str, obj: object, **dump_kwargs) -> None:
    """Write *obj* to *path* as indented JSON; extra kwargs go to ``json.dump``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, **dump_kwargs)


def _log_banner(title: str) -> None:
    """Log *title* framed by two separator lines."""
    log("=" * 60)
    log(title)
    log("=" * 60)


def _load_pretrain_graphs() -> tuple:
    """Load every dataset listed in ``PRETRAIN_DATASETS``.

    A dataset that fails to load (or yields no graphs) is logged with a
    traceback and recorded as an error; it contributes no training graphs.

    Returns:
        ``(all_graphs, dataset_stats)``. ``all_graphs`` is the flat list of
        ``.data`` objects from every successfully loaded dataset.
        ``dataset_stats`` maps each dataset name to either its summary (graph
        count, first graph's node/edge counts, load time in seconds) or
        ``{"error": message}``.
    """
    all_graphs = []
    dataset_stats = {}
    for name in PRETRAIN_DATASETS:
        log(f"Loading {name}...")
        start = time.perf_counter()
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            if not graphs:
                raise ValueError(f"load_and_prepare returned no graphs for {name!r}")
            data_items = [g.data for g in graphs]
            elapsed = time.perf_counter() - start
            first = graphs[0]
            stats = {
                "num_graphs": len(graphs),
                "num_nodes": first.num_nodes,
                "num_edges": first.num_edges,
                "load_time_s": round(elapsed, 1),
            }
        except Exception as e:
            log(f" -> FAILED: {e}")
            traceback.print_exc()
            dataset_stats[name] = {"error": str(e)}
            continue
        # Only commit graphs whose stats were computed, so the stats file and
        # the training set always agree.
        all_graphs.extend(data_items)
        dataset_stats[name] = stats
        log(f" -> {len(graphs)} graph(s), {first.num_nodes} nodes, {elapsed:.1f}s")
    return all_graphs, dataset_stats


def _evaluate_holdout(encoder: GraphSAGEEncoder, device: torch.device) -> dict:
    """Evaluate *encoder* on the first graph of each ``HOLDOUT_DATASETS`` entry.

    Returns:
        The merged dicts returned by ``run_full_evaluation``. A dataset whose
        loading or evaluation raised is recorded as ``{name: {"error": msg}}``.
    """
    all_results = {}
    for name in HOLDOUT_DATASETS:
        log(f"Evaluating on {name}...")
        # Only loading and evaluation are guarded: a failure in the summary
        # log below must not replace results that were already merged.
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            results = run_full_evaluation(encoder, graphs[0].data, device, name)
        except Exception as e:
            log(f" FAILED: {e}")
            traceback.print_exc()
            all_results[name] = {"error": str(e)}
            continue
        all_results.update(results)
        # Results are expected under ``name``; otherwise show top-level keys.
        summary = results.get(name)
        shown = summary if isinstance(summary, dict) else results
        log(f" Done: {list(shown.keys())}")
    return all_results


def main() -> None:
    """Run the full pipeline: load data, pre-train, evaluate, upload artifacts.

    Exits with status 1 if no pre-training graph could be loaded. Failures of
    individual datasets and of Hub uploads are logged but do not stop the run;
    failed uploads are listed in the final summary.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)
    failed_uploads = []

    def upload(local_path: str, repo_path: str) -> None:
        # Track failures so the final banner does not falsely claim success.
        if not upload_to_hub(local_path, repo_path):
            failed_uploads.append(repo_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"Device: {device}")
    if device.type == "cuda":
        log(f"GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties attribute is ``total_memory`` (in bytes).
        log(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- Phase 1: data -------------------------------------------------
    _log_banner("PHASE 1: Loading and preparing datasets")
    all_graphs, dataset_stats = _load_pretrain_graphs()
    stats_path = os.path.join(RESULTS_DIR, "dataset_stats.json")
    _write_json(stats_path, dataset_stats)
    upload(stats_path, "dataset_stats.json")

    log(f"\nTotal pre-training graphs: {len(all_graphs)}")
    if not all_graphs:
        log("ERROR: No graphs loaded. Exiting.")
        sys.exit(1)

    # ---- Phase 2: pre-training -----------------------------------------
    _log_banner("PHASE 2: Pre-training")
    encoder = GraphSAGEEncoder(
        in_channels=IN_CHANNELS,
        hidden_channels=PRETRAIN_CONFIG["hidden_channels"],
        out_channels=PRETRAIN_CONFIG["out_channels"],
        num_layers=PRETRAIN_CONFIG["num_layers"],
        dropout=PRETRAIN_CONFIG["dropout"],
    )
    log(f"Model parameters: {sum(p.numel() for p in encoder.parameters()):,}")

    trainer = NetFMPretrainer(
        encoder=encoder,
        device=device,
        lr=PRETRAIN_CONFIG["lr"],
        weight_decay=PRETRAIN_CONFIG["weight_decay"],
        lambda_mask=PRETRAIN_CONFIG["lambda_mask"],
        lambda_link=PRETRAIN_CONFIG["lambda_link"],
        lambda_subgraph=PRETRAIN_CONFIG["lambda_subgraph"],
        mask_ratio=PRETRAIN_CONFIG["mask_ratio"],
        edge_drop_ratio=PRETRAIN_CONFIG["edge_drop_ratio"],
        num_epochs=PRETRAIN_CONFIG["epochs"],
    )
    history = trainer.train(all_graphs)

    ckpt_path = os.path.join(CKPT_DIR, "netfm_pretrained.pt")
    trainer.save(ckpt_path)
    config_path = os.path.join(CKPT_DIR, "config.json")
    _write_json(config_path, PRETRAIN_CONFIG)
    # Upload the expensive artifacts before serializing the history, so a
    # problem there cannot cost the trained checkpoint.
    upload(ckpt_path, "netfm_pretrained.pt")
    upload(config_path, "config.json")

    history_path = os.path.join(RESULTS_DIR, "pretrain_history.json")
    # default=str matches the evaluation dump: values JSON cannot encode
    # (e.g. tensors, if the trainer records any) are stringified, not fatal.
    _write_json(history_path, history, default=str)
    upload(history_path, "pretrain_history.json")

    # ---- Phase 3: evaluation -------------------------------------------
    _log_banner("PHASE 3: Evaluation on held-out datasets")
    all_results = _evaluate_holdout(encoder, device)
    results_path = os.path.join(RESULTS_DIR, "evaluation_results.json")
    _write_json(results_path, all_results, default=str)
    upload(results_path, "evaluation_results.json")

    if failed_uploads:
        _log_banner(
            f"DONE, but {len(failed_uploads)} upload(s) failed: {', '.join(failed_uploads)}"
        )
    else:
        _log_banner("DONE! All results uploaded to HuggingFace Hub.")
    log(f"Checkpoint: {HF_REPO}")
    log(json.dumps(all_results, indent=2, default=str))


if __name__ == "__main__":
    main()