Spaces:
Sleeping
Sleeping
File size: 1,103 Bytes
b4822f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 | import ray
import torch
import torch.nn as nn
import torch.optim as optim
@ray.remote
class WorkerNode:
    """Ray actor that builds a model from source text, trains it, and reports weights.

    The model is produced by executing ``model_code`` and calling the
    ``get_model()`` function that the code is required to define. After a
    short training run on random data, the resulting ``state_dict`` is sent
    to ``head_node`` via its ``receive_weights`` remote method.
    """

    def __init__(self, part_id, model_code, head_node):
        """Store the worker's configuration; no model is built here.

        Args:
            part_id: Identifier of this worker, echoed in the log line and
                passed along with the weights.
            model_code: Python source text that must define ``get_model()``.
            head_node: Object exposing ``receive_weights.remote(part_id, weights)``
                (presumably a Ray actor handle -- confirm against the caller).
        """
        self.part_id = part_id
        self.model_code = model_code
        self.head_node = head_node

    def load_model(self):
        """Execute ``model_code`` and set ``self.model`` to ``get_model()``'s result.

        Raises:
            KeyError: If ``model_code`` does not define ``get_model``.
        """
        # SECURITY: exec() runs arbitrary code with this process's privileges.
        # Only pass model_code that comes from a trusted source.
        #
        # Use ONE namespace for both globals and locals. With separate
        # mappings, top-level names defined by model_code (helper classes,
        # imports, constants) land in the locals dict, while functions defined
        # there resolve free names through the globals dict -- so get_model()
        # could not see a sibling ``class Net`` and raised NameError.
        # Copying the module globals keeps torch/nn/optim visible to the code
        # while preventing it from mutating this module's real globals.
        namespace = dict(globals())
        # The factory must come from model_code, never from this module.
        namespace.pop("get_model", None)
        exec(self.model_code, namespace)
        try:
            factory = namespace["get_model"]
        except KeyError:
            raise KeyError(
                "model_code must define a get_model() function"
            ) from None
        self.model = factory()

    def train_model(self):
        """Train the model for 5 epochs on random data and send weights to the head node.

        Loads the model first, runs SGD (lr=0.01) with an MSE loss on a fixed
        random batch of 100 samples (10 inputs, 1 target each), then blocks
        until ``head_node.receive_weights`` has completed.
        """
        self.load_model()

        X = torch.randn(100, 10)  # Dummy input
        y = torch.randn(100, 1)  # Dummy output

        criterion = nn.MSELoss()
        optimizer = optim.SGD(self.model.parameters(), lr=0.01)

        for _ in range(5):
            optimizer.zero_grad()
            output = self.model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

        print(f"✅ Worker {self.part_id} training done.")

        # Send trained weights (not gradients) to head node; ray.get waits
        # for the remote call to finish so failures surface here.
        weights = self.model.state_dict()
        ray.get(self.head_node.receive_weights.remote(self.part_id, weights))
|