# stocker / head_node.py
# Khelendramee's picture
# Rename head.py to head_node.py
# f0f2888 verified
import ray
import torch
import torch.nn as nn
@ray.remote
class HeadNode:
    """Ray actor that collects per-part weights and assembles the full model.

    Callers hand in one set of weights per model part through
    ``receive_weights``. Once every slot is filled, the parts are merged
    into a single ``nn.Sequential`` whose state dict is written to
    ``final_model.pt``.
    """

    def __init__(self, num_parts):
        """Create ``num_parts`` empty slots, one per expected model part."""
        self.num_parts = num_parts
        # ``None`` marks a slot whose weights have not arrived yet.
        self.model_parts = [None] * num_parts

    def receive_weights(self, part_id, weights):
        """Store the weights of one part and combine once all are present.

        Args:
            part_id: Zero-based slot index, ``0 <= part_id < num_parts``.
            weights: Mapping of parameter name to value for that part; its
                values are consumed in iteration order by ``combine_model``.

        Raises:
            IndexError: If ``part_id`` is outside ``[0, num_parts)``.
        """
        # Reject negative ids explicitly: plain list indexing would accept
        # them and silently overwrite a slot counted from the end.
        if not 0 <= part_id < self.num_parts:
            raise IndexError(
                f"part_id {part_id} out of range for {self.num_parts} parts"
            )

        print(f"📩 Received weights from part {part_id}")
        self.model_parts[part_id] = weights

        # Test for presence, not truthiness: a part made only of
        # parameter-free layers has an empty (falsy) state dict and must
        # still count as received.
        if all(part is not None for part in self.model_parts):
            print("🧠 All parts received, combining full model...")
            self.combine_model()

    def combine_model(self):
        """Merge the received parts into the full model and save it.

        Values are assigned to the full model's state-dict entries purely by
        position: parts in slot order, and entries within a part in that
        part's own iteration order. Key names inside the parts are ignored.

        Raises:
            RuntimeError: If at least one part has not been received yet.
            ValueError: If the parts together do not provide exactly one
                value per entry of the full model's state dict.
        """
        missing = [
            idx for idx, part in enumerate(self.model_parts) if part is None
        ]
        if missing:
            raise RuntimeError(
                f"Cannot combine model: missing weights for parts {missing}"
            )

        # Full architecture (must match model split)
        full_model = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        full_state_dict = full_model.state_dict()
        part_keys = list(full_state_dict.keys())

        # Flatten part weights into one ordered list.
        flat_weights = [
            value for part in self.model_parts for value in part.values()
        ]

        # Too many values used to fail with a bare IndexError; too few used
        # to save a model whose trailing layers kept their random
        # initialisation. Both are reported up front instead.
        if len(flat_weights) != len(part_keys):
            raise ValueError(
                f"Received {len(flat_weights)} weight tensors but the full "
                f"model expects {len(part_keys)}"
            )

        for key, value in zip(part_keys, flat_weights):
            full_state_dict[key] = value

        full_model.load_state_dict(full_state_dict)
        torch.save(full_model.state_dict(), "final_model.pt")
        print("✅ Full model saved as final_model.pt")