File size: 540 Bytes
cd67fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import ray
from head_node import HeadNode
from worker_node import WorkerNode

# Number of stages the model was split into. The stage files on disk, the
# head node's `num_parts` and the number of worker actors must all agree, so
# the value is defined once rather than repeating the literal in each place.
NUM_PARTS = 3
MODEL_STAGE_DIR = "model_stage_files"

ray.init()  # Use address="auto" for multi-machine

# Read the source text of each model split file (node_0.py, node_1.py, ...).
model_codes = []
for i in range(NUM_PARTS):
    with open(f"{MODEL_STAGE_DIR}/node_{i}.py", "r", encoding="utf-8") as f:
        model_codes.append(f.read())

# Start head node, passing it the stage count.
head = HeadNode.remote(num_parts=NUM_PARTS)

# Start one worker actor per stage, handing each its index, its stage source
# and the head-node handle, then schedule training on it.
workers = []
train_refs = []
for i, code in enumerate(model_codes):
    worker = WorkerNode.remote(i, code, head)
    workers.append(worker)
    train_refs.append(worker.train_model.remote())

# `.remote()` only schedules the calls and returns ObjectRefs. Block until
# every worker's train_model call finishes (ray.get re-raises any remote
# failure here); without this the driver would exit right away and Ray would
# terminate the non-detached actors it created, cutting training short.
# NOTE(review): assumes train_model returns once training is complete —
# confirm against WorkerNode.
ray.get(train_refs)