import ray
import torch
import torch.nn as nn


def _build_full_model():
    """Return a freshly initialised copy of the full (unsplit) architecture.

    This must match, layer for layer and in order, the architecture that was
    split across the worker parts; `HeadNode.combine_model` maps the parts'
    tensors onto this model's state-dict keys purely by position.
    """
    return nn.Sequential(
        nn.Linear(10, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 1),
    )


@ray.remote
class HeadNode:
    """Ray actor that collects per-part weights and reassembles the full model.

    Each part reports its weights via `receive_weights`. Once every slot has
    been filled, the parts are stitched together into the architecture built
    by `_build_full_model` and the resulting state dict is saved with
    `torch.save`.
    """

    def __init__(self, num_parts, output_path="final_model.pt"):
        """Create a head node expecting `num_parts` weight submissions.

        Args:
            num_parts: Number of model parts that will report weights.
            output_path: File the combined state dict is written to.

        Raises:
            ValueError: If `num_parts` is less than 1.
        """
        if num_parts < 1:
            raise ValueError(f"num_parts must be at least 1, got {num_parts}")
        self.num_parts = num_parts
        self.output_path = output_path
        # One slot per part; None means "not received yet".
        self.model_parts = [None] * num_parts

    def receive_weights(self, part_id, weights):
        """Store the weights of one part and combine once all have arrived.

        Args:
            part_id: Index of the part, in ``range(num_parts)``. Parts are
                concatenated in this order when the model is combined.
            weights: The part's weights as a mapping of name -> tensor
                (presumably a ``state_dict()`` from the worker — confirm
                against the caller). Only its value order is used.

        Raises:
            IndexError: If `part_id` is outside ``range(num_parts)``.
        """
        # Reject negative ids explicitly: Python list indexing would otherwise
        # silently write into a slot counted from the end.
        if not 0 <= part_id < self.num_parts:
            raise IndexError(
                f"part_id {part_id} out of range for {self.num_parts} parts"
            )
        print(f"📩 Received weights from part {part_id}")
        self.model_parts[part_id] = weights
        # Test for presence, not truthiness: an empty state dict (e.g. a part
        # containing only parameter-free layers) is a valid, received part.
        if all(part is not None for part in self.model_parts):
            print("🧠 All parts received, combining full model...")
            self.combine_model()

    def combine_model(self):
        """Merge all received parts into the full model and save it.

        Tensors are taken part by part (in `part_id` order), and within each
        part in the part's own key order, then assigned positionally to the
        full model's state-dict keys.

        Raises:
            ValueError: If the parts do not contain exactly as many tensors
                as the full model expects. Without this check, missing
                tensors would leave layers at their random initialisation.
        """
        full_model = _build_full_model()
        full_state_dict = full_model.state_dict()
        full_keys = list(full_state_dict.keys())

        part_tensors = [
            tensor for part in self.model_parts for tensor in part.values()
        ]
        if len(part_tensors) != len(full_keys):
            raise ValueError(
                f"Received parts contain {len(part_tensors)} tensors, but the "
                f"full model expects {len(full_keys)}; the split does not "
                f"match the architecture."
            )

        for key, tensor in zip(full_keys, part_tensors):
            full_state_dict[key] = tensor

        # load_state_dict (strict by default) still validates tensor shapes.
        full_model.load_state_dict(full_state_dict)
        torch.save(full_model.state_dict(), self.output_path)
        print(f"✅ Full model saved as {self.output_path}")