Khelendramee commited on
Commit
b4822f9
·
verified ·
1 Parent(s): f0f2888

Create worker_node.py

Browse files
Files changed (1) hide show
  1. worker_node.py +37 -0
worker_node.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ray
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+
6
+ @ray.remote
7
+ class WorkerNode:
8
+ def __init__(self, part_id, model_code, head_node):
9
+ self.part_id = part_id
10
+ self.model_code = model_code
11
+ self.head_node = head_node
12
+
13
+ def load_model(self):
14
+ local_vars = {}
15
+ exec(self.model_code, globals(), local_vars)
16
+ self.model = local_vars['get_model']()
17
+
18
+ def train_model(self):
19
+ self.load_model()
20
+ X = torch.randn(100, 10) # Dummy input
21
+ y = torch.randn(100, 1) # Dummy output
22
+
23
+ criterion = nn.MSELoss()
24
+ optimizer = optim.SGD(self.model.parameters(), lr=0.01)
25
+
26
+ for epoch in range(5):
27
+ optimizer.zero_grad()
28
+ output = self.model(X)
29
+ loss = criterion(output, y)
30
+ loss.backward()
31
+ optimizer.step()
32
+
33
+ print(f"✅ Worker {self.part_id} training done.")
34
+
35
+ # Send trained weights (not gradients) to head node
36
+ weights = self.model.state_dict()
37
+ ray.get(self.head_node.receive_weights.remote(self.part_id, weights))