Upload hierarchy_model.py with huggingface_hub
Browse files — hierarchy_model.py (+13 −1)
hierarchy_model.py
CHANGED
|
@@ -282,7 +282,19 @@ class HierarchyEncoder(nn.Module):
|
|
| 282 |
|
| 283 |
def forward(self, hierarchy_indices):
|
| 284 |
# hierarchy_indices: (B,) - batch of hierarchy indices
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
return self.projection(emb)
|
| 287 |
|
| 288 |
class HierarchyClassifierHead(nn.Module):
|
|
|
|
def forward(self, hierarchy_indices):
    """Embed a batch of hierarchy indices and project them.

    Args:
        hierarchy_indices: LongTensor of shape (B,) — batch of hierarchy
            indices into ``self.embedding``.

    Returns:
        Tensor produced by ``self.projection`` applied to the looked-up
        embeddings (shape (B, projection_out) — assumes ``self.projection``
        is a per-row map such as ``nn.Linear``; confirm against the class
        definition, which is not visible here).
    """
    # Workaround for MPS: embedding lookups are unreliable on the MPS
    # backend, so perform the gather on CPU and move the result back.
    # Query the embedding weight's device directly — it is the tensor the
    # workaround actually concerns, and unlike next(self.parameters())
    # it cannot raise StopIteration on a parameter-less module.
    device = self.embedding.weight.device
    if device.type == 'mps':
        # Move indices to CPU for the embedding lookup.
        indices_cpu = hierarchy_indices.cpu()
        # Copy the weight to CPU and use the functional form for MPS
        # compatibility. NOTE(review): this copies the full embedding
        # table every forward pass — acceptable as a workaround, but
        # worth revisiting if the table is large.
        emb_weight = self.embedding.weight.cpu()
        emb = F.embedding(indices_cpu, emb_weight)
        # Move the result back to the model device (MPS); make it
        # contiguous so downstream ops see a dense layout.
        emb = emb.contiguous().to(device)
    else:
        emb = self.embedding(hierarchy_indices)
    # Ensure emb is on the same device as projection before calling it.
    return self.projection(emb)
|
| 299 |
|
| 300 |
class HierarchyClassifierHead(nn.Module):
|