Leacb4 commited on
Commit
fc0f58c
·
verified ·
1 Parent(s): 512487a

Upload hierarchy_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hierarchy_model.py +13 -1
hierarchy_model.py CHANGED
@@ -282,7 +282,19 @@ class HierarchyEncoder(nn.Module):
282
 
283
  def forward(self, hierarchy_indices):
284
  # hierarchy_indices: (B,) - batch of hierarchy indices
285
- emb = self.embedding(hierarchy_indices)
 
 
 
 
 
 
 
 
 
 
 
 
286
  return self.projection(emb)
287
 
288
  class HierarchyClassifierHead(nn.Module):
 
282
 
283
  def forward(self, hierarchy_indices):
284
  # hierarchy_indices: (B,) - batch of hierarchy indices
285
+ # Workaround for MPS: embedding layers don't work well with MPS, so do lookup on CPU
286
+ device = next(self.parameters()).device
287
+ if device.type == 'mps':
288
+ # Move indices to CPU for embedding lookup
289
+ indices_cpu = hierarchy_indices.cpu()
290
+ # Use functional embedding with explicit weight handling for MPS compatibility
291
+ emb_weight = self.embedding.weight.cpu()
292
+ emb = F.embedding(indices_cpu, emb_weight)
293
+ # Move result back to model device (MPS) - ensure it's contiguous
294
+ emb = emb.contiguous().to(device)
295
+ else:
296
+ emb = self.embedding(hierarchy_indices)
297
+ # Ensure emb is on the same device as projection before calling it
298
  return self.projection(emb)
299
 
300
  class HierarchyClassifierHead(nn.Module):