Spaces:
Runtime error
Runtime error
Force output_hidden_states=True in lang_encoder forward call
Browse files
- patched_factory.py +28 -49
patched_factory.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Factory with
|
| 2 |
import sys
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
@@ -17,37 +17,35 @@ class RoboFlamingoWithPolicy(nn.Module):
|
|
| 17 |
self.lang_encoder = base_model.lang_encoder
|
| 18 |
|
| 19 |
def forward(self, vision_x, lang_x, attention_mask=None):
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
)
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
embeddings =
|
| 34 |
-
print(f"
|
| 35 |
-
elif hasattr(output, 'last_hidden_state'):
|
| 36 |
-
embeddings = output.last_hidden_state
|
| 37 |
-
print(f" Using last_hidden_state: {embeddings.shape}")
|
| 38 |
else:
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
print(f" ⚠️ No hidden_states! Output type: {type(output)}")
|
| 42 |
-
print(f" Output has: {[k for k in dir(output) if not k.startswith('_')]}")
|
| 43 |
-
|
| 44 |
-
# Try to get from logits by taking last layer
|
| 45 |
-
if hasattr(output, 'logits'):
|
| 46 |
-
# Logits are (batch, seq, vocab_size=50281)
|
| 47 |
-
# We need (batch, seq, hidden=2048)
|
| 48 |
-
# This won't work - we need proper hidden states
|
| 49 |
-
print(f" ❌ Only have logits: {output.logits.shape}")
|
| 50 |
-
raise RuntimeError("Model not outputting hidden_states! Need to configure model properly.")
|
| 51 |
|
| 52 |
# Apply policy head
|
| 53 |
actions, gripper, _ = self.policy_head(embeddings)
|
|
@@ -55,7 +53,7 @@ class RoboFlamingoWithPolicy(nn.Module):
|
|
| 55 |
return {'actions': actions, 'gripper': gripper}
|
| 56 |
|
| 57 |
def create_model_and_transforms(checkpoint_path=None):
|
| 58 |
-
print("📦 Creating base
|
| 59 |
base_model, image_processor, tokenizer = create_base(
|
| 60 |
clip_vision_encoder_path="ViT-L-14",
|
| 61 |
clip_vision_encoder_pretrained="openai",
|
|
@@ -64,25 +62,6 @@ def create_model_and_transforms(checkpoint_path=None):
|
|
| 64 |
cross_attn_every_n_layers=4,
|
| 65 |
)
|
| 66 |
|
| 67 |
-
# CRITICAL: Enable hidden states output
|
| 68 |
-
print("🔧 Enabling hidden states output...")
|
| 69 |
-
|
| 70 |
-
# Try multiple ways to enable hidden states
|
| 71 |
-
if hasattr(base_model, 'lang_encoder'):
|
| 72 |
-
if hasattr(base_model.lang_encoder, 'config'):
|
| 73 |
-
base_model.lang_encoder.config.output_hidden_states = True
|
| 74 |
-
print(" ✅ Set via lang_encoder.config")
|
| 75 |
-
|
| 76 |
-
if hasattr(base_model.lang_encoder, 'transformer'):
|
| 77 |
-
if hasattr(base_model.lang_encoder.transformer, 'config'):
|
| 78 |
-
base_model.lang_encoder.transformer.config.output_hidden_states = True
|
| 79 |
-
print(" ✅ Set via transformer.config")
|
| 80 |
-
|
| 81 |
-
# Also try setting on the model itself
|
| 82 |
-
if hasattr(base_model, 'config'):
|
| 83 |
-
base_model.config.output_hidden_states = True
|
| 84 |
-
print(" ✅ Set on base_model.config")
|
| 85 |
-
|
| 86 |
print("🔨 Creating policy head...")
|
| 87 |
policy_head = LSTMPolicyHead(
|
| 88 |
input_dim=2048,
|
|
|
|
| 1 |
+
"""Factory with forced hidden states"""
|
| 2 |
import sys
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
| 17 |
self.lang_encoder = base_model.lang_encoder
|
| 18 |
|
| 19 |
def forward(self, vision_x, lang_x, attention_mask=None):
|
| 20 |
+
# Get the internal model
|
| 21 |
+
# OpenFlamingo wraps the language model, so we need to call it with output_hidden_states=True
|
| 22 |
+
|
| 23 |
+
# The base_model is Flamingo, which has lang_encoder
|
| 24 |
+
# We need to get embeddings from the language encoder
|
|
|
|
| 25 |
|
| 26 |
+
# First, process vision
|
| 27 |
+
if vision_x is not None:
|
| 28 |
+
# Vision encoder
|
| 29 |
+
vision_features = self.base_model._encode_vision_x(vision_x=vision_x)
|
| 30 |
+
else:
|
| 31 |
+
vision_features = None
|
| 32 |
+
|
| 33 |
+
# Now call language model with output_hidden_states=True
|
| 34 |
+
# The lang_encoder should support this parameter
|
| 35 |
+
lang_output = self.base_model.lang_encoder(
|
| 36 |
+
input_ids=lang_x,
|
| 37 |
+
attention_mask=attention_mask,
|
| 38 |
+
output_hidden_states=True, # FORCE hidden states output!
|
| 39 |
+
return_dict=True
|
| 40 |
+
)
|
| 41 |
|
| 42 |
+
# Now we should have hidden states
|
| 43 |
+
if hasattr(lang_output, 'hidden_states') and lang_output.hidden_states is not None:
|
| 44 |
+
embeddings = lang_output.hidden_states[-1]
|
| 45 |
+
print(f" ✅ Got hidden states: {embeddings.shape}")
|
|
|
|
|
|
|
|
|
|
| 46 |
else:
|
| 47 |
+
print(f" ❌ Still no hidden states!")
|
| 48 |
+
raise RuntimeError("Cannot get hidden states from language model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
# Apply policy head
|
| 51 |
actions, gripper, _ = self.policy_head(embeddings)
|
|
|
|
| 53 |
return {'actions': actions, 'gripper': gripper}
|
| 54 |
|
| 55 |
def create_model_and_transforms(checkpoint_path=None):
|
| 56 |
+
print("📦 Creating base...")
|
| 57 |
base_model, image_processor, tokenizer = create_base(
|
| 58 |
clip_vision_encoder_path="ViT-L-14",
|
| 59 |
clip_vision_encoder_pretrained="openai",
|
|
|
|
| 62 |
cross_attn_every_n_layers=4,
|
| 63 |
)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
print("🔨 Creating policy head...")
|
| 66 |
policy_head = LSTMPolicyHead(
|
| 67 |
input_dim=2048,
|