2toINF committed on
Commit
bb9981b
·
verified ·
1 Parent(s): d14e104

Update modeling_xvla.py

Browse files
Files changed (1) hide show
  1. modeling_xvla.py +2 -9
modeling_xvla.py CHANGED
@@ -55,23 +55,20 @@ class XVLA(PreTrainedModel):
55
  # Core settings
56
  self.num_actions: int = config.num_actions
57
  self.use_proprio: bool = config.use_proprio
58
-
59
  # Action space (dimensions + hooks)
60
  self.action_space = build_action_space(config.action_mode.lower())
61
  dim_action = self.action_space.dim_action
62
  dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)
63
 
64
  # Florence2 backbone (encoder only)
65
- self.vlm = Florence2ForConditionalGeneration(config.florence_config)
66
  if hasattr(self.vlm, "language_model"):
67
  lm = self.vlm.language_model
68
  if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
69
  del lm.model.decoder
70
  if hasattr(lm, "lm_head"):
71
  del lm.lm_head
72
- # ⚠️ VERY IMPORTANT: disable Florence2's tie_weights hooks to avoid decoder access
73
- if hasattr(self.vlm, "tie_weights"):
74
- self.vlm.tie_weights = lambda *a, **k: None
75
 
76
  projection_dim = getattr(self.vlm.config, "projection_dim", None)
77
  if projection_dim is None:
@@ -96,10 +93,6 @@ class XVLA(PreTrainedModel):
96
  # Deferred FastAPI app
97
  self.app: FastAPI | None = None
98
 
99
- def tie_weights(self):
100
- """Disable automatic weight tying (Florence is encoder-only)."""
101
- return
102
-
103
  # ============================= Florence2 encoder =============================
104
  def forward_vlm(
105
  self,
 
55
  # Core settings
56
  self.num_actions: int = config.num_actions
57
  self.use_proprio: bool = config.use_proprio
58
+ self.action_mode: str = config.action_mode.lower()
59
  # Action space (dimensions + hooks)
60
  self.action_space = build_action_space(config.action_mode.lower())
61
  dim_action = self.action_space.dim_action
62
  dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)
63
 
64
  # Florence2 backbone (encoder only)
65
+ self.vlm = Florence2ForConditionalGeneration(config.florence_config).to(torch.float32)
66
  if hasattr(self.vlm, "language_model"):
67
  lm = self.vlm.language_model
68
  if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
69
  del lm.model.decoder
70
  if hasattr(lm, "lm_head"):
71
  del lm.lm_head
 
 
 
72
 
73
  projection_dim = getattr(self.vlm.config, "projection_dim", None)
74
  if projection_dim is None:
 
93
  # Deferred FastAPI app
94
  self.app: FastAPI | None = None
95
 
 
 
 
 
96
  # ============================= Florence2 encoder =============================
97
  def forward_vlm(
98
  self,