Update modeling_xvla.py
modeling_xvla.py  CHANGED  (+2 −9)
@@ -55,23 +55,20 @@ class XVLA(PreTrainedModel):
         # Core settings
         self.num_actions: int = config.num_actions
         self.use_proprio: bool = config.use_proprio
-
+        self.action_mode: str = config.action_mode.lower()
         # Action space (dimensions + hooks)
         self.action_space = build_action_space(config.action_mode.lower())
         dim_action = self.action_space.dim_action
         dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)

         # Florence2 backbone (encoder only)
-        self.vlm = Florence2ForConditionalGeneration(config.florence_config)
+        self.vlm = Florence2ForConditionalGeneration(config.florence_config).to(torch.float32)
         if hasattr(self.vlm, "language_model"):
             lm = self.vlm.language_model
             if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
                 del lm.model.decoder
             if hasattr(lm, "lm_head"):
                 del lm.lm_head
-        # ⚠️ VERY IMPORTANT: disable Florence2's tie_weights hooks to avoid decoder access
-        if hasattr(self.vlm, "tie_weights"):
-            self.vlm.tie_weights = lambda *a, **k: None

         projection_dim = getattr(self.vlm.config, "projection_dim", None)
         if projection_dim is None:
@@ -96,10 +93,6 @@ class XVLA(PreTrainedModel):
         # Deferred FastAPI app
         self.app: FastAPI | None = None

-    def tie_weights(self):
-        """Disable automatic weight tying (Florence is encoder-only)."""
-        return
-
     # ============================= Florence2 encoder =============================
     def forward_vlm(
         self,