Spaces:
Sleeping
Sleeping
File size: 1,237 Bytes
f77f8b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | import ray
import torch
import torch.nn as nn
@ray.remote
class HeadNode:
    """Ray actor that collects per-part weights and assembles the full model.

    Senders deliver the weights of their slice of the model through
    ``receive_weights``. Once every slot has been filled, the parts are
    stitched together in part order into the ``nn.Sequential`` defined in
    ``combine_model``, and the result is saved with ``torch.save``.
    """

    def __init__(self, num_parts):
        """Create an actor that waits for ``num_parts`` weight mappings.

        Args:
            num_parts: Number of model parts to wait for. Valid part ids are
                ``0 .. num_parts - 1``.
        """
        self.num_parts = num_parts
        # One slot per part; ``None`` means "not received yet".
        self.model_parts = [None] * num_parts

    def receive_weights(self, part_id, weights):
        """Store the weights for ``part_id`` and combine once all parts exist.

        Args:
            part_id: Index of the part, in ``range(num_parts)``.
            weights: Mapping of name -> tensor for that part (presumably the
                sub-model's ``state_dict()`` -- confirm with the sender). Its
                iteration order must follow the order of the corresponding
                layers in the full model, because ``combine_model`` assigns
                tensors positionally.

        Raises:
            ValueError: If ``part_id`` is out of range, or (via
                ``combine_model``) the parts do not fit the full model.
        """
        # Reject out-of-range ids up front: a negative index would otherwise
        # silently overwrite a slot counted from the end of the list.
        if not 0 <= part_id < self.num_parts:
            raise ValueError(
                f"part_id must be in [0, {self.num_parts}), got {part_id}"
            )
        print(f"📩 Received weights from part {part_id}")
        self.model_parts[part_id] = weights
        # Check "received" with ``is not None`` rather than truthiness: a part
        # containing only parameter-free layers sends an empty mapping, which
        # is falsy and would otherwise block combining forever.
        if all(part is not None for part in self.model_parts):
            print("🧠 All parts received, combining full model...")
            self.combine_model()

    def combine_model(self, save_path="final_model.pt"):
        """Stitch the received parts into the full model and save it.

        Tensors are assigned positionally. The values of every part, taken in
        part order and in each part's own iteration order, are mapped onto
        the full model's ``state_dict`` keys in order. The names used inside
        the parts are ignored.

        Args:
            save_path: Destination for ``torch.save`` of the combined
                ``state_dict``. A relative path resolves against the working
                directory of the process running this actor.

        Raises:
            RuntimeError: If some parts have not been received yet. Also
                raised by ``load_state_dict`` if a tensor's shape does not
                match the corresponding full-model entry.
            ValueError: If the parts hold a different number of tensors than
                the full model's ``state_dict``.
        """
        missing = [i for i, part in enumerate(self.model_parts) if part is None]
        if missing:
            raise RuntimeError(f"Cannot combine: parts {missing} not received")

        # Full architecture (must match model split)
        full_model = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        full_state_dict = full_model.state_dict()
        full_keys = list(full_state_dict)

        # Flatten the part weights, in part order, into one positional list.
        flat_tensors = [
            tensor for part in self.model_parts for tensor in part.values()
        ]
        # A count mismatch means the split does not match this architecture.
        # Without this check, too few tensors would leave trailing layers at
        # their random initialisation and the model would still be saved.
        if len(flat_tensors) != len(full_keys):
            raise ValueError(
                f"Received {len(flat_tensors)} tensors across "
                f"{self.num_parts} parts, but the full model expects "
                f"{len(full_keys)}"
            )

        # Overwrite entries in the model's own state dict (rather than a fresh
        # dict) so that any metadata attached to it is preserved on load.
        for key, tensor in zip(full_keys, flat_tensors):
            full_state_dict[key] = tensor
        full_model.load_state_dict(full_state_dict)
        torch.save(full_model.state_dict(), save_path)
        print(f"✅ Full model saved as {save_path}")
|