Spaces:
Sleeping
Sleeping
| import ray | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
class WorkerNode:
    """Train one model partition locally and push the resulting weights to a head node.

    The model architecture arrives as Python source (``model_code``) that must
    define a top-level ``get_model()`` factory returning a ``torch.nn.Module``.

    Attributes:
        part_id: Identifier of this worker's partition; sent back to the head
            node together with the trained weights.
        model_code: Python source executed by ``load_model``.
        head_node: Handle exposing ``receive_weights.remote(part_id, weights)``
            (presumably a Ray actor handle -- confirm with the caller).
        model: The instantiated model, or ``None`` until ``load_model`` runs.
    """

    def __init__(self, part_id, model_code, head_node):
        self.part_id = part_id
        self.model_code = model_code
        self.head_node = head_node
        self.model = None

    def load_model(self):
        """Execute ``model_code`` and build ``self.model`` via its ``get_model()``.

        Raises:
            KeyError: if ``model_code`` does not define ``get_model``.
        """
        # SECURITY: exec() runs arbitrary code with this process's privileges.
        # model_code must only ever come from a trusted source, never from
        # end users or other untrusted input.
        #
        # Use ONE dict as both globals and locals. With separate dicts (as in
        # exec(code, globals(), local_vars)) the snippet runs like a class body,
        # so a get_model() referring to a class/helper defined in the same
        # snippet fails with NameError when called. A copy of this module's
        # globals keeps torch/nn/optim visible without letting the snippet
        # rebind names in the real module namespace.
        namespace = dict(globals())
        # Only accept a get_model that the snippet itself defines.
        namespace.pop("get_model", None)
        exec(self.model_code, namespace)
        try:
            get_model = namespace["get_model"]
        except KeyError:
            raise KeyError(
                "model_code must define a top-level get_model() function"
            ) from None
        self.model = get_model()

    def train_model(self, epochs=5, lr=0.01):
        """Build the model, fit it on random dummy data, then send its weights on.

        Args:
            epochs: Number of full-batch SGD steps (default 5).
            lr: SGD learning rate (default 0.01).

        Blocks (via ``ray.get``) until ``head_node.receive_weights`` completes.
        """
        self.load_model()
        # Placeholder data: 100 samples with 10 features and 1 target, so the
        # model is expected to take 10 input features and produce 1 output.
        X = torch.randn(100, 10)  # Dummy input
        y = torch.randn(100, 1)  # Dummy output
        criterion = nn.MSELoss()
        optimizer = optim.SGD(self.model.parameters(), lr=lr)
        # Make sure dropout / batch-norm layers behave as in training, even if
        # get_model() handed back a module in eval mode.
        self.model.train()
        for _ in range(epochs):
            optimizer.zero_grad()
            output = self.model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        print(f"✅ Worker {self.part_id} training done.")
        # Send trained weights (not gradients) to head node
        weights = self.model.state_dict()
        ray.get(self.head_node.receive_weights.remote(self.part_id, weights))