| |
| import os |
| import yaml |
| import sys |
|
|
# Absolute locations this launcher helper reads and writes.
# NOTE(review): these paths are tied to one specific checkout; adjust per deployment.
_ACCELERATE_CONFIG_PATH = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/configs/deepspeed_sencore.yaml"
_SSH_CONFIG_PATH = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/ssh_config/config"
_HOSTFILE_PATH = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/hostfile.txt"

# Slots advertised per host in the DeepSpeed hostfile.
# NOTE(review): this is fixed at 8 and does not follow gpus_per_node
# (SENSECORE_ACCELERATE_DEVICE_COUNT); confirm every host really has 8 GPUs.
_HOSTFILE_SLOTS = 8


def _env_int(name, default):
    """Return environment variable *name* parsed as an int.

    Falls back to *default* when the variable is unset, empty, or not a
    valid integer literal.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


def _build_accelerate_config(master_addr, master_port, num_nodes, gpus_per_node, node_rank):
    """Return the Accelerate DeepSpeed launch configuration as a plain dict."""
    return {
        "compute_environment": "LOCAL_MACHINE",
        "distributed_type": "DEEPSPEED",
        "deepspeed_config": {
            "deepspeed_config_file": "configs/ds_config.json",
            "zero3_init_flag": True,
            "deepspeed_multinode_launcher": "standard",
            # Same file that main() (re)generates from the ssh config.
            "deepspeed_hostfile": _HOSTFILE_PATH,
        },
        "machine_rank": node_rank,
        "main_process_ip": master_addr,
        "main_process_port": master_port,
        "main_training_function": "main",
        "num_machines": num_nodes,
        "num_processes": num_nodes * gpus_per_node,
        "same_network": True,
        "use_cpu": False,
        "rdzv_backend": "c10d",
        "tpu_env": [],
        "tpu_use_cluster": False,
        "tpu_use_sudo": False,
    }


def _parse_ssh_hostnames(lines):
    """Return every ``HostName`` argument found in ssh_config-style *lines*.

    Keywords are matched case-insensitively and may be separated from their
    argument by whitespace and/or a single ``=``; a double-quoted argument is
    unquoted. Blank lines, ``#`` comments and ``HostName`` lines with no
    argument are skipped. Order and duplicates are preserved.
    """
    hostnames = []
    for raw_line in lines:
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # The keyword ends at the first whitespace or "=".
        keyword = line.split(None, 1)[0].split("=", 1)[0]
        if keyword.lower() != "hostname":
            continue
        value = line[len(keyword):].strip()
        if value.startswith("="):
            value = value[1:].strip()
        if len(value) >= 2 and value[0] == value[-1] == '"':
            value = value[1:-1]
        if value:
            hostnames.append(value)
    return hostnames


def main():
    """Generate the Accelerate/DeepSpeed config YAML and the DeepSpeed hostfile.

    Reads MASTER_ADDR, MASTER_PORT, SENSECORE_PYTORCH_NNODES,
    SENSECORE_ACCELERATE_DEVICE_COUNT and SENSECORE_PYTORCH_NODE_RANK from
    the environment. It writes the launch config to ``_ACCELERATE_CONFIG_PATH``
    and echoes it. It then extracts the ``HostName`` entries from
    ``_SSH_CONFIG_PATH`` into ``_HOSTFILE_PATH``.

    Returns:
        0, used as the process exit status.

    Raises:
        ValueError: if MASTER_PORT is set but is not an integer.
        OSError: if the ssh config cannot be read or an output cannot be written.
    """
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    # Not defaulted on parse failure: a misconfigured port should surface
    # immediately rather than be silently replaced by 29500.
    master_port = int(os.getenv("MASTER_PORT", "29500"))
    num_nodes = _env_int("SENSECORE_PYTORCH_NNODES", 1)
    gpus_per_node = _env_int("SENSECORE_ACCELERATE_DEVICE_COUNT", 1)
    node_rank = _env_int("SENSECORE_PYTORCH_NODE_RANK", 0)
    num_processes = num_nodes * gpus_per_node

    config = _build_accelerate_config(
        master_addr, master_port, num_nodes, gpus_per_node, node_rank
    )

    print("Generated Configuration:")
    print(f" Master: {master_addr}:{master_port}")
    print(f" Number of nodes: {num_nodes}")
    print(f" GPUs per node: {gpus_per_node}")
    print(f" Total processes: {num_processes}")
    print(f" Node rank: {node_rank}")

    # Create the YAML's own parent directory. A relative "configs" directory
    # would only be correct when the CWD happens to be the training directory.
    os.makedirs(os.path.dirname(_ACCELERATE_CONFIG_PATH), exist_ok=True)
    with open(_ACCELERATE_CONFIG_PATH, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"\nConfiguration saved to: {_ACCELERATE_CONFIG_PATH}")

    print("\nFile content:")
    with open(_ACCELERATE_CONFIG_PATH, "r", encoding="utf-8") as f:
        print(f.read())

    with open(_SSH_CONFIG_PATH, "r", encoding="utf-8") as f:
        hostnames = _parse_ssh_hostnames(f)

    if not hostnames:
        print(f"Warning: no HostName entries found in {_SSH_CONFIG_PATH}", file=sys.stderr)

    # DeepSpeed hostfile format: one "<host> slots=<n>" line per machine.
    with open(_HOSTFILE_PATH, "w", encoding="utf-8") as f:
        f.writelines(f"{host} slots={_HOSTFILE_SLOTS}\n" for host in hostnames)

    # "Extracted N hostnames, written to <path>"
    print(f"提取了 {len(hostnames)} 个 hostname,已写入 {_HOSTFILE_PATH}")

    return 0
|
|
if __name__ == "__main__":
    # Run the generator and hand its return value to the OS as the exit status.
    exit_status = main()
    sys.exit(exit_status)