Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import sys | |
| import time | |
| import threading | |
| import traceback | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import HfApi | |
| from netfm.data.datasets import load_and_prepare, PRETRAIN_DATASETS, HOLDOUT_DATASETS | |
| from netfm.models.encoder import GraphSAGEEncoder | |
| from netfm.pretrain.trainer import NetFMPretrainer | |
| from netfm.evaluate.pipeline import run_full_evaluation | |
# Hub repository that receives uploads; override with the HF_REPO env variable.
HF_REPO = os.getenv("HF_REPO", "GitHunter/netfm-checkpoints")

# Local working directories for datasets, checkpoints and result files.
DATA_ROOT = "./data"
CKPT_DIR = "./checkpoints"
RESULTS_DIR = "./results"

# Pre-training hyper-parameters. The whole mapping is also written out as
# config.json next to the checkpoint.
PRETRAIN_CONFIG = dict(
    # Passed to GraphSAGEEncoder.
    hidden_channels=256,
    out_channels=128,
    num_layers=3,
    dropout=0.1,
    # Passed to NetFMPretrainer.
    lr=1e-3,
    weight_decay=1e-4,
    lambda_mask=1.0,
    lambda_link=1.0,
    lambda_subgraph=0.5,
    mask_ratio=0.15,
    edge_drop_ratio=0.1,
    epochs=100,
)

# Shared state: written by the training thread, read by the UI via get_logs().
log_lines: list[str] = []
training_status = "NOT STARTED"
def log(msg: str) -> None:
    """Record *msg* in the shared ``log_lines`` buffer and echo it to stdout.

    The message is stored first and then printed with an immediate flush.
    """
    log_lines.append(msg)
    print(msg, end="\n", flush=True)
def upload_to_hub(local_path: str, repo_path: str) -> None:
    """Push one local file to the ``HF_REPO`` model repository (best effort).

    Args:
        local_path: Path of the file on disk.
        repo_path: Destination path inside the Hub repository.

    Any exception raised by the upload is logged and swallowed, so a failed
    upload never aborts the caller.
    """
    hub = HfApi()
    upload_kwargs = {
        "path_or_fileobj": local_path,
        "path_in_repo": repo_path,
        "repo_id": HF_REPO,
        "repo_type": "model",
    }
    try:
        hub.upload_file(**upload_kwargs)
        log(f"Uploaded {local_path} -> {repo_path}")
    except Exception as exc:
        log(f"Upload failed for {repo_path}: {exc}")
def _log_banner(title: str) -> None:
    """Log *title* framed above and below by a 60-character ``=`` rule."""
    rule = "=" * 60
    log(rule)
    log(title)
    log(rule)


def _load_pretraining_graphs() -> tuple[list, dict]:
    """Load every dataset in ``PRETRAIN_DATASETS``.

    Returns:
        A pair ``(all_graphs, dataset_stats)``: the ``.data`` object of every
        loaded graph, and a per-dataset dict holding either summary stats or
        ``{"error": ...}`` when loading that dataset failed.
    """
    all_graphs: list = []
    dataset_stats: dict = {}
    for name in PRETRAIN_DATASETS:
        log(f"Loading {name}...")
        t0 = time.time()
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            if not graphs:
                # Fail with a readable message instead of an IndexError below.
                raise ValueError(f"load_and_prepare returned no graphs for {name}")
            graph_data = [g.data for g in graphs]
            elapsed = time.time() - t0
            first = graphs[0]
            stats = {
                "num_graphs": len(graphs),
                # Node/edge counts are those of the first graph only.
                "num_nodes": first.num_nodes,
                "num_edges": first.num_edges,
                "load_time_s": round(elapsed, 1),
            }
            # Commit only once everything above succeeded, so the stats file
            # and the training set never disagree about a dataset.
            all_graphs.extend(graph_data)
            dataset_stats[name] = stats
            log(f" -> {stats['num_graphs']} graph(s), {stats['num_nodes']} nodes, {elapsed:.1f}s")
        except Exception as e:
            # One bad dataset must not stop the others from loading.
            log(f" -> FAILED: {e}")
            # Route the traceback through log() so it shows up in the monitor.
            log(traceback.format_exc())
            dataset_stats[name] = {"error": str(e)}
    return all_graphs, dataset_stats


def _pretrain_and_upload(all_graphs: list, device: torch.device) -> "GraphSAGEEncoder":
    """Build the encoder, pre-train it on *all_graphs* and upload the artifacts.

    Returns:
        The encoder instance that was handed to the trainer.
    """
    encoder = GraphSAGEEncoder(
        in_channels=6,
        hidden_channels=PRETRAIN_CONFIG["hidden_channels"],
        out_channels=PRETRAIN_CONFIG["out_channels"],
        num_layers=PRETRAIN_CONFIG["num_layers"],
        dropout=PRETRAIN_CONFIG["dropout"],
    )
    log(f"Model parameters: {sum(p.numel() for p in encoder.parameters()):,}")

    trainer = NetFMPretrainer(
        encoder=encoder,
        device=device,
        lr=PRETRAIN_CONFIG["lr"],
        weight_decay=PRETRAIN_CONFIG["weight_decay"],
        lambda_mask=PRETRAIN_CONFIG["lambda_mask"],
        lambda_link=PRETRAIN_CONFIG["lambda_link"],
        lambda_subgraph=PRETRAIN_CONFIG["lambda_subgraph"],
        mask_ratio=PRETRAIN_CONFIG["mask_ratio"],
        edge_drop_ratio=PRETRAIN_CONFIG["edge_drop_ratio"],
        num_epochs=PRETRAIN_CONFIG["epochs"],
    )
    history = trainer.train(all_graphs)

    # Save and upload the checkpoint first: it is the expensive artifact and
    # must not be lost if writing one of the auxiliary JSON files fails.
    ckpt_path = os.path.join(CKPT_DIR, "netfm_pretrained.pt")
    trainer.save(ckpt_path)
    upload_to_hub(ckpt_path, "netfm_pretrained.pt")

    config_path = os.path.join(CKPT_DIR, "config.json")
    with open(config_path, "w") as f:
        json.dump(PRETRAIN_CONFIG, f, indent=2)
    upload_to_hub(config_path, "config.json")

    history_path = os.path.join(RESULTS_DIR, "pretrain_history.json")
    with open(history_path, "w") as f:
        # default=str mirrors the evaluation dump: values json cannot encode
        # are stringified instead of aborting the run.
        json.dump(history, f, indent=2, default=str)
    upload_to_hub(history_path, "pretrain_history.json")

    return encoder


