stocker / ray_node.py
Khelendramee's picture
Rename ray.py to ray_node.py
c58ad97 verified
raw
history blame contribute delete
540 Bytes
import ray
from head_node import HeadNode
from worker_node import WorkerNode
ray.init()  # Use address="auto" for multi-machine

# Number of stages the model was split into. It must match the number of
# model_stage_files/node_<i>.py files and is shared by the head and workers.
NUM_PARTS = 3

# Read the source code of each model stage file (node_0.py, node_1.py, ...).
# The path is relative to the current working directory, so run this script
# from the directory that contains model_stage_files/.
model_codes = []
for i in range(NUM_PARTS):
    with open(f"model_stage_files/node_{i}.py", "r", encoding="utf-8") as f:
        model_codes.append(f.read())

# Start the head node actor, telling it how many stages to expect.
head = HeadNode.remote(num_parts=NUM_PARTS)

# Start one worker actor per stage and launch its training call.
# The handles are kept in `workers` so the actors stay alive: Ray terminates
# an actor once every handle to it has gone out of scope.
workers = []
train_refs = []
for i, code in enumerate(model_codes):
    worker = WorkerNode.remote(i, code, head)
    workers.append(worker)
    train_refs.append(worker.train_model.remote())

# Wait for every worker's train_model call to finish. Without this the driver
# exits immediately after submitting the calls, and Ray tears down the
# (non-detached) actors owned by the exiting driver, aborting training.
# ray.get also re-raises any exception raised inside a worker.
# NOTE(review): assumes train_model eventually returns; if it is meant to run
# forever, this call simply keeps the driver alive.
ray.get(train_refs)