Khelendramee commited on
Commit
f77f8b3
·
verified ·
1 Parent(s): bdba16a

Create head.py

Browse files
Files changed (1) hide show
  1. head.py +41 -0
head.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ray
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ @ray.remote
6
+ class HeadNode:
7
+ def __init__(self, num_parts):
8
+ self.num_parts = num_parts
9
+ self.model_parts = [None] * num_parts
10
+
11
+ def receive_weights(self, part_id, weights):
12
+ print(f"📩 Received weights from part {part_id}")
13
+ self.model_parts[part_id] = weights
14
+
15
+ if all(self.model_parts):
16
+ print("🧠 All parts received, combining full model...")
17
+ self.combine_model()
18
+
19
+ def combine_model(self):
20
+ # Full architecture (must match model split)
21
+ full_model = nn.Sequential(
22
+ nn.Linear(10, 64),
23
+ nn.ReLU(),
24
+ nn.Linear(64, 64),
25
+ nn.ReLU(),
26
+ nn.Linear(64, 1)
27
+ )
28
+
29
+ full_state_dict = full_model.state_dict()
30
+ part_keys = list(full_state_dict.keys())
31
+
32
+ # Flatten part weights into full model
33
+ i = 0
34
+ for part in self.model_parts:
35
+ for key in part.keys():
36
+ full_state_dict[part_keys[i]] = part[key]
37
+ i += 1
38
+
39
+ full_model.load_state_dict(full_state_dict)
40
+ torch.save(full_model.state_dict(), "final_model.pt")
41
+ print("✅ Full model saved as final_model.pt")