Spaces:
Sleeping
Sleeping
| import ray | |
| import torch | |
| import torch.nn as nn | |
class HeadNode:
    """Collects model shards from worker nodes and reassembles the full model.

    NOTE(review): the file imports `ray`, so this class is presumably used as
    a Ray actor (e.g. `ray.remote(HeadNode)`) — confirm at the call site; the
    class body itself contains no Ray-specific code.
    """

    def __init__(self, num_parts):
        # Expected number of shards; one slot per shard, filled on arrival.
        self.num_parts = num_parts
        self.model_parts = [None] * num_parts

    def receive_weights(self, part_id, weights):
        """Store one shard's weights; combine once every slot is filled.

        Args:
            part_id: index of the shard in [0, num_parts).
            weights: state-dict-like mapping of name -> tensor. Shards must
                cover the full model's parameters in layer order, since
                combine_model() maps them positionally.
        """
        print(f"📩 Received weights from part {part_id}")
        self.model_parts[part_id] = weights
        # BUGFIX: the original `all(self.model_parts)` treated any falsy
        # shard (e.g. an empty state dict) as "not yet received", so the
        # combine step could never trigger. Test explicitly against None.
        if all(part is not None for part in self.model_parts):
            print("🧠 All parts received, combining full model...")
            self.combine_model()

    def combine_model(self):
        """Stitch received shards into the full model and save it to disk.

        Writes the combined state dict to "final_model.pt".

        Raises:
            ValueError: if the shards' total tensor count does not match the
                full architecture below (previously a bare IndexError).
        """
        # Full architecture (must match model split)
        full_model = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        full_state_dict = full_model.state_dict()
        part_keys = list(full_state_dict.keys())

        # Guard the positional mapping: fail loudly with a clear message
        # instead of an IndexError mid-loop when shard sizes are wrong.
        total = sum(len(part) for part in self.model_parts)
        if total != len(part_keys):
            raise ValueError(
                f"Received {total} parameter tensors but the full model "
                f"expects {len(part_keys)}; shards do not match architecture."
            )

        # Flatten part weights into full model: shard keys are consumed in
        # insertion order and mapped positionally onto the full model's
        # parameter keys, so shard ordering must match the layer order.
        i = 0
        for part in self.model_parts:
            for key in part.keys():
                full_state_dict[part_keys[i]] = part[key]
                i += 1
        full_model.load_state_dict(full_state_dict)
        torch.save(full_model.state_dict(), "final_model.pt")
        print("✅ Full model saved as final_model.pt")