Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import sys | |
| import time | |
| import traceback | |
| import torch | |
| from huggingface_hub import HfApi | |
| from netfm.data.datasets import load_and_prepare, PRETRAIN_DATASETS, HOLDOUT_DATASETS | |
| from netfm.models.encoder import GraphSAGEEncoder | |
| from netfm.pretrain.trainer import NetFMPretrainer | |
| from netfm.evaluate.pipeline import run_full_evaluation | |
# Hub repository that receives checkpoints and results; overridable via the
# HF_REPO environment variable.
HF_REPO = os.getenv("HF_REPO", "GitHunter/netfm-checkpoints")

# Local working directories (relative to the current working directory).
DATA_ROOT = "./data"
CKPT_DIR = "./checkpoints"
RESULTS_DIR = "./results"

# Hyper-parameters for the encoder and the pre-training objective. This dict is
# also serialized verbatim to config.json, so key order is kept stable.
PRETRAIN_CONFIG = dict(
    # Encoder architecture.
    hidden_channels=256,
    out_channels=128,
    num_layers=3,
    dropout=0.1,
    # Optimizer.
    lr=1e-3,
    weight_decay=1e-4,
    # Weights of the individual pre-training loss terms.
    lambda_mask=1.0,
    lambda_link=1.0,
    lambda_subgraph=0.5,
    # Augmentation ratios.
    mask_ratio=0.15,
    edge_drop_ratio=0.1,
    # Training length.
    epochs=100,
)
def upload_to_hub(local_path: str, repo_path: str) -> None:
    """Upload a local file into the ``HF_REPO`` model repository on the HuggingFace Hub.

    This is best-effort: any exception raised while creating the client or
    uploading is reported and swallowed, so a Hub/network failure does not abort
    the surrounding pipeline.

    Args:
        local_path: Path of the file on local disk.
        repo_path: Destination path inside the repository.
    """
    try:
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=HF_REPO,
            repo_type="model",
        )
    except Exception as e:  # deliberate best-effort, see docstring
        # flush=True keeps this message ordered with log() output (which always
        # flushes) instead of sitting in the stdout buffer when stdout is a pipe.
        print(f"Upload failed for {repo_path}: {e}", flush=True)
    else:
        print(f"Uploaded {local_path} -> {repo_path}", flush=True)
def log(msg: str) -> None:
    """Emit *msg* as one line on standard output, flushing right away.

    Flushing on every call makes progress visible immediately even when stdout
    is block-buffered (e.g. redirected to a file or pipe).
    """
    print(msg, file=sys.stdout, flush=True)
def _banner(title: str) -> None:
    """Log *title* framed by two separator lines (marks a pipeline phase)."""
    log("=" * 60)
    log(title)
    log("=" * 60)


def _load_pretrain_graphs() -> list:
    """Load every dataset in PRETRAIN_DATASETS; return the flat list of ``.data`` graphs.

    Per-dataset statistics, or the error message for datasets that failed, are
    written to ``RESULTS_DIR/dataset_stats.json`` and uploaded to the Hub. A
    failing dataset is recorded and skipped rather than aborting the run.
    """
    all_graphs = []
    dataset_stats = {}
    for name in PRETRAIN_DATASETS:
        log(f"Loading {name}...")
        t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            if not graphs:
                # Explicit message instead of an opaque IndexError on graphs[0].
                raise ValueError(f"load_and_prepare returned no graphs for {name}")
            all_graphs.extend(g.data for g in graphs)
            elapsed = time.perf_counter() - t0
            # Node/edge counts describe only the first graph of the dataset.
            dataset_stats[name] = {
                "num_graphs": len(graphs),
                "num_nodes": graphs[0].num_nodes,
                "num_edges": graphs[0].num_edges,
                "load_time_s": round(elapsed, 1),
            }
            log(f"  -> {len(graphs)} graph(s), {graphs[0].num_nodes} nodes, {elapsed:.1f}s")
        except Exception as e:
            log(f"  -> FAILED: {e}")
            traceback.print_exc()
            dataset_stats[name] = {"error": str(e)}

    stats_path = os.path.join(RESULTS_DIR, "dataset_stats.json")
    with open(stats_path, "w") as f:
        json.dump(dataset_stats, f, indent=2)
    upload_to_hub(stats_path, "dataset_stats.json")
    return all_graphs


