Upload lewm_model.py
Browse files
- lewm_model.py +4 -2
lewm_model.py
CHANGED
|
@@ -302,6 +302,8 @@ class LeWorldModel(nn.Module):
|
|
| 302 |
self.action_encoder = action_encoder
|
| 303 |
self.projector = projector or nn.Identity()
|
| 304 |
self.pred_proj = pred_proj or nn.Identity()
|
|
|
|
|
|
|
| 305 |
|
| 306 |
def encode(self, pixels: torch.Tensor) -> torch.Tensor:
|
| 307 |
"""
|
|
@@ -359,8 +361,8 @@ class LeWorldModel(nn.Module):
|
|
| 359 |
pred_loss = (pred_emb[:, :-1] - emb[:, 1:history_size]).pow(2).mean()
|
| 360 |
|
| 361 |
# SIGReg on step-wise embeddings (transpose to (T, B, D))
|
| 362 |
-
sigreg
|
| 363 |
-
sigreg_loss = sigreg(emb.transpose(0, 1))
|
| 364 |
|
| 365 |
loss = pred_loss + sigreg_weight * sigreg_loss
|
| 366 |
return {
|
|
|
|
| 302 |
self.action_encoder = action_encoder
|
| 303 |
self.projector = projector or nn.Identity()
|
| 304 |
self.pred_proj = pred_proj or nn.Identity()
|
| 305 |
+
# SIGReg registered as a submodule so model.to(device) moves its buffers
|
| 306 |
+
self.sigreg = SIGReg()
|
| 307 |
|
| 308 |
def encode(self, pixels: torch.Tensor) -> torch.Tensor:
|
| 309 |
"""
|
|
|
|
| 361 |
pred_loss = (pred_emb[:, :-1] - emb[:, 1:history_size]).pow(2).mean()
|
| 362 |
|
| 363 |
# SIGReg on step-wise embeddings (transpose to (T, B, D))
|
| 364 |
+
# self.sigreg is a registered submodule so it follows model.to(device)
|
| 365 |
+
sigreg_loss = self.sigreg(emb.transpose(0, 1))
|
| 366 |
|
| 367 |
loss = pred_loss + sigreg_weight * sigreg_loss
|
| 368 |
return {
|