Deploy with src only (no large eval data)
- README.md +4 -20
- app.py +91 -202
- open_flamingo/__init__.py +2 -0
- open_flamingo/src/__init__.py +0 -0
- open_flamingo/src/factory.py +141 -0
- open_flamingo/src/flamingo.py +338 -0
- open_flamingo/src/flamingo_lm.py +191 -0
- open_flamingo/src/helpers.py +279 -0
- open_flamingo/src/utils.py +48 -0
README.md
CHANGED
@@ -9,26 +9,10 @@ app_file: app.py
 pinned: false
 ---
 
-# RoboFlamingo Demo 🤖
-
-- 📸 Upload your own robot images
-- 💬 Enter natural language instructions
-- 🎯 Get real model predictions
-- 📊 3D trajectory visualization
-
-1. Upload third-person view image
-2. Upload gripper view image
-3. Enter instruction (e.g., "Pick up the red block")
-4. Click "Predict Actions"
-
-## Requirements
-⚠️ **Requires GPU**: Enable T4 GPU in Space settings for real model.
-Without GPU, runs in simulation mode.
-
-## Resources
-- [Paper](https://arxiv.org/abs/2311.01378)
-- [Code](https://github.com/RoboFlamingo/RoboFlamingo)
+# RoboFlamingo Interactive Demo 🤖
+
+Upload images and get real predictions!
+
+⚠️ Enable T4 GPU in Settings for real model.
+
+[Paper](https://arxiv.org/abs/2311.01378) | [Code](https://github.com/RoboFlamingo/RoboFlamingo)
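The README only describes the browser UI. As a rough illustration, the same Space could also be queried programmatically; this is an unverified sketch using `gradio_client` — the Space id and the argument order are placeholders inferred from `app.py`'s `predict(inst, img1, img2)` signature, and recent `gradio_client` versions may require wrapping image paths with `handle_file()`.

```python
# Hypothetical client-side call to the demo Space (not part of this commit).
from gradio_client import Client

client = Client("<user>/roboflamingo-demo")  # placeholder Space id
traj_img, grip_img, table_md, status = client.predict(
    "Pick up the red block",    # instruction
    "third_person_view.png",    # third-person camera image (local file path)
    "gripper_view.png",         # gripper camera image (local file path)
)
print(status)
```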
app.py
CHANGED
@@ -1,8 +1,4 @@
-"""
-RoboFlamingo Interactive Demo - Real Model
-Upload your images and get real predictions!
-"""
-
 import gradio as gr
 import torch
 import numpy as np
@@ -11,31 +7,18 @@ import matplotlib.pyplot as plt
 from io import BytesIO
 import sys
 
-# Add OpenFlamingo to path
 sys.path.insert(0, '/home/user/app/open_flamingo/src')
 
-print("
-print("🚀 INITIALIZING ROBOFLAMINGO")
-print("=" * 70)
-
-# ============================================================================
-# LOAD MODEL
-# ============================================================================
 
 MODEL_LOADED = False
-model = None
-image_processor = None
-tokenizer = None
 
 try:
-    print("📦 Importing
     from factory import create_model_and_transforms
 
-    print("✅ Import successful!")
-    print("🔧 Loading model (2-3 minutes on first run)...")
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"
 
     model, image_processor, tokenizer = create_model_and_transforms(
         clip_vision_encoder_path="ViT-L-14",
@@ -46,228 +29,134 @@ try:
         decoder_type='lstm',
     )
 
-    model.to(device)
-    model.eval()
     MODEL_LOADED = True
-
-    print("=" * 70)
-    print("✅ REAL MODEL LOADED!")
-    print("=" * 70)
-
 except Exception as e:
-    print(f"
-    print("⚠️ SIMULATION MODE")
     import traceback
     traceback.print_exc()
 
-# ============================================================================
-
-def create_trajectory_plot(actions):
-    fig = plt.figure(figsize=(10, 8))
     ax = fig.add_subplot(111, projection='3d')
-    z =
-    ax.
-    ax.
-    ax.
-    ax.set_xlabel('X (m)')
-    ax.set_ylabel('Y (m)')
-    ax.set_zlabel('Z (m)')
-    ax.set_title('Predicted Trajectory')
-    ax.legend()
-    ax.grid(True)
     buf = BytesIO()
     plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
-    buf.seek(0)
-    plt.close()
     return Image.open(buf)
 
-def
-    fig, ax = plt.subplots(figsize=(12,
-        ax.text(i, 0.5, label, ha='center', va='center', fontweight='bold')
-    ax.set_xlabel('Timestep')
-    ax.set_title('Gripper Commands')
-    ax.set_ylim(0, 1.2)
-    ax.grid(True, alpha=0.3)
     buf = BytesIO()
     plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
-    buf.seek(0)
-    plt.close()
     return Image.open(buf)
 
-def
-        table += f"| {a['timestep']} | {a['delta_x']:.3f} | {a['delta_y']:.3f} | {a['delta_z']:.3f} | "
-        table += f"({a['qw']:.2f},{a['qx']:.2f},{a['qy']:.2f},{a['qz']:.2f}) |\n"
-    return table
-
-# ============================================================================
-# PREDICTION
-# ============================================================================
-
-def simulate(instruction):
-    """Fallback simulation."""
-    seed = sum(ord(c) for c in instruction) % 100
-    np.random.seed(seed)
-
-    actions = []
     for t in range(12):
-        p = t
-
-    if not instruction or not instruction.strip():
-        return None, None, "", "❌ Enter instruction!"
-    if third_img is None:
-        return None, None, "", "❌ Upload third-person view!"
-    if grip_img is None:
-        return None, None, "", "❌ Upload gripper view!"
 
     try:
-        if isinstance(
-        if isinstance(
 
         if not MODEL_LOADED:
-            status = f"⚠️ SIMULATION\n\n{instruction}\n\nEnable GPU for real model."
         else:
-            print(f"🚀 Real inference: {instruction}")
             with torch.no_grad():
-                t1 = image_processor(
-                t2 = image_processor(
-                    truncation=True, max_length=512).to(device)
-                outputs = model(vision_x=vision_x, lang_x=tokens['input_ids'],
-                    attention_mask=tokens.get('attention_mask'))
 
-            if isinstance(
-            elif isinstance(
             else:
 
-            if
-                for t,
-                    if len(
-                    })
-
-                if grip is not None:
-                    grip_np = grip[0].cpu().numpy()
-                    gripper = [int(g>0.5) if np.isscalar(g) else int(g[0]>0.5) for g in grip_np]
-                else:
-                    gripper = [0]*len(actions)
-
-                status = f"✅ REAL MODEL\n\n{instruction}\n\n{device}\n{len(actions)} timesteps"
-                print(f"✅ Success! {len(actions)} timesteps")
             else:
-                status = f"⚠️ Unexpected output\n{
 
-        traj =
-        table =
 
    except Exception as e:
-        print(f"
        import traceback
        traceback.print_exc()
-        traj = create_trajectory_plot(actions)
-        grip_plot = create_gripper_timeline(gripper)
-        table = format_table(actions)
-        return traj, grip_plot, table, f"❌ Error: {str(e)}"
-
-# ============================================================================
-# UI
-# ============================================================================
 
-mode = "🟢 REAL
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f""
-    # 🤖 RoboFlamingo Demo - {mode}
-    ### Vision-Language Foundation Models as Effective Robot Imitators
-    {'✅ Real model loaded!' if MODEL_LOADED else '⚠️ Simulation mode - Enable GPU in settings'}
-    **Upload your images and enter instructions:**
-    """)
 
     with gr.Row():
         with gr.Column():
-            gr.
-            instruction = gr.Textbox(label="Instruction",
-                placeholder="Pick up the red block", lines=3)
             with gr.Row():
-            btn = gr.Button("🚀 Predict
         with gr.Column():
-            gr.Markdown("### 📊 Predictions")
             traj = gr.Image(label="Trajectory", type="pil")
             grip = gr.Image(label="Gripper", type="pil")
 
-    btn.click(predict, [instruction, third, gripper], [traj, grip, table, status])
-
-    gr.Markdown(f"""
-    ---
-    **Status:** {mode} | **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
-
-    [Paper](https://arxiv.org/abs/2311.01378) | [Code](https://github.com/RoboFlamingo/RoboFlamingo)
-    {
-    """)
 
 demo.launch()
+"""RoboFlamingo Interactive Demo"""
 import gradio as gr
 import torch
 import numpy as np
 from io import BytesIO
 import sys
 
 sys.path.insert(0, '/home/user/app/open_flamingo/src')
 
+print("🚀 Initializing RoboFlamingo")
 
 MODEL_LOADED = False
 
 try:
+    print("📦 Importing...")
     from factory import create_model_and_transforms
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Device: {device}")
 
     model, image_processor, tokenizer = create_model_and_transforms(
         clip_vision_encoder_path="ViT-L-14",
         decoder_type='lstm',
     )
 
+    model.to(device).eval()
     MODEL_LOADED = True
+    print("✅ Model loaded!")
 except Exception as e:
+    print(f"⚠️ Model failed: {e}")
     import traceback
     traceback.print_exc()
 
+def plot_traj(acts):
+    fig = plt.figure(figsize=(10,8))
     ax = fig.add_subplot(111, projection='3d')
+    x = np.cumsum([a['delta_x'] for a in acts])
+    y = np.cumsum([a['delta_y'] for a in acts])
+    z = np.cumsum([a['delta_z'] for a in acts])
+    ax.plot(x, y, z, 'b-', lw=2, marker='o', ms=6)
+    ax.scatter(x[0], y[0], z[0], c='green', s=100, label='Start')
+    ax.scatter(x[-1], y[-1], z[-1], c='red', s=100, label='End')
+    ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
+    ax.legend(); ax.grid()
     buf = BytesIO()
     plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
+    buf.seek(0); plt.close()
     return Image.open(buf)
 
+def plot_grip(grip):
+    fig, ax = plt.subplots(figsize=(12,3))
+    cols = ['green' if g==0 else 'red' for g in grip]
+    ax.bar(range(len(grip)), [1]*len(grip), color=cols, alpha=0.7, ec='black')
+    for i, g in enumerate(grip):
+        ax.text(i, 0.5, 'OPEN' if g==0 else 'CLOSE', ha='center', va='center', weight='bold')
+    ax.set_xlabel('Timestep'); ax.set_ylim(0,1.2); ax.grid(alpha=0.3)
     buf = BytesIO()
     plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
+    buf.seek(0); plt.close()
     return Image.open(buf)
 
+def simulate(inst):
+    np.random.seed(sum(ord(c) for c in inst) % 100)
+    acts = []
     for t in range(12):
+        p = t/12
+        acts.append({'timestep': t, 'delta_x': (0.05+np.random.randn()*0.01)*p,
+                     'delta_y': (0.02+np.random.randn()*0.01)*p,
+                     'delta_z': (-0.03+np.random.randn()*0.01)*(1-p),
+                     'qw': 0.99, 'qx': 0.01, 'qy': 0.01, 'qz': 0.01})
+    return acts, [0]*6+[1]*6
+
+def predict(inst, img1, img2):
+    if not inst or not inst.strip():
+        return None, None, "", "❌ Enter instruction"
+    if img1 is None or img2 is None:
+        return None, None, "", "❌ Upload both images"
 
     try:
+        if isinstance(img1, np.ndarray):
+            img1 = Image.fromarray(img1)
+        if isinstance(img2, np.ndarray):
+            img2 = Image.fromarray(img2)
 
         if not MODEL_LOADED:
+            acts, grip = simulate(inst)
+            status = f"⚠️ SIMULATION\n{inst}\nEnable GPU for real model"
         else:
             with torch.no_grad():
+                t1 = image_processor(img1).unsqueeze(0).to(device)
+                t2 = image_processor(img2).unsqueeze(0).to(device)
+                vis = torch.stack([t1, t2], dim=1)
+                tok = tokenizer(inst, return_tensors="pt", padding=True, truncation=True).to(device)
+                out = model(vision_x=vis, lang_x=tok['input_ids'], attention_mask=tok.get('attention_mask'))
 
+            if isinstance(out, dict):
+                a = out.get('actions', out.get('action'))
+                g = out.get('gripper')
+            elif isinstance(out, tuple):
+                a = out[0]
+                g = out[1] if len(out)>1 else None
             else:
+                a = out
+                g = None
 
+            if a is not None:
+                anp = a[0].cpu().numpy()
+                acts = []
+                for t, ac in enumerate(anp):
+                    if len(ac)<7: ac = np.pad(ac, (0,7-len(ac)))
+                    acts.append({'timestep': t, 'delta_x': float(ac[0]), 'delta_y': float(ac[1]),
+                                 'delta_z': float(ac[2]), 'qw': float(ac[3]), 'qx': float(ac[4]),
+                                 'qy': float(ac[5]), 'qz': float(ac[6])})
+                grip = [int(x>0.5) if np.isscalar(x) else int(x[0]>0.5) for x in (g[0].cpu().numpy() if g is not None else [0]*len(acts))]
+                status = f"✅ REAL MODEL\n{inst}\n{device}"
             else:
+                acts, grip = simulate(inst)
+                status = f"⚠️ Unexpected output\n{inst}"
 
+        traj = plot_traj(acts)
+        gp = plot_grip(grip)
+        table = "| T | Δx | Δy | Δz |\n|--|--|--|--|\n"
+        for a in acts:
+            table += f"| {a['timestep']} | {a['delta_x']:.3f} | {a['delta_y']:.3f} | {a['delta_z']:.3f} |\n"
 
+        return traj, gp, table, status
     except Exception as e:
+        print(f"Error: {e}")
         import traceback
         traceback.print_exc()
+        acts, grip = simulate(inst)
+        return plot_traj(acts), plot_grip(grip), "", f"❌ {str(e)}"
 
+mode = "🟢 REAL" if MODEL_LOADED else "🟡 SIM"
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(f"# 🤖 RoboFlamingo - {mode}\n{'Real model loaded!' if MODEL_LOADED else 'Enable GPU for real model'}")
 
     with gr.Row():
         with gr.Column():
+            inst = gr.Textbox(label="Instruction", placeholder="Pick up the red block", lines=3)
             with gr.Row():
+                img1 = gr.Image(label="Third-Person", type="pil", height=250)
+                img2 = gr.Image(label="Gripper", type="pil", height=250)
+            btn = gr.Button("🚀 Predict", variant="primary", size="lg")
+            st = gr.Textbox(label="Status", lines=4, interactive=False)
         with gr.Column():
             traj = gr.Image(label="Trajectory", type="pil")
             grip = gr.Image(label="Gripper", type="pil")
 
+            tab = gr.Markdown()
+    btn.click(predict, [inst, img1, img2], [traj, grip, tab, st])
 
+    gr.Markdown(f"**Status:** {mode} | [Paper](https://arxiv.org/abs/2311.01378)")
 
 demo.launch()
open_flamingo/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .src.flamingo import Flamingo
+from .src.factory import create_model_and_transforms
open_flamingo/src/__init__.py
ADDED
File without changes
open_flamingo/src/factory.py
ADDED
@@ -0,0 +1,141 @@
from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
        use_local_files (bool, optional): whether to use local files. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model
    """
    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path,
        pretrained=clip_vision_encoder_pretrained,
        cache_dir=cache_dir,
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # add Flamingo special tokens to the tokenizer
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|endofchunk|>", "<image>", "<action>"]}
    )
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
    if "mpt-1b-redpajama-200b" in lang_encoder_path:

        class EmbeddingFnMixin:
            def get_input_embeddings(self):
                return self.transformer.wte

            def set_input_embeddings(self, new_embeddings):
                self.transformer.wte = new_embeddings

        extend_instance(lang_encoder, EmbeddingFnMixin)

    # convert LM to FlamingoLM
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
            "width"
        ],
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
    model.perceiver.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
        # TODO: investigate also training the output embeddings when untied

    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )

    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
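A minimal usage sketch of the factory added above, based on its docstring and on how app.py calls it. The vision checkpoint tag and the MPT-1B language model path are assumptions for illustration (the MPT path is only mentioned in the factory's compatibility hack), not necessarily what the Space downloads.

```python
# Sketch: build a Flamingo model plus its image processor and tokenizer.
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",                      # assumed open_clip tag
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",      # assumed LM checkpoint
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)
model.eval()  # backbones are frozen by the factory; only perceiver/cross-attn layers stay trainable
```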
open_flamingo/src/flamingo.py
ADDED
@@ -0,0 +1,338 @@
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
from torch import nn
|
| 4 |
+
from .helpers import PerceiverResampler
|
| 5 |
+
from torch.distributed.fsdp.wrap import (
|
| 6 |
+
enable_wrap,
|
| 7 |
+
wrap,
|
| 8 |
+
)
|
| 9 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 10 |
+
from torch.distributed.fsdp import (
|
| 11 |
+
FullyShardedDataParallel as FSDP,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from .utils import apply_with_stopping_condition
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Flamingo(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
vision_encoder: nn.Module,
|
| 21 |
+
lang_encoder: nn.Module,
|
| 22 |
+
eoc_token_id: int,
|
| 23 |
+
media_token_id: int,
|
| 24 |
+
vis_dim: int,
|
| 25 |
+
cross_attn_every_n_layers: int = 1,
|
| 26 |
+
gradient_checkpointing: bool = False,
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
vision_encoder (nn.Module): HF CLIPModel
|
| 31 |
+
lang_encoder (nn.Module): HF causal language model
|
| 32 |
+
eoc_token_id (int): Token id for <|endofchunk|>
|
| 33 |
+
media_token_id (int): Token id for <image>
|
| 34 |
+
vis_dim (int): Dimension of the visual features.
|
| 35 |
+
Visual features are projected to match this shape along the last dimension.
|
| 36 |
+
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
|
| 37 |
+
"""
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.eoc_token_id = eoc_token_id
|
| 40 |
+
self.media_token_id = media_token_id
|
| 41 |
+
self.vis_dim = vis_dim
|
| 42 |
+
if hasattr(lang_encoder.config, "d_model"):
|
| 43 |
+
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
|
| 44 |
+
else:
|
| 45 |
+
self.lang_dim = lang_encoder.config.hidden_size
|
| 46 |
+
|
| 47 |
+
self.vision_encoder = vision_encoder.visual
|
| 48 |
+
self.perceiver = PerceiverResampler(dim=self.vis_dim)
|
| 49 |
+
self.lang_encoder = lang_encoder
|
| 50 |
+
self.lang_encoder.init_flamingo(
|
| 51 |
+
media_token_id=media_token_id,
|
| 52 |
+
lang_hidden_size=self.lang_dim,
|
| 53 |
+
vis_hidden_size=self.vis_dim,
|
| 54 |
+
cross_attn_every_n_layers=cross_attn_every_n_layers,
|
| 55 |
+
gradient_checkpointing=gradient_checkpointing,
|
| 56 |
+
)
|
| 57 |
+
self._use_gradient_checkpointing = gradient_checkpointing
|
| 58 |
+
self.perceiver._use_gradient_checkpointing = gradient_checkpointing
|
| 59 |
+
|
| 60 |
+
def forward(
|
| 61 |
+
self,
|
| 62 |
+
vision_x: torch.Tensor,
|
| 63 |
+
lang_x: torch.Tensor,
|
| 64 |
+
attention_mask: torch.Tensor = None,
|
| 65 |
+
labels: torch.Tensor = None,
|
| 66 |
+
clear_conditioned_layers: bool = True,
|
| 67 |
+
past_key_values=None,
|
| 68 |
+
use_cache: bool = False,
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Forward pass of Flamingo.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
vision_x (torch.Tensor): Vision input
|
| 75 |
+
shape (B, T_img, F, C, H, W) with F=1
|
| 76 |
+
lang_x (torch.Tensor): Language input ids
|
| 77 |
+
shape (B, T_txt)
|
| 78 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 79 |
+
labels (torch.Tensor, optional): Labels. Defaults to None.
|
| 80 |
+
clear_conditioned_layers: if True, clear the conditioned layers
|
| 81 |
+
once the foward pass is completed. Set this to false if the
|
| 82 |
+
same set of images will be reused in another subsequent
|
| 83 |
+
forward pass.
|
| 84 |
+
past_key_values: pre-computed values to pass to language model.
|
| 85 |
+
See past_key_values documentation in Hugging Face
|
| 86 |
+
CausalLM models.
|
| 87 |
+
use_cache: whether to use cached key values. See use_cache
|
| 88 |
+
documentation in Hugging Face CausalLM models.
|
| 89 |
+
"""
|
| 90 |
+
assert (
|
| 91 |
+
self.lang_encoder.initialized_flamingo
|
| 92 |
+
), "Flamingo layers are not initialized. Please call `init_flamingo` first."
|
| 93 |
+
|
| 94 |
+
assert (
|
| 95 |
+
self.lang_encoder._use_cached_vision_x or vision_x is not None
|
| 96 |
+
), "Must provide either vision_x or have precached media using cache_media()."
|
| 97 |
+
|
| 98 |
+
if self.lang_encoder._use_cached_vision_x:
|
| 99 |
+
# Case: use cached; vision_x should be cached and other
|
| 100 |
+
# vision-related inputs should not be provided.
|
| 101 |
+
assert (
|
| 102 |
+
vision_x is None
|
| 103 |
+
), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
|
| 104 |
+
assert self.lang_encoder.is_conditioned()
|
| 105 |
+
|
| 106 |
+
else:
|
| 107 |
+
# Case: do not use caching (i.e. this is a standard forward pass);
|
| 108 |
+
self._encode_vision_x(vision_x=vision_x)
|
| 109 |
+
self._condition_media_locations(input_ids=lang_x)
|
| 110 |
+
|
| 111 |
+
output = self.lang_encoder(
|
| 112 |
+
input_ids=lang_x,
|
| 113 |
+
attention_mask=attention_mask,
|
| 114 |
+
labels=labels,
|
| 115 |
+
past_key_values=past_key_values,
|
| 116 |
+
use_cache=use_cache,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
if clear_conditioned_layers:
|
| 120 |
+
self.lang_encoder.clear_conditioned_layers()
|
| 121 |
+
|
| 122 |
+
return output
|
| 123 |
+
|
| 124 |
+
def generate(
|
| 125 |
+
self,
|
| 126 |
+
vision_x: torch.Tensor,
|
| 127 |
+
lang_x: torch.Tensor,
|
| 128 |
+
attention_mask: torch.Tensor = None,
|
| 129 |
+
**kwargs,
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Generate text conditioned on vision and language inputs.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
vision_x (torch.Tensor): Vision input
|
| 136 |
+
shape (B, T_img, F, C, H, W)
|
| 137 |
+
images in the same chunk are collated along T_img, and frames are collated along F
|
| 138 |
+
currently only F=1 is supported (single-frame videos)
|
| 139 |
+
lang_x (torch.Tensor): Language input
|
| 140 |
+
shape (B, T_txt)
|
| 141 |
+
**kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
|
| 142 |
+
max_length (int, optional): Maximum length of the output. Defaults to None.
|
| 143 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 144 |
+
num_beams (int, optional): Number of beams. Defaults to 1.
|
| 145 |
+
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
|
| 146 |
+
temperature (float, optional): Temperature. Defaults to 1.0.
|
| 147 |
+
top_k (int, optional): Top k. Defaults to 50.
|
| 148 |
+
top_p (float, optional): Top p. Defaults to 1.0.
|
| 149 |
+
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
|
| 150 |
+
length_penalty (float, optional): Length penalty. Defaults to 1.0.
|
| 151 |
+
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
|
| 152 |
+
do_sample (bool, optional): Do sample. Defaults to False.
|
| 153 |
+
early_stopping (bool, optional): Early stopping. Defaults to False.
|
| 154 |
+
Returns:
|
| 155 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
| 156 |
+
"""
|
| 157 |
+
num_beams = kwargs.pop("num_beams", 1)
|
| 158 |
+
if num_beams > 1:
|
| 159 |
+
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
|
| 160 |
+
|
| 161 |
+
self.lang_encoder._use_cached_vision_x = True
|
| 162 |
+
self._encode_vision_x(vision_x=vision_x)
|
| 163 |
+
|
| 164 |
+
eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
|
| 165 |
+
output = self.lang_encoder.generate(
|
| 166 |
+
input_ids=lang_x,
|
| 167 |
+
attention_mask=attention_mask,
|
| 168 |
+
eos_token_id=eos_token_id,
|
| 169 |
+
num_beams=num_beams,
|
| 170 |
+
**kwargs,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
self.lang_encoder.clear_conditioned_layers()
|
| 174 |
+
self.lang_encoder._use_cached_vision_x = False
|
| 175 |
+
return output
|
| 176 |
+
|
| 177 |
+
def _encode_vision_x(self, vision_x: torch.Tensor):
|
| 178 |
+
"""
|
| 179 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
| 180 |
+
Args:
|
| 181 |
+
vision_x (torch.Tensor): Vision input
|
| 182 |
+
shape (B, T_img, F, C, H, W)
|
| 183 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
| 184 |
+
Currently only F=1 is supported (single-frame videos)
|
| 185 |
+
|
| 186 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
| 190 |
+
b, T, F = vision_x.shape[:3]
|
| 191 |
+
assert F == 1, "Only single frame supported"
|
| 192 |
+
|
| 193 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
| 194 |
+
with torch.no_grad():
|
| 195 |
+
vision_x = self.vision_encoder(vision_x)[1]
|
| 196 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
| 197 |
+
vision_x = self.perceiver(vision_x)
|
| 198 |
+
|
| 199 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
| 200 |
+
layer.condition_vis_x(vision_x)
|
| 201 |
+
|
| 202 |
+
def wrap_fsdp(self, wrapper_kwargs, device_id):
|
| 203 |
+
"""
|
| 204 |
+
Manually wraps submodules for FSDP and move other parameters to device_id.
|
| 205 |
+
|
| 206 |
+
Why manually wrap?
|
| 207 |
+
- all parameters within the FSDP wrapper must have the same requires_grad.
|
| 208 |
+
We have a mix of frozen and unfrozen parameters.
|
| 209 |
+
- model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors
|
| 210 |
+
See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344
|
| 211 |
+
|
| 212 |
+
The rough wrapping structure is:
|
| 213 |
+
- FlamingoModel
|
| 214 |
+
- FSDP(FSDP(vision_encoder))
|
| 215 |
+
- FSDP(FSDP(perceiver))
|
| 216 |
+
- lang_encoder
|
| 217 |
+
- FSDP(FSDP(input_embeddings))
|
| 218 |
+
- FlamingoLayers
|
| 219 |
+
- FSDP(FSDP(gated_cross_attn_layer))
|
| 220 |
+
- FSDP(FSDP(decoder_layer))
|
| 221 |
+
- FSDP(FSDP(output_embeddings))
|
| 222 |
+
- other parameters
|
| 223 |
+
|
| 224 |
+
Known issues:
|
| 225 |
+
- Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
|
| 226 |
+
train with DDP or set the --freeze_lm_embeddings flag to true.
|
| 227 |
+
- With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound.
|
| 228 |
+
Although the training curves look okay, we found that downstream performance dramatically
|
| 229 |
+
degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).
|
| 230 |
+
|
| 231 |
+
FAQs about our FSDP wrapping strategy:
|
| 232 |
+
Why double wrap?
|
| 233 |
+
As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
|
| 234 |
+
only free gathered parameters if the module is NOT FSDP root.
|
| 235 |
+
|
| 236 |
+
Why unfreeze the decoder_layers?
|
| 237 |
+
See https://github.com/pytorch/pytorch/issues/95805
|
| 238 |
+
As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param
|
| 239 |
+
requires_grad=True. We need the postback to fire to avoid OOM.
|
| 240 |
+
To effectively freeze the decoder layers, we exclude them from the optimizer.
|
| 241 |
+
|
| 242 |
+
What is assumed to be frozen v. unfrozen?
|
| 243 |
+
We assume that the model is being trained under normal Flamingo settings
|
| 244 |
+
with these lines being called in factory.py:
|
| 245 |
+
```
|
| 246 |
+
# Freeze all parameters
|
| 247 |
+
model.requires_grad_(False)
|
| 248 |
+
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 249 |
+
|
| 250 |
+
# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
|
| 251 |
+
model.perceiver.requires_grad_(True)
|
| 252 |
+
model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
|
| 253 |
+
[optional] model.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 254 |
+
```
|
| 255 |
+
"""
|
| 256 |
+
# unfreeze the decoder layers
|
| 257 |
+
for block in self.lang_encoder.old_decoder_blocks:
|
| 258 |
+
block.requires_grad_(True)
|
| 259 |
+
|
| 260 |
+
# wrap in FSDP
|
| 261 |
+
with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
|
| 262 |
+
self.perceiver = wrap(wrap(self.perceiver))
|
| 263 |
+
self.lang_encoder.old_decoder_blocks = nn.ModuleList(
|
| 264 |
+
wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
|
| 265 |
+
)
|
| 266 |
+
self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
|
| 267 |
+
wrap(wrap(layer)) if layer is not None else None
|
| 268 |
+
for layer in self.lang_encoder.gated_cross_attn_layers
|
| 269 |
+
)
|
| 270 |
+
self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
|
| 271 |
+
self.lang_encoder.set_input_embeddings(
|
| 272 |
+
wrap(wrap(self.lang_encoder.get_input_embeddings()))
|
| 273 |
+
)
|
| 274 |
+
self.lang_encoder.set_output_embeddings(
|
| 275 |
+
wrap(wrap(self.lang_encoder.get_output_embeddings()))
|
| 276 |
+
)
|
| 277 |
+
self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen
|
| 278 |
+
|
| 279 |
+
# manually move non-FSDP managed parameters to device_id
|
| 280 |
+
# these are all in lang_encoder
|
| 281 |
+
apply_with_stopping_condition(
|
| 282 |
+
module=self.lang_encoder,
|
| 283 |
+
apply_fn=lambda m: m.to(device_id),
|
| 284 |
+
apply_condition=lambda m: len(list(m.children())) == 0,
|
| 285 |
+
stopping_condition=lambda m: isinstance(m, FSDP),
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# exclude the original decoder layers from the optimizer
|
| 289 |
+
for block in self.lang_encoder.old_decoder_blocks:
|
| 290 |
+
for p in block.parameters():
|
| 291 |
+
p.exclude_from_optimizer = True
|
| 292 |
+
|
| 293 |
+
# set up clip_grad_norm_ function
|
| 294 |
+
def clip_grad_norm_(max_norm):
|
| 295 |
+
self.perceiver.clip_grad_norm_(max_norm)
|
| 296 |
+
for layer in self.lang_encoder.gated_cross_attn_layers:
|
| 297 |
+
if layer is not None:
|
| 298 |
+
layer.clip_grad_norm_(max_norm)
|
| 299 |
+
self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
|
| 300 |
+
|
| 301 |
+
self.clip_grad_norm_ = clip_grad_norm_
|
| 302 |
+
|
| 303 |
+
def _condition_media_locations(self, input_ids: torch.Tensor):
|
| 304 |
+
"""
|
| 305 |
+
Compute the media token locations from lang_x and condition the language model on these.
|
| 306 |
+
Args:
|
| 307 |
+
input_ids (torch.Tensor): Language input
|
| 308 |
+
shape (B, T_txt)
|
| 309 |
+
"""
|
| 310 |
+
media_locations = input_ids == self.media_token_id
|
| 311 |
+
|
| 312 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
| 313 |
+
layer.condition_media_locations(media_locations)
|
| 314 |
+
|
| 315 |
+
def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
|
| 316 |
+
"""
|
| 317 |
+
Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
|
| 318 |
+
All subsequent calls to forward() will generate attending to the LAST
|
| 319 |
+
image in vision_x.
|
| 320 |
+
This is not meant to be used to cache things for generate().
|
| 321 |
+
Args:
|
| 322 |
+
input_ids (torch.Tensor): Language input
|
| 323 |
+
shape (B, T_txt)
|
| 324 |
+
vision_x (torch.Tensor): Vision input
|
| 325 |
+
shape (B, T_img, F, C, H, W)
|
| 326 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
| 327 |
+
Currently only F=1 is supported (single-frame videos)
|
| 328 |
+
"""
|
| 329 |
+
self._encode_vision_x(vision_x=vision_x)
|
| 330 |
+
self._condition_media_locations(input_ids=input_ids)
|
| 331 |
+
self.lang_encoder._use_cached_vision_x = True
|
| 332 |
+
|
| 333 |
+
def uncache_media(self):
|
| 334 |
+
"""
|
| 335 |
+
Clear all conditioning.
|
| 336 |
+
"""
|
| 337 |
+
self.lang_encoder.clear_conditioned_layers()
|
| 338 |
+
self.lang_encoder._use_cached_vision_x = False
|
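A small sketch of the input shapes that the Flamingo module above expects, following its forward() and _encode_vision_x() docstrings: vision_x is (B, T_img, F, C, H, W) with F=1. The model, image_processor, and tokenizer are assumed to come from create_model_and_transforms(), and the "<image>" prompt convention is the usual OpenFlamingo one, not something this diff confirms.

```python
# Sketch: package two camera views into the (B, T_img, F, C, H, W) layout and run a forward pass.
import torch
from PIL import Image

frames = [Image.new("RGB", (224, 224)) for _ in range(2)]        # e.g. third-person + gripper view
vision_x = torch.stack([image_processor(im) for im in frames])    # (T_img, C, H, W)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)                     # (B=1, T_img=2, F=1, C, H, W)

lang_x = tokenizer("<image><image>Pick up the red block", return_tensors="pt")
with torch.no_grad():
    out = model(vision_x=vision_x,
                lang_x=lang_x["input_ids"],
                attention_mask=lang_x["attention_mask"])
```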
open_flamingo/src/flamingo_lm.py
ADDED
@@ -0,0 +1,191 @@
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .helpers import GatedCrossAttentionBlock
|
| 3 |
+
from .utils import getattr_recursive, setattr_recursive
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
class FlamingoLayer(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False, residual=False
|
| 13 |
+
):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.gated_cross_attn_layer = gated_cross_attn_layer
|
| 16 |
+
self.decoder_layer = decoder_layer
|
| 17 |
+
self.vis_x = None
|
| 18 |
+
self.media_locations = None
|
| 19 |
+
self.residual = residual
|
| 20 |
+
|
| 21 |
+
if self.gated_cross_attn_layer is not None:
|
| 22 |
+
self.gated_cross_attn_layer._use_gradient_checkpointing = (
|
| 23 |
+
gradient_checkpointing
|
| 24 |
+
)
|
| 25 |
+
self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
|
| 26 |
+
|
| 27 |
+
def clone_parameters(self):
|
| 28 |
+
self.res_layer = copy.deepcopy(self.gated_cross_attn_layer)
|
| 29 |
+
if self.res_layer is not None:
|
| 30 |
+
self.res_layer.requires_grad_(False)
|
| 31 |
+
|
| 32 |
+
def is_conditioned(self) -> bool:
|
| 33 |
+
"""Check whether the layer is conditioned."""
|
| 34 |
+
return self.vis_x is not None and self.media_locations is not None
|
| 35 |
+
|
| 36 |
+
# Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
|
| 37 |
+
def condition_vis_x(self, vis_x):
|
| 38 |
+
self.vis_x = vis_x
|
| 39 |
+
|
| 40 |
+
def condition_media_locations(self, media_locations):
|
| 41 |
+
self.media_locations = media_locations
|
| 42 |
+
|
| 43 |
+
def condition_use_cached_media(self, use_cached_media):
|
| 44 |
+
self.use_cached_media = use_cached_media
|
| 45 |
+
|
| 46 |
+
def forward(
|
| 47 |
+
self,
|
| 48 |
+
lang_x,
|
| 49 |
+
attention_mask=None,
|
| 50 |
+
**decoder_layer_kwargs,
|
| 51 |
+
):
|
| 52 |
+
# Cross attention
|
| 53 |
+
if self.gated_cross_attn_layer is not None:
|
| 54 |
+
if self.vis_x is None:
|
| 55 |
+
raise ValueError("vis_x must be conditioned before forward pass")
|
| 56 |
+
|
| 57 |
+
if self.media_locations is None:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
"media_locations must be conditioned before forward pass"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
lang_x = self.gated_cross_attn_layer(
|
| 63 |
+
lang_x,
|
| 64 |
+
self.vis_x,
|
| 65 |
+
media_locations=self.media_locations,
|
| 66 |
+
use_cached_media=self.use_cached_media,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Residual
|
| 70 |
+
if self.residual and self.res_layer is not None:
|
| 71 |
+
lang_x_res = self.res_layer(
|
| 72 |
+
lang_x,
|
| 73 |
+
self.vis_x,
|
| 74 |
+
media_locations=self.media_locations,
|
| 75 |
+
attend_previous=self.attend_previous,
|
| 76 |
+
)
|
| 77 |
+
lang_x = (lang_x + lang_x_res) / 2.0
|
| 78 |
+
|
| 79 |
+
# Normal decoder layer
|
| 80 |
+
lang_x = self.decoder_layer(
|
| 81 |
+
lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
|
| 82 |
+
)
|
| 83 |
+
return lang_x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class FlamingoLMMixin(nn.Module):
|
| 87 |
+
"""
|
| 88 |
+
Mixin to add cross-attention layers to a language model.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
| 92 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
| 93 |
+
|
| 94 |
+
def _get_decoder_layers(self):
|
| 95 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
| 96 |
+
|
| 97 |
+
def _set_decoder_layers(self, value):
|
| 98 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
| 99 |
+
|
| 100 |
+
def init_flamingo(
|
| 101 |
+
self,
|
| 102 |
+
media_token_id,
|
| 103 |
+
lang_hidden_size,
|
| 104 |
+
vis_hidden_size,
|
| 105 |
+
cross_attn_every_n_layers,
|
| 106 |
+
gradient_checkpointing,
|
| 107 |
+
residual=False,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
| 111 |
+
"""
|
| 112 |
+
print('-'*100)
|
| 113 |
+
print(self.decoder_layers_attr_name)
|
| 114 |
+
self.old_decoder_blocks = self._get_decoder_layers()
|
| 115 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
| 116 |
+
[
|
| 117 |
+
GatedCrossAttentionBlock(
|
| 118 |
+
dim=lang_hidden_size, dim_visual=vis_hidden_size
|
| 119 |
+
)
|
| 120 |
+
if (layer_idx + 1) % cross_attn_every_n_layers == 0
|
| 121 |
+
else None
|
| 122 |
+
for layer_idx, _ in enumerate(self._get_decoder_layers())
|
| 123 |
+
]
|
| 124 |
+
)
|
| 125 |
+
self.init_flamingo_layers(gradient_checkpointing, residual=residual)
|
| 126 |
+
self.media_token_id = media_token_id
|
| 127 |
+
self.initialized_flamingo = True
|
| 128 |
+
self._use_cached_vision_x = False
|
| 129 |
+
|
| 130 |
+
def init_flamingo_layers(self, gradient_checkpointing, residual=False):
|
| 131 |
+
"""
|
| 132 |
+
Re initializes the FlamingoLayers.
|
| 133 |
+
Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks
|
| 134 |
+
"""
|
| 135 |
+
self._set_decoder_layers(
|
| 136 |
+
nn.ModuleList(
|
| 137 |
+
[
|
| 138 |
+
FlamingoLayer(
|
| 139 |
+
gated_cross_attn_layer, decoder_layer, gradient_checkpointing, residual=residual
|
| 140 |
+
)
|
| 141 |
+
for gated_cross_attn_layer, decoder_layer in zip(
|
| 142 |
+
self.gated_cross_attn_layers, self.old_decoder_blocks
|
| 143 |
+
)
|
| 144 |
+
]
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def forward(self, input_ids, attention_mask, **kwargs):
|
| 149 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
| 150 |
+
if not self.initialized_flamingo:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"Flamingo layers are not initialized. Please call `init_flamingo` first."
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
media_locations = input_ids == self.media_token_id
|
| 156 |
+
|
| 157 |
+
# if there are media already cached and we're generating and there are no media tokens in the input,
|
| 158 |
+
# we'll assume that ALL input tokens should attend to the last previous media that is cached.
|
| 159 |
+
# this is especially important for HF generate() compatibility, since generate() calls forward()
|
| 160 |
+
# repeatedly one token at a time (with no media tokens).
|
| 161 |
+
# without this check, the model would not attend to any images when generating (after the first token)
|
| 162 |
+
use_cached_media_locations = (
|
| 163 |
+
self._use_cached_vision_x
|
| 164 |
+
and self.is_conditioned()
|
| 165 |
+
and not media_locations.any()
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
for layer in self._get_decoder_layers():
|
| 169 |
+
if not use_cached_media_locations:
|
| 170 |
+
layer.condition_media_locations(media_locations)
|
| 171 |
+
layer.condition_use_cached_media(use_cached_media_locations)
|
| 172 |
+
|
| 173 |
+
# package arguments for the other parent's forward. since we don't know the order of the arguments,
|
| 174 |
+
# make them all kwargs
|
| 175 |
+
kwargs["input_ids"] = input_ids
|
| 176 |
+
kwargs["attention_mask"] = attention_mask
|
| 177 |
+
return super().forward(**kwargs) # Call the other parent's forward method
|
| 178 |
+
|
| 179 |
+
def is_conditioned(self) -> bool:
|
| 180 |
+
"""Check whether all decoder layers are already conditioned."""
|
| 181 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
| 182 |
+
|
| 183 |
+
def clone_parameters(self):
|
| 184 |
+
for layer in self._get_decoder_layers():
|
| 185 |
+
layer.clone_parameters()
|
| 186 |
+
|
| 187 |
+
def clear_conditioned_layers(self):
|
| 188 |
+
for layer in self._get_decoder_layers():
|
| 189 |
+
layer.condition_vis_x(None)
|
| 190 |
+
layer.condition_media_locations(None)
|
| 191 |
+
layer.condition_use_cached_media(None)
|
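FlamingoLMMixin is not instantiated directly; factory.py grafts it onto an existing Hugging Face causal LM instance with extend_instance() from open_flamingo/src/utils.py (part of this commit but not shown here). The toy example below illustrates that class-rewriting pattern under the assumption that extend_instance swaps the object's class for one that inherits from the mixin first; it is an illustration, not the repo's code.

```python
# Toy illustration of the mixin-grafting pattern used by factory.py.
class Base:
    def forward(self, x):
        return x + 1

class Mixin:
    def forward(self, x):
        # conditioning/bookkeeping would happen here, then defer to the base class
        return super().forward(x) * 10

def extend_instance(obj, mixin):
    base_cls = obj.__class__
    # replace the instance's class with a new one whose MRO puts the mixin first
    obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})

m = Base()
extend_instance(m, Mixin)
print(m.forward(2))  # 30: Mixin.forward wraps Base.forward
```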
open_flamingo/src/helpers.py
ADDED
@@ -0,0 +1,279 @@
| 1 |
+
"""
|
| 2 |
+
Based on: https://github.com/lucidrains/flamingo-pytorch
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from einops import rearrange, repeat
|
| 7 |
+
from einops_exts import rearrange_many
|
| 8 |
+
from torch import einsum, nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def exists(val):
|
| 12 |
+
return val is not None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def FeedForward(dim, mult=4):
|
| 16 |
+
inner_dim = int(dim * mult)
|
| 17 |
+
return nn.Sequential(
|
| 18 |
+
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, n1, D)
+            latent (torch.Tensor): latent features
+                shape (b, T, n2, D)
+        """
+        x = self.norm_media(x)
+        latents = self.norm_latents(latents)
+
+        h = self.heads
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
+        q = q * self.scale
+
+        # attention
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
+        return self.to_out(out)
+
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=6,
+        dim_head=64,
+        heads=8,
+        num_latents=64,
+        max_num_media=None,
+        max_num_frames=None,
+        ff_mult=4,
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        self.frame_embs = (
+            nn.Parameter(torch.randn(max_num_frames, dim))
+            if exists(max_num_frames)
+            else None
+        )
+        self.media_time_embs = (
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
+            if exists(max_num_media)
+            else None
+        )
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, F, v, D)
+        Returns:
+            shape (b, T, n, D) where n is self.num_latents
+        """
+        b, T, F, v = x.shape[:4]
+
+        # frame and media time embeddings
+        if exists(self.frame_embs):
+            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
+            x = x + frame_embs
+        x = rearrange(
+            x, "b T F v d -> b T (F v) d"
+        )  # flatten the frame and spatial dimensions
+        if exists(self.media_time_embs):
+            x = x + self.media_time_embs[:T]
+
+        # blocks
+        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        return self.norm(latents)
+
+
+# gated cross attention
+class MaskedCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+        # whether for text to only attend to immediate preceding image, or all previous images
+        self.only_attend_immediate_media = only_attend_immediate_media
+
+    def forward(self, x, media, media_locations=None, use_cached_media=False):
+        """
+        Args:
+            x (torch.Tensor): text features
+                shape (B, T_txt, D_txt)
+            media (torch.Tensor): image features
+                shape (B, T_img, n, D_img) where n is the dim of the latents
+            media_locations: boolean mask identifying the media tokens in x
+                shape (B, T_txt)
+            use_cached_media: bool
+                If true, treat all of x as if they occur after the last media
+                registered in media_locations. T_txt does not need to exactly
+                equal media_locations.shape[1] in this case
+        """
+
+        if not use_cached_media:
+            assert (
+                media_locations.shape[1] == x.shape[1]
+            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
+
+        T_txt = x.shape[1]
+        _, T_img, n = media.shape[:3]
+        h = self.heads
+
+        x = self.norm(x)
+
+        q = self.to_q(x)
+        media = rearrange(media, "b t n d -> b (t n) d")
+
+        k, v = self.to_kv(media).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+        q = q * self.scale
+
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+        if exists(media_locations):
+            media_time = torch.arange(T_img, device=x.device) + 1
+
+            if use_cached_media:
+                # text time is set to the last cached media location
+                text_time = repeat(
+                    torch.count_nonzero(media_locations, dim=1),
+                    "b -> b i",
+                    i=T_txt,
+                )
+            else:
+                # at each boolean of True, increment the time counter (relative to media time)
+                text_time = media_locations.cumsum(dim=-1)
+
+            # text time must equal media time if only attending to most immediate image
+            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(
+                rearrange(text_time, "b i -> b 1 i 1"),
+                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
+            )
+            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
+
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        if exists(media_locations) and self.only_attend_immediate_media:
+            # any text without a preceding media needs to have attention zeroed out
+            text_without_media_mask = text_time == 0
+            text_without_media_mask = rearrange(
+                text_without_media_mask, "b i -> b 1 i 1"
+            )
+            attn = attn.masked_fill(text_without_media_mask, 0.0)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class GatedCrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.attn = MaskedCrossAttention(
+            dim=dim,
+            dim_visual=dim_visual,
+            dim_head=dim_head,
+            heads=heads,
+            only_attend_immediate_media=only_attend_immediate_media,
+        )
+        self.attn_gate = nn.Parameter(torch.tensor([0.0]))
+
+        self.ff = FeedForward(dim, mult=ff_mult)
+        self.ff_gate = nn.Parameter(torch.tensor([0.0]))
+
+    def forward(
+        self,
+        x,
+        media,
+        media_locations=None,
+        use_cached_media=False,
+    ):
+        x = (
+            self.attn(
+                x,
+                media,
+                media_locations=media_locations,
+                use_cached_media=use_cached_media,
+            )
+            * self.attn_gate.tanh()
+            + x
+        )
+        x = self.ff(x) * self.ff_gate.tanh() + x
+
+        return x
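Note (illustrative, not part of this commit): the PerceiverResampler above compresses a variable number of vision tokens into a fixed set of latents, and GatedCrossAttentionBlock injects those latents into the text stream through tanh-gated cross-attention whose gates start at zero. A minimal shape-check sketch, assuming `open_flamingo/src` is on `sys.path` and torch, einops, and einops-exts are installed:

# Illustrative sketch only -- not part of the committed files.
import torch
from helpers import PerceiverResampler, GatedCrossAttentionBlock  # assumes open_flamingo/src on sys.path

vision_feats = torch.randn(2, 1, 1, 256, 1024)           # (batch, T_img, frames, patches, dim)
resampler = PerceiverResampler(dim=1024, depth=6, num_latents=64)
latents = resampler(vision_feats)                         # -> (2, 1, 64, 1024)

text_hidden = torch.randn(2, 32, 1024)                    # (batch, T_txt, dim)
media_locations = torch.zeros(2, 32, dtype=torch.bool)    # True where an image token sits
media_locations[:, 0] = True

block = GatedCrossAttentionBlock(dim=1024, dim_visual=1024)
out = block(text_hidden, latents, media_locations=media_locations)
print(out.shape)                                          # torch.Size([2, 32, 1024])

Because both gates are initialized to 0.0, the block's output equals its text input before training; the vision pathway is blended in only as the gates are learned.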
open_flamingo/src/utils.py
ADDED
@@ -0,0 +1,48 @@
+def extend_instance(obj, mixin):
+    """Apply mixins to a class instance after creation"""
+    base_cls = obj.__class__
+    base_cls_name = obj.__class__.__name__
+    obj.__class__ = type(
+        base_cls_name, (mixin, base_cls), {}
+    )  # mixin needs to go first for our forward() logic to work
+
+
+def getattr_recursive(obj, att):
+    """
+    Return nested attribute of obj
+    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+    """
+    if att == "":
+        return obj
+    i = att.find(".")
+    if i < 0:
+        return getattr(obj, att)
+    else:
+        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+def setattr_recursive(obj, att, val):
+    """
+    Set nested attribute of obj
+    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+    """
+    if "." in att:
+        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+    setattr(obj, att.split(".")[-1], val)
+
+
+def apply_with_stopping_condition(
+    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
+):
+    if stopping_condition(module):
+        return
+    if apply_condition(module):
+        apply_fn(module, **other_args)
+    for child in module.children():
+        apply_with_stopping_condition(
+            child,
+            apply_fn,
+            apply_condition=apply_condition,
+            stopping_condition=stopping_condition,
+            **other_args
+        )
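Note (illustrative, not part of this commit): these helpers are used elsewhere in the package to mix Flamingo behavior into an existing language-model instance and to reach nested submodules by dotted path. A toy sketch of how they behave, assuming the functions above are importable (e.g. `from utils import ...` with `open_flamingo/src` on `sys.path`):

# Illustrative sketch only -- not part of the committed files.
from utils import extend_instance, getattr_recursive, setattr_recursive


class Greeter:
    def speak(self):
        return "hello"


class LoudMixin:
    def speak(self):
        return super().speak().upper()


g = Greeter()
extend_instance(g, LoudMixin)   # g's class becomes type("Greeter", (LoudMixin, Greeter), {})
print(g.speak())                # HELLO -- the mixin's method comes first in the MRO


class Inner:
    pass


class Outer:
    pass


o = Outer()
o.inner = Inner()
setattr_recursive(o, "inner.value", 42)
print(getattr_recursive(o, "inner.value"))  # 42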