aw1app committed on
Commit
7cc7ee4
·
1 Parent(s): 44a0f72

Force output_hidden_states=True in lang_encoder forward call

Browse files
Files changed (1) hide show
  1. patched_factory.py +28 -49
patched_factory.py CHANGED
@@ -1,4 +1,4 @@
1
- """Factory with proper hidden state extraction"""
2
  import sys
3
  import torch
4
  import torch.nn as nn
@@ -17,37 +17,35 @@ class RoboFlamingoWithPolicy(nn.Module):
17
  self.lang_encoder = base_model.lang_encoder
18
 
19
  def forward(self, vision_x, lang_x, attention_mask=None):
20
- # Call base model
21
- output = self.base_model(
22
- vision_x=vision_x,
23
- lang_x=lang_x,
24
- attention_mask=attention_mask
25
- )
26
 
27
- # CRITICAL: We need hidden states, not logits!
28
- # hidden_states should be enabled via config
29
- embeddings = None
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- if hasattr(output, 'hidden_states') and output.hidden_states is not None and len(output.hidden_states) > 0:
32
- # Use last layer hidden states
33
- embeddings = output.hidden_states[-1]
34
- print(f" Using hidden_states: {embeddings.shape}")
35
- elif hasattr(output, 'last_hidden_state'):
36
- embeddings = output.last_hidden_state
37
- print(f" Using last_hidden_state: {embeddings.shape}")
38
  else:
39
- # Fallback: logits have wrong dimension
40
- # We need to access the language model's actual hidden states
41
- print(f" ⚠️ No hidden_states! Output type: {type(output)}")
42
- print(f" Output has: {[k for k in dir(output) if not k.startswith('_')]}")
43
-
44
- # Try to get from logits by taking last layer
45
- if hasattr(output, 'logits'):
46
- # Logits are (batch, seq, vocab_size=50281)
47
- # We need (batch, seq, hidden=2048)
48
- # This won't work - we need proper hidden states
49
- print(f" ❌ Only have logits: {output.logits.shape}")
50
- raise RuntimeError("Model not outputting hidden_states! Need to configure model properly.")
51
 
52
  # Apply policy head
53
  actions, gripper, _ = self.policy_head(embeddings)
@@ -55,7 +53,7 @@ class RoboFlamingoWithPolicy(nn.Module):
55
  return {'actions': actions, 'gripper': gripper}
56
 
57
  def create_model_and_transforms(checkpoint_path=None):
58
- print("📦 Creating base OpenFlamingo...")
59
  base_model, image_processor, tokenizer = create_base(
60
  clip_vision_encoder_path="ViT-L-14",
61
  clip_vision_encoder_pretrained="openai",
@@ -64,25 +62,6 @@ def create_model_and_transforms(checkpoint_path=None):
64
  cross_attn_every_n_layers=4,
65
  )
66
 
67
- # CRITICAL: Enable hidden states output
68
- print("🔧 Enabling hidden states output...")
69
-
70
- # Try multiple ways to enable hidden states
71
- if hasattr(base_model, 'lang_encoder'):
72
- if hasattr(base_model.lang_encoder, 'config'):
73
- base_model.lang_encoder.config.output_hidden_states = True
74
- print(" ✅ Set via lang_encoder.config")
75
-
76
- if hasattr(base_model.lang_encoder, 'transformer'):
77
- if hasattr(base_model.lang_encoder.transformer, 'config'):
78
- base_model.lang_encoder.transformer.config.output_hidden_states = True
79
- print(" ✅ Set via transformer.config")
80
-
81
- # Also try setting on the model itself
82
- if hasattr(base_model, 'config'):
83
- base_model.config.output_hidden_states = True
84
- print(" ✅ Set on base_model.config")
85
-
86
  print("🔨 Creating policy head...")
87
  policy_head = LSTMPolicyHead(
88
  input_dim=2048,
 
1
+ """Factory with forced hidden states"""
2
  import sys
3
  import torch
4
  import torch.nn as nn
 
17
  self.lang_encoder = base_model.lang_encoder
18
 
19
  def forward(self, vision_x, lang_x, attention_mask=None):
20
+ # Get the internal model
21
+ # OpenFlamingo wraps the language model, we need to call it with output_hidden_states
22
+
23
+ # The base_model is Flamingo, which has lang_encoder
24
+ # We need to get embeddings from the language encoder
 
25
 
26
+ # First, process vision
27
+ if vision_x is not None:
28
+ # Vision encoder
29
+ vision_features = self.base_model._encode_vision_x(vision_x=vision_x)
30
+ else:
31
+ vision_features = None
32
+
33
+ # Now call language model with output_hidden_states=True
34
+ # The lang_encoder should support this parameter
35
+ lang_output = self.base_model.lang_encoder(
36
+ input_ids=lang_x,
37
+ attention_mask=attention_mask,
38
+ output_hidden_states=True, # FORCE hidden states output!
39
+ return_dict=True
40
+ )
41
 
42
+ # Now we should have hidden states
43
+ if hasattr(lang_output, 'hidden_states') and lang_output.hidden_states is not None:
44
+ embeddings = lang_output.hidden_states[-1]
45
+ print(f" Got hidden states: {embeddings.shape}")
 
 
 
46
  else:
47
+ print(f" ❌ Still no hidden states!")
48
+ raise RuntimeError("Cannot get hidden states from language model")
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Apply policy head
51
  actions, gripper, _ = self.policy_head(embeddings)
 
53
  return {'actions': actions, 'gripper': gripper}
54
 
55
  def create_model_and_transforms(checkpoint_path=None):
56
+ print("📦 Creating base...")
57
  base_model, image_processor, tokenizer = create_base(
58
  clip_vision_encoder_path="ViT-L-14",
59
  clip_vision_encoder_pretrained="openai",
 
62
  cross_attn_every_n_layers=4,
63
  )
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  print("🔨 Creating policy head...")
66
  policy_head = LSTMPolicyHead(
67
  input_dim=2048,