Khelendramee commited on
Commit
cd67fd0
·
verified ·
1 Parent(s): 0f7de26

Create ray.py

Browse files
Files changed (1) hide show
  1. ray.py +21 -0
ray.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ray
2
+ from head_node import HeadNode
3
+ from worker_node import WorkerNode
4
+
5
+ ray.init() # Use address="auto" for multi-machine
6
+
7
+ # Read model split files (node_0.py, node_1.py, etc.)
8
+ model_codes = []
9
+ for i in range(3):
10
+ with open(f"model_stage_files/node_{i}.py", "r") as f:
11
+ model_codes.append(f.read())
12
+
13
+ # Start head node
14
+ head = HeadNode.remote(num_parts=3)
15
+
16
+ # Start worker nodes
17
+ workers = []
18
+ for i in range(3):
19
+ worker = WorkerNode.remote(i, model_codes[i], head)
20
+ workers.append(worker)
21
+ worker.train_model.remote()