aw1app committed
Commit adf1b6d · 1 Parent(s): ed14d35

Final fix: Exact checkpoint dimensions (action=6, not 7)

Files changed (2)
  1. patched_factory.py +16 -37
  2. policy_head.py +14 -32
patched_factory.py CHANGED
@@ -1,4 +1,4 @@
-"""Factory with correct dimensions"""
+"""Factory - load checkpoint with exact dimensions"""
 import sys
 import torch
 import torch.nn as nn
@@ -9,42 +9,31 @@ from huggingface_hub import hf_hub_download
 from policy_head import LSTMPolicyHead
 
 class RoboFlamingoWithPolicy(nn.Module):
-    """Wraps OpenFlamingo + LSTM Policy Head"""
     def __init__(self, base_model, policy_head):
         super().__init__()
         self.base_model = base_model
         self.policy_head = policy_head
-
         self.vision_encoder = base_model.vision_encoder
         self.lang_encoder = base_model.lang_encoder
 
     def forward(self, vision_x, lang_x, attention_mask=None):
-        # Get embeddings with hidden states
         output = self.base_model(
             vision_x=vision_x,
             lang_x=lang_x,
             attention_mask=attention_mask
         )
 
-        # Get hidden states if available
         if hasattr(output, 'hidden_states') and output.hidden_states is not None:
             embeddings = output.hidden_states[-1]
         else:
-            # Fallback: use logits (not ideal)
             embeddings = output.logits
 
-        # Apply policy head
         actions, gripper, _ = self.policy_head(embeddings)
 
-        return {
-            'actions': actions,
-            'gripper': gripper
-        }
+        return {'actions': actions, 'gripper': gripper}
 
 def create_model_and_transforms(checkpoint_path=None):
-    """Load RoboFlamingo"""
-
-    print("📦 Creating base OpenFlamingo...")
+    print("📦 Creating base...")
     base_model, image_processor, tokenizer = create_base(
         clip_vision_encoder_path="ViT-L-14",
         clip_vision_encoder_pretrained="openai",
@@ -53,66 +42,56 @@ def create_model_and_transforms(checkpoint_path=None):
         cross_attn_every_n_layers=4,
     )
 
-    print("✅ Base created")
-
-    # Enable hidden states
     if hasattr(base_model.lang_encoder, 'config'):
         base_model.lang_encoder.config.output_hidden_states = True
 
-    # Create policy head with CORRECT dimensions from checkpoint
-    print("🔨 Creating policy head (4-layer LSTM, hidden=1024)...")
+    print("🔨 Creating policy head...")
     policy_head = LSTMPolicyHead(
         input_dim=2048,
         hidden_dim=1024,
-        num_layers=4,
-        action_dim=7
+        num_layers=4
     )
 
     model = RoboFlamingoWithPolicy(base_model, policy_head)
-    print("✅ Policy head attached")
+    print("✅ Model ready")
 
     if checkpoint_path:
-        print("📥 Downloading checkpoint...")
+        print("📥 Loading checkpoint...")
         ckpt_file = hf_hub_download(
            repo_id="robovlms/RoboFlamingo",
            filename="checkpoint_gripper_post_hist_1_aug_10_4_traj_cons_ws_12_mpt_3b_4.pth",
            repo_type="model"
        )
 
-        print("📥 Loading...")
         checkpoint = torch.load(ckpt_file, map_location='cpu')
         state_dict = checkpoint.get('model_state_dict', checkpoint)
 
-        # Map keys
         new_state_dict = {}
 
         for key, value in state_dict.items():
-            # Map policy head
             if 'action_head.rnn' in key:
                 new_key = key.replace('module.action_head.rnn', 'policy_head.lstm')
                 new_state_dict[new_key] = value
             elif 'action_head.actions.mlp' in key:
-                # Map actions MLP layers
                 new_key = key.replace('module.action_head.actions.mlp', 'policy_head.action_head')
                 new_state_dict[new_key] = value
             elif 'action_head.gripper.mlp' in key:
-                # Map gripper MLP layers
                 new_key = key.replace('module.action_head.gripper.mlp', 'policy_head.gripper_head')
                 new_state_dict[new_key] = value
+            elif 'transformer.wte.weight' in key:
+                # Handle vocab size mismatch (50280 -> 50281)
+                # Pad with zeros for the extra token
+                if value.shape[0] == 50280:
+                    value = torch.cat([value, torch.zeros(1, value.shape[1])], dim=0)
+                new_key = key.replace('module.', 'base_model.')
+                new_state_dict[new_key] = value
             else:
-                # Base model keys
                 new_key = key.replace('module.', 'base_model.')
                 new_state_dict[new_key] = value
 
-        # Load (strict=False to ignore size mismatches for vocab)
         missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
 
-        print(f"✅ Loaded (missing: {len(missing)}, unexpected: {len(unexpected)})")
-
-        # Show any remaining mismatches
-        if len(missing) > 0:
-            print(f" Missing keys: {list(missing)[:3]}")
-        if len(unexpected) > 0:
-            print(f" Unexpected keys: {list(unexpected)[:3]}")
+        print(f"✅ Checkpoint loaded!")
+        print(f" Missing: {len(missing)}, Unexpected: {len(unexpected)}")
 
     return model, image_processor, tokenizer
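
For context (not part of this commit), a minimal usage sketch of the patched factory. Assumptions: patched_factory.py is importable, vision_x follows OpenFlamingo's (batch, T_img, frames, C, H, W) layout, and the "<image>" prompt token comes from OpenFlamingo conventions rather than this diff; any truthy checkpoint_path triggers the fixed hub download above.

    # Usage sketch under the assumptions stated above.
    import torch
    from PIL import Image
    from patched_factory import create_model_and_transforms

    model, image_processor, tokenizer = create_model_and_transforms(checkpoint_path=True)
    model.eval()

    frame = Image.new("RGB", (224, 224))   # placeholder camera frame
    vision_x = image_processor(frame)      # -> (C, H, W)
    vision_x = vision_x[None, None, None]  # -> (batch=1, T_img=1, frames=1, C, H, W)

    lang = tokenizer(["<image>pick up the red block"], return_tensors="pt")
    with torch.no_grad():
        out = model(vision_x=vision_x,
                    lang_x=lang["input_ids"],
                    attention_mask=lang["attention_mask"])

    print(out["actions"].shape)  # (1, seq_len, 6): translation + rotation
    print(out["gripper"].shape)  # (1, seq_len, 1): open/close probability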
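
The new transformer.wte.weight branch zero-pads the embedding table to absorb a one-token vocab mismatch (50280 -> 50281). The same trick in isolation, with toy sizes standing in for the real matrices:

    import torch

    ckpt_wte = torch.randn(5, 8)  # checkpoint table: 5 tokens, dim 8
    model_rows = 6                # model expects one extra token row
    if ckpt_wte.shape[0] == model_rows - 1:
        ckpt_wte = torch.cat([ckpt_wte, torch.zeros(1, ckpt_wte.shape[1])], dim=0)
    assert ckpt_wte.shape == (6, 8)
    # The appended row is all zeros: the extra token starts with a null
    # embedding, which is safe to load but carries no pretrained meaning.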
policy_head.py CHANGED
@@ -1,21 +1,13 @@
-"""LSTM Policy Head with correct dimensions from checkpoint"""
+"""LSTM Policy Head - EXACT checkpoint dimensions"""
 import torch
 import torch.nn as nn
 
 class LSTMPolicyHead(nn.Module):
-    """
-    LSTM-based policy head from RoboFlamingo checkpoint.
-    Dimensions extracted from checkpoint weights.
-    """
-    def __init__(self, input_dim=2048, hidden_dim=1024, num_layers=4, action_dim=7):
+    """Exact architecture from RoboFlamingo checkpoint"""
+    def __init__(self, input_dim=2048, hidden_dim=1024, num_layers=4):
         super().__init__()
 
-        self.input_dim = input_dim
-        self.hidden_dim = hidden_dim
-        self.num_layers = num_layers
-        self.action_dim = action_dim
-
-        # LSTM with 4 layers, hidden_dim=1024
+        # LSTM: 4 layers, hidden=1024
         self.lstm = nn.LSTM(
             input_size=input_dim,
             hidden_size=hidden_dim,
@@ -23,40 +15,30 @@ class LSTMPolicyHead(nn.Module):
             batch_first=True
         )
 
-        # Action MLP (4 layers based on checkpoint)
+        # Action MLP: 1024 -> 1024 -> 512 -> 256 -> 6
         self.action_head = nn.Sequential(
-            nn.Linear(hidden_dim, 512),
+            nn.Linear(1024, 1024),
             nn.ReLU(),
-            nn.Linear(512, 256),
+            nn.Linear(1024, 512),
             nn.ReLU(),
-            nn.Linear(256, 128),
+            nn.Linear(512, 256),
             nn.ReLU(),
-            nn.Linear(128, action_dim)
+            nn.Linear(256, 6)  # 6 outputs (position + rotation, no gripper here)
         )
 
-        # Gripper MLP (4 layers based on checkpoint)
+        # Gripper MLP: 1024 -> 1024 -> 512 -> 256 -> 1
         self.gripper_head = nn.Sequential(
-            nn.Linear(hidden_dim, 512),
+            nn.Linear(1024, 1024),
             nn.ReLU(),
-            nn.Linear(512, 256),
+            nn.Linear(1024, 512),
             nn.ReLU(),
-            nn.Linear(256, 128),
+            nn.Linear(512, 256),
             nn.ReLU(),
-            nn.Linear(128, 1),
+            nn.Linear(256, 1),
             nn.Sigmoid()
         )
 
     def forward(self, x, hidden=None):
-        """
-        Args:
-            x: (batch, seq_len, input_dim)
-            hidden: tuple of (h_0, c_0)
-
-        Returns:
-            actions: (batch, seq_len, action_dim)
-            gripper: (batch, seq_len, 1)
-            hidden: tuple of (h_n, c_n)
-        """
         # LSTM
         lstm_out, hidden = self.lstm(x, hidden)
 
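
The "action=6, not 7" figure in the commit title can be checked directly against the checkpoint weights. A minimal sketch of that inspection, reusing the module.action_head.* key scheme that patched_factory.py maps:

    import torch
    from huggingface_hub import hf_hub_download

    ckpt_file = hf_hub_download(
        repo_id="robovlms/RoboFlamingo",
        filename="checkpoint_gripper_post_hist_1_aug_10_4_traj_cons_ws_12_mpt_3b_4.pth",
    )
    checkpoint = torch.load(ckpt_file, map_location="cpu")
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    for key, value in sorted(state_dict.items()):
        if "action_head" in key and "weight" in key:
            print(key, tuple(value.shape))
    # Expected: the final actions.mlp Linear prints (6, 256) -> action_dim=6,
    # the final gripper.mlp Linear prints (1, 256), and the rnn weight_hh_l*
    # shapes (4096, 1024) are consistent with hidden_dim=1024 over 4 layers.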