import ray
import torch
import torch.nn as nn
import torch.optim as optim


@ray.remote
class WorkerNode:
    """Ray actor that builds a model from source text, trains it, and reports weights.

    The model is produced by executing ``model_code`` and calling the
    ``get_model()`` function that code defines. After training on randomly
    generated data, the model's ``state_dict`` is sent to ``head_node`` via
    its ``receive_weights`` remote method.

    Args:
        part_id: Identifier for this worker; used in the log line and passed
            to the head node together with the weights.
        model_code: Python source text that must define a top-level
            ``get_model()`` callable returning the model to train.
        head_node: Handle exposing a ``receive_weights.remote(part_id,
            weights)`` method (presumably a Ray actor handle -- confirm
            against the caller).
    """

    def __init__(self, part_id, model_code, head_node):
        self.part_id = part_id
        self.model_code = model_code
        self.head_node = head_node

    def load_model(self):
        """Execute ``model_code`` and store the result of ``get_model()`` on ``self.model``.

        Raises:
            KeyError: If ``model_code`` does not define ``get_model``.
        """
        # SECURITY: exec() runs arbitrary code with this process's privileges.
        # Only pass model_code that comes from a trusted source.
        #
        # Use ONE namespace for both globals and locals. With two separate
        # mappings, exec() behaves like a class body: top-level names defined
        # by model_code (e.g. `class Net`) would not be visible from inside
        # `get_model()`, which resolves free names through its globals only.
        # Copying the module globals keeps this module itself unmodified.
        namespace = {**globals(), "torch": torch, "nn": nn, "optim": optim}
        # Make sure the factory we call is the one model_code defines, not a
        # same-named object that happens to live in this module.
        namespace.pop("get_model", None)
        exec(self.model_code, namespace)
        try:
            factory = namespace["get_model"]
        except KeyError:
            raise KeyError(
                "model_code must define a top-level get_model() function"
            ) from None
        self.model = factory()

    def train_model(self, *, epochs=5, lr=0.01):
        """Load the model, train it on random data, and send its weights to the head node.

        Args:
            epochs: Number of full-batch optimisation steps to run.
            lr: Learning rate for the SGD optimiser.
        """
        self.load_model()

        # Dummy regression data: 100 samples, 10 input features, 1 target.
        X = torch.randn(100, 10)
        y = torch.randn(100, 1)

        criterion = nn.MSELoss()
        optimizer = optim.SGD(self.model.parameters(), lr=lr)

        for _ in range(epochs):
            optimizer.zero_grad()
            output = self.model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

        print(f"✅ Worker {self.part_id} training done.")

        # Send trained weights (not gradients) to the head node and block
        # until it has processed them.
        # NOTE(review): this blocks inside the actor; confirm the head node is
        # free to serve receive_weights while workers are training.
        weights = self.model.state_dict()
        ray.get(self.head_node.receive_weights.remote(self.part_id, weights))