aw1app committed
Commit ff95b80 · 1 Parent(s): 15c4e01

Deploy with src only (no large eval data)
README.md CHANGED
@@ -9,26 +9,10 @@ app_file: app.py
 pinned: false
 ---
 
-# RoboFlamingo Demo 🤖
-
-Interactive demo with REAL model inference!
-
-## Features
-- 📸 Upload your own robot images
-- 💬 Enter natural language instructions
-- 🎯 Get real model predictions
-- 📊 3D trajectory visualization
-
-## Usage
-1. Upload third-person view image
-2. Upload gripper view image
-3. Enter instruction (e.g., "Pick up the red block")
-4. Click "Predict Actions"
-
-## Requirements
-⚠️ **Requires GPU**: Enable T4 GPU in Space settings for real model.
-Without GPU, runs in simulation mode.
-
-## Resources
-- [Paper](https://arxiv.org/abs/2311.01378)
-- [Code](https://github.com/RoboFlamingo/RoboFlamingo)
+# RoboFlamingo Interactive Demo 🤖
+
+Upload images and get real predictions!
+
+⚠️ Enable T4 GPU in Settings for real model.
+
+[Paper](https://arxiv.org/abs/2311.01378) | [Code](https://github.com/RoboFlamingo/RoboFlamingo)
app.py CHANGED
@@ -1,8 +1,4 @@
-"""
-RoboFlamingo Interactive Demo - Real Model
-Upload your images and get real predictions!
-"""
-
+"""RoboFlamingo Interactive Demo"""
 import gradio as gr
 import torch
 import numpy as np
@@ -11,31 +7,18 @@ import matplotlib.pyplot as plt
 from io import BytesIO
 import sys
 
-# Add OpenFlamingo to path
 sys.path.insert(0, '/home/user/app/open_flamingo/src')
 
-print("=" * 70)
-print("🚀 INITIALIZING ROBOFLAMINGO")
-print("=" * 70)
-
-# ============================================================================
-# LOAD MODEL
-# ============================================================================
+print("🚀 Initializing RoboFlamingo")
 
 MODEL_LOADED = False
-model = None
-image_processor = None
-tokenizer = None
 
 try:
-    print("📦 Importing OpenFlamingo...")
+    print("📦 Importing...")
    from factory import create_model_and_transforms
 
-    print("✅ Import successful!")
-    print("🔧 Loading model (2-3 minutes on first run)...")
-
    device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"📍 Device: {device}")
+    print(f"Device: {device}")
 
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
@@ -46,228 +29,134 @@ try:
        decoder_type='lstm',
    )
 
-    model.to(device)
-    model.eval()
+    model.to(device).eval()
    MODEL_LOADED = True
-
-    print("=" * 70)
-    print("✅ REAL MODEL LOADED!")
-    print("=" * 70)
-
+    print("✅ Model loaded!")
 except Exception as e:
-    print(f"❌ Model failed: {e}")
-    print("⚠️ SIMULATION MODE")
+    print(f"⚠️ Model failed: {e}")
    import traceback
    traceback.print_exc()
 
-# ============================================================================
-# VISUALIZATION
-# ============================================================================
-
-def create_trajectory_plot(actions):
-    fig = plt.figure(figsize=(10, 8))
+def plot_traj(acts):
+    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(111, projection='3d')
-
-    x = np.cumsum([a['delta_x'] for a in actions])
-    y = np.cumsum([a['delta_y'] for a in actions])
-    z = np.cumsum([a['delta_z'] for a in actions])
-
-    ax.plot(x, y, z, 'b-', linewidth=2, marker='o', markersize=6)
-    ax.scatter([x[0]], [y[0]], [z[0]], c='green', s=100, label='Start')
-    ax.scatter([x[-1]], [y[-1]], [z[-1]], c='red', s=100, label='End')
-
-    ax.set_xlabel('X (m)')
-    ax.set_ylabel('Y (m)')
-    ax.set_zlabel('Z (m)')
-    ax.set_title('Predicted Trajectory')
-    ax.legend()
-    ax.grid(True)
-
+    x = np.cumsum([a['delta_x'] for a in acts])
+    y = np.cumsum([a['delta_y'] for a in acts])
+    z = np.cumsum([a['delta_z'] for a in acts])
+    ax.plot(x, y, z, 'b-', lw=2, marker='o', ms=6)
+    ax.scatter(x[0], y[0], z[0], c='green', s=100, label='Start')
+    ax.scatter(x[-1], y[-1], z[-1], c='red', s=100, label='End')
+    ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
+    ax.legend(); ax.grid()
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
-    buf.seek(0)
-    plt.close()
+    buf.seek(0); plt.close()
    return Image.open(buf)
 
-def create_gripper_timeline(gripper):
-    fig, ax = plt.subplots(figsize=(12, 3))
-    colors = ['green' if g==0 else 'red' for g in gripper]
-    labels = ['OPEN' if g==0 else 'CLOSE' for g in gripper]
-
-    ax.bar(range(len(gripper)), [1]*len(gripper), color=colors, alpha=0.7, edgecolor='black')
-    for i, label in enumerate(labels):
-        ax.text(i, 0.5, label, ha='center', va='center', fontweight='bold')
-
-    ax.set_xlabel('Timestep')
-    ax.set_title('Gripper Commands')
-    ax.set_ylim(0, 1.2)
-    ax.grid(True, alpha=0.3)
-
+def plot_grip(grip):
+    fig, ax = plt.subplots(figsize=(12,3))
+    cols = ['green' if g==0 else 'red' for g in grip]
+    ax.bar(range(len(grip)), [1]*len(grip), color=cols, alpha=0.7, ec='black')
+    for i, g in enumerate(grip):
+        ax.text(i, 0.5, 'OPEN' if g==0 else 'CLOSE', ha='center', va='center', weight='bold')
+    ax.set_xlabel('Timestep'); ax.set_ylim(0,1.2); ax.grid(alpha=0.3)
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
-    buf.seek(0)
-    plt.close()
+    buf.seek(0); plt.close()
    return Image.open(buf)
 
-def format_table(actions):
-    table = "| T | Δx | Δy | Δz | Quat |\n|--|--|--|--|--|\n"
-    for a in actions:
-        table += f"| {a['timestep']} | {a['delta_x']:.3f} | {a['delta_y']:.3f} | {a['delta_z']:.3f} | "
-        table += f"({a['qw']:.2f},{a['qx']:.2f},{a['qy']:.2f},{a['qz']:.2f}) |\n"
-    return table
-
-# ============================================================================
-# PREDICTION
-# ============================================================================
-
-def simulate(instruction):
-    """Fallback simulation."""
-    seed = sum(ord(c) for c in instruction) % 100
-    np.random.seed(seed)
-
-    actions = []
+def simulate(inst):
+    np.random.seed(sum(ord(c) for c in inst) % 100)
+    acts = []
    for t in range(12):
-        p = t / 12
-        actions.append({
-            'timestep': t,
-            'delta_x': (0.05 + np.random.randn()*0.01) * p,
-            'delta_y': (0.02 + np.random.randn()*0.01) * p,
-            'delta_z': (-0.03 + np.random.randn()*0.01) * (1-p),
-            'qw': 0.99, 'qx': 0.01, 'qy': 0.01, 'qz': 0.01
-        })
-    return actions, [0]*6 + [1]*6
-
-def predict(instruction, third_img, grip_img):
-    """Main prediction function."""
-
-    if not instruction or not instruction.strip():
-        return None, None, "", "❌ Enter instruction!"
-    if third_img is None:
-        return None, None, "", "❌ Upload third-person view!"
-    if grip_img is None:
-        return None, None, "", "❌ Upload gripper view!"
+        p = t/12
+        acts.append({'timestep': t, 'delta_x': (0.05+np.random.randn()*0.01)*p,
+                     'delta_y': (0.02+np.random.randn()*0.01)*p,
+                     'delta_z': (-0.03+np.random.randn()*0.01)*(1-p),
+                     'qw': 0.99, 'qx': 0.01, 'qy': 0.01, 'qz': 0.01})
+    return acts, [0]*6+[1]*6
+
+def predict(inst, img1, img2):
+    if not inst or not inst.strip():
+        return None, None, "", "❌ Enter instruction"
+    if img1 is None or img2 is None:
+        return None, None, "", "❌ Upload both images"
 
    try:
-        if isinstance(third_img, np.ndarray):
-            third_img = Image.fromarray(third_img)
-        if isinstance(grip_img, np.ndarray):
-            grip_img = Image.fromarray(grip_img)
+        if isinstance(img1, np.ndarray):
+            img1 = Image.fromarray(img1)
+        if isinstance(img2, np.ndarray):
+            img2 = Image.fromarray(img2)
 
        if not MODEL_LOADED:
-            print("⚠️ Using simulation")
-            actions, gripper = simulate(instruction)
-            status = f"⚠️ SIMULATION\n\n{instruction}\n\nEnable GPU for real model."
+            acts, grip = simulate(inst)
+            status = f"⚠️ SIMULATION\n{inst}\nEnable GPU for real model"
        else:
-            print(f"🚀 Real inference: {instruction}")
-
            with torch.no_grad():
-                t1 = image_processor(third_img).unsqueeze(0).to(device)
-                t2 = image_processor(grip_img).unsqueeze(0).to(device)
-                vision_x = torch.stack([t1, t2], dim=1)
-
-                tokens = tokenizer(instruction, return_tensors="pt", padding=True,
-                                   truncation=True, max_length=512).to(device)
-
-                outputs = model(vision_x=vision_x, lang_x=tokens['input_ids'],
-                                attention_mask=tokens.get('attention_mask'))
+                t1 = image_processor(img1).unsqueeze(0).to(device)
+                t2 = image_processor(img2).unsqueeze(0).to(device)
+                vis = torch.stack([t1, t2], dim=1)
+                tok = tokenizer(inst, return_tensors="pt", padding=True, truncation=True).to(device)
+                out = model(vision_x=vis, lang_x=tok['input_ids'], attention_mask=tok.get('attention_mask'))
 
-            if isinstance(outputs, dict):
-                acts = outputs.get('actions', outputs.get('action'))
-                grip = outputs.get('gripper', outputs.get('gripper_command'))
-            elif isinstance(outputs, tuple):
-                acts = outputs[0]
-                grip = outputs[1] if len(outputs) > 1 else None
+            if isinstance(out, dict):
+                a = out.get('actions', out.get('action'))
+                g = out.get('gripper')
+            elif isinstance(out, tuple):
+                a = out[0]
+                g = out[1] if len(out)>1 else None
            else:
-                acts = outputs
-                grip = None
+                a = out
+                g = None
 
-            if acts is not None:
-                acts_np = acts[0].cpu().numpy()
-                actions = []
-                for t, a in enumerate(acts_np):
-                    if len(a) < 7:
-                        a = np.pad(a, (0, 7-len(a)))
-                    actions.append({
-                        'timestep': t, 'delta_x': float(a[0]), 'delta_y': float(a[1]),
-                        'delta_z': float(a[2]), 'qw': float(a[3]), 'qx': float(a[4]),
-                        'qy': float(a[5]), 'qz': float(a[6])
-                    })
-
-                if grip is not None:
-                    grip_np = grip[0].cpu().numpy()
-                    gripper = [int(g>0.5) if np.isscalar(g) else int(g[0]>0.5) for g in grip_np]
-                else:
-                    gripper = [0]*len(actions)
-
-                status = f"✅ REAL MODEL\n\n{instruction}\n\n{device}\n{len(actions)} timesteps"
-                print(f"✅ Success! {len(actions)} timesteps")
+            if a is not None:
+                anp = a[0].cpu().numpy()
+                acts = []
+                for t, ac in enumerate(anp):
+                    if len(ac)<7: ac = np.pad(ac, (0,7-len(ac)))
+                    acts.append({'timestep': t, 'delta_x': float(ac[0]), 'delta_y': float(ac[1]),
+                                 'delta_z': float(ac[2]), 'qw': float(ac[3]), 'qx': float(ac[4]),
+                                 'qy': float(ac[5]), 'qz': float(ac[6])})
+                grip = [int(x>0.5) if np.isscalar(x) else int(x[0]>0.5) for x in (g[0].cpu().numpy() if g is not None else [0]*len(acts))]
+                status = f"✅ REAL MODEL\n{inst}\n{device}"
            else:
-                actions, gripper = simulate(instruction)
-                status = f"⚠️ Unexpected output\n{instruction}"
+                acts, grip = simulate(inst)
+                status = f"⚠️ Unexpected output\n{inst}"
 
-        traj = create_trajectory_plot(actions)
-        grip_plot = create_gripper_timeline(gripper)
-        table = format_table(actions)
-
-        return traj, grip_plot, table, status
+        traj = plot_traj(acts)
+        gp = plot_grip(grip)
+        table = "| T | Δx | Δy | Δz |\n|--|--|--|--|\n"
+        for a in acts:
+            table += f"| {a['timestep']} | {a['delta_x']:.3f} | {a['delta_y']:.3f} | {a['delta_z']:.3f} |\n"
 
+        return traj, gp, table, status
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
-
-        actions, gripper = simulate(instruction)
-        traj = create_trajectory_plot(actions)
-        grip_plot = create_gripper_timeline(gripper)
-        table = format_table(actions)
-
-        return traj, grip_plot, table, f"❌ Error: {str(e)}"
-
-# ============================================================================
-# UI
-# ============================================================================
+        acts, grip = simulate(inst)
+        return plot_traj(acts), plot_grip(grip), "", f"❌ {str(e)}"
 
-mode = "🟢 REAL MODEL" if MODEL_LOADED else "🟡 SIMULATION"
+mode = "🟢 REAL" if MODEL_LOADED else "🟡 SIM"
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"""
-    # 🤖 RoboFlamingo Demo - {mode}
-
-    ### Vision-Language Foundation Models as Effective Robot Imitators
-
-    {'✅ Real model loaded!' if MODEL_LOADED else '⚠️ Simulation mode - Enable GPU in settings'}
-
-    **Upload your images and enter instructions:**
-    """)
+    gr.Markdown(f"# 🤖 RoboFlamingo - {mode}\n{'Real model loaded!' if MODEL_LOADED else 'Enable GPU for real model'}")
 
    with gr.Row():
        with gr.Column():
-            gr.Markdown("### 📥 Inputs")
-            instruction = gr.Textbox(label="Instruction",
-                                     placeholder="Pick up the red block", lines=3)
+            inst = gr.Textbox(label="Instruction", placeholder="Pick up the red block", lines=3)
            with gr.Row():
-                third = gr.Image(label="Third-Person View", type="pil", height=250)
-                gripper = gr.Image(label="Gripper View", type="pil", height=250)
-            btn = gr.Button("🚀 Predict Actions", variant="primary", size="lg")
-            status = gr.Textbox(label="Status", lines=5, interactive=False)
-
+                img1 = gr.Image(label="Third-Person", type="pil", height=250)
+                img2 = gr.Image(label="Gripper", type="pil", height=250)
+            btn = gr.Button("🚀 Predict", variant="primary", size="lg")
+            st = gr.Textbox(label="Status", lines=4, interactive=False)
        with gr.Column():
-            gr.Markdown("### 📊 Predictions")
            traj = gr.Image(label="Trajectory", type="pil")
            grip = gr.Image(label="Gripper", type="pil")
 
-    with gr.Row():
-        table = gr.Markdown()
-
-    btn.click(predict, [instruction, third, gripper], [traj, grip, table, status])
-
-    gr.Markdown(f"""
-    ---
-    **Status:** {mode} | **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
-
-    [Paper](https://arxiv.org/abs/2311.01378) | [Code](https://github.com/RoboFlamingo/RoboFlamingo)
-
-    {'⚠️ Enable T4 GPU in Settings → Hardware for real model' if not MODEL_LOADED else ''}
-    """)
+    tab = gr.Markdown()
+    btn.click(predict, [inst, img1, img2], [traj, grip, tab, st])
 
+    gr.Markdown(f"**Status:** {mode} | [Paper](https://arxiv.org/abs/2311.01378)")
 
 demo.launch()
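
For reference, the one non-obvious step in the rewritten `plot_traj` is that the model emits per-timestep deltas, and the plot integrates them into an absolute end-effector path with `np.cumsum`. A minimal sketch of that step (numpy only; the values below are made up):

```python
import numpy as np

# plot_traj integrates per-step deltas into an absolute path via cumsum;
# three +0.05 m steps along x put the end-effector 0.15 m from its start.
deltas = [0.05, 0.05, 0.05]
print(np.cumsum(deltas))  # [0.05 0.1  0.15]
```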
open_flamingo/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .src.flamingo import Flamingo
+from .src.factory import create_model_and_transforms
open_flamingo/src/__init__.py ADDED
File without changes
open_flamingo/src/factory.py ADDED
@@ -0,0 +1,141 @@
+from typing import Optional
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import open_clip
+
+from .flamingo import Flamingo
+from .flamingo_lm import FlamingoLMMixin
+from .utils import extend_instance
+
+
+def create_model_and_transforms(
+    clip_vision_encoder_path: str,
+    clip_vision_encoder_pretrained: str,
+    lang_encoder_path: str,
+    tokenizer_path: str,
+    cross_attn_every_n_layers: int = 1,
+    use_local_files: bool = False,
+    decoder_layers_attr_name: str = None,
+    freeze_lm_embeddings: bool = False,
+    cache_dir: Optional[str] = None,
+    **flamingo_kwargs,
+):
+    """
+    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
+    Appends special tokens to the tokenizer and freezes backbones.
+
+    Args:
+        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
+        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
+        lang_encoder_path (str): path to pretrained language encoder
+        tokenizer_path (str): path to pretrained tokenizer
+        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
+        use_local_files (bool, optional): whether to use local files. Defaults to False.
+        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
+        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
+    Returns:
+        Flamingo: Flamingo model from pretrained vision and language encoders
+        Image processor: Pipeline to preprocess input images
+        Tokenizer: A tokenizer for the language model
+    """
+    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
+        clip_vision_encoder_path,
+        pretrained=clip_vision_encoder_pretrained,
+        cache_dir=cache_dir,
+    )
+    # set the vision encoder to output the visual features
+    vision_encoder.visual.output_tokens = True
+
+    text_tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path,
+        local_files_only=use_local_files,
+        trust_remote_code=True,
+        cache_dir=cache_dir,
+    )
+    # add Flamingo special tokens to the tokenizer
+    text_tokenizer.add_special_tokens(
+        {"additional_special_tokens": ["<|endofchunk|>", "<image>", "<action>"]}
+    )
+    if text_tokenizer.pad_token is None:
+        # Issue: GPT models don't have a pad token, which we use to
+        # modify labels for the loss.
+        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    lang_encoder = AutoModelForCausalLM.from_pretrained(
+        lang_encoder_path,
+        local_files_only=use_local_files,
+        trust_remote_code=True,
+        cache_dir=cache_dir,
+    )
+
+    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
+    if "mpt-1b-redpajama-200b" in lang_encoder_path:
+
+        class EmbeddingFnMixin:
+            def get_input_embeddings(self):
+                return self.transformer.wte
+
+            def set_input_embeddings(self, new_embeddings):
+                self.transformer.wte = new_embeddings
+
+        extend_instance(lang_encoder, EmbeddingFnMixin)
+
+    # convert LM to FlamingoLM
+    extend_instance(lang_encoder, FlamingoLMMixin)
+
+    if decoder_layers_attr_name is None:
+        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
+    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
+    lang_encoder.resize_token_embeddings(len(text_tokenizer))
+
+    model = Flamingo(
+        vision_encoder,
+        lang_encoder,
+        text_tokenizer.encode("<|endofchunk|>")[-1],
+        text_tokenizer.encode("<image>")[-1],
+        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
+            "width"
+        ],
+        cross_attn_every_n_layers=cross_attn_every_n_layers,
+        **flamingo_kwargs,
+    )
+
+    # Freeze all parameters
+    model.requires_grad_(False)
+    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
+
+    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
+    model.perceiver.requires_grad_(True)
+    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
+    if not freeze_lm_embeddings:
+        model.lang_encoder.get_input_embeddings().requires_grad_(True)
+    # TODO: investigate also training the output embeddings when untied
+
+    print(
+        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+    )
+
+    return model, image_processor, text_tokenizer
+
+
+def _infer_decoder_layers_attr_name(model):
+    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
+        if k.lower() in model.__class__.__name__.lower():
+            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
+
+    raise ValueError(
+        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
+    )
+
+
+__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
+    "opt": "model.decoder.layers",
+    "gptj": "transformer.h",
+    "gpt-j": "transformer.h",
+    "pythia": "gpt_neox.layers",
+    "llama": "model.layers",
+    "gptneoxforcausallm": "gpt_neox.layers",
+    "mpt": "transformer.blocks",
+    "mosaicgpt": "transformer.blocks",
+}
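
For context, a hedged sketch of how this factory is typically invoked. The repo ids below are illustrative assumptions, not values pinned by this commit (`ViT-L-14`/`openai` is a valid open_clip pair, and the MPT-1B id matches the `mpt-1b-redpajama-200b` hack above):

```python
# Illustrative only: checkpoint/repo names are assumptions.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",  # assumed repo id
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",     # assumed repo id
    cross_attn_every_n_layers=1,
)
```

Note that app.py also passes `decoder_type='lstm'`; that keyword is forwarded through `**flamingo_kwargs` to `Flamingo`, whose `__init__` in this commit does not declare it, so it is omitted here.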
open_flamingo/src/flamingo.py ADDED
@@ -0,0 +1,338 @@
+import torch
+from einops import rearrange
+from torch import nn
+from .helpers import PerceiverResampler
+from torch.distributed.fsdp.wrap import (
+    enable_wrap,
+    wrap,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+
+from .utils import apply_with_stopping_condition
+
+
+class Flamingo(nn.Module):
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        lang_encoder: nn.Module,
+        eoc_token_id: int,
+        media_token_id: int,
+        vis_dim: int,
+        cross_attn_every_n_layers: int = 1,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        Args:
+            vision_encoder (nn.Module): HF CLIPModel
+            lang_encoder (nn.Module): HF causal language model
+            eoc_token_id (int): Token id for <|endofchunk|>
+            media_token_id (int): Token id for <image>
+            vis_dim (int): Dimension of the visual features.
+                Visual features are projected to match this shape along the last dimension.
+            cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
+        """
+        super().__init__()
+        self.eoc_token_id = eoc_token_id
+        self.media_token_id = media_token_id
+        self.vis_dim = vis_dim
+        if hasattr(lang_encoder.config, "d_model"):
+            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
+        else:
+            self.lang_dim = lang_encoder.config.hidden_size
+
+        self.vision_encoder = vision_encoder.visual
+        self.perceiver = PerceiverResampler(dim=self.vis_dim)
+        self.lang_encoder = lang_encoder
+        self.lang_encoder.init_flamingo(
+            media_token_id=media_token_id,
+            lang_hidden_size=self.lang_dim,
+            vis_hidden_size=self.vis_dim,
+            cross_attn_every_n_layers=cross_attn_every_n_layers,
+            gradient_checkpointing=gradient_checkpointing,
+        )
+        self._use_gradient_checkpointing = gradient_checkpointing
+        self.perceiver._use_gradient_checkpointing = gradient_checkpointing
+
+    def forward(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        labels: torch.Tensor = None,
+        clear_conditioned_layers: bool = True,
+        past_key_values=None,
+        use_cache: bool = False,
+    ):
+        """
+        Forward pass of Flamingo.
+
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W) with F=1
+            lang_x (torch.Tensor): Language input ids
+                shape (B, T_txt)
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            labels (torch.Tensor, optional): Labels. Defaults to None.
+            clear_conditioned_layers: if True, clear the conditioned layers
+                once the forward pass is completed. Set this to false if the
+                same set of images will be reused in another subsequent
+                forward pass.
+            past_key_values: pre-computed values to pass to language model.
+                See past_key_values documentation in Hugging Face
+                CausalLM models.
+            use_cache: whether to use cached key values. See use_cache
+                documentation in Hugging Face CausalLM models.
+        """
+        assert (
+            self.lang_encoder.initialized_flamingo
+        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
+
+        assert (
+            self.lang_encoder._use_cached_vision_x or vision_x is not None
+        ), "Must provide either vision_x or have precached media using cache_media()."
+
+        if self.lang_encoder._use_cached_vision_x:
+            # Case: use cached; vision_x should be cached and other
+            # vision-related inputs should not be provided.
+            assert (
+                vision_x is None
+            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
+            assert self.lang_encoder.is_conditioned()
+
+        else:
+            # Case: do not use caching (i.e. this is a standard forward pass);
+            self._encode_vision_x(vision_x=vision_x)
+            self._condition_media_locations(input_ids=lang_x)
+
+        output = self.lang_encoder(
+            input_ids=lang_x,
+            attention_mask=attention_mask,
+            labels=labels,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+
+        if clear_conditioned_layers:
+            self.lang_encoder.clear_conditioned_layers()
+
+        return output
+
+    def generate(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        **kwargs,
+    ):
+        """
+        Generate text conditioned on vision and language inputs.
+
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                images in the same chunk are collated along T_img, and frames are collated along F
+                currently only F=1 is supported (single-frame videos)
+            lang_x (torch.Tensor): Language input
+                shape (B, T_txt)
+            **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
+                max_length (int, optional): Maximum length of the output. Defaults to None.
+                attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+                num_beams (int, optional): Number of beams. Defaults to 1.
+                max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
+                temperature (float, optional): Temperature. Defaults to 1.0.
+                top_k (int, optional): Top k. Defaults to 50.
+                top_p (float, optional): Top p. Defaults to 1.0.
+                no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
+                length_penalty (float, optional): Length penalty. Defaults to 1.0.
+                num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
+                do_sample (bool, optional): Do sample. Defaults to False.
+                early_stopping (bool, optional): Early stopping. Defaults to False.
+        Returns:
+            torch.Tensor: lang_x with generated tokens appended to it
+        """
+        num_beams = kwargs.pop("num_beams", 1)
+        if num_beams > 1:
+            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
+
+        self.lang_encoder._use_cached_vision_x = True
+        self._encode_vision_x(vision_x=vision_x)
+
+        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
+        output = self.lang_encoder.generate(
+            input_ids=lang_x,
+            attention_mask=attention_mask,
+            eos_token_id=eos_token_id,
+            num_beams=num_beams,
+            **kwargs,
+        )
+
+        self.lang_encoder.clear_conditioned_layers()
+        self.lang_encoder._use_cached_vision_x = False
+        return output
+
+    def _encode_vision_x(self, vision_x: torch.Tensor):
+        """
+        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                Images in the same chunk are collated along T_img, and frames are collated along F
+                Currently only F=1 is supported (single-frame videos)
+
+        rearrange code based on https://github.com/dhansmair/flamingo-mini
+        """
+
+        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
+        b, T, F = vision_x.shape[:3]
+        assert F == 1, "Only single frame supported"
+
+        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
+        with torch.no_grad():
+            vision_x = self.vision_encoder(vision_x)[1]
+        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
+        vision_x = self.perceiver(vision_x)
+
+        for layer in self.lang_encoder._get_decoder_layers():
+            layer.condition_vis_x(vision_x)
+
+    def wrap_fsdp(self, wrapper_kwargs, device_id):
+        """
+        Manually wraps submodules for FSDP and moves other parameters to device_id.
+
+        Why manually wrap?
+        - all parameters within the FSDP wrapper must have the same requires_grad.
+          We have a mix of frozen and unfrozen parameters.
+        - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors
+          See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344
+
+        The rough wrapping structure is:
+        - FlamingoModel
+            - FSDP(FSDP(vision_encoder))
+            - FSDP(FSDP(perceiver))
+            - lang_encoder
+                - FSDP(FSDP(input_embeddings))
+                - FlamingoLayers
+                    - FSDP(FSDP(gated_cross_attn_layer))
+                    - FSDP(FSDP(decoder_layer))
+                - FSDP(FSDP(output_embeddings))
+                - other parameters
+
+        Known issues:
+        - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
+          train with DDP or set the --freeze_lm_embeddings flag to true.
+        - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound.
+          Although the training curves look okay, we found that downstream performance dramatically
+          degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).
+
+        FAQs about our FSDP wrapping strategy:
+        Why double wrap?
+        As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
+        only free gathered parameters if the module is NOT FSDP root.
+
+        Why unfreeze the decoder_layers?
+        See https://github.com/pytorch/pytorch/issues/95805
+        As of torch==2.0.1, FSDP's _post_backward_hook is only registered if the flat param
+        requires_grad=True. We need the post-backward hook to fire to avoid OOM.
+        To effectively freeze the decoder layers, we exclude them from the optimizer.
+
+        What is assumed to be frozen v. unfrozen?
+        We assume that the model is being trained under normal Flamingo settings
+        with these lines being called in factory.py:
+        ```
+        # Freeze all parameters
+        model.requires_grad_(False)
+        assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
+
+        # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
+        model.perceiver.requires_grad_(True)
+        model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
+        [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True)
+        ```
+        """
+        # unfreeze the decoder layers
+        for block in self.lang_encoder.old_decoder_blocks:
+            block.requires_grad_(True)
+
+        # wrap in FSDP
+        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
+            self.perceiver = wrap(wrap(self.perceiver))
+            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
+                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
+            )
+            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
+                wrap(wrap(layer)) if layer is not None else None
+                for layer in self.lang_encoder.gated_cross_attn_layers
+            )
+            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
+            self.lang_encoder.set_input_embeddings(
+                wrap(wrap(self.lang_encoder.get_input_embeddings()))
+            )
+            self.lang_encoder.set_output_embeddings(
+                wrap(wrap(self.lang_encoder.get_output_embeddings()))
+            )
+            self.vision_encoder = wrap(wrap(self.vision_encoder))  # frozen
+
+        # manually move non-FSDP managed parameters to device_id
+        # these are all in lang_encoder
+        apply_with_stopping_condition(
+            module=self.lang_encoder,
+            apply_fn=lambda m: m.to(device_id),
+            apply_condition=lambda m: len(list(m.children())) == 0,
+            stopping_condition=lambda m: isinstance(m, FSDP),
+        )
+
+        # exclude the original decoder layers from the optimizer
+        for block in self.lang_encoder.old_decoder_blocks:
+            for p in block.parameters():
+                p.exclude_from_optimizer = True
+
+        # set up clip_grad_norm_ function
+        def clip_grad_norm_(max_norm):
+            self.perceiver.clip_grad_norm_(max_norm)
+            for layer in self.lang_encoder.gated_cross_attn_layers:
+                if layer is not None:
+                    layer.clip_grad_norm_(max_norm)
+            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
+
+        self.clip_grad_norm_ = clip_grad_norm_
+
+    def _condition_media_locations(self, input_ids: torch.Tensor):
+        """
+        Compute the media token locations from lang_x and condition the language model on these.
+        Args:
+            input_ids (torch.Tensor): Language input
+                shape (B, T_txt)
+        """
+        media_locations = input_ids == self.media_token_id
+
+        for layer in self.lang_encoder._get_decoder_layers():
+            layer.condition_media_locations(media_locations)
+
+    def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
+        """
+        Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
+        All subsequent calls to forward() will generate attending to the LAST
+        image in vision_x.
+        This is not meant to be used to cache things for generate().
+        Args:
+            input_ids (torch.Tensor): Language input
+                shape (B, T_txt)
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                Images in the same chunk are collated along T_img, and frames are collated along F
+                Currently only F=1 is supported (single-frame videos)
+        """
+        self._encode_vision_x(vision_x=vision_x)
+        self._condition_media_locations(input_ids=input_ids)
+        self.lang_encoder._use_cached_vision_x = True
+
+    def uncache_media(self):
+        """
+        Clear all conditioning.
+        """
+        self.lang_encoder.clear_conditioned_layers()
+        self.lang_encoder._use_cached_vision_x = False
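
As a shape sanity check for `Flamingo.forward`, here is a sketch assuming `model` and `tokenizer` come from `create_model_and_transforms` above; two images per sample mirrors app.py's third-person plus gripper views, with one `<image>` token per image in the prompt:

```python
import torch

B, T_img, F = 1, 2, 1                                # two single-frame images
vision_x = torch.randn(B, T_img, F, 3, 224, 224)     # (B, T_img, F, C, H, W)
lang_x = tokenizer("<image><image>pick up the red block",
                   return_tensors="pt")["input_ids"]  # (B, T_txt)
with torch.no_grad():
    out = model(vision_x=vision_x, lang_x=lang_x)    # HF CausalLM output
```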
open_flamingo/src/flamingo_lm.py ADDED
@@ -0,0 +1,191 @@
+import torch.nn as nn
+from .helpers import GatedCrossAttentionBlock
+from .utils import getattr_recursive, setattr_recursive
+import copy
+
+class FlamingoLayer(nn.Module):
+    """
+    FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
+    """
+
+    def __init__(
+        self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False, residual=False
+    ):
+        super().__init__()
+        self.gated_cross_attn_layer = gated_cross_attn_layer
+        self.decoder_layer = decoder_layer
+        self.vis_x = None
+        self.media_locations = None
+        self.residual = residual
+
+        if self.gated_cross_attn_layer is not None:
+            self.gated_cross_attn_layer._use_gradient_checkpointing = (
+                gradient_checkpointing
+            )
+        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
+
+    def clone_parameters(self):
+        self.res_layer = copy.deepcopy(self.gated_cross_attn_layer)
+        if self.res_layer is not None:
+            self.res_layer.requires_grad_(False)
+
+    def is_conditioned(self) -> bool:
+        """Check whether the layer is conditioned."""
+        return self.vis_x is not None and self.media_locations is not None
+
+    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
+    def condition_vis_x(self, vis_x):
+        self.vis_x = vis_x
+
+    def condition_media_locations(self, media_locations):
+        self.media_locations = media_locations
+
+    def condition_use_cached_media(self, use_cached_media):
+        self.use_cached_media = use_cached_media
+
+    def forward(
+        self,
+        lang_x,
+        attention_mask=None,
+        **decoder_layer_kwargs,
+    ):
+        # Cross attention
+        if self.gated_cross_attn_layer is not None:
+            if self.vis_x is None:
+                raise ValueError("vis_x must be conditioned before forward pass")
+
+            if self.media_locations is None:
+                raise ValueError(
+                    "media_locations must be conditioned before forward pass"
+                )
+
+            lang_x = self.gated_cross_attn_layer(
+                lang_x,
+                self.vis_x,
+                media_locations=self.media_locations,
+                use_cached_media=self.use_cached_media,
+            )
+
+            # Residual
+            if self.residual and self.res_layer is not None:
+                lang_x_res = self.res_layer(
+                    lang_x,
+                    self.vis_x,
+                    media_locations=self.media_locations,
+                    attend_previous=self.attend_previous,
+                )
+                lang_x = (lang_x + lang_x_res) / 2.0
+
+        # Normal decoder layer
+        lang_x = self.decoder_layer(
+            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
+        )
+        return lang_x
+
+
+class FlamingoLMMixin(nn.Module):
+    """
+    Mixin to add cross-attention layers to a language model.
+    """
+
+    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
+        self.decoder_layers_attr_name = decoder_layers_attr_name
+
+    def _get_decoder_layers(self):
+        return getattr_recursive(self, self.decoder_layers_attr_name)
+
+    def _set_decoder_layers(self, value):
+        setattr_recursive(self, self.decoder_layers_attr_name, value)
+
+    def init_flamingo(
+        self,
+        media_token_id,
+        lang_hidden_size,
+        vis_hidden_size,
+        cross_attn_every_n_layers,
+        gradient_checkpointing,
+        residual=False,
+    ):
+        """
+        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
+        """
+        print('-' * 100)
+        print(self.decoder_layers_attr_name)
+        self.old_decoder_blocks = self._get_decoder_layers()
+        self.gated_cross_attn_layers = nn.ModuleList(
+            [
+                GatedCrossAttentionBlock(
+                    dim=lang_hidden_size, dim_visual=vis_hidden_size
+                )
+                if (layer_idx + 1) % cross_attn_every_n_layers == 0
+                else None
+                for layer_idx, _ in enumerate(self._get_decoder_layers())
+            ]
+        )
+        self.init_flamingo_layers(gradient_checkpointing, residual=residual)
+        self.media_token_id = media_token_id
+        self.initialized_flamingo = True
+        self._use_cached_vision_x = False
+
+    def init_flamingo_layers(self, gradient_checkpointing, residual=False):
+        """
+        Re-initializes the FlamingoLayers.
+        Propagates any changes made to self.gated_cross_attn_layers or self.old_decoder_blocks
+        """
+        self._set_decoder_layers(
+            nn.ModuleList(
+                [
+                    FlamingoLayer(
+                        gated_cross_attn_layer, decoder_layer, gradient_checkpointing, residual=residual
+                    )
+                    for gated_cross_attn_layer, decoder_layer in zip(
+                        self.gated_cross_attn_layers, self.old_decoder_blocks
+                    )
+                ]
+            )
+        )
+
+    def forward(self, input_ids, attention_mask, **kwargs):
+        """Condition the Flamingo layers on the media locations before forward()"""
+        if not self.initialized_flamingo:
+            raise ValueError(
+                "Flamingo layers are not initialized. Please call `init_flamingo` first."
+            )
+
+        media_locations = input_ids == self.media_token_id
+
+        # if there are media already cached and we're generating and there are no media tokens in the input,
+        # we'll assume that ALL input tokens should attend to the last previous media that is cached.
+        # this is especially important for HF generate() compatibility, since generate() calls forward()
+        # repeatedly one token at a time (with no media tokens).
+        # without this check, the model would not attend to any images when generating (after the first token)
+        use_cached_media_locations = (
+            self._use_cached_vision_x
+            and self.is_conditioned()
+            and not media_locations.any()
+        )
+
+        for layer in self._get_decoder_layers():
+            if not use_cached_media_locations:
+                layer.condition_media_locations(media_locations)
+            layer.condition_use_cached_media(use_cached_media_locations)
+
+        # package arguments for the other parent's forward. since we don't know the order of the arguments,
+        # make them all kwargs
+        kwargs["input_ids"] = input_ids
+        kwargs["attention_mask"] = attention_mask
+        return super().forward(**kwargs)  # Call the other parent's forward method
+
+    def is_conditioned(self) -> bool:
+        """Check whether all decoder layers are already conditioned."""
+        return all(l.is_conditioned() for l in self._get_decoder_layers())
+
+    def clone_parameters(self):
+        for layer in self._get_decoder_layers():
+            layer.clone_parameters()
+
+    def clear_conditioned_layers(self):
+        for layer in self._get_decoder_layers():
+            layer.condition_vis_x(None)
+            layer.condition_media_locations(None)
+            layer.condition_use_cached_media(None)
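
The only non-obvious arithmetic in `init_flamingo` is which decoder layers receive a `GatedCrossAttentionBlock`. A standalone sketch of the placement rule, using a toy layer count:

```python
# init_flamingo inserts a cross-attention block wherever
# (layer_idx + 1) % cross_attn_every_n_layers == 0, and None elsewhere.
n_layers, every = 8, 4
print([(i + 1) % every == 0 for i in range(n_layers)])
# [False, False, False, True, False, False, False, True]
```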
open_flamingo/src/helpers.py ADDED
@@ -0,0 +1,279 @@
+"""
+Based on: https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+from einops_exts import rearrange_many
+from torch import einsum, nn
+
+
+def exists(val):
+    return val is not None
+
+
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, T, n2, D)
+        """
+        x = self.norm_media(x)
+        latents = self.norm_latents(latents)
+
+        h = self.heads
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
+        q = q * self.scale
+
+        # attention
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
+        return self.to_out(out)
+
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=6,
+        dim_head=64,
+        heads=8,
+        num_latents=64,
+        max_num_media=None,
+        max_num_frames=None,
+        ff_mult=4,
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        self.frame_embs = (
+            nn.Parameter(torch.randn(max_num_frames, dim))
+            if exists(max_num_frames)
+            else None
+        )
+        self.media_time_embs = (
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
+            if exists(max_num_media)
+            else None
+        )
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, F, v, D)
+        Returns:
+            shape (b, T, n, D) where n is self.num_latents
+        """
+        b, T, F, v = x.shape[:4]
+
+        # frame and media time embeddings
+        if exists(self.frame_embs):
+            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
+            x = x + frame_embs
+        x = rearrange(
+            x, "b T F v d -> b T (F v) d"
+        )  # flatten the frame and spatial dimensions
+        if exists(self.media_time_embs):
+            x = x + self.media_time_embs[:T]
+
+        # blocks
+        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        return self.norm(latents)
+
+
+# gated cross attention
+class MaskedCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+        # whether for text to only attend to immediate preceding image, or all previous images
+        self.only_attend_immediate_media = only_attend_immediate_media
+
+    def forward(self, x, media, media_locations=None, use_cached_media=False):
+        """
+        Args:
+            x (torch.Tensor): text features
+                shape (B, T_txt, D_txt)
+            media (torch.Tensor): image features
+                shape (B, T_img, n, D_img) where n is the dim of the latents
+            media_locations: boolean mask identifying the media tokens in x
+                shape (B, T_txt)
+            use_cached_media: bool
+                If true, treat all of x as if they occur after the last media
+                registered in media_locations. T_txt does not need to exactly
+                equal media_locations.shape[1] in this case
+        """
+
+        if not use_cached_media:
+            assert (
+                media_locations.shape[1] == x.shape[1]
+            ), f"media_locations.shape is {media_locations.shape} but x.shape is {x.shape}"
+
+        T_txt = x.shape[1]
+        _, T_img, n = media.shape[:3]
+        h = self.heads
+
+        x = self.norm(x)
+
+        q = self.to_q(x)
+        media = rearrange(media, "b t n d -> b (t n) d")
+
+        k, v = self.to_kv(media).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+        q = q * self.scale
+
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+        if exists(media_locations):
+            media_time = torch.arange(T_img, device=x.device) + 1
+
+            if use_cached_media:
+                # text time is set to the last cached media location
+                text_time = repeat(
+                    torch.count_nonzero(media_locations, dim=1),
+                    "b -> b i",
+                    i=T_txt,
+                )
+            else:
+                # at each boolean of True, increment the time counter (relative to media time)
+                text_time = media_locations.cumsum(dim=-1)
+
+            # text time must equal media time if only attending to most immediate image
+            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(
+                rearrange(text_time, "b i -> b 1 i 1"),
+                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
+            )
+            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
+
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        if exists(media_locations) and self.only_attend_immediate_media:
+            # any text without a preceding media needs to have attention zeroed out
+            text_without_media_mask = text_time == 0
+            text_without_media_mask = rearrange(
+                text_without_media_mask, "b i -> b 1 i 1"
+            )
+            attn = attn.masked_fill(text_without_media_mask, 0.0)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class GatedCrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.attn = MaskedCrossAttention(
+            dim=dim,
+            dim_visual=dim_visual,
+            dim_head=dim_head,
+            heads=heads,
+            only_attend_immediate_media=only_attend_immediate_media,
+        )
+        self.attn_gate = nn.Parameter(torch.tensor([0.0]))
+
+        self.ff = FeedForward(dim, mult=ff_mult)
+        self.ff_gate = nn.Parameter(torch.tensor([0.0]))
+
+    def forward(
+        self,
+        x,
+        media,
+        media_locations=None,
+        use_cached_media=False,
+    ):
+        x = (
+            self.attn(
+                x,
+                media,
+                media_locations=media_locations,
+                use_cached_media=use_cached_media,
+            )
+            * self.attn_gate.tanh()
+            + x
+        )
+        x = self.ff(x) * self.ff_gate.tanh() + x
+
+        return x
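
A standalone shape check for the resampler above: it compresses the `v` patch tokens per image down to `num_latents` tokens while preserving the batch and media dimensions. `dim=1024` mirrors the ViT-L-14 width the factory passes as `vis_dim`; the other sizes here are arbitrary.

```python
import torch

resampler = PerceiverResampler(dim=1024, depth=2, num_latents=64)
x = torch.randn(2, 3, 1, 256, 1024)   # (b, T, F, v, D)
print(resampler(x).shape)             # torch.Size([2, 3, 64, 1024])
```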
open_flamingo/src/utils.py ADDED
@@ -0,0 +1,48 @@
+def extend_instance(obj, mixin):
+    """Apply mixins to a class instance after creation"""
+    base_cls = obj.__class__
+    base_cls_name = obj.__class__.__name__
+    obj.__class__ = type(
+        base_cls_name, (mixin, base_cls), {}
+    )  # mixin needs to go first for our forward() logic to work
+
+
+def getattr_recursive(obj, att):
+    """
+    Return nested attribute of obj
+    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+    """
+    if att == "":
+        return obj
+    i = att.find(".")
+    if i < 0:
+        return getattr(obj, att)
+    else:
+        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+def setattr_recursive(obj, att, val):
+    """
+    Set nested attribute of obj
+    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+    """
+    if "." in att:
+        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+    setattr(obj, att.split(".")[-1], val)
+
+
+def apply_with_stopping_condition(
+    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
+):
+    if stopping_condition(module):
+        return
+    if apply_condition(module):
+        apply_fn(module, **other_args)
+    for child in module.children():
+        apply_with_stopping_condition(
+            child,
+            apply_fn,
+            apply_condition=apply_condition,
+            stopping_condition=stopping_condition,
+            **other_args
+        )