ar27111994 committed on
Commit
295a4bf
·
verified ·
1 Parent(s): a15cf53

Upload lewm_model.py

Browse files
Files changed (1) hide show
  1. lewm_model.py +4 -2
lewm_model.py CHANGED
@@ -302,6 +302,8 @@ class LeWorldModel(nn.Module):
302
  self.action_encoder = action_encoder
303
  self.projector = projector or nn.Identity()
304
  self.pred_proj = pred_proj or nn.Identity()
 
 
305
 
306
  def encode(self, pixels: torch.Tensor) -> torch.Tensor:
307
  """
@@ -359,8 +361,8 @@ class LeWorldModel(nn.Module):
359
  pred_loss = (pred_emb[:, :-1] - emb[:, 1:history_size]).pow(2).mean()
360
 
361
  # SIGReg on step-wise embeddings (transpose to (T, B, D))
362
- sigreg = SIGReg()
363
- sigreg_loss = sigreg(emb.transpose(0, 1))
364
 
365
  loss = pred_loss + sigreg_weight * sigreg_loss
366
  return {
 
302
  self.action_encoder = action_encoder
303
  self.projector = projector or nn.Identity()
304
  self.pred_proj = pred_proj or nn.Identity()
305
+ # SIGReg registered as a submodule so model.to(device) moves its buffers
306
+ self.sigreg = SIGReg()
307
 
308
  def encode(self, pixels: torch.Tensor) -> torch.Tensor:
309
  """
 
361
  pred_loss = (pred_emb[:, :-1] - emb[:, 1:history_size]).pow(2).mean()
362
 
363
  # SIGReg on step-wise embeddings (transpose to (T, B, D))
364
+ # self.sigreg is a registered submodule so it follows model.to(device)
365
+ sigreg_loss = self.sigreg(emb.transpose(0, 1))
366
 
367
  loss = pred_loss + sigreg_weight * sigreg_loss
368
  return {