Spaces:
Sleeping
Sleeping
| import os | |
| import wandb | |
| from src.utils.paths import get_path | |
| api = wandb.Api() | |
| def get_run_by_name(name): | |
| runs = api.runs( | |
| path="fcc_ml/svj_clustering", | |
| filters={"display_name": {"$eq": name.strip()}} | |
| ) | |
| runs = api.runs( | |
| path="fcc_ml/svj_clustering", | |
| filters={"display_name": {"$eq": name.strip()}} | |
| ) | |
| if runs.length != 1: | |
| return None | |
| return runs[0] | |
| def get_steps_from_file(fname): | |
| # fname looks like "/work/gkrzmanc/jetclustering/results/train/lgatr_CONT_ds_cap_1000_2025_01_21_19_41_51/step_22000_epoch_1467.ckpt" -> extract 22000 | |
| return int(fname.split("/")[-1].split("_")[1]) | |
| def get_run_initial_steps(run): | |
| if not run.config["load_model_weights"]: | |
| return 0 | |
| else: | |
| run_name_1 = run.config["load_model_weights"].split("/")[-2] | |
| run_1 = get_run_by_name(run_name_1) | |
| if run_1 is None: raise Exception("Run doesn't exist: " + run_name_1) | |
| return get_run_initial_steps(run_1) + get_steps_from_file(run.config["load_model_weights"]) | |
| def extract_relative_path(run_path): | |
| # just return everything after train/.. - run_path looks like /a/b/c/d/train/e/f | |
| return get_path("train/" + run_path.split("train/")[-1], type="results", fallback=True) | |
| #return "train/" + run_path.split("train/")[-1] | |
| def get_run_step_direct(run_path, step): | |
| # get the step of the run directly | |
| p = extract_relative_path(run_path) | |
| print("Run-path:", p) | |
| lst = os.listdir(p) | |
| lst = [x for x in lst if x.endswith(".ckpt")] # files are of format step_x_epoch_y.ckpt | |
| steps = [int(x.split("_")[1]) for x in lst] | |
| if step not in steps: | |
| print("Available steps:", steps) | |
| raise Exception("Step not found in run") | |
| full_path = os.path.join(p, [x for x in lst if int(x.split("_")[1]) == step][0]) | |
| # return everything after "train/" | |
| return "train/" + full_path.split("train/")[-1] | |
| def get_run_step_ckpt(run, step, steps_from_zero): | |
| if not run.config["load_model_weights"] or steps_from_zero: | |
| return get_run_step_direct(run.config["run_path"], step), run | |
| else: | |
| run_name_1 = run.config["load_model_weights"].split("/")[-2] | |
| run_1 = get_run_by_name(run_name_1) | |
| if run_1 is None: raise Exception("Run doesn't exist: " + run_name_1) | |
| steps = get_run_initial_steps(run) | |
| if step > steps: | |
| print("Step", step, "is in run", run.name) | |
| return get_run_step_direct(run_1.config["run_path"], step - steps), run_1 | |
| else: | |
| return get_run_step_ckpt(run_1, step) | |
| args_to_update = ["validation_steps", "start_lr", "lr_scheduler", "optimizer", "embed_as_vectors", "epsilon", | |
| "min_samples", "min_cluster_size", "spatial_part_only", "scalars_oc", "lorentz_norm", "beta_type", | |
| "coord_loss_weight", "repul_loss_weight", "attr_loss_weight", "gt_radius", "loss", "num_steps", | |
| "num_epochs", "hidden_s_channels", "hidden_mv_channels", "n_heads", "internal_dim", | |
| "num_blocks", "network_config", "data_config", "no_pid"] | |
| def update_args(args, run): | |
| for arg in args_to_update: | |
| if arg in ["min_samples", "min_cluster_size", "epsilon"]: | |
| print("Skipping setting clustering args") | |
| continue | |
| if arg not in run.config: | |
| print("Skipping setting", arg) | |
| continue | |
| print("Setting", arg, run.config[arg]) | |
| setattr(args, arg, run.config[arg]) | |
| print("Loaded args from run", run.name) | |
| args.parent_run = run.name | |
| return args | |