def _evaluate_holdout_datasets(encoder, device: torch.device) -> dict:
    """Run ``run_full_evaluation`` on the first graph of each held-out dataset.

    Returns:
        The merged result dicts; a dataset that failed maps to
        ``{"error": ...}``.
    """
    all_results: dict = {}
    for name in HOLDOUT_DATASETS:
        log(f"Evaluating on {name}...")
        try:
            graphs = load_and_prepare(name, root=DATA_ROOT)
            if not graphs:
                raise ValueError(f"load_and_prepare returned no graphs for {name}")
            results = run_full_evaluation(encoder, graphs[0].data, device, name)
            all_results.update(results)
            log(f" Done: {list(results[name].keys())}")
        except Exception as e:
            log(f" FAILED: {e}")
            log(traceback.format_exc())
            all_results[name] = {"error": str(e)}
    return all_results


def _run_training_inner() -> None:
    """Run the full pipeline: load data, pre-train, evaluate, upload results.

    Progress is reported through ``log`` and the module-level
    ``training_status``. Returns early, with status
    ``"FAILED - No graphs loaded"``, when no pre-training graph could be
    loaded. Exceptions outside the per-dataset loops propagate to the caller.
    """
    global training_status

    training_status = "LOADING DATA"
    os.makedirs(CKPT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"Device: {device}")
    if device.type == "cuda":
        log(f"GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        # Try both attribute spellings; fall back to 0 if neither exists.
        vram = getattr(props, "total_memory", getattr(props, "total_mem", 0))
        log(f"VRAM: {vram / 1e9:.1f} GB")

    _log_banner("PHASE 1: Loading and preparing datasets")
    all_graphs, dataset_stats = _load_pretraining_graphs()

    stats_path = os.path.join(RESULTS_DIR, "dataset_stats.json")
    with open(stats_path, "w") as f:
        json.dump(dataset_stats, f, indent=2)
    upload_to_hub(stats_path, "dataset_stats.json")

    log(f"\nTotal pre-training graphs: {len(all_graphs)}")
    if not all_graphs:
        training_status = "FAILED - No graphs loaded"
        return

    training_status = "PRE-TRAINING"
    _log_banner("PHASE 2: Pre-training")
    encoder = _pretrain_and_upload(all_graphs, device)

    training_status = "EVALUATING"
    _log_banner("PHASE 3: Evaluation on held-out datasets")
    all_results = _evaluate_holdout_datasets(encoder, device)

    results_path = os.path.join(RESULTS_DIR, "evaluation_results.json")
    with open(results_path, "w") as f:
        json.dump(all_results, f, indent=2, default=str)
    upload_to_hub(results_path, "evaluation_results.json")

    training_status = "COMPLETE"
    _log_banner("DONE! All results uploaded to HuggingFace Hub.")
def run_training() -> None:
    """Thread entry point: run the pipeline and record any crash.

    An exception escaping ``_run_training_inner`` is not re-raised. It is
    written into ``training_status`` and into the log, followed by the
    formatted traceback.
    """
    global training_status
    try:
        _run_training_inner()
    except Exception as exc:
        reason = str(exc)
        training_status = "CRASHED: " + reason
        log("FATAL ERROR: " + reason)
        log(traceback.format_exc())
def get_logs() -> str:
    """Build the monitor text: a status header plus the last 200 log lines."""
    rule = "=" * 60
    header = "Status: " + str(training_status) + "\n" + rule + "\n"
    recent = log_lines[-200:]
    return header + "\n".join(recent)
def main() -> None:
    """Start training on a daemon thread, then serve the Gradio monitor page.

    The page shows the compute device and a textbox whose value comes from
    ``get_logs``, re-polled on a 5-second interval. The server listens on
    0.0.0.0:7860.
    """
    worker = threading.Thread(target=run_training, daemon=True)
    worker.start()

    with gr.Blocks(title="NetFM Training") as demo:
        gr.Markdown("# NetFM Training Monitor")
        if torch.cuda.is_available():
            device_label = torch.cuda.get_device_name(0)
        else:
            device_label = "CPU"
        gr.Markdown(f"Training on: **{device_label}**")
        gr.Textbox(label="Training Logs", lines=30, value=get_logs, every=5)

    demo.launch(server_name="0.0.0.0", server_port=7860)
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()