def _pretrain(all_graphs: list, device: torch.device) -> GraphSAGEEncoder:
    """Build the encoder, pre-train it on *all_graphs*, and persist/upload artifacts.

    Saves the checkpoint and PRETRAIN_CONFIG under CKPT_DIR and the training
    history under RESULTS_DIR, uploads all three to the Hub, and returns the
    trained encoder.
    """
    encoder = GraphSAGEEncoder(
        in_channels=6,  # NOTE(review): must match the node-feature width produced by load_and_prepare — confirm
        hidden_channels=PRETRAIN_CONFIG["hidden_channels"],
        out_channels=PRETRAIN_CONFIG["out_channels"],
        num_layers=PRETRAIN_CONFIG["num_layers"],
        dropout=PRETRAIN_CONFIG["dropout"],
    )
    log(f"Model parameters: {sum(p.numel() for p in encoder.parameters()):,}")
    trainer = NetFMPretrainer(
        encoder=encoder,
        device=device,
        lr=PRETRAIN_CONFIG["lr"],
        weight_decay=PRETRAIN_CONFIG["weight_decay"],
        lambda_mask=PRETRAIN_CONFIG["lambda_mask"],
        lambda_link=PRETRAIN_CONFIG["lambda_link"],
        lambda_subgraph=PRETRAIN_CONFIG["lambda_subgraph"],
        mask_ratio=PRETRAIN_CONFIG["mask_ratio"],
        edge_drop_ratio=PRETRAIN_CONFIG["edge_drop_ratio"],
        num_epochs=PRETRAIN_CONFIG["epochs"],
    )
    history = trainer.train(all_graphs)

    ckpt_path = os.path.join(CKPT_DIR, "netfm_pretrained.pt")
    trainer.save(ckpt_path)
    config_path = os.path.join(CKPT_DIR, "config.json")
    with open(config_path, "w") as f:
        json.dump(PRETRAIN_CONFIG, f, indent=2)
    history_path = os.path.join(RESULTS_DIR, "pretrain_history.json")
    with open(history_path, "w") as f:
        # default=str: if the history holds non-JSON values (e.g. tensors — TODO
        # confirm what trainer.train returns), a TypeError here would abort the
        # run before the checkpoint below is uploaded.
        json.dump(history, f, indent=2, default=str)

    upload_to_hub(ckpt_path, "netfm_pretrained.pt")
    upload_to_hub(config_path, "config.json")
    upload_to_hub(history_path, "pretrain_history.json")
    return encoder


def _evaluate(encoder: GraphSAGEEncoder, device: torch.device) -> dict:
    """Evaluate *encoder* on each HOLDOUT_DATASETS entry; write, upload and return results.

    A dataset that fails is recorded as ``{"error": ...}`` under its name.
    """
    all_results = {}
    for name in HOLDOUT_DATASETS:
        log(f"Evaluating on {name}...")
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            data = graphs[0].data
            results = run_full_evaluation(encoder, data, device, name)
        except Exception as e:
            log(f"  FAILED: {e}")
            traceback.print_exc()
            all_results[name] = {"error": str(e)}
            continue
        all_results.update(results)
        # Logging is outside the try so a result dict not keyed by `name` can no
        # longer turn a successful evaluation into an error entry.
        summary = results.get(name, results)
        log(f"  Done: {list(summary) if isinstance(summary, dict) else summary}")

    results_path = os.path.join(RESULTS_DIR, "evaluation_results.json")
    with open(results_path, "w") as f:
        json.dump(all_results, f, indent=2, default=str)
    upload_to_hub(results_path, "evaluation_results.json")
    return all_results


def main() -> None:
    """Run the full pipeline: load data, pre-train, evaluate on held-out sets.

    Exits with status 1 if no pre-training graph could be loaded.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"Device: {device}")
    if device.type == "cuda":
        props = torch.cuda.get_device_properties(0)
        log(f"GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties field is `total_memory`; `total_mem` raised AttributeError.
        log(f"VRAM: {props.total_memory / 1e9:.1f} GB")

    _banner("PHASE 1: Loading and preparing datasets")
    all_graphs = _load_pretrain_graphs()
    log(f"\nTotal pre-training graphs: {len(all_graphs)}")
    if not all_graphs:
        log("ERROR: No graphs loaded. Exiting.")
        sys.exit(1)

    _banner("PHASE 2: Pre-training")
    encoder = _pretrain(all_graphs, device)

    _banner("PHASE 3: Evaluation on held-out datasets")
    all_results = _evaluate(encoder, device)

    _banner("DONE! All results uploaded to HuggingFace Hub.")
    log(f"Checkpoint: {HF_REPO}")
    log(json.dumps(all_results, indent=2, default=str))


if __name__ == "__main__":
    main()