Spaces:
Runtime error
Runtime error
Deploy real interactive model
Browse files- README.md +4 -35
- app.py +47 -116
- open_flamingo/eval/README.md +68 -0
- open_flamingo/eval/__init__.py +1 -0
- open_flamingo/eval/classification_utils.py +1008 -0
- open_flamingo/eval/coco_metric.py +22 -0
- open_flamingo/eval/data/textvqa/train_questions_vqa_format.json +0 -0
- open_flamingo/eval/data/textvqa/val_annotations_vqa_format.json +0 -0
- open_flamingo/eval/data/textvqa/val_questions_vqa_format.json +0 -0
- open_flamingo/eval/data/vizwiz/test_questions_vqa_format.json +0 -0
- open_flamingo/eval/data/vizwiz/train_questions_vqa_format.json +0 -0
- open_flamingo/eval/data/vizwiz/val_annotations_vqa_format.json +0 -0
- open_flamingo/eval/data/vizwiz/val_questions_vqa_format.json +0 -0
- open_flamingo/eval/eval_datasets.py +157 -0
- open_flamingo/eval/eval_model.py +89 -0
- open_flamingo/eval/evaluate.py +1301 -0
- open_flamingo/eval/models/blip.py +117 -0
- open_flamingo/eval/models/open_flamingo.py +334 -0
- open_flamingo/eval/ok_vqa_utils.py +215 -0
- open_flamingo/eval/rices.py +95 -0
- open_flamingo/eval/utils.py +124 -0
- open_flamingo/eval/vqa_metric.py +560 -0
- open_flamingo/scripts/cache_rices_features.py +370 -0
- open_flamingo/scripts/convert_mmc4_to_wds.py +85 -0
- open_flamingo/scripts/fill_vqa_testdev_results.py +142 -0
- open_flamingo/train/README.md +65 -0
- open_flamingo/train/__init__.py +1 -0
- open_flamingo/train/data.py +492 -0
- open_flamingo/train/data_utils.py +234 -0
- open_flamingo/train/distributed.py +132 -0
- open_flamingo/train/train.py +484 -0
- open_flamingo/train/train_utils.py +377 -0
- requirements.txt +1 -6
README.md
CHANGED
|
@@ -9,41 +9,10 @@ app_file: app.py
|
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
-
#
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
- ✅ LSTM Policy Head for action prediction
|
| 20 |
-
- ✅ Trained on CALVIN robot manipulation dataset
|
| 21 |
-
- ✅ Outputs 7-DOF robot actions (position + rotation + gripper)
|
| 22 |
-
- ✅ State-of-the-art performance: 4.09 average task length
|
| 23 |
-
|
| 24 |
-
## Features
|
| 25 |
-
- 📸 Upload robot camera views (third-person + gripper)
|
| 26 |
-
- 💬 Natural language instructions
|
| 27 |
-
- 🎯 Real trained model predictions
|
| 28 |
-
- 📊 7-DOF trajectory visualization
|
| 29 |
-
- 🤏 Gripper command timeline
|
| 30 |
-
|
| 31 |
-
## Model Details
|
| 32 |
-
- **Base**: OpenFlamingo (vision-language model)
|
| 33 |
-
- **Policy Head**: LSTM with MLP
|
| 34 |
-
- **Training**: CALVIN dataset (34 manipulation tasks)
|
| 35 |
-
- **Parameters**: 386M trainable
|
| 36 |
-
- **Performance**: SOTA on CALVIN benchmark
|
| 37 |
-
|
| 38 |
-
## Requirements
|
| 39 |
-
⚠️ **Enable T4 GPU** in Space Settings → Hardware
|
| 40 |
-
|
| 41 |
-
## Citation
|
| 42 |
-
```bibtex
|
| 43 |
-
@article{li2023vision,
|
| 44 |
-
title={Vision-Language Foundation Models as Effective Robot Imitators},
|
| 45 |
-
author={Li, Xinghang and others},
|
| 46 |
-
journal={arXiv preprint arXiv:2311.01378},
|
| 47 |
-
year={2023}
|
| 48 |
-
}
|
| 49 |
-
```
|
|
|
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# RoboFlamingo Interactive Demo 🤖
|
| 13 |
|
| 14 |
+
Upload images and get real predictions!
|
| 15 |
|
| 16 |
+
⚠️ Enable T4 GPU in Settings for real model.
|
| 17 |
|
| 18 |
+
[Paper](https://arxiv.org/abs/2311.01378) | [Code](https://github.com/RoboFlamingo/RoboFlamingo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""RoboFlamingo
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
@@ -7,105 +7,33 @@ import matplotlib.pyplot as plt
|
|
| 7 |
from io import BytesIO
|
| 8 |
import sys
|
| 9 |
|
| 10 |
-
sys.path.insert(0, '/home/user/app')
|
| 11 |
|
| 12 |
-
print("
|
| 13 |
|
| 14 |
MODEL_LOADED = False
|
| 15 |
-
device = "cpu"
|
| 16 |
|
| 17 |
try:
|
| 18 |
-
|
| 19 |
-
from
|
| 20 |
-
|
| 21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
print(f"Device: {device}")
|
| 23 |
-
|
| 24 |
-
print("📦 Creating model structure...")
|
| 25 |
-
|
| 26 |
-
# Create model
|
| 27 |
model, image_processor, tokenizer = create_model_and_transforms(
|
| 28 |
clip_vision_encoder_path="ViT-L-14",
|
| 29 |
clip_vision_encoder_pretrained="openai",
|
| 30 |
lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
|
| 31 |
tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
|
| 32 |
cross_attn_every_n_layers=4,
|
|
|
|
| 33 |
)
|
| 34 |
-
|
| 35 |
-
print("✅ Model structure created")
|
| 36 |
-
|
| 37 |
-
# Download checkpoint from HuggingFace
|
| 38 |
-
print("📥 Downloading trained checkpoint from robovlms/RoboFlamingo...")
|
| 39 |
-
|
| 40 |
-
try:
|
| 41 |
-
ckpt_path = hf_hub_download(
|
| 42 |
-
repo_id="robovlms/RoboFlamingo",
|
| 43 |
-
filename="checkpoint.pth",
|
| 44 |
-
repo_type="model"
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
print(f"✅ Downloaded: {ckpt_path}")
|
| 48 |
-
|
| 49 |
-
# Load checkpoint
|
| 50 |
-
print("📥 Loading checkpoint weights...")
|
| 51 |
-
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
| 52 |
-
|
| 53 |
-
# Try different checkpoint formats
|
| 54 |
-
if 'model_state_dict' in checkpoint:
|
| 55 |
-
state_dict = checkpoint['model_state_dict']
|
| 56 |
-
elif 'state_dict' in checkpoint:
|
| 57 |
-
state_dict = checkpoint['state_dict']
|
| 58 |
-
else:
|
| 59 |
-
state_dict = checkpoint
|
| 60 |
-
|
| 61 |
-
# Load with strict=False to handle any mismatches
|
| 62 |
-
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
| 63 |
-
|
| 64 |
-
if len(missing) > 0:
|
| 65 |
-
print(f"⚠️ Missing keys: {len(missing)}")
|
| 66 |
-
if len(unexpected) > 0:
|
| 67 |
-
print(f"⚠️ Unexpected keys: {len(unexpected)}")
|
| 68 |
-
|
| 69 |
-
print("✅ LOADED TRAINED CHECKPOINT!")
|
| 70 |
-
|
| 71 |
-
except Exception as e:
|
| 72 |
-
print(f"⚠️ Checkpoint download/load failed: {e}")
|
| 73 |
-
print(" Trying alternative checkpoint...")
|
| 74 |
-
|
| 75 |
-
try:
|
| 76 |
-
# Try the other HF repo
|
| 77 |
-
ckpt_path = hf_hub_download(
|
| 78 |
-
repo_id="hywslxh/RoboFlamingo-MPT",
|
| 79 |
-
filename="checkpoint.pth",
|
| 80 |
-
repo_type="model"
|
| 81 |
-
)
|
| 82 |
-
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
| 83 |
-
|
| 84 |
-
if 'model_state_dict' in checkpoint:
|
| 85 |
-
state_dict = checkpoint['model_state_dict']
|
| 86 |
-
elif 'state_dict' in checkpoint:
|
| 87 |
-
state_dict = checkpoint['state_dict']
|
| 88 |
-
else:
|
| 89 |
-
state_dict = checkpoint
|
| 90 |
-
|
| 91 |
-
model.load_state_dict(state_dict, strict=False)
|
| 92 |
-
print("✅ Loaded from hywslxh/RoboFlamingo-MPT")
|
| 93 |
-
|
| 94 |
-
except Exception as e2:
|
| 95 |
-
print(f"⚠️ Alternative also failed: {e2}")
|
| 96 |
-
print(" Using model without checkpoint")
|
| 97 |
-
|
| 98 |
model.to(device).eval()
|
| 99 |
MODEL_LOADED = True
|
| 100 |
-
|
| 101 |
-
print("=" * 70)
|
| 102 |
-
print("✅ MODEL READY!")
|
| 103 |
-
print("=" * 70)
|
| 104 |
-
|
| 105 |
except Exception as e:
|
| 106 |
-
print(f"⚠️
|
| 107 |
-
import traceback
|
| 108 |
-
traceback.print_exc()
|
| 109 |
|
| 110 |
def plot_traj(acts):
|
| 111 |
fig = plt.figure(figsize=(10,8))
|
|
@@ -117,7 +45,7 @@ def plot_traj(acts):
|
|
| 117 |
ax.scatter(x[0], y[0], z[0], c='green', s=100, label='Start')
|
| 118 |
ax.scatter(x[-1], y[-1], z[-1], c='red', s=100, label='End')
|
| 119 |
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
|
| 120 |
-
ax.
|
| 121 |
buf = BytesIO()
|
| 122 |
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
|
| 123 |
buf.seek(0); plt.close()
|
|
@@ -141,7 +69,7 @@ def simulate(inst):
|
|
| 141 |
for t in range(12):
|
| 142 |
p = t/12
|
| 143 |
acts.append({'timestep': t, 'delta_x': (0.05+np.random.randn()*0.01)*p,
|
| 144 |
-
'delta_y': (0.02+np.random.randn()*0.01)*p,
|
| 145 |
'delta_z': (-0.03+np.random.randn()*0.01)*(1-p),
|
| 146 |
'qw': 0.99, 'qx': 0.01, 'qy': 0.01, 'qz': 0.01})
|
| 147 |
return acts, [0]*6+[1]*6
|
|
@@ -150,78 +78,81 @@ def predict(inst, img1, img2):
|
|
| 150 |
if not inst or not inst.strip():
|
| 151 |
return None, None, "", "❌ Enter instruction"
|
| 152 |
if img1 is None or img2 is None:
|
| 153 |
-
return None, None, "", "❌ Upload images"
|
| 154 |
-
|
| 155 |
try:
|
| 156 |
if isinstance(img1, np.ndarray):
|
| 157 |
img1 = Image.fromarray(img1)
|
| 158 |
if isinstance(img2, np.ndarray):
|
| 159 |
img2 = Image.fromarray(img2)
|
| 160 |
-
|
| 161 |
if not MODEL_LOADED:
|
| 162 |
acts, grip = simulate(inst)
|
| 163 |
-
status = f"⚠️
|
| 164 |
else:
|
| 165 |
-
print(f"🤖 {inst}")
|
| 166 |
with torch.no_grad():
|
| 167 |
t1 = image_processor(img1).unsqueeze(0).to(device)
|
| 168 |
t2 = image_processor(img2).unsqueeze(0).to(device)
|
| 169 |
-
vis = torch.stack([t1, t2], dim=1)
|
| 170 |
-
tok = tokenizer(inst, return_tensors="pt", padding=True,
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
out = model(vision_x=vis, lang_x=tok['input_ids'])
|
| 174 |
-
|
| 175 |
if isinstance(out, dict):
|
| 176 |
-
a = out.get('actions')
|
| 177 |
g = out.get('gripper')
|
| 178 |
elif isinstance(out, tuple):
|
| 179 |
a = out[0]
|
| 180 |
g = out[1] if len(out)>1 else None
|
| 181 |
else:
|
| 182 |
-
a = out
|
| 183 |
-
|
|
|
|
| 184 |
if a is not None:
|
| 185 |
anp = a[0].cpu().numpy()
|
| 186 |
acts = []
|
| 187 |
for t, ac in enumerate(anp):
|
| 188 |
if len(ac)<7: ac = np.pad(ac, (0,7-len(ac)))
|
| 189 |
-
acts.append({'timestep': t, 'delta_x': float(ac[0]),
|
| 190 |
-
'
|
| 191 |
-
'qw': float(ac[3]), 'qx': float(ac[4]),
|
| 192 |
'qy': float(ac[5]), 'qz': float(ac[6])})
|
| 193 |
-
grip = [int(x>0.5) if np.isscalar(x) else int(x[0]>0.5)
|
| 194 |
-
|
| 195 |
-
status = f"✅ MODEL\n{inst}\n{device}"
|
| 196 |
else:
|
| 197 |
acts, grip = simulate(inst)
|
| 198 |
-
status = "⚠️ Unexpected"
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
except Exception as e:
|
| 203 |
-
print(f"
|
| 204 |
-
import traceback; traceback.print_exc()
|
| 205 |
acts, grip = simulate(inst)
|
| 206 |
return plot_traj(acts), plot_grip(grip), "", f"❌ {str(e)}"
|
| 207 |
|
| 208 |
-
mode = "🟢
|
| 209 |
|
| 210 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 211 |
-
gr.Markdown(f"# 🤖 RoboFlamingo - {mode}")
|
|
|
|
| 212 |
with gr.Row():
|
| 213 |
with gr.Column():
|
| 214 |
-
inst = gr.Textbox(label="Instruction", placeholder="Pick up red block", lines=3)
|
| 215 |
with gr.Row():
|
| 216 |
img1 = gr.Image(label="Third-Person", type="pil", height=250)
|
| 217 |
img2 = gr.Image(label="Gripper", type="pil", height=250)
|
| 218 |
-
btn = gr.Button("
|
| 219 |
st = gr.Textbox(label="Status", lines=4, interactive=False)
|
| 220 |
with gr.Column():
|
| 221 |
traj = gr.Image(label="Trajectory", type="pil")
|
| 222 |
grip = gr.Image(label="Gripper", type="pil")
|
|
|
|
| 223 |
tab = gr.Markdown()
|
| 224 |
btn.click(predict, [inst, img1, img2], [traj, grip, tab, st])
|
| 225 |
-
|
|
|
|
| 226 |
|
| 227 |
demo.launch()
|
|
|
|
| 1 |
+
"""RoboFlamingo Interactive Demo"""
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
|
|
| 7 |
from io import BytesIO
|
| 8 |
import sys
|
| 9 |
|
| 10 |
+
sys.path.insert(0, '/home/user/app/open_flamingo/src')
|
| 11 |
|
| 12 |
+
print("🚀 Initializing RoboFlamingo")
|
| 13 |
|
| 14 |
MODEL_LOADED = False
|
|
|
|
| 15 |
|
| 16 |
try:
|
| 17 |
+
print("📦 Importing...")
|
| 18 |
+
from factory import create_model_and_transforms
|
| 19 |
+
|
| 20 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
print(f"Device: {device}")
|
| 22 |
+
|
|
|
|
|
|
|
|
|
|
| 23 |
model, image_processor, tokenizer = create_model_and_transforms(
|
| 24 |
clip_vision_encoder_path="ViT-L-14",
|
| 25 |
clip_vision_encoder_pretrained="openai",
|
| 26 |
lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
|
| 27 |
tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
|
| 28 |
cross_attn_every_n_layers=4,
|
| 29 |
+
decoder_type='lstm',
|
| 30 |
)
|
| 31 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
model.to(device).eval()
|
| 33 |
MODEL_LOADED = True
|
| 34 |
+
print("✅ Model loaded!")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
except Exception as e:
|
| 36 |
+
print(f"⚠️ Model failed: {e}")
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def plot_traj(acts):
|
| 39 |
fig = plt.figure(figsize=(10,8))
|
|
|
|
| 45 |
ax.scatter(x[0], y[0], z[0], c='green', s=100, label='Start')
|
| 46 |
ax.scatter(x[-1], y[-1], z[-1], c='red', s=100, label='End')
|
| 47 |
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
|
| 48 |
+
ax.legend(); ax.grid()
|
| 49 |
buf = BytesIO()
|
| 50 |
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
|
| 51 |
buf.seek(0); plt.close()
|
|
|
|
| 69 |
for t in range(12):
|
| 70 |
p = t/12
|
| 71 |
acts.append({'timestep': t, 'delta_x': (0.05+np.random.randn()*0.01)*p,
|
| 72 |
+
'delta_y': (0.02+np.random.randn()*0.01)*p,
|
| 73 |
'delta_z': (-0.03+np.random.randn()*0.01)*(1-p),
|
| 74 |
'qw': 0.99, 'qx': 0.01, 'qy': 0.01, 'qz': 0.01})
|
| 75 |
return acts, [0]*6+[1]*6
|
|
|
|
| 78 |
if not inst or not inst.strip():
|
| 79 |
return None, None, "", "❌ Enter instruction"
|
| 80 |
if img1 is None or img2 is None:
|
| 81 |
+
return None, None, "", "❌ Upload both images"
|
| 82 |
+
|
| 83 |
try:
|
| 84 |
if isinstance(img1, np.ndarray):
|
| 85 |
img1 = Image.fromarray(img1)
|
| 86 |
if isinstance(img2, np.ndarray):
|
| 87 |
img2 = Image.fromarray(img2)
|
| 88 |
+
|
| 89 |
if not MODEL_LOADED:
|
| 90 |
acts, grip = simulate(inst)
|
| 91 |
+
status = f"⚠️ SIMULATION\n{inst}\nEnable GPU for real model"
|
| 92 |
else:
|
|
|
|
| 93 |
with torch.no_grad():
|
| 94 |
t1 = image_processor(img1).unsqueeze(0).to(device)
|
| 95 |
t2 = image_processor(img2).unsqueeze(0).to(device)
|
| 96 |
+
vis = torch.stack([t1, t2], dim=1)
|
| 97 |
+
tok = tokenizer(inst, return_tensors="pt", padding=True, truncation=True).to(device)
|
| 98 |
+
out = model(vision_x=vis, lang_x=tok['input_ids'], attention_mask=tok.get('attention_mask'))
|
| 99 |
+
|
|
|
|
|
|
|
| 100 |
if isinstance(out, dict):
|
| 101 |
+
a = out.get('actions', out.get('action'))
|
| 102 |
g = out.get('gripper')
|
| 103 |
elif isinstance(out, tuple):
|
| 104 |
a = out[0]
|
| 105 |
g = out[1] if len(out)>1 else None
|
| 106 |
else:
|
| 107 |
+
a = out
|
| 108 |
+
g = None
|
| 109 |
+
|
| 110 |
if a is not None:
|
| 111 |
anp = a[0].cpu().numpy()
|
| 112 |
acts = []
|
| 113 |
for t, ac in enumerate(anp):
|
| 114 |
if len(ac)<7: ac = np.pad(ac, (0,7-len(ac)))
|
| 115 |
+
acts.append({'timestep': t, 'delta_x': float(ac[0]), 'delta_y': float(ac[1]),
|
| 116 |
+
'delta_z': float(ac[2]), 'qw': float(ac[3]), 'qx': float(ac[4]),
|
|
|
|
| 117 |
'qy': float(ac[5]), 'qz': float(ac[6])})
|
| 118 |
+
grip = [int(x>0.5) if np.isscalar(x) else int(x[0]>0.5) for x in (g[0].cpu().numpy() if g is not None else [0]*len(acts))]
|
| 119 |
+
status = f"✅ REAL MODEL\n{inst}\n{device}"
|
|
|
|
| 120 |
else:
|
| 121 |
acts, grip = simulate(inst)
|
| 122 |
+
status = f"⚠️ Unexpected output\n{inst}"
|
| 123 |
+
|
| 124 |
+
traj = plot_traj(acts)
|
| 125 |
+
gp = plot_grip(grip)
|
| 126 |
+
table = "| T | Δx | Δy | Δz |\n|--|--|--|--|\n"
|
| 127 |
+
for a in acts:
|
| 128 |
+
table += f"| {a['timestep']} | {a['delta_x']:.3f} | {a['delta_y']:.3f} | {a['delta_z']:.3f} |\n"
|
| 129 |
+
|
| 130 |
+
return traj, gp, table, status
|
| 131 |
except Exception as e:
|
| 132 |
+
print(f"Error: {e}")
|
|
|
|
| 133 |
acts, grip = simulate(inst)
|
| 134 |
return plot_traj(acts), plot_grip(grip), "", f"❌ {str(e)}"
|
| 135 |
|
| 136 |
+
mode = "🟢 REAL" if MODEL_LOADED else "🟡 SIM"
|
| 137 |
|
| 138 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 139 |
+
gr.Markdown(f"# 🤖 RoboFlamingo - {mode}\n{'Real model loaded!' if MODEL_LOADED else 'Enable GPU for real model'}")
|
| 140 |
+
|
| 141 |
with gr.Row():
|
| 142 |
with gr.Column():
|
| 143 |
+
inst = gr.Textbox(label="Instruction", placeholder="Pick up the red block", lines=3)
|
| 144 |
with gr.Row():
|
| 145 |
img1 = gr.Image(label="Third-Person", type="pil", height=250)
|
| 146 |
img2 = gr.Image(label="Gripper", type="pil", height=250)
|
| 147 |
+
btn = gr.Button("🚀 Predict", variant="primary", size="lg")
|
| 148 |
st = gr.Textbox(label="Status", lines=4, interactive=False)
|
| 149 |
with gr.Column():
|
| 150 |
traj = gr.Image(label="Trajectory", type="pil")
|
| 151 |
grip = gr.Image(label="Gripper", type="pil")
|
| 152 |
+
|
| 153 |
tab = gr.Markdown()
|
| 154 |
btn.click(predict, [inst, img1, img2], [traj, grip, tab, st])
|
| 155 |
+
|
| 156 |
+
gr.Markdown(f"**Status:** {mode} | [Paper](https://arxiv.org/abs/2311.01378)")
|
| 157 |
|
| 158 |
demo.launch()
|
open_flamingo/eval/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenFlamingo Evaluation Suite
|
| 2 |
+
|
| 3 |
+
This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets.
|
| 4 |
+
|
| 5 |
+
*This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).*
|
| 6 |
+
|
| 7 |
+
## Supported datasets
|
| 8 |
+
|
| 9 |
+
|Dataset|Task|Metric|Evaluation method|
|
| 10 |
+
|-------|----|------|-----------------|
|
| 11 |
+
|[COCO](https://arxiv.org/abs/1405.0312)|Captioning|CIDEr|Generation|
|
| 12 |
+
|[Flickr-30K](https://aclanthology.org/Q14-1006/)|Captioning|CIDEr|Generation|
|
| 13 |
+
|[VQAv2](https://arxiv.org/abs/1612.00837v3)|VQA|VQA accuracy|Generation|
|
| 14 |
+
|[OK-VQA](https://arxiv.org/abs/1906.00067)|VQA|VQA accuracy|Generation|
|
| 15 |
+
|[TextVQA](https://arxiv.org/abs/1904.08920)|VQA|VQA accuracy|Generation|
|
| 16 |
+
|[VizWiz](https://arxiv.org/abs/1802.08218)|VQA|VQA accuracy|Generation|
|
| 17 |
+
|[Hateful Memes](https://arxiv.org/abs/2005.04790)|Classification|ROC AUC|Logprobs|
|
| 18 |
+
|[ImageNet](https://arxiv.org/abs/1409.0575)|Classification|Top-1 accuracy|Logprobs|
|
| 19 |
+
|
| 20 |
+
When evaluating a model using `num_shots` shots, we sample the exemplars from the training split. Performance is evaluated on a disjoint test split, subsampled to `--num_samples` examples (or using the full test split if `--num_samples=-1`).
|
| 21 |
+
|
| 22 |
+
## Sample scripts
|
| 23 |
+
Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun`. We provide a sample Slurm evaluation script in `open_flamingo/open_flamingo/scripts/run_eval.sh`.
|
| 24 |
+
|
| 25 |
+
We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16.
|
| 26 |
+
|
| 27 |
+
To evaluate one of our pretrained checkpoints, we suggest first downloading a local copy of the weights, as follows:
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
# grab model checkpoint from huggingface hub
|
| 31 |
+
from huggingface_hub import hf_hub_download
|
| 32 |
+
HF_TOKEN="<your-hf-token-here>"
|
| 33 |
+
|
| 34 |
+
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
|
| 35 |
+
checkpoint_path= hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b",
|
| 36 |
+
"checkpoint.pt",
|
| 37 |
+
local_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b",
|
| 38 |
+
cache_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b",
|
| 39 |
+
local_dir_use_symlinks=False,
|
| 40 |
+
token=HF_TOKEN)
|
| 41 |
+
print(checkpoint_path)
|
| 42 |
+
## openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
This should place the OpenFlamingo model at the expected location in the evaluation script.
|
| 46 |
+
|
| 47 |
+
For TextVQA and VizWiz we expect annotations to be formatted differently than the original datasets. We provide the custom annotations in `open_flamingo/open_flamingo/eval/data/`. We have also uploaded all the annotation files in a [huggingface dataset](https://huggingface.co/datasets/openflamingo/eval_benchmark/tree/main) for easy access.
|
| 48 |
+
|
| 49 |
+
# Evaluating using RICES (Retrieval-based In-Context Example Selection)
|
| 50 |
+
|
| 51 |
+
We provide the option to evaluate using RICES, which is a method for selecting exemplars from the training set based on image similarity. This method was used in DeepMind's implementation for evaluating on ImageNet, but can be used for any dataset in our evaluation suite.
|
| 52 |
+
|
| 53 |
+
To use RICES, you must first create features for a benchmark's training set. We provide a script for doing so in `open_flamingo/open_flamingo/scripts/cache_rices_features.py`. This script will extract image features for a given dataset using a given CLIP model checkpoint. For example, to extract features for the COCO training set, you can run:
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
python cache_rices_features.py \
|
| 57 |
+
--vision_encoder_path ViT-L-14 \
|
| 58 |
+
--vision_encoder_pretrained openai \
|
| 59 |
+
--batch_size 128 \
|
| 60 |
+
--eval_coco \
|
| 61 |
+
--coco_train_image_dir_path /path/to/coco/train2014 \
|
| 62 |
+
--coco_val_image_dir_path /path/to/coco/val2014 \
|
| 63 |
+
--coco_karpathy_json_path /path/to/coco/dataset_coco.json \
|
| 64 |
+
--coco_annotations_json_path /path/to/coco/annotations/captions_train2014.json \
|
| 65 |
+
--output_dir /path/to/coco/features
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
This will create a directory at `/path/to/coco/features` containing a file named `coco.pkl` with the extracted features. You can then use this directory to evaluate using RICES by passing the `--rices` flag to the evaluation script, specifying the path to the features directory using the `--cached_demonstration_features` flag, and specifying the vision encoder to use for RICES using the `--rices_vision_encoder_path` and `--rices_vision_encoder_pretrained` flags.
|
open_flamingo/eval/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
open_flamingo/eval/classification_utils.py
ADDED
|
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1
|
| 2 |
+
IMAGENET_CLASSNAMES = [
|
| 3 |
+
"tench",
|
| 4 |
+
"goldfish",
|
| 5 |
+
"great white shark",
|
| 6 |
+
"tiger shark",
|
| 7 |
+
"hammerhead shark",
|
| 8 |
+
"electric ray",
|
| 9 |
+
"stingray",
|
| 10 |
+
"rooster",
|
| 11 |
+
"hen",
|
| 12 |
+
"ostrich",
|
| 13 |
+
"brambling",
|
| 14 |
+
"goldfinch",
|
| 15 |
+
"house finch",
|
| 16 |
+
"junco",
|
| 17 |
+
"indigo bunting",
|
| 18 |
+
"American robin",
|
| 19 |
+
"bulbul",
|
| 20 |
+
"jay",
|
| 21 |
+
"magpie",
|
| 22 |
+
"chickadee",
|
| 23 |
+
"American dipper",
|
| 24 |
+
"kite (bird of prey)",
|
| 25 |
+
"bald eagle",
|
| 26 |
+
"vulture",
|
| 27 |
+
"great grey owl",
|
| 28 |
+
"fire salamander",
|
| 29 |
+
"smooth newt",
|
| 30 |
+
"newt",
|
| 31 |
+
"spotted salamander",
|
| 32 |
+
"axolotl",
|
| 33 |
+
"American bullfrog",
|
| 34 |
+
"tree frog",
|
| 35 |
+
"tailed frog",
|
| 36 |
+
"loggerhead sea turtle",
|
| 37 |
+
"leatherback sea turtle",
|
| 38 |
+
"mud turtle",
|
| 39 |
+
"terrapin",
|
| 40 |
+
"box turtle",
|
| 41 |
+
"banded gecko",
|
| 42 |
+
"green iguana",
|
| 43 |
+
"Carolina anole",
|
| 44 |
+
"desert grassland whiptail lizard",
|
| 45 |
+
"agama",
|
| 46 |
+
"frilled-necked lizard",
|
| 47 |
+
"alligator lizard",
|
| 48 |
+
"Gila monster",
|
| 49 |
+
"European green lizard",
|
| 50 |
+
"chameleon",
|
| 51 |
+
"Komodo dragon",
|
| 52 |
+
"Nile crocodile",
|
| 53 |
+
"American alligator",
|
| 54 |
+
"triceratops",
|
| 55 |
+
"worm snake",
|
| 56 |
+
"ring-necked snake",
|
| 57 |
+
"eastern hog-nosed snake",
|
| 58 |
+
"smooth green snake",
|
| 59 |
+
"kingsnake",
|
| 60 |
+
"garter snake",
|
| 61 |
+
"water snake",
|
| 62 |
+
"vine snake",
|
| 63 |
+
"night snake",
|
| 64 |
+
"boa constrictor",
|
| 65 |
+
"African rock python",
|
| 66 |
+
"Indian cobra",
|
| 67 |
+
"green mamba",
|
| 68 |
+
"sea snake",
|
| 69 |
+
"Saharan horned viper",
|
| 70 |
+
"eastern diamondback rattlesnake",
|
| 71 |
+
"sidewinder rattlesnake",
|
| 72 |
+
"trilobite",
|
| 73 |
+
"harvestman",
|
| 74 |
+
"scorpion",
|
| 75 |
+
"yellow garden spider",
|
| 76 |
+
"barn spider",
|
| 77 |
+
"European garden spider",
|
| 78 |
+
"southern black widow",
|
| 79 |
+
"tarantula",
|
| 80 |
+
"wolf spider",
|
| 81 |
+
"tick",
|
| 82 |
+
"centipede",
|
| 83 |
+
"black grouse",
|
| 84 |
+
"ptarmigan",
|
| 85 |
+
"ruffed grouse",
|
| 86 |
+
"prairie grouse",
|
| 87 |
+
"peafowl",
|
| 88 |
+
"quail",
|
| 89 |
+
"partridge",
|
| 90 |
+
"african grey parrot",
|
| 91 |
+
"macaw",
|
| 92 |
+
"sulphur-crested cockatoo",
|
| 93 |
+
"lorikeet",
|
| 94 |
+
"coucal",
|
| 95 |
+
"bee eater",
|
| 96 |
+
"hornbill",
|
| 97 |
+
"hummingbird",
|
| 98 |
+
"jacamar",
|
| 99 |
+
"toucan",
|
| 100 |
+
"duck",
|
| 101 |
+
"red-breasted merganser",
|
| 102 |
+
"goose",
|
| 103 |
+
"black swan",
|
| 104 |
+
"tusker",
|
| 105 |
+
"echidna",
|
| 106 |
+
"platypus",
|
| 107 |
+
"wallaby",
|
| 108 |
+
"koala",
|
| 109 |
+
"wombat",
|
| 110 |
+
"jellyfish",
|
| 111 |
+
"sea anemone",
|
| 112 |
+
"brain coral",
|
| 113 |
+
"flatworm",
|
| 114 |
+
"nematode",
|
| 115 |
+
"conch",
|
| 116 |
+
"snail",
|
| 117 |
+
"slug",
|
| 118 |
+
"sea slug",
|
| 119 |
+
"chiton",
|
| 120 |
+
"chambered nautilus",
|
| 121 |
+
"Dungeness crab",
|
| 122 |
+
"rock crab",
|
| 123 |
+
"fiddler crab",
|
| 124 |
+
"red king crab",
|
| 125 |
+
"American lobster",
|
| 126 |
+
"spiny lobster",
|
| 127 |
+
"crayfish",
|
| 128 |
+
"hermit crab",
|
| 129 |
+
"isopod",
|
| 130 |
+
"white stork",
|
| 131 |
+
"black stork",
|
| 132 |
+
"spoonbill",
|
| 133 |
+
"flamingo",
|
| 134 |
+
"little blue heron",
|
| 135 |
+
"great egret",
|
| 136 |
+
"bittern bird",
|
| 137 |
+
"crane bird",
|
| 138 |
+
"limpkin",
|
| 139 |
+
"common gallinule",
|
| 140 |
+
"American coot",
|
| 141 |
+
"bustard",
|
| 142 |
+
"ruddy turnstone",
|
| 143 |
+
"dunlin",
|
| 144 |
+
"common redshank",
|
| 145 |
+
"dowitcher",
|
| 146 |
+
"oystercatcher",
|
| 147 |
+
"pelican",
|
| 148 |
+
"king penguin",
|
| 149 |
+
"albatross",
|
| 150 |
+
"grey whale",
|
| 151 |
+
"killer whale",
|
| 152 |
+
"dugong",
|
| 153 |
+
"sea lion",
|
| 154 |
+
"Chihuahua",
|
| 155 |
+
"Japanese Chin",
|
| 156 |
+
"Maltese",
|
| 157 |
+
"Pekingese",
|
| 158 |
+
"Shih Tzu",
|
| 159 |
+
"King Charles Spaniel",
|
| 160 |
+
"Papillon",
|
| 161 |
+
"toy terrier",
|
| 162 |
+
"Rhodesian Ridgeback",
|
| 163 |
+
"Afghan Hound",
|
| 164 |
+
"Basset Hound",
|
| 165 |
+
"Beagle",
|
| 166 |
+
"Bloodhound",
|
| 167 |
+
"Bluetick Coonhound",
|
| 168 |
+
"Black and Tan Coonhound",
|
| 169 |
+
"Treeing Walker Coonhound",
|
| 170 |
+
"English foxhound",
|
| 171 |
+
"Redbone Coonhound",
|
| 172 |
+
"borzoi",
|
| 173 |
+
"Irish Wolfhound",
|
| 174 |
+
"Italian Greyhound",
|
| 175 |
+
"Whippet",
|
| 176 |
+
"Ibizan Hound",
|
| 177 |
+
"Norwegian Elkhound",
|
| 178 |
+
"Otterhound",
|
| 179 |
+
"Saluki",
|
| 180 |
+
"Scottish Deerhound",
|
| 181 |
+
"Weimaraner",
|
| 182 |
+
"Staffordshire Bull Terrier",
|
| 183 |
+
"American Staffordshire Terrier",
|
| 184 |
+
"Bedlington Terrier",
|
| 185 |
+
"Border Terrier",
|
| 186 |
+
"Kerry Blue Terrier",
|
| 187 |
+
"Irish Terrier",
|
| 188 |
+
"Norfolk Terrier",
|
| 189 |
+
"Norwich Terrier",
|
| 190 |
+
"Yorkshire Terrier",
|
| 191 |
+
"Wire Fox Terrier",
|
| 192 |
+
"Lakeland Terrier",
|
| 193 |
+
"Sealyham Terrier",
|
| 194 |
+
"Airedale Terrier",
|
| 195 |
+
"Cairn Terrier",
|
| 196 |
+
"Australian Terrier",
|
| 197 |
+
"Dandie Dinmont Terrier",
|
| 198 |
+
"Boston Terrier",
|
| 199 |
+
"Miniature Schnauzer",
|
| 200 |
+
"Giant Schnauzer",
|
| 201 |
+
"Standard Schnauzer",
|
| 202 |
+
"Scottish Terrier",
|
| 203 |
+
"Tibetan Terrier",
|
| 204 |
+
"Australian Silky Terrier",
|
| 205 |
+
"Soft-coated Wheaten Terrier",
|
| 206 |
+
"West Highland White Terrier",
|
| 207 |
+
"Lhasa Apso",
|
| 208 |
+
"Flat-Coated Retriever",
|
| 209 |
+
"Curly-coated Retriever",
|
| 210 |
+
"Golden Retriever",
|
| 211 |
+
"Labrador Retriever",
|
| 212 |
+
"Chesapeake Bay Retriever",
|
| 213 |
+
"German Shorthaired Pointer",
|
| 214 |
+
"Vizsla",
|
| 215 |
+
"English Setter",
|
| 216 |
+
"Irish Setter",
|
| 217 |
+
"Gordon Setter",
|
| 218 |
+
"Brittany dog",
|
| 219 |
+
"Clumber Spaniel",
|
| 220 |
+
"English Springer Spaniel",
|
| 221 |
+
"Welsh Springer Spaniel",
|
| 222 |
+
"Cocker Spaniel",
|
| 223 |
+
"Sussex Spaniel",
|
| 224 |
+
"Irish Water Spaniel",
|
| 225 |
+
"Kuvasz",
|
| 226 |
+
"Schipperke",
|
| 227 |
+
"Groenendael dog",
|
| 228 |
+
"Malinois",
|
| 229 |
+
"Briard",
|
| 230 |
+
"Australian Kelpie",
|
| 231 |
+
"Komondor",
|
| 232 |
+
"Old English Sheepdog",
|
| 233 |
+
"Shetland Sheepdog",
|
| 234 |
+
"collie",
|
| 235 |
+
"Border Collie",
|
| 236 |
+
"Bouvier des Flandres dog",
|
| 237 |
+
"Rottweiler",
|
| 238 |
+
"German Shepherd Dog",
|
| 239 |
+
"Dobermann",
|
| 240 |
+
"Miniature Pinscher",
|
| 241 |
+
"Greater Swiss Mountain Dog",
|
| 242 |
+
"Bernese Mountain Dog",
|
| 243 |
+
"Appenzeller Sennenhund",
|
| 244 |
+
"Entlebucher Sennenhund",
|
| 245 |
+
"Boxer",
|
| 246 |
+
"Bullmastiff",
|
| 247 |
+
"Tibetan Mastiff",
|
| 248 |
+
"French Bulldog",
|
| 249 |
+
"Great Dane",
|
| 250 |
+
"St. Bernard",
|
| 251 |
+
"husky",
|
| 252 |
+
"Alaskan Malamute",
|
| 253 |
+
"Siberian Husky",
|
| 254 |
+
"Dalmatian",
|
| 255 |
+
"Affenpinscher",
|
| 256 |
+
"Basenji",
|
| 257 |
+
"pug",
|
| 258 |
+
"Leonberger",
|
| 259 |
+
"Newfoundland dog",
|
| 260 |
+
"Great Pyrenees dog",
|
| 261 |
+
"Samoyed",
|
| 262 |
+
"Pomeranian",
|
| 263 |
+
"Chow Chow",
|
| 264 |
+
"Keeshond",
|
| 265 |
+
"brussels griffon",
|
| 266 |
+
"Pembroke Welsh Corgi",
|
| 267 |
+
"Cardigan Welsh Corgi",
|
| 268 |
+
"Toy Poodle",
|
| 269 |
+
"Miniature Poodle",
|
| 270 |
+
"Standard Poodle",
|
| 271 |
+
"Mexican hairless dog (xoloitzcuintli)",
|
| 272 |
+
"grey wolf",
|
| 273 |
+
"Alaskan tundra wolf",
|
| 274 |
+
"red wolf or maned wolf",
|
| 275 |
+
"coyote",
|
| 276 |
+
"dingo",
|
| 277 |
+
"dhole",
|
| 278 |
+
"African wild dog",
|
| 279 |
+
"hyena",
|
| 280 |
+
"red fox",
|
| 281 |
+
"kit fox",
|
| 282 |
+
"Arctic fox",
|
| 283 |
+
"grey fox",
|
| 284 |
+
"tabby cat",
|
| 285 |
+
"tiger cat",
|
| 286 |
+
"Persian cat",
|
| 287 |
+
"Siamese cat",
|
| 288 |
+
"Egyptian Mau",
|
| 289 |
+
"cougar",
|
| 290 |
+
"lynx",
|
| 291 |
+
"leopard",
|
| 292 |
+
"snow leopard",
|
| 293 |
+
"jaguar",
|
| 294 |
+
"lion",
|
| 295 |
+
"tiger",
|
| 296 |
+
"cheetah",
|
| 297 |
+
"brown bear",
|
| 298 |
+
"American black bear",
|
| 299 |
+
"polar bear",
|
| 300 |
+
"sloth bear",
|
| 301 |
+
"mongoose",
|
| 302 |
+
"meerkat",
|
| 303 |
+
"tiger beetle",
|
| 304 |
+
"ladybug",
|
| 305 |
+
"ground beetle",
|
| 306 |
+
"longhorn beetle",
|
| 307 |
+
"leaf beetle",
|
| 308 |
+
"dung beetle",
|
| 309 |
+
"rhinoceros beetle",
|
| 310 |
+
"weevil",
|
| 311 |
+
"fly",
|
| 312 |
+
"bee",
|
| 313 |
+
"ant",
|
| 314 |
+
"grasshopper",
|
| 315 |
+
"cricket insect",
|
| 316 |
+
"stick insect",
|
| 317 |
+
"cockroach",
|
| 318 |
+
"praying mantis",
|
| 319 |
+
"cicada",
|
| 320 |
+
"leafhopper",
|
| 321 |
+
"lacewing",
|
| 322 |
+
"dragonfly",
|
| 323 |
+
"damselfly",
|
| 324 |
+
"red admiral butterfly",
|
| 325 |
+
"ringlet butterfly",
|
| 326 |
+
"monarch butterfly",
|
| 327 |
+
"small white butterfly",
|
| 328 |
+
"sulphur butterfly",
|
| 329 |
+
"gossamer-winged butterfly",
|
| 330 |
+
"starfish",
|
| 331 |
+
"sea urchin",
|
| 332 |
+
"sea cucumber",
|
| 333 |
+
"cottontail rabbit",
|
| 334 |
+
"hare",
|
| 335 |
+
"Angora rabbit",
|
| 336 |
+
"hamster",
|
| 337 |
+
"porcupine",
|
| 338 |
+
"fox squirrel",
|
| 339 |
+
"marmot",
|
| 340 |
+
"beaver",
|
| 341 |
+
"guinea pig",
|
| 342 |
+
"common sorrel horse",
|
| 343 |
+
"zebra",
|
| 344 |
+
"pig",
|
| 345 |
+
"wild boar",
|
| 346 |
+
"warthog",
|
| 347 |
+
"hippopotamus",
|
| 348 |
+
"ox",
|
| 349 |
+
"water buffalo",
|
| 350 |
+
"bison",
|
| 351 |
+
"ram (adult male sheep)",
|
| 352 |
+
"bighorn sheep",
|
| 353 |
+
"Alpine ibex",
|
| 354 |
+
"hartebeest",
|
| 355 |
+
"impala (antelope)",
|
| 356 |
+
"gazelle",
|
| 357 |
+
"arabian camel",
|
| 358 |
+
"llama",
|
| 359 |
+
"weasel",
|
| 360 |
+
"mink",
|
| 361 |
+
"European polecat",
|
| 362 |
+
"black-footed ferret",
|
| 363 |
+
"otter",
|
| 364 |
+
"skunk",
|
| 365 |
+
"badger",
|
| 366 |
+
"armadillo",
|
| 367 |
+
"three-toed sloth",
|
| 368 |
+
"orangutan",
|
| 369 |
+
"gorilla",
|
| 370 |
+
"chimpanzee",
|
| 371 |
+
"gibbon",
|
| 372 |
+
"siamang",
|
| 373 |
+
"guenon",
|
| 374 |
+
"patas monkey",
|
| 375 |
+
"baboon",
|
| 376 |
+
"macaque",
|
| 377 |
+
"langur",
|
| 378 |
+
"black-and-white colobus",
|
| 379 |
+
"proboscis monkey",
|
| 380 |
+
"marmoset",
|
| 381 |
+
"white-headed capuchin",
|
| 382 |
+
"howler monkey",
|
| 383 |
+
"titi monkey",
|
| 384 |
+
"Geoffroy's spider monkey",
|
| 385 |
+
"common squirrel monkey",
|
| 386 |
+
"ring-tailed lemur",
|
| 387 |
+
"indri",
|
| 388 |
+
"Asian elephant",
|
| 389 |
+
"African bush elephant",
|
| 390 |
+
"red panda",
|
| 391 |
+
"giant panda",
|
| 392 |
+
"snoek fish",
|
| 393 |
+
"eel",
|
| 394 |
+
"silver salmon",
|
| 395 |
+
"rock beauty fish",
|
| 396 |
+
"clownfish",
|
| 397 |
+
"sturgeon",
|
| 398 |
+
"gar fish",
|
| 399 |
+
"lionfish",
|
| 400 |
+
"pufferfish",
|
| 401 |
+
"abacus",
|
| 402 |
+
"abaya",
|
| 403 |
+
"academic gown",
|
| 404 |
+
"accordion",
|
| 405 |
+
"acoustic guitar",
|
| 406 |
+
"aircraft carrier",
|
| 407 |
+
"airliner",
|
| 408 |
+
"airship",
|
| 409 |
+
"altar",
|
| 410 |
+
"ambulance",
|
| 411 |
+
"amphibious vehicle",
|
| 412 |
+
"analog clock",
|
| 413 |
+
"apiary",
|
| 414 |
+
"apron",
|
| 415 |
+
"trash can",
|
| 416 |
+
"assault rifle",
|
| 417 |
+
"backpack",
|
| 418 |
+
"bakery",
|
| 419 |
+
"balance beam",
|
| 420 |
+
"balloon",
|
| 421 |
+
"ballpoint pen",
|
| 422 |
+
"Band-Aid",
|
| 423 |
+
"banjo",
|
| 424 |
+
"baluster / handrail",
|
| 425 |
+
"barbell",
|
| 426 |
+
"barber chair",
|
| 427 |
+
"barbershop",
|
| 428 |
+
"barn",
|
| 429 |
+
"barometer",
|
| 430 |
+
"barrel",
|
| 431 |
+
"wheelbarrow",
|
| 432 |
+
"baseball",
|
| 433 |
+
"basketball",
|
| 434 |
+
"bassinet",
|
| 435 |
+
"bassoon",
|
| 436 |
+
"swimming cap",
|
| 437 |
+
"bath towel",
|
| 438 |
+
"bathtub",
|
| 439 |
+
"station wagon",
|
| 440 |
+
"lighthouse",
|
| 441 |
+
"beaker",
|
| 442 |
+
"military hat (bearskin or shako)",
|
| 443 |
+
"beer bottle",
|
| 444 |
+
"beer glass",
|
| 445 |
+
"bell tower",
|
| 446 |
+
"baby bib",
|
| 447 |
+
"tandem bicycle",
|
| 448 |
+
"bikini",
|
| 449 |
+
"ring binder",
|
| 450 |
+
"binoculars",
|
| 451 |
+
"birdhouse",
|
| 452 |
+
"boathouse",
|
| 453 |
+
"bobsleigh",
|
| 454 |
+
"bolo tie",
|
| 455 |
+
"poke bonnet",
|
| 456 |
+
"bookcase",
|
| 457 |
+
"bookstore",
|
| 458 |
+
"bottle cap",
|
| 459 |
+
"hunting bow",
|
| 460 |
+
"bow tie",
|
| 461 |
+
"brass memorial plaque",
|
| 462 |
+
"bra",
|
| 463 |
+
"breakwater",
|
| 464 |
+
"breastplate",
|
| 465 |
+
"broom",
|
| 466 |
+
"bucket",
|
| 467 |
+
"buckle",
|
| 468 |
+
"bulletproof vest",
|
| 469 |
+
"high-speed train",
|
| 470 |
+
"butcher shop",
|
| 471 |
+
"taxicab",
|
| 472 |
+
"cauldron",
|
| 473 |
+
"candle",
|
| 474 |
+
"cannon",
|
| 475 |
+
"canoe",
|
| 476 |
+
"can opener",
|
| 477 |
+
"cardigan",
|
| 478 |
+
"car mirror",
|
| 479 |
+
"carousel",
|
| 480 |
+
"tool kit",
|
| 481 |
+
"cardboard box / carton",
|
| 482 |
+
"car wheel",
|
| 483 |
+
"automated teller machine",
|
| 484 |
+
"cassette",
|
| 485 |
+
"cassette player",
|
| 486 |
+
"castle",
|
| 487 |
+
"catamaran",
|
| 488 |
+
"CD player",
|
| 489 |
+
"cello",
|
| 490 |
+
"mobile phone",
|
| 491 |
+
"chain",
|
| 492 |
+
"chain-link fence",
|
| 493 |
+
"chain mail",
|
| 494 |
+
"chainsaw",
|
| 495 |
+
"storage chest",
|
| 496 |
+
"chiffonier",
|
| 497 |
+
"bell or wind chime",
|
| 498 |
+
"china cabinet",
|
| 499 |
+
"Christmas stocking",
|
| 500 |
+
"church",
|
| 501 |
+
"movie theater",
|
| 502 |
+
"cleaver",
|
| 503 |
+
"cliff dwelling",
|
| 504 |
+
"cloak",
|
| 505 |
+
"clogs",
|
| 506 |
+
"cocktail shaker",
|
| 507 |
+
"coffee mug",
|
| 508 |
+
"coffeemaker",
|
| 509 |
+
"spiral or coil",
|
| 510 |
+
"combination lock",
|
| 511 |
+
"computer keyboard",
|
| 512 |
+
"candy store",
|
| 513 |
+
"container ship",
|
| 514 |
+
"convertible",
|
| 515 |
+
"corkscrew",
|
| 516 |
+
"cornet",
|
| 517 |
+
"cowboy boot",
|
| 518 |
+
"cowboy hat",
|
| 519 |
+
"cradle",
|
| 520 |
+
"construction crane",
|
| 521 |
+
"crash helmet",
|
| 522 |
+
"crate",
|
| 523 |
+
"infant bed",
|
| 524 |
+
"Crock Pot",
|
| 525 |
+
"croquet ball",
|
| 526 |
+
"crutch",
|
| 527 |
+
"cuirass",
|
| 528 |
+
"dam",
|
| 529 |
+
"desk",
|
| 530 |
+
"desktop computer",
|
| 531 |
+
"rotary dial telephone",
|
| 532 |
+
"diaper",
|
| 533 |
+
"digital clock",
|
| 534 |
+
"digital watch",
|
| 535 |
+
"dining table",
|
| 536 |
+
"dishcloth",
|
| 537 |
+
"dishwasher",
|
| 538 |
+
"disc brake",
|
| 539 |
+
"dock",
|
| 540 |
+
"dog sled",
|
| 541 |
+
"dome",
|
| 542 |
+
"doormat",
|
| 543 |
+
"drilling rig",
|
| 544 |
+
"drum",
|
| 545 |
+
"drumstick",
|
| 546 |
+
"dumbbell",
|
| 547 |
+
"Dutch oven",
|
| 548 |
+
"electric fan",
|
| 549 |
+
"electric guitar",
|
| 550 |
+
"electric locomotive",
|
| 551 |
+
"entertainment center",
|
| 552 |
+
"envelope",
|
| 553 |
+
"espresso machine",
|
| 554 |
+
"face powder",
|
| 555 |
+
"feather boa",
|
| 556 |
+
"filing cabinet",
|
| 557 |
+
"fireboat",
|
| 558 |
+
"fire truck",
|
| 559 |
+
"fire screen",
|
| 560 |
+
"flagpole",
|
| 561 |
+
"flute",
|
| 562 |
+
"folding chair",
|
| 563 |
+
"football helmet",
|
| 564 |
+
"forklift",
|
| 565 |
+
"fountain",
|
| 566 |
+
"fountain pen",
|
| 567 |
+
"four-poster bed",
|
| 568 |
+
"freight car",
|
| 569 |
+
"French horn",
|
| 570 |
+
"frying pan",
|
| 571 |
+
"fur coat",
|
| 572 |
+
"garbage truck",
|
| 573 |
+
"gas mask or respirator",
|
| 574 |
+
"gas pump",
|
| 575 |
+
"goblet",
|
| 576 |
+
"go-kart",
|
| 577 |
+
"golf ball",
|
| 578 |
+
"golf cart",
|
| 579 |
+
"gondola",
|
| 580 |
+
"gong",
|
| 581 |
+
"gown",
|
| 582 |
+
"grand piano",
|
| 583 |
+
"greenhouse",
|
| 584 |
+
"radiator grille",
|
| 585 |
+
"grocery store",
|
| 586 |
+
"guillotine",
|
| 587 |
+
"hair clip",
|
| 588 |
+
"hair spray",
|
| 589 |
+
"half-track",
|
| 590 |
+
"hammer",
|
| 591 |
+
"hamper",
|
| 592 |
+
"hair dryer",
|
| 593 |
+
"hand-held computer",
|
| 594 |
+
"handkerchief",
|
| 595 |
+
"hard disk drive",
|
| 596 |
+
"harmonica",
|
| 597 |
+
"harp",
|
| 598 |
+
"combine harvester",
|
| 599 |
+
"hatchet",
|
| 600 |
+
"holster",
|
| 601 |
+
"home theater",
|
| 602 |
+
"honeycomb",
|
| 603 |
+
"hook",
|
| 604 |
+
"hoop skirt",
|
| 605 |
+
"gymnastic horizontal bar",
|
| 606 |
+
"horse-drawn vehicle",
|
| 607 |
+
"hourglass",
|
| 608 |
+
"iPod",
|
| 609 |
+
"clothes iron",
|
| 610 |
+
"carved pumpkin",
|
| 611 |
+
"jeans",
|
| 612 |
+
"jeep",
|
| 613 |
+
"T-shirt",
|
| 614 |
+
"jigsaw puzzle",
|
| 615 |
+
"rickshaw",
|
| 616 |
+
"joystick",
|
| 617 |
+
"kimono",
|
| 618 |
+
"knee pad",
|
| 619 |
+
"knot",
|
| 620 |
+
"lab coat",
|
| 621 |
+
"ladle",
|
| 622 |
+
"lampshade",
|
| 623 |
+
"laptop computer",
|
| 624 |
+
"lawn mower",
|
| 625 |
+
"lens cap",
|
| 626 |
+
"letter opener",
|
| 627 |
+
"library",
|
| 628 |
+
"lifeboat",
|
| 629 |
+
"lighter",
|
| 630 |
+
"limousine",
|
| 631 |
+
"ocean liner",
|
| 632 |
+
"lipstick",
|
| 633 |
+
"slip-on shoe",
|
| 634 |
+
"lotion",
|
| 635 |
+
"music speaker",
|
| 636 |
+
"loupe magnifying glass",
|
| 637 |
+
"sawmill",
|
| 638 |
+
"magnetic compass",
|
| 639 |
+
"messenger bag",
|
| 640 |
+
"mailbox",
|
| 641 |
+
"tights",
|
| 642 |
+
"one-piece bathing suit",
|
| 643 |
+
"manhole cover",
|
| 644 |
+
"maraca",
|
| 645 |
+
"marimba",
|
| 646 |
+
"mask",
|
| 647 |
+
"matchstick",
|
| 648 |
+
"maypole",
|
| 649 |
+
"maze",
|
| 650 |
+
"measuring cup",
|
| 651 |
+
"medicine cabinet",
|
| 652 |
+
"megalith",
|
| 653 |
+
"microphone",
|
| 654 |
+
"microwave oven",
|
| 655 |
+
"military uniform",
|
| 656 |
+
"milk can",
|
| 657 |
+
"minibus",
|
| 658 |
+
"miniskirt",
|
| 659 |
+
"minivan",
|
| 660 |
+
"missile",
|
| 661 |
+
"mitten",
|
| 662 |
+
"mixing bowl",
|
| 663 |
+
"mobile home",
|
| 664 |
+
"ford model t",
|
| 665 |
+
"modem",
|
| 666 |
+
"monastery",
|
| 667 |
+
"monitor",
|
| 668 |
+
"moped",
|
| 669 |
+
"mortar and pestle",
|
| 670 |
+
"graduation cap",
|
| 671 |
+
"mosque",
|
| 672 |
+
"mosquito net",
|
| 673 |
+
"vespa",
|
| 674 |
+
"mountain bike",
|
| 675 |
+
"tent",
|
| 676 |
+
"computer mouse",
|
| 677 |
+
"mousetrap",
|
| 678 |
+
"moving van",
|
| 679 |
+
"muzzle",
|
| 680 |
+
"metal nail",
|
| 681 |
+
"neck brace",
|
| 682 |
+
"necklace",
|
| 683 |
+
"baby pacifier",
|
| 684 |
+
"notebook computer",
|
| 685 |
+
"obelisk",
|
| 686 |
+
"oboe",
|
| 687 |
+
"ocarina",
|
| 688 |
+
"odometer",
|
| 689 |
+
"oil filter",
|
| 690 |
+
"pipe organ",
|
| 691 |
+
"oscilloscope",
|
| 692 |
+
"overskirt",
|
| 693 |
+
"bullock cart",
|
| 694 |
+
"oxygen mask",
|
| 695 |
+
"product packet / packaging",
|
| 696 |
+
"paddle",
|
| 697 |
+
"paddle wheel",
|
| 698 |
+
"padlock",
|
| 699 |
+
"paintbrush",
|
| 700 |
+
"pajamas",
|
| 701 |
+
"palace",
|
| 702 |
+
"pan flute",
|
| 703 |
+
"paper towel",
|
| 704 |
+
"parachute",
|
| 705 |
+
"parallel bars",
|
| 706 |
+
"park bench",
|
| 707 |
+
"parking meter",
|
| 708 |
+
"railroad car",
|
| 709 |
+
"patio",
|
| 710 |
+
"payphone",
|
| 711 |
+
"pedestal",
|
| 712 |
+
"pencil case",
|
| 713 |
+
"pencil sharpener",
|
| 714 |
+
"perfume",
|
| 715 |
+
"Petri dish",
|
| 716 |
+
"photocopier",
|
| 717 |
+
"plectrum",
|
| 718 |
+
"Pickelhaube",
|
| 719 |
+
"picket fence",
|
| 720 |
+
"pickup truck",
|
| 721 |
+
"pier",
|
| 722 |
+
"piggy bank",
|
| 723 |
+
"pill bottle",
|
| 724 |
+
"pillow",
|
| 725 |
+
"ping-pong ball",
|
| 726 |
+
"pinwheel",
|
| 727 |
+
"pirate ship",
|
| 728 |
+
"drink pitcher",
|
| 729 |
+
"block plane",
|
| 730 |
+
"planetarium",
|
| 731 |
+
"plastic bag",
|
| 732 |
+
"plate rack",
|
| 733 |
+
"farm plow",
|
| 734 |
+
"plunger",
|
| 735 |
+
"Polaroid camera",
|
| 736 |
+
"pole",
|
| 737 |
+
"police van",
|
| 738 |
+
"poncho",
|
| 739 |
+
"pool table",
|
| 740 |
+
"soda bottle",
|
| 741 |
+
"plant pot",
|
| 742 |
+
"potter's wheel",
|
| 743 |
+
"power drill",
|
| 744 |
+
"prayer rug",
|
| 745 |
+
"printer",
|
| 746 |
+
"prison",
|
| 747 |
+
"missile",
|
| 748 |
+
"projector",
|
| 749 |
+
"hockey puck",
|
| 750 |
+
"punching bag",
|
| 751 |
+
"purse",
|
| 752 |
+
"quill",
|
| 753 |
+
"quilt",
|
| 754 |
+
"race car",
|
| 755 |
+
"racket",
|
| 756 |
+
"radiator",
|
| 757 |
+
"radio",
|
| 758 |
+
"radio telescope",
|
| 759 |
+
"rain barrel",
|
| 760 |
+
"recreational vehicle",
|
| 761 |
+
"fishing casting reel",
|
| 762 |
+
"reflex camera",
|
| 763 |
+
"refrigerator",
|
| 764 |
+
"remote control",
|
| 765 |
+
"restaurant",
|
| 766 |
+
"revolver",
|
| 767 |
+
"rifle",
|
| 768 |
+
"rocking chair",
|
| 769 |
+
"rotisserie",
|
| 770 |
+
"eraser",
|
| 771 |
+
"rugby ball",
|
| 772 |
+
"ruler measuring stick",
|
| 773 |
+
"sneaker",
|
| 774 |
+
"safe",
|
| 775 |
+
"safety pin",
|
| 776 |
+
"salt shaker",
|
| 777 |
+
"sandal",
|
| 778 |
+
"sarong",
|
| 779 |
+
"saxophone",
|
| 780 |
+
"scabbard",
|
| 781 |
+
"weighing scale",
|
| 782 |
+
"school bus",
|
| 783 |
+
"schooner",
|
| 784 |
+
"scoreboard",
|
| 785 |
+
"CRT monitor",
|
| 786 |
+
"screw",
|
| 787 |
+
"screwdriver",
|
| 788 |
+
"seat belt",
|
| 789 |
+
"sewing machine",
|
| 790 |
+
"shield",
|
| 791 |
+
"shoe store",
|
| 792 |
+
"shoji screen / room divider",
|
| 793 |
+
"shopping basket",
|
| 794 |
+
"shopping cart",
|
| 795 |
+
"shovel",
|
| 796 |
+
"shower cap",
|
| 797 |
+
"shower curtain",
|
| 798 |
+
"ski",
|
| 799 |
+
"balaclava ski mask",
|
| 800 |
+
"sleeping bag",
|
| 801 |
+
"slide rule",
|
| 802 |
+
"sliding door",
|
| 803 |
+
"slot machine",
|
| 804 |
+
"snorkel",
|
| 805 |
+
"snowmobile",
|
| 806 |
+
"snowplow",
|
| 807 |
+
"soap dispenser",
|
| 808 |
+
"soccer ball",
|
| 809 |
+
"sock",
|
| 810 |
+
"solar thermal collector",
|
| 811 |
+
"sombrero",
|
| 812 |
+
"soup bowl",
|
| 813 |
+
"keyboard space bar",
|
| 814 |
+
"space heater",
|
| 815 |
+
"space shuttle",
|
| 816 |
+
"spatula",
|
| 817 |
+
"motorboat",
|
| 818 |
+
"spider web",
|
| 819 |
+
"spindle",
|
| 820 |
+
"sports car",
|
| 821 |
+
"spotlight",
|
| 822 |
+
"stage",
|
| 823 |
+
"steam locomotive",
|
| 824 |
+
"through arch bridge",
|
| 825 |
+
"steel drum",
|
| 826 |
+
"stethoscope",
|
| 827 |
+
"scarf",
|
| 828 |
+
"stone wall",
|
| 829 |
+
"stopwatch",
|
| 830 |
+
"stove",
|
| 831 |
+
"strainer",
|
| 832 |
+
"tram",
|
| 833 |
+
"stretcher",
|
| 834 |
+
"couch",
|
| 835 |
+
"stupa",
|
| 836 |
+
"submarine",
|
| 837 |
+
"suit",
|
| 838 |
+
"sundial",
|
| 839 |
+
"sunglasses",
|
| 840 |
+
"sunglasses",
|
| 841 |
+
"sunscreen",
|
| 842 |
+
"suspension bridge",
|
| 843 |
+
"mop",
|
| 844 |
+
"sweatshirt",
|
| 845 |
+
"swim trunks / shorts",
|
| 846 |
+
"swing",
|
| 847 |
+
"electrical switch",
|
| 848 |
+
"syringe",
|
| 849 |
+
"table lamp",
|
| 850 |
+
"tank",
|
| 851 |
+
"tape player",
|
| 852 |
+
"teapot",
|
| 853 |
+
"teddy bear",
|
| 854 |
+
"television",
|
| 855 |
+
"tennis ball",
|
| 856 |
+
"thatched roof",
|
| 857 |
+
"front curtain",
|
| 858 |
+
"thimble",
|
| 859 |
+
"threshing machine",
|
| 860 |
+
"throne",
|
| 861 |
+
"tile roof",
|
| 862 |
+
"toaster",
|
| 863 |
+
"tobacco shop",
|
| 864 |
+
"toilet seat",
|
| 865 |
+
"torch",
|
| 866 |
+
"totem pole",
|
| 867 |
+
"tow truck",
|
| 868 |
+
"toy store",
|
| 869 |
+
"tractor",
|
| 870 |
+
"semi-trailer truck",
|
| 871 |
+
"tray",
|
| 872 |
+
"trench coat",
|
| 873 |
+
"tricycle",
|
| 874 |
+
"trimaran",
|
| 875 |
+
"tripod",
|
| 876 |
+
"triumphal arch",
|
| 877 |
+
"trolleybus",
|
| 878 |
+
"trombone",
|
| 879 |
+
"hot tub",
|
| 880 |
+
"turnstile",
|
| 881 |
+
"typewriter keyboard",
|
| 882 |
+
"umbrella",
|
| 883 |
+
"unicycle",
|
| 884 |
+
"upright piano",
|
| 885 |
+
"vacuum cleaner",
|
| 886 |
+
"vase",
|
| 887 |
+
"vaulted or arched ceiling",
|
| 888 |
+
"velvet fabric",
|
| 889 |
+
"vending machine",
|
| 890 |
+
"vestment",
|
| 891 |
+
"viaduct",
|
| 892 |
+
"violin",
|
| 893 |
+
"volleyball",
|
| 894 |
+
"waffle iron",
|
| 895 |
+
"wall clock",
|
| 896 |
+
"wallet",
|
| 897 |
+
"wardrobe",
|
| 898 |
+
"military aircraft",
|
| 899 |
+
"sink",
|
| 900 |
+
"washing machine",
|
| 901 |
+
"water bottle",
|
| 902 |
+
"water jug",
|
| 903 |
+
"water tower",
|
| 904 |
+
"whiskey jug",
|
| 905 |
+
"whistle",
|
| 906 |
+
"hair wig",
|
| 907 |
+
"window screen",
|
| 908 |
+
"window shade",
|
| 909 |
+
"Windsor tie",
|
| 910 |
+
"wine bottle",
|
| 911 |
+
"airplane wing",
|
| 912 |
+
"wok",
|
| 913 |
+
"wooden spoon",
|
| 914 |
+
"wool",
|
| 915 |
+
"split-rail fence",
|
| 916 |
+
"shipwreck",
|
| 917 |
+
"sailboat",
|
| 918 |
+
"yurt",
|
| 919 |
+
"website",
|
| 920 |
+
"comic book",
|
| 921 |
+
"crossword",
|
| 922 |
+
"traffic or street sign",
|
| 923 |
+
"traffic light",
|
| 924 |
+
"dust jacket",
|
| 925 |
+
"menu",
|
| 926 |
+
"plate",
|
| 927 |
+
"guacamole",
|
| 928 |
+
"consomme",
|
| 929 |
+
"hot pot",
|
| 930 |
+
"trifle",
|
| 931 |
+
"ice cream",
|
| 932 |
+
"popsicle",
|
| 933 |
+
"baguette",
|
| 934 |
+
"bagel",
|
| 935 |
+
"pretzel",
|
| 936 |
+
"cheeseburger",
|
| 937 |
+
"hot dog",
|
| 938 |
+
"mashed potatoes",
|
| 939 |
+
"cabbage",
|
| 940 |
+
"broccoli",
|
| 941 |
+
"cauliflower",
|
| 942 |
+
"zucchini",
|
| 943 |
+
"spaghetti squash",
|
| 944 |
+
"acorn squash",
|
| 945 |
+
"butternut squash",
|
| 946 |
+
"cucumber",
|
| 947 |
+
"artichoke",
|
| 948 |
+
"bell pepper",
|
| 949 |
+
"cardoon",
|
| 950 |
+
"mushroom",
|
| 951 |
+
"Granny Smith apple",
|
| 952 |
+
"strawberry",
|
| 953 |
+
"orange",
|
| 954 |
+
"lemon",
|
| 955 |
+
"fig",
|
| 956 |
+
"pineapple",
|
| 957 |
+
"banana",
|
| 958 |
+
"jackfruit",
|
| 959 |
+
"cherimoya (custard apple)",
|
| 960 |
+
"pomegranate",
|
| 961 |
+
"hay",
|
| 962 |
+
"carbonara",
|
| 963 |
+
"chocolate syrup",
|
| 964 |
+
"dough",
|
| 965 |
+
"meatloaf",
|
| 966 |
+
"pizza",
|
| 967 |
+
"pot pie",
|
| 968 |
+
"burrito",
|
| 969 |
+
"red wine",
|
| 970 |
+
"espresso",
|
| 971 |
+
"tea cup",
|
| 972 |
+
"eggnog",
|
| 973 |
+
"mountain",
|
| 974 |
+
"bubble",
|
| 975 |
+
"cliff",
|
| 976 |
+
"coral reef",
|
| 977 |
+
"geyser",
|
| 978 |
+
"lakeshore",
|
| 979 |
+
"promontory",
|
| 980 |
+
"sandbar",
|
| 981 |
+
"beach",
|
| 982 |
+
"valley",
|
| 983 |
+
"volcano",
|
| 984 |
+
"baseball player",
|
| 985 |
+
"bridegroom",
|
| 986 |
+
"scuba diver",
|
| 987 |
+
"rapeseed",
|
| 988 |
+
"daisy",
|
| 989 |
+
"yellow lady's slipper",
|
| 990 |
+
"corn",
|
| 991 |
+
"acorn",
|
| 992 |
+
"rose hip",
|
| 993 |
+
"horse chestnut seed",
|
| 994 |
+
"coral fungus",
|
| 995 |
+
"agaric",
|
| 996 |
+
"gyromitra",
|
| 997 |
+
"stinkhorn mushroom",
|
| 998 |
+
"earth star fungus",
|
| 999 |
+
"hen of the woods mushroom",
|
| 1000 |
+
"bolete",
|
| 1001 |
+
"corn cob",
|
| 1002 |
+
"toilet paper",
|
| 1003 |
+
]
|
| 1004 |
+
|
| 1005 |
+
# Class names for the Hateful Memes binary classification task.
# Index 0 -> "no" (not hateful), index 1 -> "yes" (hateful), matching the
# integer "label" field used by HatefulMemesDataset.
HM_CLASSNAMES = [
    "no",
    "yes",
]
|
open_flamingo/eval/coco_metric.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pycocoevalcap.eval import COCOEvalCap
|
| 2 |
+
from pycocotools.coco import COCO
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compute_cider(
    result_path,
    annotations_path,
):
    """Score generated captions against COCO-style reference annotations.

    Runs the full pycocoevalcap pipeline (which includes CIDEr among its
    metrics) over the predictions in ``result_path``.

    Args:
        result_path: path to a COCO-format results JSON of generated captions.
        annotations_path: path to the COCO-format ground-truth annotations.

    Returns:
        The evaluator's metric dictionary (metric name -> score).
    """
    # Ground-truth references and the predictions registered against them.
    references = COCO(annotations_path)
    candidates = references.loadRes(result_path)

    # Restrict evaluation to images that actually have a prediction.
    scorer = COCOEvalCap(references, candidates)
    scorer.params["image_id"] = candidates.getImgIds()
    scorer.evaluate()

    return scorer.eval
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def postprocess_captioning_generation(predictions):
    """Truncate a generated caption at the first occurrence of "Output".

    The model's few-shot prompt uses "Output" as a delimiter, so anything
    from that token onward belongs to a hallucinated next example. If the
    marker is absent the text is returned unchanged.
    """
    truncated, _, _ = predictions.partition("Output")
    return truncated
|
open_flamingo/eval/data/textvqa/train_questions_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/data/textvqa/val_annotations_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/data/textvqa/val_questions_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/data/vizwiz/test_questions_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/data/vizwiz/train_questions_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/data/vizwiz/val_annotations_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/data/vizwiz/val_questions_vqa_format.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open_flamingo/eval/eval_datasets.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from torchvision.datasets import ImageFolder
|
| 7 |
+
|
| 8 |
+
from open_flamingo.eval.classification_utils import IMAGENET_CLASSNAMES
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CaptionDataset(Dataset):
    """Image-captioning dataset for COCO (Karpathy split) or Flickr30k.

    Each item is a dict with the PIL image, one reference caption, and an
    image identifier suitable for COCO-style caption evaluation.
    """

    def __init__(
        self,
        image_train_dir_path,
        annotations_path,
        is_train,
        dataset_name,
        image_val_dir_path=None,
    ):
        """
        Args:
            image_train_dir_path: directory holding train-split images
                (for flickr, the single image directory).
            annotations_path: path to a Karpathy-style JSON file with an
                "images" list.
            is_train: keep only the "train" split if True, else "test".
            dataset_name: "coco" or "flickr".
            image_val_dir_path: directory holding COCO val-split images
                (used when an annotation's "filepath" is not "train2014").

        Raises:
            ValueError: if ``dataset_name`` is not supported. (The original
                code deferred this to a NameError inside __getitem__.)
        """
        if dataset_name not in ("coco", "flickr"):
            raise ValueError(f"Unknown captioning dataset {dataset_name}")
        self.image_train_dir_path = image_train_dir_path
        self.image_val_dir_path = image_val_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name

        full_annotations = json.load(open(annotations_path))["images"]

        # Keep only annotations belonging to the requested split.
        wanted_split = "train" if is_train else "test"
        self.annotations = [
            ann for ann in full_annotations if ann["split"] == wanted_split
        ]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        if self.dataset_name == "coco":
            # COCO images live in separate train/val directories; the
            # Karpathy "filepath" field says which one this image is in.
            image_dir = (
                self.image_train_dir_path
                if annotation["filepath"] == "train2014"
                else self.image_val_dir_path
            )
        else:  # flickr: everything lives in one directory
            image_dir = self.image_train_dir_path
        image = Image.open(os.path.join(image_dir, annotation["filename"]))
        image.load()
        caption = annotation["sentences"][0]["raw"]
        return {
            "image": image,
            "caption": caption,
            # COCO evaluation keys on the numeric id; flickr uses the
            # filename stem instead.
            "image_id": annotation["cocoid"]
            if self.dataset_name == "coco"
            else annotation["filename"].split(".")[0],
        }
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class VQADataset(Dataset):
    """VQA-style dataset (VQAv2, OK-VQA, VizWiz, TextVQA).

    Each item is a dict with the PIL image, question text, question id,
    and — when annotations are available — the list of reference answers.
    """

    def __init__(
        self, image_dir_path, question_path, annotations_path, is_train, dataset_name
    ):
        """
        Args:
            image_dir_path: directory holding the images. For COCO-based
                datasets ("vqav2", "ok_vqa") the last path component must
                be the COCO split name (train2014 / val2014 / test2015).
            question_path: VQA-format questions JSON.
            annotations_path: VQA-format annotations JSON, or None for
                test sets without published answers.
            is_train: whether this is the training split (kept for API
                compatibility; the image path does not depend on it).
            dataset_name: one of "vqav2", "ok_vqa", "vizwiz", "textvqa".
        """
        self.questions = json.load(open(question_path, "r"))["questions"]
        if annotations_path is not None:
            self.answers = json.load(open(annotations_path, "r"))["annotations"]
        else:
            self.answers = None
        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # COCO filenames embed the split name (e.g.
            # COCO_train2014_000000000042.jpg); recover it from the dir.
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            assert self.img_coco_split in {"train2014", "val2014", "test2015"}

    def __len__(self):
        return len(self.questions)

    def get_img_path(self, question):
        """Return the full path of the image for one question record."""
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # The original code branched on self.is_train here, but both
            # branches produced the same string; the filename is fully
            # determined by img_coco_split.
            return os.path.join(
                self.image_dir_path,
                f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg",
            )
        elif self.dataset_name == "vizwiz":
            # VizWiz's "image_id" field is already the full filename.
            return os.path.join(self.image_dir_path, question["image_id"])
        elif self.dataset_name == "textvqa":
            return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
        else:
            raise ValueError(f"Unknown VQA dataset {self.dataset_name}")

    def __getitem__(self, idx):
        question = self.questions[idx]
        image = Image.open(self.get_img_path(question))
        image.load()
        results = {
            "image": image,
            "question": question["question"],
            "question_id": question["question_id"],
        }
        if self.answers is not None:
            # Questions and annotations are aligned by index in VQA files.
            results["answers"] = [a["answer"] for a in self.answers[idx]["answers"]]
        return results
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ImageNetDataset(ImageFolder):
    """ImageFolder wrapper for ImageNet-1k that yields labelled dicts.

    Items carry both the numeric class id and the corresponding
    human-readable class name from IMAGENET_CLASSNAMES.
    """

    def __init__(self, root, **kwargs):
        super().__init__(root=root, **kwargs)
        # Map numeric class id -> human-readable class name.
        self.class_id_to_name = {
            class_id: class_name
            for class_id, class_name in enumerate(IMAGENET_CLASSNAMES)
        }

    def __getitem__(self, idx):
        sample, target = super().__getitem__(idx)
        return {
            "id": idx,
            "image": sample,
            "class_id": target,  # numeric ID of the ImageNet class
            "class_name": self.class_id_to_name[target],  # readable name
        }
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class HatefulMemesDataset(Dataset):
    """Hateful Memes dataset read from a JSON-lines annotation file.

    Each sample pairs a meme image with its OCR text and a binary
    hatefulness label ("yes"/"no" plus the raw 0/1 class id).
    """

    def __init__(self, image_dir_path, annotations_path):
        self.image_dir_path = image_dir_path
        # Annotations are JSONL: one JSON object per line.
        with open(annotations_path, "r") as handle:
            self.annotations = [json.loads(row) for row in handle]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        record = self.annotations[idx]
        # The "img" field stores a relative path; keep only the basename
        # and resolve it against our image directory.
        filename = record["img"].split("/")[-1]
        image = Image.open(os.path.join(self.image_dir_path, filename))
        image.load()
        label = record["label"]
        return {
            "id": record["id"],
            "image": image,
            "ocr": record["text"],
            "class_name": "yes" if label == 1 else "no",
            "class_id": label,
        }
|
open_flamingo/eval/eval_model.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import argparse
|
| 3 |
+
from typing import List
|
| 4 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseEvalModel(abc.ABC):
    """Abstract interface for models evaluated by the benchmark harness.

    Subclasses wrap a concrete model (stored on ``self.model``) and provide
    generation, prompting, and rank-classification entry points. The two
    device/distributed helpers below are concrete and shared by all models.
    """

    def __init__(self, args: List[str]):
        """Initialize the model from command-line style arguments.

        Args:
            args: arguments to model. These should be parsed, or if the model
                has no applicable arguments, an error should be thrown if `args`
                is non-empty.
        """

    def init_distributed(self):
        """Wrap ``self.model`` in DistributedDataParallel on our device."""
        self.model = DDP(self.model, device_ids=[self.device])

    def set_device(self, device):
        """Move the wrapped model to ``device`` and remember the device."""
        self.model = self.model.to(device)
        self.device = device

    def get_outputs(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        min_generation_length: int,
        max_generation_length: int,
        num_beams: int,
        length_penalty: float,
    ) -> List[str]:
        """Generate text for a batch of interleaved image/text prompts.

        Args:
            batch_text: list of text strings, with the text "<image>" in place
                of any images to be included.
            batch_images: images to provide to model. Should be a list of lists,
                where each list contains the images for a single example.
            min_generation_length: minimum length of the generated text.
            max_generation_length: maximum length of the generated text.
            num_beams: number of beams to use for beam search.
            length_penalty: length penalty for beam search.

        Returns:
            List of decoded output strings.
        """

    def vqa_prompt(self, question, answer=None) -> str:
        """Build the VQA prompt; leave the answer slot blank when answer is None
        so the model fills it in.

        Returns:
            The prompt to use for VQA.
        """

    def caption_prompt(self, caption=None) -> str:
        """Build the captioning prompt; leave the caption slot blank when
        caption is None so the model fills it in.

        Returns:
            The prompt to use for captioning.
        """

    def get_rank_classifications(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        all_class_names: List[str],
        use_cache: bool,
        normalize_length: bool,
    ):
        """Score every candidate class name for each example in the batch.

        Args:
            batch_text: list of text strings, with the text "<image>" in place
                of any images to be included.
            batch_images: images to provide to model. Should be a list of lists,
                where each list contains the images for a single example.
            all_class_names: list of all class names.
            use_cache: whether to cache the context to speed up evaluations.
            normalize_length: whether to normalize logprobs by the length of the
                class name
        Returns:
            (B, |all_class_names|) tensor containing the logprobs for each class name.
        """
|
open_flamingo/eval/evaluate.py
ADDED
|
@@ -0,0 +1,1301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import importlib
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
import random
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from sklearn.metrics import roc_auc_score
|
| 12 |
+
import utils
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
from coco_metric import compute_cider, postprocess_captioning_generation
|
| 16 |
+
from eval_datasets import (
|
| 17 |
+
CaptionDataset,
|
| 18 |
+
VQADataset,
|
| 19 |
+
ImageNetDataset,
|
| 20 |
+
HatefulMemesDataset,
|
| 21 |
+
)
|
| 22 |
+
from rices import RICES
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from classification_utils import (
|
| 27 |
+
IMAGENET_CLASSNAMES,
|
| 28 |
+
HM_CLASSNAMES,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from eval_model import BaseEvalModel
|
| 32 |
+
|
| 33 |
+
from ok_vqa_utils import postprocess_ok_vqa_generation
|
| 34 |
+
from open_flamingo.src.flamingo import Flamingo
|
| 35 |
+
from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
|
| 36 |
+
|
| 37 |
+
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
|
| 38 |
+
|
| 39 |
+
parser = argparse.ArgumentParser()
|
| 40 |
+
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--model",
|
| 43 |
+
type=str,
|
| 44 |
+
help="Model name. Currently only `OpenFlamingo` is supported.",
|
| 45 |
+
default="open_flamingo",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--results_file", type=str, default=None, help="JSON file to save results"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Trial arguments
|
| 52 |
+
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--num_trials",
|
| 55 |
+
type=int,
|
| 56 |
+
default=1,
|
| 57 |
+
help="Number of trials to run for each shot using different demonstrations",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--trial_seeds",
|
| 61 |
+
nargs="+",
|
| 62 |
+
type=int,
|
| 63 |
+
default=[42],
|
| 64 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--num_samples",
|
| 68 |
+
type=int,
|
| 69 |
+
default=-1,
|
| 70 |
+
help="Number of samples to evaluate on. -1 for all samples.",
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--query_set_size", type=int, default=2048, help="Size of demonstration query set"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 77 |
+
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--no_caching_for_classification",
|
| 80 |
+
action="store_true",
|
| 81 |
+
help="Whether to skip using key-value caching for classification evals, which usually speeds it up.",
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--classification_prompt_ensembling",
|
| 85 |
+
action="store_true",
|
| 86 |
+
help="Whether to use prompt ensembling (average log-likelihoods over permutations of in-context examples)",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--rices",
|
| 90 |
+
action="store_true",
|
| 91 |
+
help="Whether to use RICES for evaluation. If False, uses random demonstrations.",
|
| 92 |
+
)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--rices_vision_encoder_path",
|
| 95 |
+
default="ViT-L-14",
|
| 96 |
+
type=str,
|
| 97 |
+
help="CLIP vision encoder to use for RICES if cached_demonstration_features is None.",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--rices_vision_encoder_pretrained",
|
| 101 |
+
default="openai",
|
| 102 |
+
type=str,
|
| 103 |
+
help="CLIP vision encoder to use for RICES if cached_demonstration_features is None.",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--cached_demonstration_features",
|
| 107 |
+
default=None,
|
| 108 |
+
help="Directory where rices features for all choices of in-context examples are stored as a pkl file with the dataset name. If None, features are re-computed by script.",
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Per-dataset evaluation flags
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--eval_coco",
|
| 114 |
+
action="store_true",
|
| 115 |
+
default=False,
|
| 116 |
+
help="Whether to evaluate on COCO.",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--eval_vqav2",
|
| 120 |
+
action="store_true",
|
| 121 |
+
default=False,
|
| 122 |
+
help="Whether to evaluate on VQAV2.",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--eval_ok_vqa",
|
| 126 |
+
action="store_true",
|
| 127 |
+
default=False,
|
| 128 |
+
help="Whether to evaluate on OK-VQA.",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--eval_vizwiz",
|
| 132 |
+
action="store_true",
|
| 133 |
+
default=False,
|
| 134 |
+
help="Whether to evaluate on VizWiz.",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--eval_textvqa",
|
| 138 |
+
action="store_true",
|
| 139 |
+
default=False,
|
| 140 |
+
help="Whether to evaluate on TextVQA.",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--eval_imagenet",
|
| 144 |
+
action="store_true",
|
| 145 |
+
default=False,
|
| 146 |
+
help="Whether to evaluate on ImageNet.",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--eval_flickr30",
|
| 150 |
+
action="store_true",
|
| 151 |
+
default=False,
|
| 152 |
+
help="Whether to evaluate on Flickr30.",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--eval_hateful_memes",
|
| 156 |
+
action="store_true",
|
| 157 |
+
default=False,
|
| 158 |
+
help="Whether to evaluate on Hateful Memes.",
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Dataset arguments
|
| 162 |
+
|
| 163 |
+
## Flickr30 Dataset
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--flickr_image_dir_path",
|
| 166 |
+
type=str,
|
| 167 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
| 168 |
+
default=None,
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--flickr_karpathy_json_path",
|
| 172 |
+
type=str,
|
| 173 |
+
help="Path to the dataset_flickr30k.json file.",
|
| 174 |
+
default=None,
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--flickr_annotations_json_path",
|
| 178 |
+
type=str,
|
| 179 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
| 180 |
+
)
|
| 181 |
+
## COCO Dataset
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--coco_train_image_dir_path",
|
| 184 |
+
type=str,
|
| 185 |
+
default=None,
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--coco_val_image_dir_path",
|
| 189 |
+
type=str,
|
| 190 |
+
default=None,
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--coco_karpathy_json_path",
|
| 194 |
+
type=str,
|
| 195 |
+
default=None,
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--coco_annotations_json_path",
|
| 199 |
+
type=str,
|
| 200 |
+
default=None,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
## VQAV2 Dataset
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--vqav2_train_image_dir_path",
|
| 206 |
+
type=str,
|
| 207 |
+
default=None,
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--vqav2_train_questions_json_path",
|
| 211 |
+
type=str,
|
| 212 |
+
default=None,
|
| 213 |
+
)
|
| 214 |
+
parser.add_argument(
|
| 215 |
+
"--vqav2_train_annotations_json_path",
|
| 216 |
+
type=str,
|
| 217 |
+
default=None,
|
| 218 |
+
)
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--vqav2_test_image_dir_path",
|
| 221 |
+
type=str,
|
| 222 |
+
default=None,
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--vqav2_test_questions_json_path",
|
| 226 |
+
type=str,
|
| 227 |
+
default=None,
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--vqav2_test_annotations_json_path",
|
| 231 |
+
type=str,
|
| 232 |
+
default=None,
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--vqav2_final_test_questions_json_path",
|
| 236 |
+
type=str,
|
| 237 |
+
help="Path to the v2_OpenEnded_mscoco_test2015_questions.json file containing all test questions. This is required to format the predictions for EvalAI.",
|
| 238 |
+
default=None,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
## OK-VQA Dataset
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--ok_vqa_train_image_dir_path",
|
| 244 |
+
type=str,
|
| 245 |
+
help="Path to the vqav2/train2014 directory.",
|
| 246 |
+
default=None,
|
| 247 |
+
)
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--ok_vqa_train_questions_json_path",
|
| 250 |
+
type=str,
|
| 251 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
| 252 |
+
default=None,
|
| 253 |
+
)
|
| 254 |
+
parser.add_argument(
|
| 255 |
+
"--ok_vqa_train_annotations_json_path",
|
| 256 |
+
type=str,
|
| 257 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
| 258 |
+
default=None,
|
| 259 |
+
)
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--ok_vqa_test_image_dir_path",
|
| 262 |
+
type=str,
|
| 263 |
+
help="Path to the vqav2/val2014 directory.",
|
| 264 |
+
default=None,
|
| 265 |
+
)
|
| 266 |
+
parser.add_argument(
|
| 267 |
+
"--ok_vqa_test_questions_json_path",
|
| 268 |
+
type=str,
|
| 269 |
+
help="Path to the v2_OpenEnded_mscoco_val2014_questions.json file.",
|
| 270 |
+
default=None,
|
| 271 |
+
)
|
| 272 |
+
parser.add_argument(
|
| 273 |
+
"--ok_vqa_test_annotations_json_path",
|
| 274 |
+
type=str,
|
| 275 |
+
help="Path to the v2_mscoco_val2014_annotations.json file.",
|
| 276 |
+
default=None,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
## VizWiz Dataset
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
"--vizwiz_train_image_dir_path",
|
| 282 |
+
type=str,
|
| 283 |
+
help="Path to the vizwiz train images directory.",
|
| 284 |
+
default=None,
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--vizwiz_test_image_dir_path",
|
| 288 |
+
type=str,
|
| 289 |
+
help="Path to the vizwiz test images directory.",
|
| 290 |
+
default=None,
|
| 291 |
+
)
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
"--vizwiz_train_questions_json_path",
|
| 294 |
+
type=str,
|
| 295 |
+
help="Path to the vizwiz questions json file.",
|
| 296 |
+
default=None,
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument(
|
| 299 |
+
"--vizwiz_train_annotations_json_path",
|
| 300 |
+
type=str,
|
| 301 |
+
help="Path to the vizwiz annotations json file.",
|
| 302 |
+
default=None,
|
| 303 |
+
)
|
| 304 |
+
parser.add_argument(
|
| 305 |
+
"--vizwiz_test_questions_json_path",
|
| 306 |
+
type=str,
|
| 307 |
+
help="Path to the vizwiz questions json file.",
|
| 308 |
+
default=None,
|
| 309 |
+
)
|
| 310 |
+
parser.add_argument(
|
| 311 |
+
"--vizwiz_test_annotations_json_path",
|
| 312 |
+
type=str,
|
| 313 |
+
help="Path to the vizwiz annotations json file.",
|
| 314 |
+
default=None,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# TextVQA Dataset
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--textvqa_image_dir_path",
|
| 320 |
+
type=str,
|
| 321 |
+
help="Path to the textvqa images directory.",
|
| 322 |
+
default=None,
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--textvqa_train_questions_json_path",
|
| 326 |
+
type=str,
|
| 327 |
+
help="Path to the textvqa questions json file.",
|
| 328 |
+
default=None,
|
| 329 |
+
)
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--textvqa_train_annotations_json_path",
|
| 332 |
+
type=str,
|
| 333 |
+
help="Path to the textvqa annotations json file.",
|
| 334 |
+
default=None,
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--textvqa_test_questions_json_path",
|
| 338 |
+
type=str,
|
| 339 |
+
help="Path to the textvqa questions json file.",
|
| 340 |
+
default=None,
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--textvqa_test_annotations_json_path",
|
| 344 |
+
type=str,
|
| 345 |
+
help="Path to the textvqa annotations json file.",
|
| 346 |
+
default=None,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
## Imagenet dataset
|
| 350 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
| 351 |
+
|
| 352 |
+
## Hateful Memes dataset
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--hateful_memes_image_dir_path",
|
| 355 |
+
type=str,
|
| 356 |
+
default=None,
|
| 357 |
+
)
|
| 358 |
+
parser.add_argument(
|
| 359 |
+
"--hateful_memes_train_annotations_json_path",
|
| 360 |
+
type=str,
|
| 361 |
+
default=None,
|
| 362 |
+
)
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
"--hateful_memes_test_annotations_json_path",
|
| 365 |
+
type=str,
|
| 366 |
+
default=None,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Distributed evaluation
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--dist-url",
|
| 372 |
+
default="env://",
|
| 373 |
+
type=str,
|
| 374 |
+
help="url used to set up distributed training",
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
| 378 |
+
)
|
| 379 |
+
parser.add_argument(
|
| 380 |
+
"--horovod",
|
| 381 |
+
default=False,
|
| 382 |
+
action="store_true",
|
| 383 |
+
help="Use horovod for distributed training.",
|
| 384 |
+
)
|
| 385 |
+
parser.add_argument(
|
| 386 |
+
"--no-set-device-rank",
|
| 387 |
+
default=False,
|
| 388 |
+
action="store_true",
|
| 389 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def main():
    """Entry point: parse CLI args, build the eval model, and run every
    benchmark enabled via ``--eval_*`` flags, accumulating per-shot results.

    Unrecognized CLI flags (``leftovers``) are forwarded to the model class as
    ``--key value`` pairs. Results are aggregated only on rank 0 and written to
    ``args.results_file`` as JSON if that path is set.
    """
    args, leftovers = parser.parse_known_args()
    # Dynamically load the eval wrapper matching --model (e.g. "open_flamingo", "blip").
    module = importlib.import_module(f"open_flamingo.eval.models.{args.model}")

    # Leftover flags are assumed to come in "--key value" pairs; strip the dashes
    # to build the model-specific kwargs dict.
    model_args = {
        leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers), 2)
    }
    eval_model = module.EvalModel(model_args)

    # set up distributed evaluation
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    device_id = init_distributed_device(args)
    eval_model.set_device(device_id)
    eval_model.init_distributed()

    if args.model != "open_flamingo" and args.shots != [0]:
        raise ValueError("Only 0 shot eval is supported for non-open_flamingo models")

    if len(args.trial_seeds) != args.num_trials:
        raise ValueError("Number of trial seeds must be == number of trials.")

    # dataset name -> list of {"shots", "trials", "mean", "stddev"} dicts
    results = defaultdict(list)

    if args.eval_flickr30:
        print("Evaluating on Flickr30k...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/flickr30.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                cider_score = evaluate_captioning(
                    args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    dataset_name="flickr",
                    cached_features=cached_features,
                )
                if args.rank == 0:
                    print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
                scores.append(cider_score)

            if args.rank == 0:
                print(f"Shots {shot} Mean CIDEr score: {np.nanmean(scores)}")
                results["flickr30"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_coco:
        print("Evaluating on COCO...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/coco.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                cider_score = evaluate_captioning(
                    args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    dataset_name="coco",
                    cached_features=cached_features,
                )
                if args.rank == 0:
                    print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
                scores.append(cider_score)

            if args.rank == 0:
                print(f"Shots {shot} Mean CIDEr score: {np.nanmean(scores)}")
                results["coco"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_ok_vqa:
        print("Evaluating on OK-VQA...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/ok_vqa.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                ok_vqa_score = evaluate_vqa(
                    args=args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    dataset_name="ok_vqa",
                    cached_features=cached_features,
                )
                if args.rank == 0:
                    print(f"Shots {shot} Trial {trial} OK-VQA score: {ok_vqa_score}")
                scores.append(ok_vqa_score)

            if args.rank == 0:
                print(f"Shots {shot} Mean OK-VQA score: {np.nanmean(scores)}")
                results["ok_vqa"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_vqav2:
        print("Evaluating on VQAv2...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/vqav2.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                vqa_score = evaluate_vqa(
                    args=args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    dataset_name="vqav2",
                    cached_features=cached_features,
                )
                # vqa_score may be None when no test annotations were supplied
                # (test-dev submission mode) -- only record real scores.
                if args.rank == 0 and vqa_score is not None:
                    print(f"Shots {shot} Trial {trial} VQA score: {vqa_score}")
                    scores.append(vqa_score)

            if args.rank == 0 and len(scores) > 0:
                print(f"Shots {shot} Mean VQA score: {np.nanmean(scores)}")
                results["vqav2"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_vizwiz:
        print("Evaluating on VizWiz...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/vizwiz.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                vizwiz_score = evaluate_vqa(
                    args=args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    dataset_name="vizwiz",
                    cached_features=cached_features,
                )
                # As with VQAv2, score is None in annotation-free submission mode.
                if args.rank == 0 and vizwiz_score is not None:
                    print(f"Shots {shot} Trial {trial} VizWiz score: {vizwiz_score}")
                    scores.append(vizwiz_score)

            if args.rank == 0 and len(scores) > 0:
                print(f"Shots {shot} Mean VizWiz score: {np.nanmean(scores)}")
                results["vizwiz"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_textvqa:
        print("Evaluating on TextVQA...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/textvqa.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                textvqa_score = evaluate_vqa(
                    args=args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    dataset_name="textvqa",
                    max_generation_length=10,
                    cached_features=cached_features,
                )
                if args.rank == 0:
                    print(f"Shots {shot} Trial {trial} TextVQA score: {textvqa_score}")
                scores.append(textvqa_score)

            if args.rank == 0:
                print(f"Shots {shot} Mean TextVQA score: {np.nanmean(scores)}")
                results["textvqa"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_imagenet:
        print("Evaluating on ImageNet...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/imagenet.pkl", map_location="cpu"
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                imagenet_score = evaluate_classification(
                    args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    no_kv_caching=args.no_caching_for_classification,
                    dataset_name="imagenet",
                    cached_features=cached_features,
                    use_prompt_ensembling=args.classification_prompt_ensembling,
                )
                if args.rank == 0:
                    print(
                        f"Shots {shot} Trial {trial} "
                        f"ImageNet score: {imagenet_score}"
                    )
                scores.append(imagenet_score)

            if args.rank == 0:
                print(f"Shots {shot} Mean ImageNet score: {np.nanmean(scores)}")
                results["imagenet"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    if args.eval_hateful_memes:
        print("Evaluating on Hateful Memes...")

        # load cached demonstration features for RICES
        if args.cached_demonstration_features is not None:
            cached_features = torch.load(
                f"{args.cached_demonstration_features}/hateful_memes.pkl",
                map_location="cpu",
            )
        else:
            cached_features = None

        for shot in args.shots:
            scores = []
            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
                hateful_memes_score = evaluate_classification(
                    args,
                    eval_model=eval_model,
                    num_shots=shot,
                    seed=seed,
                    no_kv_caching=args.no_caching_for_classification,
                    dataset_name="hateful_memes",
                    cached_features=cached_features,
                )
                if args.rank == 0:
                    print(
                        f"Shots {shot} Trial {trial} "
                        f"Hateful Memes score: {hateful_memes_score}"
                    )
                scores.append(hateful_memes_score)

            if args.rank == 0:
                print(f"Shots {shot} Mean Hateful Memes score: {np.nanmean(scores)}")
                results["hateful_memes"].append(
                    {
                        "shots": shot,
                        "trials": scores,
                        "mean": np.nanmean(scores),
                        "stddev": np.nanstd(scores),
                    }
                )

    # Persist aggregated results (rank 0 only) when a results file was requested.
    if args.rank == 0 and args.results_file is not None:
        with open(args.results_file, "w") as f:
            json.dump(results, f)
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def evaluate_captioning(
    args: argparse.Namespace,
    eval_model: BaseEvalModel,
    seed: int = 42,
    min_generation_length: int = 0,
    max_generation_length: int = 20,
    num_beams: int = 3,
    length_penalty: float = 0.0,
    num_shots: int = 8,
    dataset_name: str = "coco",
    cached_features=None,
):
    """Evaluate a model on a captioning dataset (COCO or Flickr30k).

    Runs in-context captioning inference over the test split, gathers
    predictions across all distributed ranks, and scores them with CIDEr.

    Args:
        args (argparse.Namespace): arguments
        eval_model (BaseEvalModel): model to evaluate
        seed (int, optional): seed for random number generator. Defaults to 42.
        min_generation_length (int, optional): minimum length of the generated caption. Defaults to 0.
        max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20.
        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
        length_penalty (float, optional): length penalty for beam search. Defaults to 0.0.
        num_shots (int, optional): number of in-context samples to use. Defaults to 8.
        dataset_name (str, optional): dataset to evaluate on. Can be "coco" or "flickr". Defaults to "coco".
        cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None.
    Returns:
        float: CIDEr score (x100) on rank 0; None on all other ranks.
    """

    if dataset_name == "coco":
        image_train_dir_path = args.coco_train_image_dir_path
        image_val_dir_path = args.coco_val_image_dir_path
        annotations_path = args.coco_karpathy_json_path
    elif dataset_name == "flickr":
        image_train_dir_path = (
            args.flickr_image_dir_path
        )  # Note: calling this "train" for consistency with COCO but Flickr only has one split for images
        image_val_dir_path = None
        annotations_path = args.flickr_karpathy_json_path
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # Train split supplies the in-context demonstration pool.
    train_dataset = CaptionDataset(
        image_train_dir_path=image_train_dir_path,
        image_val_dir_path=image_val_dir_path,
        annotations_path=annotations_path,
        is_train=True,
        dataset_name=dataset_name if dataset_name != "nocaps" else "coco",
    )

    test_dataset = CaptionDataset(
        image_train_dir_path=image_train_dir_path,
        image_val_dir_path=image_val_dir_path,
        annotations_path=annotations_path,
        is_train=False,
        dataset_name=dataset_name,
    )

    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)

    np.random.seed(seed)
    test_dataloader = utils.prepare_eval_samples(
        test_dataset,
        args.num_samples if args.num_samples > 0 else len(test_dataset),
        args.batch_size,
    )

    if args.rices:
        # Retrieval-based In-Context Example Selection: pick demos by
        # visual similarity instead of random sampling.
        rices_dataset = RICES(
            train_dataset,
            eval_model.device,
            args.batch_size,
            cached_features=cached_features,
            vision_encoder_path=args.rices_vision_encoder_path,
            vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
        )
    else:
        # subset of the training set to sample context images from
        query_set = utils.get_query_set(train_dataset, args.query_set_size)

    utils.random_seed(seed, args.rank)
    # NOTE(review): defaultdict() with no factory behaves like a plain dict here.
    predictions = defaultdict()
    for batch in tqdm(
        test_dataloader,
        desc=f"Running inference {dataset_name.upper()}",
        disable=args.rank != 0,
    ):
        if args.rices:
            batch_demo_samples = rices_dataset.find(batch["image"], effective_num_shots)
        else:
            batch_demo_samples = utils.sample_batch_demos_from_query_set(
                query_set, effective_num_shots, len(batch["image"])
            )

        batch_images, batch_text = [], []
        for i in range(len(batch["image"])):
            if num_shots > 0:
                context_images = [x["image"] for x in batch_demo_samples[i]]
            else:
                context_images = []
            batch_images.append(context_images + [batch["image"][i]])

            context_text = "".join(
                [
                    eval_model.get_caption_prompt(caption=x["caption"].strip()) + "\n"
                    for x in batch_demo_samples[i]
                ]
            )

            # Keep the text but remove the image tags for the zero-shot case
            if num_shots == 0:
                context_text = context_text.replace("<image>", "")

            batch_text.append(context_text + eval_model.get_caption_prompt())

        outputs = eval_model.get_outputs(
            batch_images=batch_images,
            batch_text=batch_text,
            min_generation_length=min_generation_length,
            max_generation_length=max_generation_length,
            num_beams=num_beams,
            length_penalty=length_penalty,
        )

        # Strip quotes so captions compare cleanly against references.
        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "") for out in outputs
        ]

        for i, sample_id in enumerate(batch["image_id"]):
            predictions[sample_id] = {
                "caption": new_predictions[i],
            }

    # all gather
    all_predictions = [None for _ in range(args.world_size)]
    torch.distributed.all_gather_object(all_predictions, predictions)  # list of dicts

    # Only rank 0 computes the metric; other ranks are done.
    if args.rank != 0:
        return None

    all_predictions = {
        k: v for d in all_predictions for k, v in d.items()
    }  # merge dicts

    # save the predictions to a temporary file
    results_path = f"{dataset_name}results_{uuid.uuid4()}.json"

    with open(results_path, "w") as f:
        f.write(
            json.dumps(
                [
                    {"image_id": k, "caption": all_predictions[k]["caption"]}
                    for k in all_predictions
                ],
                indent=4,
            )
        )

    metrics = compute_cider(
        result_path=results_path,
        annotations_path=args.coco_annotations_json_path
        if dataset_name == "coco"
        else args.flickr_annotations_json_path,
    )

    # delete the temporary file
    os.remove(results_path)

    return metrics["CIDEr"] * 100.0
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def evaluate_vqa(
    args: argparse.Namespace,
    eval_model: BaseEvalModel,
    seed: int = 42,
    min_generation_length: int = 0,
    max_generation_length: int = 5,
    num_beams: int = 3,
    length_penalty: float = 0.0,
    num_shots: int = 8,
    dataset_name: str = "vqav2",
    cached_features=None,
):
    """
    Evaluate a model on VQA datasets. Currently supports VQA v2.0, OK-VQA, VizWiz and TextVQA.

    Predictions are gathered across distributed ranks and scored with the VQA
    accuracy metric when test annotations are available; otherwise (VQAv2 /
    VizWiz test-dev mode) a submission-format json is produced instead.

    Args:
        args (argparse.Namespace): arguments
        eval_model (BaseEvalModel): model to evaluate
        seed (int, optional): random seed. Defaults to 42.
        min_generation_length (int, optional): min generation length. Defaults to 0.
        max_generation_length (int, optional): max generation length. Defaults to 5.
        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
        length_penalty (float, optional): length penalty for beam search. Defaults to 0.0.
        num_shots (int, optional): number of shots to use. Defaults to 8.
        dataset_name (string): type of vqa dataset: currently supports vqav2, ok_vqa. Defaults to vqav2.
        cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None.
    Returns:
        float: accuracy score on rank 0 (None on other ranks, or when no
        annotations were provided).
    """

    # Resolve dataset-specific paths from args.
    if dataset_name == "ok_vqa":
        train_image_dir_path = args.ok_vqa_train_image_dir_path
        train_questions_json_path = args.ok_vqa_train_questions_json_path
        train_annotations_json_path = args.ok_vqa_train_annotations_json_path
        test_image_dir_path = args.ok_vqa_test_image_dir_path
        test_questions_json_path = args.ok_vqa_test_questions_json_path
        test_annotations_json_path = args.ok_vqa_test_annotations_json_path
    elif dataset_name == "vqav2":
        train_image_dir_path = args.vqav2_train_image_dir_path
        train_questions_json_path = args.vqav2_train_questions_json_path
        train_annotations_json_path = args.vqav2_train_annotations_json_path
        test_image_dir_path = args.vqav2_test_image_dir_path
        test_questions_json_path = args.vqav2_test_questions_json_path
        test_annotations_json_path = args.vqav2_test_annotations_json_path
    elif dataset_name == "vizwiz":
        train_image_dir_path = args.vizwiz_train_image_dir_path
        train_questions_json_path = args.vizwiz_train_questions_json_path
        train_annotations_json_path = args.vizwiz_train_annotations_json_path
        test_image_dir_path = args.vizwiz_test_image_dir_path
        test_questions_json_path = args.vizwiz_test_questions_json_path
        test_annotations_json_path = args.vizwiz_test_annotations_json_path
    elif dataset_name == "textvqa":
        # TextVQA keeps train and test images in the same directory.
        train_image_dir_path = args.textvqa_image_dir_path
        train_questions_json_path = args.textvqa_train_questions_json_path
        train_annotations_json_path = args.textvqa_train_annotations_json_path
        test_image_dir_path = args.textvqa_image_dir_path
        test_questions_json_path = args.textvqa_test_questions_json_path
        test_annotations_json_path = args.textvqa_test_annotations_json_path
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # Train split supplies the in-context demonstration pool.
    train_dataset = VQADataset(
        image_dir_path=train_image_dir_path,
        question_path=train_questions_json_path,
        annotations_path=train_annotations_json_path,
        is_train=True,
        dataset_name=dataset_name,
    )

    test_dataset = VQADataset(
        image_dir_path=test_image_dir_path,
        question_path=test_questions_json_path,
        annotations_path=test_annotations_json_path,
        is_train=False,
        dataset_name=dataset_name,
    )

    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)

    np.random.seed(seed)
    test_dataloader = utils.prepare_eval_samples(
        test_dataset,
        args.num_samples if args.num_samples > 0 else len(test_dataset),
        args.batch_size,
    )

    if args.rices:
        # Retrieval-based In-Context Example Selection by visual similarity.
        rices_dataset = RICES(
            train_dataset,
            eval_model.device,
            args.batch_size,
            cached_features=cached_features,
            vision_encoder_path=args.rices_vision_encoder_path,
            vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
        )
    else:
        query_set = utils.get_query_set(train_dataset, args.query_set_size)

    utils.random_seed(seed, args.rank)
    predictions = []
    for batch in tqdm(
        test_dataloader,
        desc=f"Running inference {dataset_name}",
        disable=args.rank != 0,
    ):
        if args.rices:
            batch_demo_samples = rices_dataset.find(batch["image"], effective_num_shots)
        else:
            batch_demo_samples = utils.sample_batch_demos_from_query_set(
                query_set, effective_num_shots, len(batch["image"])
            )

        batch_images, batch_text = [], []
        for i in range(len(batch["image"])):
            if num_shots > 0:
                context_images = [x["image"] for x in batch_demo_samples[i]]
            else:
                context_images = []
            batch_images.append(context_images + [batch["image"][i]])

            # Demonstrations use the first listed answer as the target.
            context_text = "".join(
                [
                    eval_model.get_vqa_prompt(
                        question=x["question"], answer=x["answers"][0]
                    )
                    + "\n"
                    for x in batch_demo_samples[i]
                ]
            )

            # Keep the text but remove the image tags for the zero-shot case
            if num_shots == 0:
                context_text = context_text.replace("<image>", "")

            batch_text.append(
                context_text + eval_model.get_vqa_prompt(question=batch["question"][i])
            )

        outputs = eval_model.get_outputs(
            batch_images=batch_images,
            batch_text=batch_text,
            min_generation_length=min_generation_length,
            max_generation_length=max_generation_length,
            num_beams=num_beams,
            length_penalty=length_penalty,
        )

        # OK-VQA uses its own answer normalization (stemming); others share one.
        process_function = (
            postprocess_ok_vqa_generation
            if dataset_name == "ok_vqa"
            else postprocess_vqa_generation
        )

        new_predictions = map(process_function, outputs)

        for new_prediction, sample_id in zip(new_predictions, batch["question_id"]):
            predictions.append({"answer": new_prediction, "question_id": sample_id})

    # all gather
    all_predictions = [None for _ in range(args.world_size)]
    torch.distributed.all_gather_object(all_predictions, predictions)  # list of lists

    # Only rank 0 scores / writes results.
    if args.rank != 0:
        return None

    all_predictions = [
        item for sublist in all_predictions for item in sublist
    ]  # flatten

    # save the predictions to a temporary file
    random_uuid = str(uuid.uuid4())
    with open(f"{dataset_name}results_{random_uuid}.json", "w") as f:
        f.write(json.dumps(all_predictions, indent=4))

    if test_annotations_json_path is not None:
        acc = compute_vqa_accuracy(
            f"{dataset_name}results_{random_uuid}.json",
            test_questions_json_path,
            test_annotations_json_path,
        )
        # delete the temporary file
        os.remove(f"{dataset_name}results_{random_uuid}.json")

    else:
        # No ground truth: produce a leaderboard submission file instead
        # (VQAv2 / VizWiz test-dev only).
        print("No annotations provided, skipping accuracy computation.")
        acc = None
        if dataset_name == "vqav2":
            from open_flamingo.scripts.fill_vqa_testdev_results import (
                fill_vqav2_test_json,
            )

            fill_fn = fill_vqav2_test_json
        elif dataset_name == "vizwiz":
            from open_flamingo.scripts.fill_vqa_testdev_results import (
                fill_vizwiz_test_json,
            )

            fill_fn = fill_vizwiz_test_json
        else:
            # Other datasets have no submission format; keep the raw file.
            print(
                "Temporary file saved to ", f"{dataset_name}results_{random_uuid}.json"
            )
            return

        fill_fn(
            f"{dataset_name}results_{random_uuid}.json",
            f"{dataset_name}-testdev_{eval_model.lm_name}_{num_shots}_{'rices' if args.rices else 'random'}_{seed}.json",
            args.vqav2_final_test_questions_json_path
            if dataset_name == "vqav2"
            else args.vizwiz_test_questions_json_path,
        )
        print(
            "Test-dev results saved to ",
            f"{dataset_name}-testdev_{eval_model.lm_name}_{num_shots}_{'rices' if args.rices else 'random'}_{seed}.json",
        )
        os.remove(f"{dataset_name}results_{random_uuid}.json")

    return acc
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
def evaluate_classification(
    args: argparse.Namespace,
    eval_model,
    seed: int = 42,
    num_shots: int = 8,
    dataset_name: str = "imagenet",
    cached_features=None,
    no_kv_caching=False,
    use_prompt_ensembling: bool = False,
):
    """
    Evaluate a model on a classification dataset via ranked log-likelihood.

    Args:
        args (argparse.Namespace): evaluation CLI arguments.
        eval_model (BaseEvalModel): model to evaluate.
        seed (int, optional): random seed. Defaults to 42.
        num_shots (int, optional): number of in-context shots. Defaults to 8.
        dataset_name (str, optional): "imagenet" or "hateful_memes".
            Defaults to "imagenet".
        cached_features (tensor, optional): cached demonstration features for
            RICES. Defaults to None.
        no_kv_caching (bool): whether to disable key-value caching.
        use_prompt_ensembling (bool): average logprobs over up to 6 shuffled
            orderings of the in-context demonstrations.

    Returns:
        float: top-1 accuracy (imagenet) or ROC-AUC (hateful_memes) on rank 0;
        None on other ranks.
    """
    if args.model != "open_flamingo":
        raise NotImplementedError(
            "evaluate_classification is currently only supported for OpenFlamingo"
        )

    if dataset_name == "imagenet":
        train_dataset = ImageNetDataset(os.path.join(args.imagenet_root, "train"))
        test_dataset = ImageNetDataset(os.path.join(args.imagenet_root, "val"))
        prompt_fn = lambda x: eval_model.get_imagenet_prompt(label=x["class_name"])
        all_class_names = IMAGENET_CLASSNAMES
        k = 5  # report top-5 candidates internally; accuracy below uses topk[0]
    elif dataset_name == "hateful_memes":
        train_dataset = HatefulMemesDataset(
            args.hateful_memes_image_dir_path,
            args.hateful_memes_train_annotations_json_path,
        )
        test_dataset = HatefulMemesDataset(
            args.hateful_memes_image_dir_path,
            args.hateful_memes_test_annotations_json_path,
        )
        prompt_fn = lambda x: eval_model.get_hateful_memes_prompt(
            text=x["ocr"], label=x["class_name"]
        )
        all_class_names = HM_CLASSNAMES
        k = 1
    else:
        raise ValueError(f"Unsupported dataset {dataset_name}")

    class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names))

    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)

    np.random.seed(seed)
    test_dataloader = utils.prepare_eval_samples(
        test_dataset,
        args.num_samples if args.num_samples > 0 else len(test_dataset),
        args.batch_size,
    )

    if args.rices:
        rices_dataset = RICES(
            train_dataset,
            eval_model.device,
            args.batch_size,
            cached_features=cached_features,
            vision_encoder_path=args.rices_vision_encoder_path,
            vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
        )
    else:
        # subset of the training set to sample context images from
        query_set = utils.get_query_set(train_dataset, args.query_set_size)

    utils.random_seed(seed, args.rank)
    predictions = []
    for batch_idx, batch in tqdm(
        enumerate(test_dataloader),
        desc=f"Running inference {dataset_name}",
        disable=args.rank != 0,
    ):
        if args.rices:
            batch_demo_samples = rices_dataset.find(batch["image"], effective_num_shots)
        else:
            batch_demo_samples = utils.sample_batch_demos_from_query_set(
                query_set, effective_num_shots, len(batch["image"])
            )

        # set up prompt ensembling
        num_permutations = (
            min(6, math.factorial(effective_num_shots)) if use_prompt_ensembling else 1
        )
        logprobs = []
        for _ in range(num_permutations):
            batch_images, batch_text = [], []
            for i in range(len(batch["image"])):
                if use_prompt_ensembling:
                    random.shuffle(batch_demo_samples[i])

                if effective_num_shots > 0:
                    context_images = [x["image"] for x in batch_demo_samples[i]]
                else:
                    context_images = []
                batch_images.append(context_images + [batch["image"][i]])

                context_text = "".join([prompt_fn(x) for x in batch_demo_samples[i]])

                # Keep the text but remove the image tags for the zero-shot case
                if num_shots == 0:
                    context_text = context_text.replace("<image>", "")

                # Only hateful_memes batches carry an "ocr" field; an
                # unconditional batch["ocr"][i] would raise KeyError on
                # imagenet (whose prompt_fn ignores "ocr" anyway).
                batch_text.append(
                    context_text
                    + prompt_fn(
                        {
                            "ocr": batch["ocr"][i] if "ocr" in batch else None,
                            "class_name": None,
                        }
                    )
                )

            # get predicted class names
            logprobs.append(
                eval_model.get_rank_classifications(
                    batch_text,
                    batch_images,
                    all_class_names,
                    use_cache=(not no_kv_caching),
                    normalize_length=True,
                )
            )

        # ensemble logprobs together
        logprobs = torch.mean(torch.stack(logprobs, dim=-1), dim=-1)

        predicted_classnames, predicted_logprobs = utils.get_predicted_classnames(
            logprobs,
            k,
            class_id_to_name,
        )

        # compute accuracy
        for i, topk in enumerate(predicted_classnames):
            y_i = batch["class_name"][i]
            # softmax probability of the top-1 class among all classes
            score = torch.exp(
                predicted_logprobs[i][0] - torch.logsumexp(logprobs[i], dim=0)
            ).item()
            predictions.append(
                {
                    "id": batch["id"][i],
                    "gt_label": y_i,
                    "pred_label": topk[0],
                    "pred_score": score,
                }
            )

    # all gather
    all_predictions = [None for _ in range(args.world_size)]
    torch.distributed.all_gather_object(all_predictions, predictions)  # list of lists
    if args.rank != 0:
        return

    all_predictions = [
        item for sublist in all_predictions for item in sublist
    ]  # flatten

    if dataset_name == "hateful_memes":
        # return ROC-AUC score
        greater_label = max(all_class_names)
        gts = [pred["gt_label"] for pred in all_predictions]
        pred_scores = [
            pred["pred_score"]
            if pred["pred_label"] == greater_label
            else 1 - pred["pred_score"]
            for pred in all_predictions
        ]
        return roc_auc_score(gts, pred_scores)
    else:
        # return top-1 accuracy
        acc1 = sum(
            int(pred["gt_label"] == pred["pred_label"]) for pred in all_predictions
        )
        return float(acc1) / len(all_predictions)
|
| 1298 |
+
|
| 1299 |
+
|
| 1300 |
+
# Script entry point: parse CLI arguments and run the requested evaluations.
if __name__ == "__main__":
    main()
|
open_flamingo/eval/models/blip.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
| 7 |
+
from open_flamingo.eval.eval_model import BaseEvalModel
|
| 8 |
+
from open_flamingo.eval.utils import unwrap_model
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EvalModel(BaseEvalModel):
    """BLIP-2 model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model.
        device: Index of GPU to use, or the string "cpu"
    """

    def __init__(self, model_args):
        assert (
            "processor_path" in model_args and "lm_path" in model_args
        ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified"

        self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            model_args["lm_path"]
        )
        self.model.eval()
        # Left padding so generated tokens always follow each prompt in a batch.
        self.processor.tokenizer.padding_side = "left"
        self.lm_name = model_args["lm_path"].split("/")[-1]
        # NOTE(review): self.device is read in get_outputs but never assigned
        # here -- presumably set by BaseEvalModel or the caller; confirm.

    def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor:
        """Preprocess images and stack them.

        Args:
            batch: A list of lists of images (exactly one image per inner list).

        Returns:
            A Tensor of shape (batch_size, channels, height, width),
            or None when the batch is empty.
        """
        assert all(
            len(example) == 1 for example in batch
        ), "BLIP-2 only supports one image per example"

        if not batch:
            return None

        # Preprocess every example and concatenate once; the original grew the
        # tensor with repeated torch.cat calls (O(n^2) copying) and re-asserted
        # the single-image invariant inside the loop.
        pixel_values = [
            self.processor.image_processor(example, return_tensors="pt")[
                "pixel_values"
            ]
            for example in batch
        ]
        return torch.cat(pixel_values, dim=0)

    def get_outputs(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        min_generation_length: int,
        max_generation_length: int,
        num_beams: int,
        length_penalty: float,
    ) -> List[str]:
        """Generate text continuations for a batch of (text, image) prompts.

        Args:
            batch_text: prompt strings, one per example.
            batch_images: one single-image list per example.
            min_generation_length: minimum number of new tokens.
            max_generation_length: maximum number of new tokens.
            num_beams: beam search width.
            length_penalty: beam search length penalty.

        Returns:
            Decoded generations with special tokens stripped.
        """
        encodings = self.processor.tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]

        with torch.inference_mode():
            outputs = unwrap_model(self.model).generate(
                self._prepare_images(batch_images).to(self.device),
                input_ids.to(self.device),
                attention_mask=attention_mask.to(self.device),
                max_new_tokens=max_generation_length,
                min_new_tokens=min_generation_length,
                num_beams=num_beams,
                length_penalty=length_penalty,
            )

        return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def get_vqa_prompt(self, question, answer=None) -> str:
        """VQA prompt; the answer is omitted at inference time."""
        return (
            f"Question:{question} Short answer:{answer if answer is not None else ''}"
        )

    def get_caption_prompt(self, caption=None) -> str:
        """Captioning prompt; the caption is omitted at inference time."""
        return f"A photo of {caption if caption is not None else ''}"

    def get_rank_classifications(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        all_class_names: List[str],
        use_cache: bool,
        normalize_length: bool,
    ):
        """Ranked-classification scoring is not supported for BLIP-2."""
        raise NotImplementedError(
            "BLIP-2 classification-based evaluation not implemented"
        )
|
open_flamingo/eval/models/open_flamingo.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict
|
| 2 |
+
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
from einops import repeat
|
| 6 |
+
|
| 7 |
+
from open_flamingo.eval.eval_model import BaseEvalModel
|
| 8 |
+
from open_flamingo.src.factory import create_model_and_transforms
|
| 9 |
+
from open_flamingo.eval.utils import unwrap_model, get_autocast, get_cast_dtype
|
| 10 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class EvalModel(BaseEvalModel):
    """OpenFlamingo model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model.
        device: Index of GPU to use, or the string "CPU"
    """

    def __init__(self, model_args):
        # Fail fast if any argument required to build/restore the model is missing.
        assert (
            "vision_encoder_path" in model_args
            and "lm_path" in model_args
            and "checkpoint_path" in model_args
            and "lm_tokenizer_path" in model_args
            and "cross_attn_every_n_layers" in model_args
            and "vision_encoder_pretrained" in model_args
            and "precision" in model_args
        ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained, and precision arguments to be specified"

        # NOTE(review): assumes model_args["device"] is an integer GPU index;
        # a string device such as "cuda" would make the >= 0 comparison raise.
        self.device = (
            model_args["device"]
            if ("device" in model_args and model_args["device"] >= 0)
            else "cpu"
        )

        (
            self.model,
            self.image_processor,
            self.tokenizer,
        ) = create_model_and_transforms(
            model_args["vision_encoder_path"],
            model_args["vision_encoder_pretrained"],
            model_args["lm_path"],
            model_args["lm_tokenizer_path"],
            cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]),
        )
        # Load the checkpoint; strip DDP's "module." prefix so keys match the
        # unwrapped model. strict=False tolerates missing/extra keys.
        checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device)
        if "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
        self.model.load_state_dict(checkpoint, strict=False)
        self.model.to(self.device)
        self.model.eval()
        # Left padding so batched generation starts right after each prompt.
        self.tokenizer.padding_side = "left"

        self.lm_name = model_args["lm_path"].split("/")[-1]

        # autocast / dtype casting configured from the requested precision
        self.autocast = get_autocast(model_args["precision"])
        self.cast_dtype = get_cast_dtype(model_args["precision"])

    def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
        """
        Convert images to tensors, reshape them, and stack them.
        Args:
            batch: A list of lists of images.
        Returns:
            preprocessed images (tensors) or None
                shape (B, T_img, F, C, H, W)
                None if no images in batch
        """
        images_per_example = max(len(x) for x in batch)
        batch_images = None
        for iexample, example in enumerate(batch):
            for iimage, image in enumerate(example):
                preprocessed = self.image_processor(image)
                # Lazily allocate once the preprocessed shape is known;
                # zero-padding covers examples with fewer images.
                if batch_images is None:
                    batch_images = torch.zeros(
                        (len(batch), images_per_example, 1) + preprocessed.shape,
                        dtype=preprocessed.dtype,
                    )
                batch_images[iexample, iimage, 0] = preprocessed
        if batch_images is not None:
            batch_images = batch_images.to(
                self.device, dtype=self.cast_dtype, non_blocking=True
            )
        return batch_images

    def _prepare_text(
        self,
        batch: List[List[str]],
        padding="longest",
        truncation=True,
        max_length=2000,
    ):
        """
        Tokenize the text and stack them.
        Args:
            batch: A list of lists of strings.
        Returns:
            input_ids (tensor)
                shape (B, T_txt)
            attention_mask (tensor)
                shape (B, T_txt)
        """
        encodings = self.tokenizer(
            batch,
            padding=padding,
            truncation=truncation,
            return_tensors="pt",
            max_length=max_length,
        )
        input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
        # NOTE(review): casting input_ids to self.cast_dtype is a no-op when
        # cast_dtype is None (fp32/amp), but would turn integer token ids into
        # floats for half precision -- verify this is intended.
        input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True)
        attention_mask = attention_mask.to(
            self.device, dtype=self.cast_dtype, non_blocking=True
        )
        return input_ids, attention_mask.bool()

    def get_outputs(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        min_generation_length: int,
        max_generation_length: int,
        num_beams: int,
        length_penalty: float,
    ) -> List[str]:
        """
        Get generation outputs.
        """
        batch_images = self._prepare_images(batch_images)
        input_ids, attention_mask = self._prepare_text(batch_text)

        with torch.inference_mode():
            with self.autocast():
                outputs = unwrap_model(self.model).generate(
                    batch_images,
                    input_ids,
                    attention_mask,
                    min_new_tokens=min_generation_length,
                    max_new_tokens=max_generation_length,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                )

        # Extract only the newly generated tokens (generate() returns the
        # prompt followed by the continuation).
        outputs = outputs[:, len(input_ids[0]) :]

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def get_rank_classifications(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        all_class_names: List[str],
        use_cache: bool,
        normalize_length: bool,
    ):
        """
        Returns a (B, |all_class_names|) tensor containing the logprobs for each class name.

        When use_cache is True, the shared context (prompt + images) is run
        once and its past_key_values are reused for every candidate class,
        instead of re-running the full context per class.
        """
        batch_images = self._prepare_images(batch_images)
        ctx_input_ids, ctx_attention_mask = self._prepare_text(batch_text)

        # Cache the context
        if use_cache:
            # reserve the last token in the context for the main forward pass
            self.cache_media(
                input_ids=ctx_input_ids,
                vision_x=batch_images,
            )
            precomputed = self.__call__(
                vision_x=None,
                lang_x=ctx_input_ids,
                attention_mask=ctx_attention_mask,
                clear_conditioned_layers=False,
                use_cache=True,
            )
            precomputed_logits = precomputed.logits
            precomputed_pkvs = precomputed.past_key_values
        else:
            precomputed_pkvs = None

        # Loop through class names and get log-likelihoods
        # Note: if all classnames are one token, this code is redundant, since we could
        # get all logits after one pass. However, if there are multi-token classnames,
        # we need to loop through each classname separately.
        overall_probs = []
        for class_name in all_class_names:
            # Tokenize only the class name
            classname_tokens = self.tokenizer(
                class_name, add_special_tokens=False, return_tensors="pt"
            )["input_ids"].to(self.device)
            assert classname_tokens.ndim == 2
            # Broadcast the single tokenized classname across the batch.
            classname_tokens = repeat(
                classname_tokens, "b s -> (repeat b) s", repeat=len(batch_text)
            )
            num_tokens_in_classname = classname_tokens.shape[1]

            # Concatenate the class name tokens
            if not use_cache:
                _lang_x = torch.cat([ctx_input_ids, classname_tokens], dim=1)
                _attention_mask = torch.cat(
                    [
                        ctx_attention_mask,
                        torch.ones_like(classname_tokens).bool(),
                    ],
                    dim=1,
                )
                _vision_x = batch_images
            else:
                # Context is already in the KV cache: feed only the new tokens.
                _lang_x = classname_tokens
                _attention_mask = None
                _vision_x = None

            # Call forward to get the logits
            outputs = self.__call__(
                vision_x=_vision_x,
                lang_x=_lang_x,
                attention_mask=_attention_mask,
                clear_conditioned_layers=(not use_cache),
                past_key_values=precomputed_pkvs,
            )

            # Get the logits of the classname
            # logits shape is either (B, num_tokens_in_classname, vocab_len) with use_cache
            # or (B, len(_lang_x), vocab_len) without use_cache
            # remember that the logits at index t on dim 1 correspond to predictions for the t+1st token
            logits = outputs.logits
            if use_cache:
                logits = torch.cat([precomputed_logits, logits], dim=1)

            logprobs = torch.log_softmax(logits, dim=-1)
            # Shift by one: logits at position t predict token t+1.
            gen_probs = logprobs[
                :, -num_tokens_in_classname - 1 : -1, :
            ]  # (B, num_tokens_in_classname, vocab_len)
            gen_probs = torch.gather(
                gen_probs, 2, classname_tokens[:, :, None]
            ).squeeze(-1)

            # Aggregate over tokens in the classname
            if normalize_length:
                # Mean avoids penalizing multi-token classnames for their length.
                class_prob = torch.mean(gen_probs, dim=1)
            else:
                class_prob = torch.sum(gen_probs, dim=1)
            overall_probs.append(class_prob)  # (B, 1)

        self.uncache_media()
        overall_probs = torch.vstack(overall_probs).T.cpu()  # shape (B, num_classes)
        return overall_probs

    def __call__(
        self,
        lang_x: torch.Tensor,
        vision_x: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: torch.Tensor = None,
        clear_conditioned_layers: bool = False,
        use_cache: bool = False,
    ):
        """
        Calls the forward function of the model.
        Special logic to handle the case if past_key_values is not None:
            then lang_x is assumed to contain the tokens to be generated
            *excluding* the tokens already in past_key_values.
            We then repeatedly call forward, updating the past_key_values.
        """
        # standard forward pass
        if past_key_values is None:
            with torch.inference_mode():
                with self.autocast():
                    outputs = self.model(
                        vision_x=vision_x,
                        lang_x=lang_x,
                        attention_mask=attention_mask,
                        clear_conditioned_layers=clear_conditioned_layers,
                        past_key_values=past_key_values,
                        use_cache=use_cache,
                    )
            return outputs

        # loop to handle updating past_key_values
        # (one token per forward pass so the KV cache stays consistent)
        logits = []
        for token_idx in range(lang_x.shape[1]):
            _lang_x = lang_x[:, token_idx].reshape((-1, 1))
            if attention_mask is not None:
                _attention_mask = attention_mask[:, token_idx].reshape((-1, 1))
            else:
                _attention_mask = None

            with torch.inference_mode():
                with self.autocast():
                    outputs = self.model(
                        vision_x=vision_x,
                        lang_x=_lang_x,
                        attention_mask=_attention_mask,
                        clear_conditioned_layers=False,
                        past_key_values=past_key_values,
                        use_cache=True,
                    )

            past_key_values = outputs.past_key_values
            logits.append(outputs.logits)

        logits = torch.cat(logits, dim=1)
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
        )

    def encode_vision_x(self, image_tensor: torch.Tensor):
        # Run only the vision encoder and condition the cross-attention layers.
        unwrap_model(self.model)._encode_vision_x(image_tensor.to(self.device))

    def uncache_media(self):
        # Clear any cached vision conditioning from the model.
        unwrap_model(self.model).uncache_media()

    def cache_media(self, input_ids, vision_x):
        # Cache vision conditioning so subsequent forwards can skip the encoder.
        unwrap_model(self.model).cache_media(input_ids=input_ids, vision_x=vision_x)

    def get_vqa_prompt(self, question, answer=None) -> str:
        # Demonstrations include the answer + end-of-chunk; queries omit both.
        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

    def get_caption_prompt(self, caption=None) -> str:
        return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

    def get_imagenet_prompt(self, label=None) -> str:
        return f"<image>Output:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"

    def get_hateful_memes_prompt(self, text, label=None) -> str:
        return f"<image>is an image with: '{text}' written on it. Is it hateful? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
|
open_flamingo/eval/ok_vqa_utils.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Those are manual mapping that are not caught by our stemming rules or would
|
| 2 |
+
# would be done incorrectly by our automatic stemming rule. In details,
|
| 3 |
+
# the keys of the _MANUAL_MATCHES dict contains the original word and the value
|
| 4 |
+
# contains the transformation of the word expected by the OKVQA stemming rule.
|
| 5 |
+
# These manual rules were found by checking the `raw_answers` and the `answers`
|
| 6 |
+
# fields of the released OKVQA dataset and checking all things that were not
|
| 7 |
+
# properly mapped by our automatic rules. In particular some of the mapping
|
| 8 |
+
# are sometimes constant, e.g. christmas -> christmas which was incorrectly
|
| 9 |
+
# singularized by our inflection.singularize.
|
| 10 |
+
import re
|
| 11 |
+
import nltk
|
| 12 |
+
from nltk.corpus.reader import VERB
|
| 13 |
+
import inflection
|
| 14 |
+
|
| 15 |
+
# Lookup table: raw answer word -> its expected OKVQA-stemmed form. Checked
# before the automatic rules in OKVQAStemmer.stem, so it both adds mappings
# the automatic rules miss and overrides ones they would get wrong
# (e.g. identity entries like "christmas" that inflection would mangle).
_MANUAL_MATCHES = {
    "police": "police",
    "las": "las",
    "vegas": "vegas",
    "yes": "yes",
    "jeans": "jean",
    "hell's": "hell",
    "domino's": "domino",
    "morning": "morn",
    "clothes": "cloth",
    "are": "are",
    "riding": "ride",
    "leaves": "leaf",
    "dangerous": "danger",
    "clothing": "cloth",
    "texting": "text",
    "kiting": "kite",
    "firefighters": "firefight",
    "ties": "tie",
    "married": "married",
    "teething": "teeth",
    "gloves": "glove",
    "tennis": "tennis",
    "dining": "dine",
    "directions": "direct",
    "waves": "wave",
    "christmas": "christmas",
    "drives": "drive",
    "pudding": "pud",
    "coding": "code",
    "plating": "plate",
    "quantas": "quanta",
    "hornes": "horn",
    "graves": "grave",
    "mating": "mate",
    "paned": "pane",
    "alertness": "alert",
    "sunbathing": "sunbath",
    "tenning": "ten",
    "wetness": "wet",
    "urinating": "urine",
    "sickness": "sick",
    "braves": "brave",
    "firefighting": "firefight",
    "lenses": "lens",
    "reflections": "reflect",
    "backpackers": "backpack",
    "eatting": "eat",
    "designers": "design",
    "curiousity": "curious",
    "playfulness": "play",
    "blindness": "blind",
    "hawke": "hawk",
    "tomatoe": "tomato",
    "rodeoing": "rodeo",
    "brightness": "bright",
    "circuses": "circus",
    "skateboarders": "skateboard",
    "staring": "stare",
    "electronics": "electron",
    "electicity": "elect",
    "mountainous": "mountain",
    "socializing": "social",
    "hamburgers": "hamburg",
    "caves": "cave",
    "transitions": "transit",
    "wading": "wade",
    "creame": "cream",
    "toileting": "toilet",
    "sautee": "saute",
    "buildings": "build",
    "belongings": "belong",
    "stockings": "stock",
    "walle": "wall",
    "cumulis": "cumuli",
    "travelers": "travel",
    "conducter": "conduct",
    "browsing": "brows",
    "pooping": "poop",
    "haircutting": "haircut",
    "toppings": "top",
    "hearding": "heard",
    "sunblocker": "sunblock",
    "bases": "base",
    "markings": "mark",
    "mopeds": "mope",
    "kindergartener": "kindergarten",
    "pies": "pie",
    "scrapbooking": "scrapbook",
    "couponing": "coupon",
    "meetings": "meet",
    "elevators": "elev",
    "lowes": "low",
    "men's": "men",
    "childrens": "children",
    "shelves": "shelve",
    "paintings": "paint",
    "raines": "rain",
    "paring": "pare",
    "expressions": "express",
    "routes": "rout",
    "pease": "peas",
    "vastness": "vast",
    "awning": "awn",
    "boy's": "boy",
    "drunkenness": "drunken",
    "teasing": "teas",
    "conferences": "confer",
    "ripeness": "ripe",
    "suspenders": "suspend",
    "earnings": "earn",
    "reporters": "report",
    "kid's": "kid",
    "containers": "contain",
    "corgie": "corgi",
    "porche": "porch",
    "microwaves": "microwave",
    "batter's": "batter",
    "sadness": "sad",
    "apartments": "apart",
    "oxygenize": "oxygen",
    "striping": "stripe",
    "purring": "pure",
    "professionals": "profession",
    "piping": "pipe",
    "farmer's": "farmer",
    "potatoe": "potato",
    "emirates": "emir",
    "womens": "women",
    "veteran's": "veteran",
    "wilderness": "wilder",
    "propellers": "propel",
    "alpes": "alp",
    "charioteering": "chariot",
    "swining": "swine",
    "illness": "ill",
    "crepte": "crept",
    "adhesives": "adhesive",
    "regent's": "regent",
    "decorations": "decor",
    "rabbies": "rabbi",
    "overseas": "oversea",
    "travellers": "travel",
    "casings": "case",
    "smugness": "smug",
    "doves": "dove",
    "nationals": "nation",
    "mustange": "mustang",
    "ringe": "ring",
    "gondoliere": "gondolier",
    "vacationing": "vacate",
    "reminders": "remind",
    "baldness": "bald",
    "settings": "set",
    "glaced": "glace",
    "coniferous": "conifer",
    "revelations": "revel",
    "personals": "person",
    "daughter's": "daughter",
    "badness": "bad",
    "projections": "project",
    "polarizing": "polar",
    "vandalizers": "vandal",
    "minerals": "miner",
    "protesters": "protest",
    "controllers": "control",
    "weddings": "wed",
    "sometimes": "sometime",
    "earing": "ear",
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class OKVQAStemmer:
    """Stemmer reproducing the OKVQA v1.1 answer-normalization procedure."""

    def __init__(self):
        # WordNet lemmatizer handles the "-ing" verb forms below.
        self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()

    def stem(self, input_string):
        """Return ``input_string`` with each token stemmed per the OKVQA rules."""

        def _stem_token(word, pos_tag):
            # Hand-curated exceptions take precedence over the automatic rules.
            if word in _MANUAL_MATCHES:
                return _MANUAL_MATCHES[word]
            # Gerund-looking tokens are reduced to their verb lemma.
            if word.endswith("ing"):
                return self._wordnet_lemmatizer.lemmatize(word, VERB)
            # Plural nouns (common or proper) are singularized.
            if pos_tag.startswith("NNS") or pos_tag.startswith("NNPS"):
                return inflection.singularize(word)
            return word

        tagged = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
        return " ".join(_stem_token(word, tag) for word, tag in tagged)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# Module-level singleton so the lemmatizer is constructed only once.
stemmer = OKVQAStemmer()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def postprocess_ok_vqa_generation(predictions) -> str:
    """Clean a raw OKVQA generation and stem it for metric computation.

    The text is truncated at the first follow-up marker ("Question",
    "Answer" or "Short") and then at the first ", " so that only the first
    short answer remains; the result is stemmed with the module-level
    OKVQA stemmer.

    Args:
        predictions: raw generated string for one example.

    Returns:
        The stemmed first answer.
    """
    # maxsplit must be passed by keyword: the positional form is deprecated
    # (and removed in Python 3.13's re.split signature).
    prediction = re.split("Question|Answer|Short", predictions, maxsplit=1)[0]
    prediction = re.split(", ", prediction, maxsplit=1)[0]
    return stemmer.stem(prediction)
|
open_flamingo/eval/rices.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import open_clip
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import torch
|
| 5 |
+
from utils import custom_collate_fn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RICES:
    """Retrieval-based In-Context Example Selection.

    Embeds every example image in ``dataset`` with a CLIP vision encoder and,
    at query time, returns the dataset items whose images are most similar to
    the query images (dot product of L2-normalized CLIP features).
    """

    def __init__(
        self,
        dataset,
        device,
        batch_size,
        vision_encoder_path="ViT-B-32",
        vision_encoder_pretrained="openai",
        cached_features=None,
    ):
        # Dataset items are dicts containing an "image" key
        # (see _precompute_features).
        self.dataset = dataset
        self.device = device
        self.batch_size = batch_size

        # Load the model and processor
        vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
            vision_encoder_path,
            pretrained=vision_encoder_pretrained,
        )
        self.model = vision_encoder.to(self.device)
        self.image_processor = image_processor

        # Precompute features
        # NOTE(review): features computed here stay on ``device``, while the
        # query features in ``find`` are moved to CPU before the matmul —
        # presumably callers pass CPU-resident ``cached_features``; confirm
        # device placement when device != "cpu".
        if cached_features is None:
            self.features = self._precompute_features()
        else:
            self.features = cached_features

    def _precompute_features(self):
        """Encode every dataset image; returns an (N, D) normalized feature tensor."""
        features = []

        # Switch to evaluation mode
        self.model.eval()

        # Set up loader
        loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            collate_fn=custom_collate_fn,
        )

        with torch.no_grad():
            for batch in tqdm(
                loader,
                desc="Precomputing features for RICES",
            ):
                batch = batch["image"]
                inputs = torch.stack(
                    [self.image_processor(image) for image in batch]
                ).to(self.device)
                image_features = self.model.encode_image(inputs)
                # L2-normalize so dot products below equal cosine similarity.
                image_features /= image_features.norm(dim=-1, keepdim=True)
                features.append(image_features.detach())

        features = torch.cat(features)
        return features

    def find(self, batch, num_examples):
        """
        Get the top num_examples most similar examples to the images.
        """
        # Switch to evaluation mode
        self.model.eval()

        with torch.no_grad():
            inputs = torch.stack([self.image_processor(image) for image in batch]).to(
                self.device
            )

            # Get the feature of the input image
            query_feature = self.model.encode_image(inputs)
            query_feature /= query_feature.norm(dim=-1, keepdim=True)
            query_feature = query_feature.detach().cpu()

            # Keep a batch dimension even for a single query image.
            if query_feature.ndim == 1:
                query_feature = query_feature.unsqueeze(0)

            # Compute the similarity of the input image to the precomputed features
            similarity = (query_feature @ self.features.T).squeeze()

            # squeeze() drops the batch dim when there is one query; restore it.
            if similarity.ndim == 1:
                similarity = similarity.unsqueeze(0)

            # Get the indices of the 'num_examples' most similar images
            indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples]

        # Return with the most similar images last
        return [[self.dataset[i] for i in reversed(row)] for row in indices]
|
open_flamingo/eval/utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from contextlib import suppress
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def random_seed(seed=42, rank=0):
    """Seed the torch, numpy and stdlib RNGs with a rank-offset seed."""
    effective_seed = seed + rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    random.seed(effective_seed)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def custom_collate_fn(batch):
    """
    Collate function for DataLoader that collates a list of dicts into a dict of lists.

    Args:
        batch: list of dicts, all sharing the same keys.

    Returns:
        Dict mapping each key to the list of per-item values, in batch order.
        An empty batch collates to an empty dict.
    """
    # Guard: the original ``batch[0]`` access raises IndexError on an empty batch.
    if not batch:
        return {}
    return {key: [item[key] for item in batch] for key in batch[0]}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_effective_num_shots(num_shots, model_type):
    """
    Compute the effective number of shots for a given model type.
    For example, following Flamingo, 0-shot OF evaluations use two text-only shots.
    """
    if model_type != "open_flamingo":
        return num_shots
    # OpenFlamingo substitutes two text-only demonstrations in the 0-shot case.
    return 2 if num_shots <= 0 else num_shots
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def sample_batch_demos_from_query_set(query_set, num_samples, batch_size):
    """
    Sample random demonstrations from the query set.

    Draws ``num_samples`` examples without replacement, independently for each
    of the ``batch_size`` batch elements.
    """
    demos = []
    for _ in range(batch_size):
        demos.append(random.sample(query_set, num_samples))
    return demos
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_query_set(train_dataset, query_set_size):
    """
    Get a subset of the training dataset to use as the query set.

    Samples ``query_set_size`` indices without replacement (numpy RNG) and
    materializes the corresponding examples.
    """
    chosen_indices = np.random.choice(len(train_dataset), query_set_size, replace=False)
    return [train_dataset[idx] for idx in chosen_indices]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def prepare_eval_samples(test_dataset, num_samples, batch_size):
    """
    Subset the test dataset and return a DataLoader.

    Args:
        test_dataset: indexable dataset to sample from.
        num_samples: number of examples to draw without replacement
            (numpy RNG — seed beforehand for reproducibility).
        batch_size: per-process batch size for the returned loader.

    Returns:
        A DataLoader over a random subset, sharded across ranks.

    NOTE(review): DistributedSampler requires torch.distributed to be
    initialized before this is called — confirm callers do so.
    """
    random_indices = np.random.choice(len(test_dataset), num_samples, replace=False)
    dataset = torch.utils.data.Subset(test_dataset, random_indices)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=custom_collate_fn,
    )
    return loader
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_indices_of_unique(x):
    """
    Return the indices of x that correspond to unique elements.
    If value v is unique and two indices in x have value v, the first index is returned.
    """
    # torch.unique returns the sorted distinct values; for each, keep the
    # position of its first occurrence in x.
    first_occurrences = [
        torch.where(x == value)[0][0] for value in torch.unique(x)
    ]
    return torch.tensor(first_occurrences)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def unwrap_model(model):
    """
    Unwrap a model from a DataParallel or DistributedDataParallel wrapper.

    Returns the wrapped ``.module`` if ``model`` is parallel-wrapped,
    otherwise returns ``model`` itself.
    """
    is_wrapped = isinstance(
        model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
    )
    return model.module if is_wrapped else model
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_predicted_classnames(logprobs, k, class_id_to_name):
    """
    Args:
        - logprobs shape (B, Y) containing logprobs for each classname
        - k: number for top-k
        - class_id_to_name: dict mapping class index to classname

    Returns:
        - top-k predicted classnames shape (B, k) type str
        - top-k logprobs shape (B, k) type float
    """
    # topk yields both the values and their class indices, already sorted
    # in descending order, so no separate gather is needed.
    topk_logprobs, topk_indices = torch.topk(logprobs, k=k, dim=1)  # each (B, k)
    predicted_classnames = [
        [class_id_to_name[class_ix] for class_ix in row]
        for row in topk_indices.tolist()
    ]
    return predicted_classnames, topk_logprobs
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_cast_dtype(precision: str):
    """Map a precision string to its cast dtype (None means no casting)."""
    return {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }.get(precision)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_autocast(precision):
    """Return a context-manager factory implementing the requested AMP mode."""
    if precision == "amp":
        return torch.cuda.amp.autocast
    if precision in ("amp_bfloat16", "amp_bf16"):
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    # Any other precision string disables autocasting entirely.
    return suppress
|
open_flamingo/eval/vqa_metric.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import datetime
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
# Interface for accessing the VQA dataset.
|
| 10 |
+
|
| 11 |
+
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
| 12 |
+
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
|
| 13 |
+
|
| 14 |
+
# The following functions are defined:
|
| 15 |
+
# VQA - VQA class that loads VQA annotation file and prepares data structures.
|
| 16 |
+
# getQuesIds - Get question ids that satisfy given filter conditions.
|
| 17 |
+
# getImgIds - Get image ids that satisfy given filter conditions.
|
| 18 |
+
# loadQA - Load questions and answers with the specified question ids.
|
| 19 |
+
# showQA - Display the specified questions and answers.
|
| 20 |
+
# loadRes - Load result file and create result object.
|
| 21 |
+
|
| 22 |
+
# Help on each function can be accessed by: "help(COCO.function)"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class VQA:
    """Loader and index for a VQA-format annotation/question file pair.

    Based on the MSCOCO Python API; method names keep the legacy camelCase
    for compatibility with existing VQA evaluation code.
    """

    def __init__(self, annotation_file=None, question_file=None):
        """
        Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :return:
        """
        # load dataset
        self.dataset = {}
        self.questions = {}
        self.qa = {}  # question_id -> annotation dict (after createIndex)
        self.qqa = {}  # question_id -> question dict (after createIndex)
        self.imgToQA = {}  # image_id -> list of annotation dicts
        if not annotation_file == None and not question_file == None:
            print("loading VQA annotations and questions into memory...")
            time_t = datetime.datetime.utcnow()
            dataset = json.load(open(annotation_file, "r"))
            questions = json.load(open(question_file, "r"))
            print(datetime.datetime.utcnow() - time_t)
            self.dataset = dataset
            self.questions = questions
            self.createIndex()

    def createIndex(self):
        """Build the question_id / image_id lookup tables from the loaded data."""
        # create index
        print("creating index...")
        imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
        qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
        qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
        for ann in self.dataset["annotations"]:
            imgToQA[ann["image_id"]] += [ann]
            qa[ann["question_id"]] = ann
        for ques in self.questions["questions"]:
            qqa[ques["question_id"]] = ques
        print("index created!")

        # create class members
        self.qa = qa
        self.qqa = qqa
        self.imgToQA = imgToQA

    def info(self):
        """
        Print information about the VQA annotation file.
        :return:
        """
        for key, value in self.dataset["info"].items():
            print("%s: %s" % (key, value))

    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
        """
        Get question ids that satisfy given filter conditions. default skips that filter
        :param imgIds (int array) : get question ids for given imgs
               quesTypes (str array) : get question ids for given question types
               ansTypes (str array) : get question ids for given answer types
        :return: ids (int array) : integer array of question ids
        """
        # NOTE: the mutable default args are only read, never mutated
        # (legacy VQA API signature).
        imgIds = imgIds if type(imgIds) == list else [imgIds]
        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset["annotations"]
        else:
            if not len(imgIds) == 0:
                # Flatten the per-image annotation lists for the requested images.
                anns = sum(
                    [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
                    [],
                )
            else:
                anns = self.dataset["annotations"]
            anns = (
                anns
                if len(quesTypes) == 0
                else [ann for ann in anns if ann["question_type"] in quesTypes]
            )
            anns = (
                anns
                if len(ansTypes) == 0
                else [ann for ann in anns if ann["answer_type"] in ansTypes]
            )
        ids = [ann["question_id"] for ann in anns]
        return ids

    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
        """
        Get image ids that satisfy given filter conditions. default skips that filter
        :param quesIds (int array) : get image ids for given question ids
               quesTypes (str array) : get image ids for given question types
               ansTypes (str array) : get image ids for given answer types
        :return: ids (int array) : integer array of image ids
        """
        quesIds = quesIds if type(quesIds) == list else [quesIds]
        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
            anns = self.dataset["annotations"]
        else:
            if not len(quesIds) == 0:
                anns = sum(
                    [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
                )
            else:
                anns = self.dataset["annotations"]
            anns = (
                anns
                if len(quesTypes) == 0
                else [ann for ann in anns if ann["question_type"] in quesTypes]
            )
            anns = (
                anns
                if len(ansTypes) == 0
                else [ann for ann in anns if ann["answer_type"] in ansTypes]
            )
        ids = [ann["image_id"] for ann in anns]
        return ids

    def loadQA(self, ids=[]):
        """
        Load questions and answers with the specified question ids.
        :param ids (int array)       : integer ids specifying question ids
        :return: qa (object array)   : loaded qa objects
        """
        if type(ids) == list:
            return [self.qa[id] for id in ids]
        elif type(ids) == int:
            return [self.qa[ids]]

    def showQA(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None
        """
        if len(anns) == 0:
            return 0
        for ann in anns:
            quesId = ann["question_id"]
            print("Question: %s" % (self.qqa[quesId]["question"]))
            for ans in ann["answers"]:
                print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))

    def loadRes(self, resFile, quesFile):
        """
        Load result file and return a result object.
        :param resFile (str)     : file name of result file
        :return: res (obj)       : result api object
        """
        res = VQA()
        res.questions = json.load(open(quesFile))
        # Copy the metadata of this question set onto the result object.
        res.dataset["info"] = copy.deepcopy(self.questions["info"])
        res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
        res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
        res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
        res.dataset["license"] = copy.deepcopy(self.questions["license"])

        print("Loading and preparing results... ")
        time_t = datetime.datetime.utcnow()
        anns = json.load(open(resFile))
        assert type(anns) == list, "results is not an array of objects"
        annsQuesIds = [ann["question_id"] for ann in anns]
        # print set of question ids that do not have corresponding annotations

        # assert set(annsQuesIds) == set(self.getQuesIds()), \
        # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
        for ann in anns:
            quesId = ann["question_id"]
            if res.dataset["task_type"] == "Multiple Choice":
                assert (
                    ann["answer"] in self.qqa[quesId]["multiple_choices"]
                ), "predicted answer is not one of the multiple choices"
            # Enrich each result with metadata from the ground-truth annotation.
            qaAnn = self.qa[quesId]
            ann["image_id"] = qaAnn["image_id"]
            ann["question_type"] = qaAnn["question_type"]
            if "answer_type" in ann:
                ann["answer_type"] = qaAnn["answer_type"]
        print(
            "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
        )

        res.dataset["annotations"] = anns
        res.createIndex()
        return res
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class VQAEval:
|
| 212 |
+
def __init__(self, vqa, vqaRes, n=2):
|
| 213 |
+
self.n = n
|
| 214 |
+
self.accuracy = {}
|
| 215 |
+
self.evalQA = {}
|
| 216 |
+
self.evalQuesType = {}
|
| 217 |
+
self.evalAnsType = {}
|
| 218 |
+
self.vqa = vqa
|
| 219 |
+
self.vqaRes = vqaRes
|
| 220 |
+
if not vqa is None and not vqaRes is None:
|
| 221 |
+
self.params = {"question_id": vqaRes.getQuesIds()}
|
| 222 |
+
self.contractions = {
|
| 223 |
+
"aint": "ain't",
|
| 224 |
+
"arent": "aren't",
|
| 225 |
+
"cant": "can't",
|
| 226 |
+
"couldve": "could've",
|
| 227 |
+
"couldnt": "couldn't",
|
| 228 |
+
"couldn'tve": "couldn't've",
|
| 229 |
+
"couldnt've": "couldn't've",
|
| 230 |
+
"didnt": "didn't",
|
| 231 |
+
"doesnt": "doesn't",
|
| 232 |
+
"dont": "don't",
|
| 233 |
+
"hadnt": "hadn't",
|
| 234 |
+
"hadnt've": "hadn't've",
|
| 235 |
+
"hadn'tve": "hadn't've",
|
| 236 |
+
"hasnt": "hasn't",
|
| 237 |
+
"havent": "haven't",
|
| 238 |
+
"hed": "he'd",
|
| 239 |
+
"hed've": "he'd've",
|
| 240 |
+
"he'dve": "he'd've",
|
| 241 |
+
"hes": "he's",
|
| 242 |
+
"howd": "how'd",
|
| 243 |
+
"howll": "how'll",
|
| 244 |
+
"hows": "how's",
|
| 245 |
+
"Id've": "I'd've",
|
| 246 |
+
"I'dve": "I'd've",
|
| 247 |
+
"Im": "I'm",
|
| 248 |
+
"Ive": "I've",
|
| 249 |
+
"isnt": "isn't",
|
| 250 |
+
"itd": "it'd",
|
| 251 |
+
"itd've": "it'd've",
|
| 252 |
+
"it'dve": "it'd've",
|
| 253 |
+
"itll": "it'll",
|
| 254 |
+
"let's": "let's",
|
| 255 |
+
"maam": "ma'am",
|
| 256 |
+
"mightnt": "mightn't",
|
| 257 |
+
"mightnt've": "mightn't've",
|
| 258 |
+
"mightn'tve": "mightn't've",
|
| 259 |
+
"mightve": "might've",
|
| 260 |
+
"mustnt": "mustn't",
|
| 261 |
+
"mustve": "must've",
|
| 262 |
+
"neednt": "needn't",
|
| 263 |
+
"notve": "not've",
|
| 264 |
+
"oclock": "o'clock",
|
| 265 |
+
"oughtnt": "oughtn't",
|
| 266 |
+
"ow's'at": "'ow's'at",
|
| 267 |
+
"'ows'at": "'ow's'at",
|
| 268 |
+
"'ow'sat": "'ow's'at",
|
| 269 |
+
"shant": "shan't",
|
| 270 |
+
"shed've": "she'd've",
|
| 271 |
+
"she'dve": "she'd've",
|
| 272 |
+
"she's": "she's",
|
| 273 |
+
"shouldve": "should've",
|
| 274 |
+
"shouldnt": "shouldn't",
|
| 275 |
+
"shouldnt've": "shouldn't've",
|
| 276 |
+
"shouldn'tve": "shouldn't've",
|
| 277 |
+
"somebody'd": "somebodyd",
|
| 278 |
+
"somebodyd've": "somebody'd've",
|
| 279 |
+
"somebody'dve": "somebody'd've",
|
| 280 |
+
"somebodyll": "somebody'll",
|
| 281 |
+
"somebodys": "somebody's",
|
| 282 |
+
"someoned": "someone'd",
|
| 283 |
+
"someoned've": "someone'd've",
|
| 284 |
+
"someone'dve": "someone'd've",
|
| 285 |
+
"someonell": "someone'll",
|
| 286 |
+
"someones": "someone's",
|
| 287 |
+
"somethingd": "something'd",
|
| 288 |
+
"somethingd've": "something'd've",
|
| 289 |
+
"something'dve": "something'd've",
|
| 290 |
+
"somethingll": "something'll",
|
| 291 |
+
"thats": "that's",
|
| 292 |
+
"thered": "there'd",
|
| 293 |
+
"thered've": "there'd've",
|
| 294 |
+
"there'dve": "there'd've",
|
| 295 |
+
"therere": "there're",
|
| 296 |
+
"theres": "there's",
|
| 297 |
+
"theyd": "they'd",
|
| 298 |
+
"theyd've": "they'd've",
|
| 299 |
+
"they'dve": "they'd've",
|
| 300 |
+
"theyll": "they'll",
|
| 301 |
+
"theyre": "they're",
|
| 302 |
+
"theyve": "they've",
|
| 303 |
+
"twas": "'twas",
|
| 304 |
+
"wasnt": "wasn't",
|
| 305 |
+
"wed've": "we'd've",
|
| 306 |
+
"we'dve": "we'd've",
|
| 307 |
+
"weve": "we've",
|
| 308 |
+
"werent": "weren't",
|
| 309 |
+
"whatll": "what'll",
|
| 310 |
+
"whatre": "what're",
|
| 311 |
+
"whats": "what's",
|
| 312 |
+
"whatve": "what've",
|
| 313 |
+
"whens": "when's",
|
| 314 |
+
"whered": "where'd",
|
| 315 |
+
"wheres": "where's",
|
| 316 |
+
"whereve": "where've",
|
| 317 |
+
"whod": "who'd",
|
| 318 |
+
"whod've": "who'd've",
|
| 319 |
+
"who'dve": "who'd've",
|
| 320 |
+
"wholl": "who'll",
|
| 321 |
+
"whos": "who's",
|
| 322 |
+
"whove": "who've",
|
| 323 |
+
"whyll": "why'll",
|
| 324 |
+
"whyre": "why're",
|
| 325 |
+
"whys": "why's",
|
| 326 |
+
"wont": "won't",
|
| 327 |
+
"wouldve": "would've",
|
| 328 |
+
"wouldnt": "wouldn't",
|
| 329 |
+
"wouldnt've": "wouldn't've",
|
| 330 |
+
"wouldn'tve": "wouldn't've",
|
| 331 |
+
"yall": "y'all",
|
| 332 |
+
"yall'll": "y'all'll",
|
| 333 |
+
"y'allll": "y'all'll",
|
| 334 |
+
"yall'd've": "y'all'd've",
|
| 335 |
+
"y'alld've": "y'all'd've",
|
| 336 |
+
"y'all'dve": "y'all'd've",
|
| 337 |
+
"youd": "you'd",
|
| 338 |
+
"youd've": "you'd've",
|
| 339 |
+
"you'dve": "you'd've",
|
| 340 |
+
"youll": "you'll",
|
| 341 |
+
"youre": "you're",
|
| 342 |
+
"youve": "you've",
|
| 343 |
+
}
|
| 344 |
+
self.manualMap = {
|
| 345 |
+
"none": "0",
|
| 346 |
+
"zero": "0",
|
| 347 |
+
"one": "1",
|
| 348 |
+
"two": "2",
|
| 349 |
+
"three": "3",
|
| 350 |
+
"four": "4",
|
| 351 |
+
"five": "5",
|
| 352 |
+
"six": "6",
|
| 353 |
+
"seven": "7",
|
| 354 |
+
"eight": "8",
|
| 355 |
+
"nine": "9",
|
| 356 |
+
"ten": "10",
|
| 357 |
+
}
|
| 358 |
+
self.articles = ["a", "an", "the"]
|
| 359 |
+
|
| 360 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
| 361 |
+
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
| 362 |
+
self.punct = [
|
| 363 |
+
";",
|
| 364 |
+
r"/",
|
| 365 |
+
"[",
|
| 366 |
+
"]",
|
| 367 |
+
'"',
|
| 368 |
+
"{",
|
| 369 |
+
"}",
|
| 370 |
+
"(",
|
| 371 |
+
")",
|
| 372 |
+
"=",
|
| 373 |
+
"+",
|
| 374 |
+
"\\",
|
| 375 |
+
"_",
|
| 376 |
+
"-",
|
| 377 |
+
">",
|
| 378 |
+
"<",
|
| 379 |
+
"@",
|
| 380 |
+
"`",
|
| 381 |
+
",",
|
| 382 |
+
"?",
|
| 383 |
+
"!",
|
| 384 |
+
]
|
| 385 |
+
|
| 386 |
+
def evaluate(self, quesIds=None):
|
| 387 |
+
if quesIds == None:
|
| 388 |
+
quesIds = [quesId for quesId in self.params["question_id"]]
|
| 389 |
+
gts = {}
|
| 390 |
+
res = {}
|
| 391 |
+
for quesId in quesIds:
|
| 392 |
+
gts[quesId] = self.vqa.qa[quesId]
|
| 393 |
+
res[quesId] = self.vqaRes.qa[quesId]
|
| 394 |
+
|
| 395 |
+
# =================================================
|
| 396 |
+
# Compute accuracy
|
| 397 |
+
# =================================================
|
| 398 |
+
accQA = []
|
| 399 |
+
accQuesType = {}
|
| 400 |
+
accAnsType = {}
|
| 401 |
+
print("computing accuracy")
|
| 402 |
+
step = 0
|
| 403 |
+
for quesId in quesIds:
|
| 404 |
+
for ansDic in gts[quesId]["answers"]:
|
| 405 |
+
ansDic["answer"] = ansDic["answer"].replace("\n", " ")
|
| 406 |
+
ansDic["answer"] = ansDic["answer"].replace("\t", " ")
|
| 407 |
+
ansDic["answer"] = ansDic["answer"].strip()
|
| 408 |
+
resAns = res[quesId]["answer"]
|
| 409 |
+
resAns = resAns.replace("\n", " ")
|
| 410 |
+
resAns = resAns.replace("\t", " ")
|
| 411 |
+
resAns = resAns.strip()
|
| 412 |
+
resAns = self.processPunctuation(resAns)
|
| 413 |
+
resAns = self.processDigitArticle(resAns)
|
| 414 |
+
gtAcc = []
|
| 415 |
+
|
| 416 |
+
for ansDic in gts[quesId]["answers"]:
|
| 417 |
+
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
|
| 418 |
+
ansDic["answer"] = self.processDigitArticle(ansDic["answer"])
|
| 419 |
+
|
| 420 |
+
for gtAnsDatum in gts[quesId]["answers"]:
|
| 421 |
+
otherGTAns = [
|
| 422 |
+
item for item in gts[quesId]["answers"] if item != gtAnsDatum
|
| 423 |
+
]
|
| 424 |
+
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
|
| 425 |
+
acc = min(1, float(len(matchingAns)) / 3)
|
| 426 |
+
gtAcc.append(acc)
|
| 427 |
+
quesType = gts[quesId]["question_type"]
|
| 428 |
+
ansType = (
|
| 429 |
+
gts[quesId]["answer_type"] if "answer_type" in gts[quesId] else "other"
|
| 430 |
+
)
|
| 431 |
+
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
|
| 432 |
+
accQA.append(avgGTAcc)
|
| 433 |
+
if quesType not in accQuesType:
|
| 434 |
+
accQuesType[quesType] = []
|
| 435 |
+
accQuesType[quesType].append(avgGTAcc)
|
| 436 |
+
if ansType not in accAnsType:
|
| 437 |
+
accAnsType[ansType] = []
|
| 438 |
+
accAnsType[ansType].append(avgGTAcc)
|
| 439 |
+
self.setEvalQA(quesId, avgGTAcc)
|
| 440 |
+
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
| 441 |
+
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
| 442 |
+
if step % 100 == 0:
|
| 443 |
+
self.updateProgress(step / float(len(quesIds)))
|
| 444 |
+
step = step + 1
|
| 445 |
+
|
| 446 |
+
self.setAccuracy(accQA, accQuesType, accAnsType)
|
| 447 |
+
print("Done computing accuracy")
|
| 448 |
+
|
| 449 |
+
def processPunctuation(self, inText):
|
| 450 |
+
outText = inText
|
| 451 |
+
for p in self.punct:
|
| 452 |
+
if (p + " " in inText or " " + p in inText) or (
|
| 453 |
+
re.search(self.commaStrip, inText) != None
|
| 454 |
+
):
|
| 455 |
+
outText = outText.replace(p, "")
|
| 456 |
+
else:
|
| 457 |
+
outText = outText.replace(p, " ")
|
| 458 |
+
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
| 459 |
+
return outText
|
| 460 |
+
|
| 461 |
+
def processDigitArticle(self, inText):
|
| 462 |
+
outText = []
|
| 463 |
+
tempText = inText.lower().split()
|
| 464 |
+
for word in tempText:
|
| 465 |
+
word = self.manualMap.setdefault(word, word)
|
| 466 |
+
if word not in self.articles:
|
| 467 |
+
outText.append(word)
|
| 468 |
+
else:
|
| 469 |
+
pass
|
| 470 |
+
for wordId, word in enumerate(outText):
|
| 471 |
+
if word in self.contractions:
|
| 472 |
+
outText[wordId] = self.contractions[word]
|
| 473 |
+
outText = " ".join(outText)
|
| 474 |
+
return outText
|
| 475 |
+
|
| 476 |
+
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
| 477 |
+
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
|
| 478 |
+
self.accuracy["perQuestionType"] = {
|
| 479 |
+
quesType: round(
|
| 480 |
+
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
|
| 481 |
+
self.n,
|
| 482 |
+
)
|
| 483 |
+
for quesType in accQuesType
|
| 484 |
+
}
|
| 485 |
+
self.accuracy["perAnswerType"] = {
|
| 486 |
+
ansType: round(
|
| 487 |
+
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
|
| 488 |
+
)
|
| 489 |
+
for ansType in accAnsType
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
def setEvalQA(self, quesId, acc):
|
| 493 |
+
self.evalQA[quesId] = round(100 * acc, self.n)
|
| 494 |
+
|
| 495 |
+
def setEvalQuesType(self, quesId, quesType, acc):
|
| 496 |
+
if quesType not in self.evalQuesType:
|
| 497 |
+
self.evalQuesType[quesType] = {}
|
| 498 |
+
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
|
| 499 |
+
|
| 500 |
+
def setEvalAnsType(self, quesId, ansType, acc):
|
| 501 |
+
if ansType not in self.evalAnsType:
|
| 502 |
+
self.evalAnsType[ansType] = {}
|
| 503 |
+
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
|
| 504 |
+
|
| 505 |
+
def updateProgress(self, progress):
|
| 506 |
+
barLength = 20
|
| 507 |
+
status = ""
|
| 508 |
+
if isinstance(progress, int):
|
| 509 |
+
progress = float(progress)
|
| 510 |
+
if not isinstance(progress, float):
|
| 511 |
+
progress = 0
|
| 512 |
+
status = "error: progress var must be float\r\n"
|
| 513 |
+
if progress < 0:
|
| 514 |
+
progress = 0
|
| 515 |
+
status = "Halt...\r\n"
|
| 516 |
+
if progress >= 1:
|
| 517 |
+
progress = 1
|
| 518 |
+
status = "Done...\r\n"
|
| 519 |
+
block = int(round(barLength * progress))
|
| 520 |
+
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
|
| 521 |
+
"#" * block + "-" * (barLength - block), int(progress * 100), status
|
| 522 |
+
)
|
| 523 |
+
sys.stdout.write(text)
|
| 524 |
+
sys.stdout.flush()
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path):
|
| 528 |
+
"""Compute the VQA accuracy metric.
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
result_json_path (str): Path to the json file with model outputs
|
| 532 |
+
question_json_path (str): Path to the json file with questions
|
| 533 |
+
annotation_json_path (str): Path to the json file with annotations
|
| 534 |
+
|
| 535 |
+
Returns:
|
| 536 |
+
float: VQA accuracy
|
| 537 |
+
"""
|
| 538 |
+
|
| 539 |
+
# create vqa object and vqaRes object
|
| 540 |
+
vqa = VQA(annotation_json_path, question_json_path)
|
| 541 |
+
vqaRes = vqa.loadRes(result_json_path, question_json_path)
|
| 542 |
+
|
| 543 |
+
# create vqaEval object by taking vqa and vqaRes
|
| 544 |
+
# n is precision of accuracy (number of places after decimal), default is 2
|
| 545 |
+
vqaEval = VQAEval(vqa, vqaRes, n=2)
|
| 546 |
+
|
| 547 |
+
# evaluate results
|
| 548 |
+
"""
|
| 549 |
+
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
|
| 550 |
+
By default it uses all the question ids in annotation file
|
| 551 |
+
"""
|
| 552 |
+
vqaEval.evaluate()
|
| 553 |
+
|
| 554 |
+
return vqaEval.accuracy["overall"]
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def postprocess_vqa_generation(predictions):
|
| 558 |
+
answer = re.split("Question|Answer|Short", predictions, 1)[0]
|
| 559 |
+
answer = re.split(", ", answer, 1)[0]
|
| 560 |
+
return answer
|
open_flamingo/scripts/cache_rices_features.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cache CLIP features for all images in training split in preparation for RICES
|
| 3 |
+
"""
|
| 4 |
+
import argparse
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
sys.path.append(
|
| 9 |
+
os.path.join(
|
| 10 |
+
os.path.dirname(os.path.abspath(__file__)),
|
| 11 |
+
"..",
|
| 12 |
+
)
|
| 13 |
+
)
|
| 14 |
+
from eval.rices import RICES
|
| 15 |
+
from eval.eval_datasets import (
|
| 16 |
+
CaptionDataset,
|
| 17 |
+
VQADataset,
|
| 18 |
+
ImageNetDataset,
|
| 19 |
+
HatefulMemesDataset,
|
| 20 |
+
)
|
| 21 |
+
import os
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
parser = argparse.ArgumentParser()
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--output_dir",
|
| 27 |
+
type=str,
|
| 28 |
+
required=True,
|
| 29 |
+
help="Directory to save the cached features.",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
| 32 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
| 33 |
+
parser.add_argument("--batch_size", default=256)
|
| 34 |
+
|
| 35 |
+
# Per-dataset flags
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--eval_coco",
|
| 38 |
+
action="store_true",
|
| 39 |
+
default=False,
|
| 40 |
+
help="Whether to cache COCO.",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--eval_vqav2",
|
| 44 |
+
action="store_true",
|
| 45 |
+
default=False,
|
| 46 |
+
help="Whether to cache VQAV2.",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--eval_ok_vqa",
|
| 50 |
+
action="store_true",
|
| 51 |
+
default=False,
|
| 52 |
+
help="Whether to cache OK-VQA.",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--eval_vizwiz",
|
| 56 |
+
action="store_true",
|
| 57 |
+
default=False,
|
| 58 |
+
help="Whether to cache VizWiz.",
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--eval_textvqa",
|
| 62 |
+
action="store_true",
|
| 63 |
+
default=False,
|
| 64 |
+
help="Whether to cache TextVQA.",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--eval_imagenet",
|
| 68 |
+
action="store_true",
|
| 69 |
+
default=False,
|
| 70 |
+
help="Whether to cache ImageNet.",
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--eval_flickr30",
|
| 74 |
+
action="store_true",
|
| 75 |
+
default=False,
|
| 76 |
+
help="Whether to cache Flickr30.",
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--eval_hateful_memes",
|
| 80 |
+
action="store_true",
|
| 81 |
+
default=False,
|
| 82 |
+
help="Whether to cache Hateful Memes.",
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Dataset arguments
|
| 86 |
+
|
| 87 |
+
## Flickr30 Dataset
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--flickr_image_dir_path",
|
| 90 |
+
type=str,
|
| 91 |
+
help="Path to the flickr30/flickr30k_images directory.",
|
| 92 |
+
default=None,
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--flickr_karpathy_json_path",
|
| 96 |
+
type=str,
|
| 97 |
+
help="Path to the dataset_flickr30k.json file.",
|
| 98 |
+
default=None,
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--flickr_annotations_json_path",
|
| 102 |
+
type=str,
|
| 103 |
+
help="Path to the dataset_flickr30k_coco_style.json file.",
|
| 104 |
+
)
|
| 105 |
+
## COCO Dataset
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--coco_train_image_dir_path",
|
| 108 |
+
type=str,
|
| 109 |
+
default=None,
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--coco_val_image_dir_path",
|
| 113 |
+
type=str,
|
| 114 |
+
default=None,
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--coco_karpathy_json_path",
|
| 118 |
+
type=str,
|
| 119 |
+
default=None,
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--coco_annotations_json_path",
|
| 123 |
+
type=str,
|
| 124 |
+
default=None,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
## VQAV2 Dataset
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--vqav2_train_image_dir_path",
|
| 130 |
+
type=str,
|
| 131 |
+
default=None,
|
| 132 |
+
)
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--vqav2_train_questions_json_path",
|
| 135 |
+
type=str,
|
| 136 |
+
default=None,
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--vqav2_train_annotations_json_path",
|
| 140 |
+
type=str,
|
| 141 |
+
default=None,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
## OK-VQA Dataset
|
| 145 |
+
parser.add_argument(
|
| 146 |
+
"--ok_vqa_train_image_dir_path",
|
| 147 |
+
type=str,
|
| 148 |
+
help="Path to the vqav2/train2014 directory.",
|
| 149 |
+
default=None,
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--ok_vqa_train_questions_json_path",
|
| 153 |
+
type=str,
|
| 154 |
+
help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
|
| 155 |
+
default=None,
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--ok_vqa_train_annotations_json_path",
|
| 159 |
+
type=str,
|
| 160 |
+
help="Path to the v2_mscoco_train2014_annotations.json file.",
|
| 161 |
+
default=None,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
## VizWiz Dataset
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--vizwiz_train_image_dir_path",
|
| 167 |
+
type=str,
|
| 168 |
+
help="Path to the vizwiz train images directory.",
|
| 169 |
+
default=None,
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--vizwiz_train_questions_json_path",
|
| 173 |
+
type=str,
|
| 174 |
+
help="Path to the vizwiz questions json file.",
|
| 175 |
+
default=None,
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--vizwiz_train_annotations_json_path",
|
| 179 |
+
type=str,
|
| 180 |
+
help="Path to the vizwiz annotations json file.",
|
| 181 |
+
default=None,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# TextVQA Dataset
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--textvqa_image_dir_path",
|
| 187 |
+
type=str,
|
| 188 |
+
help="Path to the textvqa images directory.",
|
| 189 |
+
default=None,
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--textvqa_train_questions_json_path",
|
| 193 |
+
type=str,
|
| 194 |
+
help="Path to the textvqa questions json file.",
|
| 195 |
+
default=None,
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--textvqa_train_annotations_json_path",
|
| 199 |
+
type=str,
|
| 200 |
+
help="Path to the textvqa annotations json file.",
|
| 201 |
+
default=None,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
## Imagenet dataset
|
| 206 |
+
parser.add_argument("--imagenet_root", type=str, default="/tmp")
|
| 207 |
+
|
| 208 |
+
## Hateful Memes dataset
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--hateful_memes_image_dir_path",
|
| 211 |
+
type=str,
|
| 212 |
+
default=None,
|
| 213 |
+
)
|
| 214 |
+
parser.add_argument(
|
| 215 |
+
"--hateful_memes_train_annotations_json_path",
|
| 216 |
+
type=str,
|
| 217 |
+
default=None,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def main():
|
| 222 |
+
args, leftovers = parser.parse_known_args()
|
| 223 |
+
device_id = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
|
| 224 |
+
if args.eval_flickr30:
|
| 225 |
+
print("Caching Flickr30k...")
|
| 226 |
+
train_dataset = CaptionDataset(
|
| 227 |
+
image_train_dir_path=args.flickr_image_dir_path,
|
| 228 |
+
image_val_dir_path=None,
|
| 229 |
+
annotations_path=args.flickr_karpathy_json_path,
|
| 230 |
+
is_train=True,
|
| 231 |
+
dataset_name="flickr",
|
| 232 |
+
)
|
| 233 |
+
rices_dataset = RICES(
|
| 234 |
+
train_dataset,
|
| 235 |
+
device_id,
|
| 236 |
+
args.batch_size,
|
| 237 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 238 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 239 |
+
)
|
| 240 |
+
torch.save(
|
| 241 |
+
rices_dataset.features,
|
| 242 |
+
os.path.join(args.output_dir, "flickr30.pkl"),
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
if args.eval_coco:
|
| 246 |
+
print("Caching COCO...")
|
| 247 |
+
train_dataset = CaptionDataset(
|
| 248 |
+
image_train_dir_path=args.coco_train_image_dir_path,
|
| 249 |
+
image_val_dir_path=args.coco_val_image_dir_path,
|
| 250 |
+
annotations_path=args.coco_karpathy_json_path,
|
| 251 |
+
is_train=True,
|
| 252 |
+
dataset_name="coco",
|
| 253 |
+
)
|
| 254 |
+
rices_dataset = RICES(
|
| 255 |
+
train_dataset,
|
| 256 |
+
device_id,
|
| 257 |
+
args.batch_size,
|
| 258 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 259 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 260 |
+
)
|
| 261 |
+
torch.save(
|
| 262 |
+
rices_dataset.features,
|
| 263 |
+
os.path.join(args.output_dir, "coco.pkl"),
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
if args.eval_ok_vqa:
|
| 267 |
+
print("Caching OK-VQA...")
|
| 268 |
+
train_dataset = VQADataset(
|
| 269 |
+
image_dir_path=args.ok_vqa_train_image_dir_path,
|
| 270 |
+
question_path=args.ok_vqa_train_questions_json_path,
|
| 271 |
+
annotations_path=args.ok_vqa_train_annotations_json_path,
|
| 272 |
+
is_train=True,
|
| 273 |
+
dataset_name="ok_vqa",
|
| 274 |
+
)
|
| 275 |
+
rices_dataset = RICES(
|
| 276 |
+
train_dataset,
|
| 277 |
+
device_id,
|
| 278 |
+
args.batch_size,
|
| 279 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 280 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 281 |
+
)
|
| 282 |
+
torch.save(
|
| 283 |
+
rices_dataset.features,
|
| 284 |
+
os.path.join(args.output_dir, "ok_vqa.pkl"),
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if args.eval_vizwiz:
|
| 288 |
+
print("Caching VizWiz...")
|
| 289 |
+
train_dataset = VQADataset(
|
| 290 |
+
image_dir_path=args.vizwiz_train_image_dir_path,
|
| 291 |
+
question_path=args.vizwiz_train_questions_json_path,
|
| 292 |
+
annotations_path=args.vizwiz_train_annotations_json_path,
|
| 293 |
+
is_train=True,
|
| 294 |
+
dataset_name="vizwiz",
|
| 295 |
+
)
|
| 296 |
+
rices_dataset = RICES(
|
| 297 |
+
train_dataset,
|
| 298 |
+
device_id,
|
| 299 |
+
args.batch_size,
|
| 300 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 301 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 302 |
+
)
|
| 303 |
+
torch.save(
|
| 304 |
+
rices_dataset.features,
|
| 305 |
+
os.path.join(args.output_dir, "vizwiz.pkl"),
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if args.eval_vqav2:
|
| 309 |
+
print("Caching VQAv2...")
|
| 310 |
+
train_dataset = VQADataset(
|
| 311 |
+
image_dir_path=args.vqav2_train_image_dir_path,
|
| 312 |
+
question_path=args.vqav2_train_questions_json_path,
|
| 313 |
+
annotations_path=args.vqav2_train_annotations_json_path,
|
| 314 |
+
is_train=True,
|
| 315 |
+
dataset_name="vqav2",
|
| 316 |
+
)
|
| 317 |
+
rices_dataset = RICES(
|
| 318 |
+
train_dataset,
|
| 319 |
+
device_id,
|
| 320 |
+
args.batch_size,
|
| 321 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 322 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 323 |
+
)
|
| 324 |
+
torch.save(
|
| 325 |
+
rices_dataset.features,
|
| 326 |
+
os.path.join(args.output_dir, "vqav2.pkl"),
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
if args.eval_textvqa:
|
| 330 |
+
print("Caching TextVQA...")
|
| 331 |
+
train_dataset = VQADataset(
|
| 332 |
+
image_dir_path=args.textvqa_image_dir_path,
|
| 333 |
+
question_path=args.textvqa_train_questions_json_path,
|
| 334 |
+
annotations_path=args.textvqa_train_annotations_json_path,
|
| 335 |
+
is_train=True,
|
| 336 |
+
dataset_name="textvqa",
|
| 337 |
+
)
|
| 338 |
+
rices_dataset = RICES(
|
| 339 |
+
train_dataset,
|
| 340 |
+
device_id,
|
| 341 |
+
args.batch_size,
|
| 342 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 343 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 344 |
+
)
|
| 345 |
+
torch.save(
|
| 346 |
+
rices_dataset.features,
|
| 347 |
+
os.path.join(args.output_dir, "textvqa.pkl"),
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
if args.eval_hateful_memes:
|
| 351 |
+
print("Caching Hateful Memes...")
|
| 352 |
+
train_dataset = HatefulMemesDataset(
|
| 353 |
+
image_dir_path=args.hateful_memes_image_dir_path,
|
| 354 |
+
annotations_path=args.hateful_memes_train_annotations_json_path,
|
| 355 |
+
)
|
| 356 |
+
rices_dataset = RICES(
|
| 357 |
+
train_dataset,
|
| 358 |
+
device_id,
|
| 359 |
+
args.batch_size,
|
| 360 |
+
vision_encoder_path=args.vision_encoder_path,
|
| 361 |
+
vision_encoder_pretrained=args.vision_encoder_pretrained,
|
| 362 |
+
)
|
| 363 |
+
torch.save(
|
| 364 |
+
rices_dataset.features,
|
| 365 |
+
os.path.join(args.output_dir, "hateful_memes.pkl"),
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
if __name__ == "__main__":
|
| 370 |
+
main()
|
open_flamingo/scripts/convert_mmc4_to_wds.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import uuid
|
| 5 |
+
import zipfile
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import base64
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
|
| 10 |
+
import braceexpand
|
| 11 |
+
import webdataset as wds
|
| 12 |
+
|
| 13 |
+
arg_parser = argparse.ArgumentParser()
|
| 14 |
+
arg_parser.add_argument(
|
| 15 |
+
"--output_dir",
|
| 16 |
+
type=str,
|
| 17 |
+
help="Pass in the directory where the output shards (as tar files) will be written to.",
|
| 18 |
+
)
|
| 19 |
+
arg_parser.add_argument(
|
| 20 |
+
"--zip_files",
|
| 21 |
+
type=str,
|
| 22 |
+
help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip",
|
| 23 |
+
)
|
| 24 |
+
arg_parser.add_argument(
|
| 25 |
+
"--image_dir",
|
| 26 |
+
type=str,
|
| 27 |
+
help="Pass in the directory where the images have been downloaded to.",
|
| 28 |
+
)
|
| 29 |
+
arg_parser.add_argument(
|
| 30 |
+
"--num_files_per_shard",
|
| 31 |
+
type=int,
|
| 32 |
+
default=1000,
|
| 33 |
+
)
|
| 34 |
+
args = arg_parser.parse_args()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main():
|
| 38 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
doc_shards = list(braceexpand.braceexpand(args.zip_files))
|
| 41 |
+
|
| 42 |
+
with wds.ShardWriter(args.output_dir + "/%09d.tar") as sink:
|
| 43 |
+
for idx in range(len(doc_shards)):
|
| 44 |
+
# Open the ZIP archive and extract the JSON file
|
| 45 |
+
with zipfile.ZipFile(doc_shards[idx], "r") as zip_file:
|
| 46 |
+
# Assumes the JSON file is the first file in the archive
|
| 47 |
+
json_filename = zip_file.namelist()[0]
|
| 48 |
+
with zip_file.open(json_filename, "r") as json_file:
|
| 49 |
+
for sample_data in json_file:
|
| 50 |
+
# get image names from json
|
| 51 |
+
sample_data = json.loads(sample_data)
|
| 52 |
+
image_info = sample_data["image_info"]
|
| 53 |
+
image_names = [image["image_name"] for image in image_info]
|
| 54 |
+
|
| 55 |
+
# Add each image to the tar file
|
| 56 |
+
for img_idx, image_name in enumerate(image_names):
|
| 57 |
+
try:
|
| 58 |
+
# load image
|
| 59 |
+
img = Image.open(
|
| 60 |
+
os.path.join(args.image_dir, str(idx), image_name)
|
| 61 |
+
).convert("RGB")
|
| 62 |
+
buffered = BytesIO()
|
| 63 |
+
img.save(buffered, format="JPEG")
|
| 64 |
+
img_str = base64.b64encode(buffered.getvalue())
|
| 65 |
+
|
| 66 |
+
# convert to base64
|
| 67 |
+
sample_data["image_info"][img_idx][
|
| 68 |
+
"image_base64"
|
| 69 |
+
] = img_str.decode("utf-8")
|
| 70 |
+
except FileNotFoundError:
|
| 71 |
+
print(
|
| 72 |
+
f"Did not find {image_name} downloaded. This can happen if the url is now 404."
|
| 73 |
+
)
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"Error processing {image_name}: {e}")
|
| 76 |
+
|
| 77 |
+
key_str = uuid.uuid4().hex
|
| 78 |
+
sink.write({"__key__": key_str, "json": sample_data})
|
| 79 |
+
|
| 80 |
+
if (idx + 1) % args.num_files_per_shard == 0:
|
| 81 |
+
sink.next_stream()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
open_flamingo/scripts/fill_vqa_testdev_results.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission.
|
| 3 |
+
Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set.
|
| 4 |
+
Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction.
|
| 5 |
+
"""
|
| 6 |
+
import json
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
sys.path.append(
|
| 11 |
+
os.path.join(
|
| 12 |
+
os.path.dirname(os.path.abspath(__file__)),
|
| 13 |
+
"..",
|
| 14 |
+
)
|
| 15 |
+
)
|
| 16 |
+
from eval.vqa_metric import VQAEval
|
| 17 |
+
|
| 18 |
+
postprocessor = VQAEval(None, None)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def fill_vizwiz_test_json(
|
| 22 |
+
input_path,
|
| 23 |
+
output_path,
|
| 24 |
+
vqa_test_questions_json_path,
|
| 25 |
+
):
|
| 26 |
+
# read the input json and build a set with all question_ids
|
| 27 |
+
with open(input_path, "r") as f:
|
| 28 |
+
input_json = json.load(f)
|
| 29 |
+
|
| 30 |
+
# postprocess answers
|
| 31 |
+
question_id_to_answer = {}
|
| 32 |
+
for q in input_json:
|
| 33 |
+
resAns = q["answer"]
|
| 34 |
+
resAns = resAns.replace("\n", " ")
|
| 35 |
+
resAns = resAns.replace("\t", " ")
|
| 36 |
+
resAns = resAns.strip()
|
| 37 |
+
resAns = postprocessor.processPunctuation(resAns)
|
| 38 |
+
resAns = postprocessor.processDigitArticle(resAns)
|
| 39 |
+
question_id_to_answer[q["question_id"]] = resAns
|
| 40 |
+
|
| 41 |
+
# read the vqa test json to get all the qustion_ids that need to be filled
|
| 42 |
+
with open(vqa_test_questions_json_path, "r") as f:
|
| 43 |
+
vqa_test_json = json.load(f)
|
| 44 |
+
vqa_test_json = vqa_test_json["questions"]
|
| 45 |
+
|
| 46 |
+
# if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer
|
| 47 |
+
output_json = []
|
| 48 |
+
for q in vqa_test_json:
|
| 49 |
+
output_json.append(
|
| 50 |
+
{
|
| 51 |
+
"image": q["image_id"],
|
| 52 |
+
"answer": question_id_to_answer.get(q["question_id"], ""),
|
| 53 |
+
}
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# write the json to the output path
|
| 57 |
+
with open(output_path, "w") as f:
|
| 58 |
+
json.dump(output_json, f)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def fill_vqav2_test_json(
|
| 62 |
+
input_path,
|
| 63 |
+
output_path,
|
| 64 |
+
vqa_test_questions_json_path,
|
| 65 |
+
):
|
| 66 |
+
# read the input json and build a set with all question_ids
|
| 67 |
+
with open(input_path, "r") as f:
|
| 68 |
+
input_json = json.load(f)
|
| 69 |
+
question_ids = set()
|
| 70 |
+
for q in input_json:
|
| 71 |
+
question_ids.add(q["question_id"])
|
| 72 |
+
|
| 73 |
+
# make a copy of the input json
|
| 74 |
+
output_json = []
|
| 75 |
+
for q in input_json:
|
| 76 |
+
resAns = q["answer"]
|
| 77 |
+
resAns = resAns.replace("\n", " ")
|
| 78 |
+
resAns = resAns.replace("\t", " ")
|
| 79 |
+
resAns = resAns.strip()
|
| 80 |
+
resAns = postprocessor.processPunctuation(resAns)
|
| 81 |
+
resAns = postprocessor.processDigitArticle(resAns)
|
| 82 |
+
q["answer"] = resAns
|
| 83 |
+
output_json.append(q)
|
| 84 |
+
|
| 85 |
+
# read the vqa test json to get all the qustion_ids that need to be filled
|
| 86 |
+
with open(vqa_test_questions_json_path, "r") as f:
|
| 87 |
+
vqa_test_json = json.load(f)
|
| 88 |
+
vqa_test_json = vqa_test_json["questions"]
|
| 89 |
+
|
| 90 |
+
# if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer
|
| 91 |
+
for q in vqa_test_json:
|
| 92 |
+
if q["question_id"] not in question_ids:
|
| 93 |
+
output_json.append(
|
| 94 |
+
{
|
| 95 |
+
"question_id": q["question_id"],
|
| 96 |
+
"answer": "",
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# write the json to the output path
|
| 101 |
+
with open(output_path, "w") as f:
|
| 102 |
+
json.dump(output_json, f)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
|
| 106 |
+
import argparse
|
| 107 |
+
|
| 108 |
+
parser = argparse.ArgumentParser()
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--dataset",
|
| 111 |
+
type=str,
|
| 112 |
+
choices=["vqav2", "vizwiz"],
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--input_path",
|
| 116 |
+
type=str,
|
| 117 |
+
help="Path to the json file with the subset of the vqa test-dev questions.",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--vqa_test_questions_json_path",
|
| 121 |
+
type=str,
|
| 122 |
+
help="Path to the json file with all the vqa test questions.",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--output_path",
|
| 126 |
+
type=str,
|
| 127 |
+
help="Path to store the filled json.",
|
| 128 |
+
)
|
| 129 |
+
args = parser.parse_args()
|
| 130 |
+
|
| 131 |
+
if args.dataset == "vqav2":
|
| 132 |
+
fill_vqav2_test_json(
|
| 133 |
+
args.input_path,
|
| 134 |
+
args.output_path,
|
| 135 |
+
args.vqa_test_questions_json_path,
|
| 136 |
+
)
|
| 137 |
+
else:
|
| 138 |
+
fill_vizwiz_test_json(
|
| 139 |
+
args.input_path,
|
| 140 |
+
args.output_path,
|
| 141 |
+
args.vqa_test_questions_json_path,
|
| 142 |
+
)
|
open_flamingo/train/README.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenFlamingo Training
|
| 2 |
+
To train OpenFlamingo, please ensure your environment matches that of `environment.yml`.
|
| 3 |
+
|
| 4 |
+
## Data
|
| 5 |
+
Our codebase uses [WebDataset](https://github.com/webdataset/webdataset) to efficiently load `.tar` files containing image and text sequences. We recommend resampling shards with replacement during training using the `--dataset_resampled` flag.
|
| 6 |
+
|
| 7 |
+
### LAION-2B Dataset
|
| 8 |
+
[LAION-2B](https://arxiv.org/abs/2210.08402) contains 2B web-scraped (image, text) pairs.
|
| 9 |
+
We use [img2dataset](https://github.com/rom1504/img2dataset) to download this dataset into tar files.
|
| 10 |
+
|
| 11 |
+
### Multimodal C4 Dataset
|
| 12 |
+
We train on the full version of [Multimodal C4 (MMC4)](https://github.com/allenai/mmc4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, we truncate sequences to 256 text tokens and six images per sequence.
|
| 13 |
+
|
| 14 |
+
Our codebase expects `.tar` files containing `.json` files, which include raw images encoded in base64.
|
| 15 |
+
We provide scripts to convert MMC4 to this format:
|
| 16 |
+
|
| 17 |
+
1. Download the MMC4 shards into `.zip` files using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `fewer_facesv2.sh`).
|
| 18 |
+
2. Download the MMC4 raw images into an image directory using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `download_images.py`).
|
| 19 |
+
3. Run `scripts/convert_mmc4_to_wds.py` to convert the downloaded items into the expected tar files.
|
| 20 |
+
|
| 21 |
+
### ChatGPT-generated sequences
|
| 22 |
+
A subset of our models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at [this CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xdcd888ff7c754ae680c5e038f6ed1d9b). We are unable to distribute raw images in the released shards; images must be pre-downloaded from the urls in the json files and converted to base64 before using this data for training in our codebase.
|
| 23 |
+
|
| 24 |
+
Models trained with ChatGPT-generated sequences:
|
| 25 |
+
|
| 26 |
+
* OpenFlamingo-4B-vitl-rpj3b
|
| 27 |
+
* OpenFlamingo-4B-vitl-rpj3b-langinstruct
|
| 28 |
+
|
| 29 |
+
## Example training command
|
| 30 |
+
We provide a sample Slurm training script in `scripts/`. You can also modify the following command:
|
| 31 |
+
|
| 32 |
+
```
|
| 33 |
+
torchrun --nnodes=1 --nproc_per_node=4 train.py \
|
| 34 |
+
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
|
| 35 |
+
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
|
| 36 |
+
--cross_attn_every_n_layers 1 \
|
| 37 |
+
--dataset_resampled \
|
| 38 |
+
--batch_size_mmc4 32 \
|
| 39 |
+
--batch_size_laion 64 \
|
| 40 |
+
--train_num_samples_mmc4 125000\
|
| 41 |
+
--train_num_samples_laion 250000 \
|
| 42 |
+
--loss_multiplier_laion 0.2 \
|
| 43 |
+
--workers=4 \
|
| 44 |
+
--run_name OpenFlamingo-3B-vitl-mpt1b \
|
| 45 |
+
--num_epochs 480 \
|
| 46 |
+
--warmup_steps 1875 \
|
| 47 |
+
--mmc4_textsim_threshold 0.24 \
|
| 48 |
+
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
|
| 49 |
+
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
|
| 50 |
+
--report_to_wandb
|
| 51 |
+
```
|
| 52 |
+
*Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).*
|
| 53 |
+
|
| 54 |
+
## Distributed training
|
| 55 |
+
|
| 56 |
+
By default, `train.py` uses Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html) for training.
|
| 57 |
+
To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), use the `--fsdp` flag.
|
| 58 |
+
|
| 59 |
+
Some notes on FSDP:
|
| 60 |
+
|
| 61 |
+
* We recommend using the `--fsdp_use_orig_params` flag. If `--fsdp` is on without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to only train the newly added `<image>` and `<|endofchunk|>` tokens.)
|
| 62 |
+
* Note: we've encountered issues using OPT with this flag. Other language models should be compatible.
|
| 63 |
+
* Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input / output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the `--freeze_lm_embeddings` flag.
|
| 64 |
+
|
| 65 |
+
We also implement gradient checkpointing and mixed precision training. Use the `--gradient_checkpointing` and `--precision` arguments respectively.
|
open_flamingo/train/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
open_flamingo/train/data.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Preprocess and load datasets for training.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
import re
|
| 10 |
+
import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torchvision
|
| 14 |
+
import webdataset as wds
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import base64
|
| 17 |
+
from scipy.optimize import linear_sum_assignment
|
| 18 |
+
|
| 19 |
+
from data_utils import *
|
| 20 |
+
|
| 21 |
+
Image.MAX_IMAGE_PIXELS = 1000000000
|
| 22 |
+
N_CHANNELS = 3
|
| 23 |
+
MIN_KB = 10
|
| 24 |
+
_SHARD_SHUFFLE_SIZE = 2000
|
| 25 |
+
_SHARD_SHUFFLE_INITIAL = 500
|
| 26 |
+
_SAMPLE_SHUFFLE_SIZE = 5000
|
| 27 |
+
_SAMPLE_SHUFFLE_INITIAL = 1000
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
import horovod.torch as hvd
|
| 31 |
+
except ImportError:
|
| 32 |
+
hvd = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def preprocess_image(sample, image_processor):
|
| 36 |
+
"""
|
| 37 |
+
Convert images to tensors for training.
|
| 38 |
+
Augmentations: random horizontal flip.
|
| 39 |
+
Normalization handled by wds.
|
| 40 |
+
"""
|
| 41 |
+
image = [image_processor(s).unsqueeze(0) for s in sample]
|
| 42 |
+
image = torch.cat(image, dim=0)
|
| 43 |
+
image = torchvision.transforms.RandomHorizontalFlip(p=0.5)(image)
|
| 44 |
+
return image
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def filter_no_caption_or_no_image(sample):
    """
    Keep only LAION samples that carry both a caption and an image payload.
    """
    has_caption = "txt" in sample
    has_image = any(ext in sample for ext in ("png", "jpg", "jpeg"))
    return has_caption and has_image
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def preprocess_laion_text(sample, tokenizer, max_tokens=32):
    """
    Tokenize a batch of LAION captions for training.

    Each caption is wrapped as ``<image>{caption}<|endofchunk|>{eos}`` and the
    batch is right-padded ("longest") and truncated to ``max_tokens``
    (32 by default).

    Returns:
        (input_ids, attention_mask) from the tokenizer output.
    """
    tokenizer.padding_side = "right"
    wrapped = []
    for caption in sample:
        wrapped.append(f"<image>{caption.strip()}<|endofchunk|>{tokenizer.eos_token}")
    encoded = tokenizer(
        wrapped,
        max_length=max_tokens,
        padding="longest",
        truncation="only_first",
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def preprocess_gpt_interleaved(
    info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens=256
):
    """
    Preprocess a ChatGPT-generated image-text sequence.

    Args:
        info: parsed json dict with keys "example" (text containing
            ``_!_IMAGE{i}_!_`` placeholders) and "image_map" (placeholder ->
            dict with a "base64_image" entry).
        tokenizer: tokenizer whose vocabulary contains the <image> and
            <|endofchunk|> special tokens.
        clip_processor: vision preprocessor applied to each decoded PIL image.
        min_num_images: reject the sample if fewer <image> tokens survive
            tokenization.
        max_num_images: images beyond this count are dropped; the image tensor
            is zero-padded up to this count.
        max_tokens: the text is truncated/padded to this many tokens.

    Returns:
        (images_tensors, (input_ids, attention_mask))

    Raises:
        ValueError: if the tokenized text contains fewer than
            ``min_num_images`` <image> tokens.
    """
    text = info["example"]
    # replace every image placeholder with the model's special-token pair
    text = re.sub(r"_!_IMAGE\d+_!_", "<|endofchunk|><image>", text)

    # convert images from base64 to PIL
    images = []
    for image_key in range(1, len(info["image_map"]) + 1):
        image_base64 = info["image_map"][f"_!_IMAGE{image_key}_!_"]["base64_image"]
        rawbytes = base64.b64decode(image_base64)
        images.append(Image.open(io.BytesIO(rawbytes)).convert("RGB"))

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)
    keep_ixs = range(min(len(images_tensors), max_num_images))
    images_tensors = images_tensors[keep_ixs]
    if len(images_tensors) < max_num_images:
        # zero-pad so every sample carries exactly max_num_images frames
        # NOTE(review): hard-codes 3x224x224 — assumes clip_processor output
        # size; TODO confirm (preprocess_interleaved derives it from the data)
        zero_padding = torch.zeros(
            (max_num_images - len(images_tensors), 3, 224, 224), dtype=torch.float
        )
        images_tensors = torch.cat((images_tensors, zero_padding), dim=0)

    # preprocess and tokenize text
    text = text.replace("<|endofchunk|>", "", 1)  # but remove first eoc
    # whitespace cleanup
    text = (
        text.replace(" <|endofchunk|>", "<|endofchunk|>")
        .replace("<image> ", "<image>")
        .replace(" <image>", "<image>")
    )

    # if the text references more images than we kept, cut it off at the
    # start of the max_num_images-th <image> token
    # NOTE(review): this keeps only max_num_images - 1 <image> tokens in the
    # text (possibly off by one) — confirm intent
    indices = [m.start() for m in re.finditer("<image>", text)]
    if len(indices) > max_num_images:
        start_index = indices[max_num_images - 1]
        text = text[:start_index]

    text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
    tokenizer.padding_side = "right"
    text_tensor = tokenizer(
        text,
        max_length=max_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # reject sequences with too few images after truncation
    num_images = torch.count_nonzero(
        text_tensor["input_ids"]
        == tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]
    )
    if num_images < min_num_images:
        raise ValueError(f"Fewer than {min_num_images} images in sample")

    return (images_tensors, (text_tensor["input_ids"], text_tensor["attention_mask"]))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def preprocess_interleaved(
    sample,
    tokenizer,
    clip_processor,
    sim_threshold,
    min_num_images,
    max_num_images,
    max_tokens=256,
):
    """
    Preprocess an interleaved image-text sequence, either by calling
    preprocess_gpt_interleaved (if the sequence is ChatGPT-generated) or by
    preprocessing in this function (if the sequence is from MMC4).

    For MMC4: images are decoded from base64, matched one-to-one to sentences
    via a Hungarian assignment on the CLIP similarity matrix, pairs below
    ``sim_threshold`` are dropped, and <image>/<|endofchunk|> tokens are
    spliced into the text before tokenizing to ``max_tokens``.

    Returns:
        (images_tensors, (input_ids, attention_mask))

    Raises:
        ValueError: if no usable image remains, too few <image> tokens survive
            tokenization, a single-image sample is randomly rejected
            (50% chance), or the only <image> token falls at the very end of
            the tokenized text.
    """
    info = json.loads(sample[0])
    if "is_gpt" in info:
        return preprocess_gpt_interleaved(
            info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens
        )

    sentences = info["text_list"]
    sim_matrix = info["similarity_matrix"]

    # load images first to find which ones are valid
    valid_images, valid_image_indices = [], []
    for i, sample_image in enumerate(info["image_info"]):
        if "image_base64" not in sample_image:
            continue
        image_base64 = sample_image["image_base64"]
        rawbytes = base64.b64decode(image_base64)

        # filter to images >= 10KB
        # NOTE(review): with MIN_KB=10 and integer division this actually
        # drops images up to ~11KB — slightly stricter than the comment
        if len(rawbytes) // 1000 <= MIN_KB:
            continue

        image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
        valid_images.append(image)
        valid_image_indices.append(i)

    if len(valid_image_indices) == 0:
        raise ValueError("No images in sample")

    sim_matrix = np.array(sim_matrix)  # of shape images x sentences
    sim_matrix = sim_matrix[valid_image_indices]

    # negate the similarities to turn them into costs
    cost_matrix = -sim_matrix
    # find one to one assignments (Hungarian algorithm)
    image_indices, sentence_indices = linear_sum_assignment(cost_matrix)

    images, sentence_ixs = [], []
    for i, sim_ix in zip(image_indices, sentence_indices):
        sim_score = sim_matrix[i][sim_ix]

        # drop weakly matched (image, sentence) pairs
        if sim_score < sim_threshold:
            continue

        images.append(valid_images[i])
        sentence_ixs.append(sim_ix)

    if len(images) == 0:
        raise ValueError("No images in sample")

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)
    keep_ixs = range(min(len(images_tensors), max_num_images))
    images_tensors = images_tensors[keep_ixs]
    sentence_ixs = [sentence_ixs[ix] for ix in keep_ixs]
    if len(images_tensors) < max_num_images:
        # zero-pad so every sample carries exactly max_num_images frames;
        # spatial size is taken from the real images already processed
        zero_padding = torch.zeros(
            (
                max_num_images - len(images_tensors),
                N_CHANNELS,
                images_tensors[0].shape[1],
                images_tensors[0].shape[2],
            ),
            dtype=torch.float,
        )
        images_tensors = torch.cat((images_tensors, zero_padding), dim=0)

    # preprocess and tokenize text
    # add in <image> and <eoc> tokens
    for ix in sentence_ixs:
        sentences[ix] = f"<|endofchunk|><image>{sentences[ix]}"
    text = " ".join(sentences)
    text = text.replace("<|endofchunk|>", "", 1)  # but remove first eoc
    # whitespace cleanup
    text = (
        text.replace(" <|endofchunk|>", "<|endofchunk|>")
        .replace("<image> ", "<image>")
        .replace(" <image>", "<image>")
    )
    text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
    tokenizer.padding_side = "right"
    text_tensor = tokenizer(
        text,
        max_length=max_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # reject sequences with too few images (after truncation)
    num_images = torch.count_nonzero(
        text_tensor["input_ids"]
        == tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]
    )
    if num_images < min_num_images:
        raise ValueError(f"Fewer than {min_num_images} images in sample")
    elif (
        num_images == 1 and random.random() <= 0.5
    ):  # 50% chance of keeping single image samples
        raise ValueError("Only one image in sample")

    # avoid the situation where there's one <image> token and it's at the end
    # (the loss would then mask every label to -100)
    if (
        num_images == 1
        and text_tensor["input_ids"][:, -1]
        == tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]
    ):
        raise ValueError(
            "Only one image at the end of sample, so labels will all be -100"
        )

    return (
        images_tensors,
        (text_tensor["input_ids"], text_tensor["attention_mask"]),
    )
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """
    Initialize webdataset for MMC4 / ChatGPT sequences.

    Builds a wds.DataPipeline that (optionally) resamples shards, shuffles,
    parses each sample's json via preprocess_interleaved, and batches to
    args.batch_size_mmc4, then wraps it in a wds.WebLoader.

    Args:
        args: training args; reads mmc4_shards, dataset_resampled,
            train_num_samples_mmc4, mmc4_textsim_threshold,
            mmc4_min_num_images, mmc4_max_num_images, batch_size_mmc4,
            seed, workers, world_size.
        image_processor: vision preprocessor forwarded to preprocess_interleaved.
        tokenizer: text tokenizer forwarded to preprocess_interleaved.
        epoch: initial value for the shared epoch counter.
        floor: round per-node/per-worker batch counts down instead of up.

    Returns:
        DataInfo holding the dataloader and the shared epoch counter.
    """
    input_shards = args.mmc4_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)

    num_samples, num_shards = get_dataset_size(input_shards)
    # NOTE(review): the metadata-derived sample count is deliberately
    # discarded, so args.train_num_samples_mmc4 is effectively required
    num_samples = None
    if not num_samples:
        num_samples = args.train_num_samples_mmc4
        if not num_samples:
            raise RuntimeError(
                "Currently, number of dataset samples must be specified for training dataset. "
                "Please specify via `--train-num-samples` if no dataset length info present."
            )

    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    if resampled:
        pipeline = [
            ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
        ]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    preprocess_fn = functools.partial(
        preprocess_interleaved,
        clip_processor=image_processor,
        tokenizer=tokenizer,
        sim_threshold=args.mmc4_textsim_threshold,
        min_num_images=args.mmc4_min_num_images,
        max_num_images=args.mmc4_max_num_images,
    )

    # at this point we have an iterator over all the shards
    if not resampled:
        # deterministic (seeded) shard-level shuffle, then partition shards
        # across nodes and dataloader workers
        pipeline.extend(
            [
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ]
        )
    pipeline.extend(
        [
            # at this point, we have an iterator over the shards assigned to each worker at each node
            # wds.tarfile_to_samples(handler=log_and_continue),
            tarfile_to_samples_nothrow,
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ]
    )

    pipeline.extend(
        [
            wds.to_tuple("json", handler=log_and_continue),
            wds.map(preprocess_fn, handler=log_and_continue),
            wds.batched(args.batch_size_mmc4, partial=False),
        ]
    )

    dataset = wds.DataPipeline(*pipeline)
    if not resampled:
        assert (
            num_shards >= args.workers * args.world_size
        ), "number of shards must be >= total workers"
    # roll over and repeat a few samples to get same number of full batches on each node
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_mmc4 * args.world_size
    num_batches = round_fn(num_samples / global_batch_size)
    num_workers = max(1, args.workers)
    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
    num_batches = num_worker_batches * num_workers
    num_samples = num_batches * global_batch_size
    # each worker is iterating over this
    dataset = dataset.with_epoch(num_worker_batches)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """
    Initialize webdataset for LAION data.

    Builds a wds.DataPipeline that (optionally) resamples shards, shuffles,
    decodes (image, caption) pairs, batches to args.batch_size_laion, and
    maps preprocess_image / preprocess_laion_text over each batch, then
    wraps it in a wds.WebLoader.

    Args:
        args: training args; reads laion_shards, dataset_resampled,
            train_num_samples_laion, batch_size_laion, seed, workers,
            world_size.
        image_processor: vision preprocessor for preprocess_image.
        tokenizer: text tokenizer for preprocess_laion_text.
        epoch: initial value for the shared epoch counter.
        floor: round per-node/per-worker batch counts down instead of up.

    Returns:
        DataInfo holding the dataloader and the shared epoch counter.
    """
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)

    num_samples, num_shards = get_dataset_size(input_shards)
    # NOTE(review): the metadata-derived sample count is deliberately
    # discarded, so args.train_num_samples_laion is effectively required
    num_samples = None
    if not num_samples:
        num_samples = args.train_num_samples_laion
        if not num_samples:
            raise RuntimeError(
                "Currently, number of dataset samples must be specified for training dataset. "
                "Please specify via `--train-num-samples` if no dataset length info present."
            )

    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    if resampled:
        pipeline = [
            ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
        ]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # create two preprocess functions that take in the passed in image_processor and tokenizer
    preprocess_image_fn = functools.partial(
        preprocess_image, image_processor=image_processor
    )
    preprocess_text_fn = functools.partial(preprocess_laion_text, tokenizer=tokenizer)

    # at this point we have an iterator over all the shards
    if not resampled:
        # deterministic (seeded) shard-level shuffle, then partition shards
        # across nodes and dataloader workers
        pipeline.extend(
            [
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ]
        )
    pipeline.extend(
        [
            # at this point, we have an iterator over the shards assigned to each worker at each node
            # wds.tarfile_to_samples(handler=log_and_continue),
            tarfile_to_samples_nothrow,
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ]
    )

    pipeline.extend(
        [
            wds.select(filter_no_caption_or_no_image),
            wds.decode("pilrgb", handler=log_and_continue),
            wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
            # note: batching happens BEFORE preprocessing here, so the
            # preprocess fns receive whole batches (lists), unlike MMC4
            wds.batched(args.batch_size_laion, partial=False),
            wds.map_tuple(
                preprocess_image_fn, preprocess_text_fn, handler=log_and_continue
            ),
        ]
    )

    dataset = wds.DataPipeline(*pipeline)
    if not resampled:
        assert (
            num_shards >= args.workers * args.world_size
        ), "number of shards must be >= total workers"
    # roll over and repeat a few samples to get same number of full batches on each node
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_laion * args.world_size
    num_batches = round_fn(num_samples / global_batch_size)
    num_workers = max(1, args.workers)
    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
    num_batches = num_worker_batches * num_workers
    num_samples = num_batches * global_batch_size
    # each worker is iterating over this
    dataset = dataset.with_epoch(num_worker_batches)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def get_dataset_fn(dataset_type):
    """
    Map a dataset-type string to its dataset-construction function.

    "image_text" -> get_laion_dataset; "mmc4" -> get_mmc4_dataset.

    Raises:
        ValueError: for any other dataset type.
    """
    if dataset_type == "image_text":
        return get_laion_dataset
    if dataset_type == "mmc4":
        return get_mmc4_dataset
    raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
    """
    Interface for building the webdataset DataInfo for a given dataset type.
    """
    builder = get_dataset_fn(dataset_type)
    return builder(
        args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
    )
|
open_flamingo/train/data_utils.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Util functions for initializing webdataset objects
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import ast
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import sys
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from multiprocessing import Value
|
| 13 |
+
|
| 14 |
+
import braceexpand
|
| 15 |
+
import numpy as np
|
| 16 |
+
import webdataset as wds
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
| 19 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 20 |
+
from webdataset.filters import _shuffle
|
| 21 |
+
from webdataset.tariterators import (
|
| 22 |
+
base_plus_ext,
|
| 23 |
+
tar_file_expander,
|
| 24 |
+
url_opener,
|
| 25 |
+
valid_sample,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import horovod.torch as hvd
|
| 30 |
+
except ImportError:
|
| 31 |
+
hvd = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SharedEpoch:
    """Epoch counter held in multiprocessing shared memory so dataloader
    worker processes observe epoch changes made by the training process."""

    def __init__(self, epoch: int = 0):
        # 'i' = C signed int; Value is shared across forked workers
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch: int) -> None:
        self.shared_epoch.value = epoch

    def get_value(self) -> int:
        return self.shared_epoch.value
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class DataInfo:
    """Bundle of a dataloader plus the objects needed to advance its epoch."""

    dataloader: DataLoader
    # optional sampler; only DistributedSampler instances receive set_epoch()
    sampler: DistributedSampler = None
    # shared epoch counter for webdataset pipelines (see SharedEpoch)
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch: int) -> None:
        """Propagate the new epoch to the shared counter and/or sampler."""
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_dataset_size(shards):
    """
    Estimate the total sample count and shard count for a webdataset shard spec.

    Args:
        shards: shard url or brace pattern, e.g. "/data/shard-{0000..0999}.tar".

    Returns:
        (total_size, num_shards): total_size is the number of samples if a
        "sizes.json" or "__len__" metadata file exists next to the shards,
        else None.
    """
    shards_list = list(braceexpand.braceexpand(shards))
    # BUGFIX: was os.path.dirname(shards[0]), which indexed the first
    # *character* of the pattern string, so the metadata directory was
    # never resolved and total_size silently came back wrong/None.
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, "sizes.json")
    len_filename = os.path.join(dir_path, "__len__")
    if os.path.exists(sizes_filename):
        # per-shard sample counts keyed by shard basename
        with open(sizes_filename, "r") as f:
            sizes = json.load(f)
        total_size = sum(
            int(sizes[os.path.basename(shard)])
            if os.path.basename(shard) in sizes
            else 0
            for shard in shards_list
        )
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        with open(len_filename, "r") as f:
            total_size = ast.literal_eval(f.read())
    else:
        total_size = None  # num samples undefined
        # some common dataset sizes (at time of authors last download)
        # CC3M (train): 2905954
        # CC12M: 10968539
        # LAION-400M: 407332084
        # LAION-2B (english): 2170337258
    num_shards = len(shards_list)
    return total_size, num_shards
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def count_samples(dataloader):
    """Exhaust *dataloader* once and return ``(n_samples, n_batches)``.

    Pins the ``WDS_EPOCH`` env var to "0" so webdataset pipelines iterate
    deterministically while counting. Asserts every batch has as many
    images as texts.
    """
    os.environ["WDS_EPOCH"] = "0"
    n_elements = 0
    n_batches = 0
    for images, texts in dataloader:
        assert len(images) == len(texts)
        n_batches += 1
        n_elements += len(images)
    return n_elements, n_batches
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue.

    Returning True tells webdataset to skip the failing sample and keep going.
    """
    # Lazy %-style args: same rendered message, formatted only if emitted.
    logging.warning("Handling webdataset error (%r). Ignoring.", exn)
    return True
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def group_by_keys_nothrow(
    data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
    """Return function over iterator that groups key, value pairs into samples.

    Non-throwing variant of webdataset's ``group_by_keys``: a repeated suffix
    starts a new sample instead of raising.

    :param data: iterator of dicts with "fname", "data" and "__url__" entries
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: optional whitelist of suffixes to keep; None keeps all
    :param handler: accepted for pipeline-stage signature compatibility; unused here
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            # entry whose name doesn't split into (base, ext) -- skip it
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        if (
            current_sample is None
            or prefix != current_sample["__key__"]
            or suffix in current_sample
        ):
            # flush the accumulated sample before starting a new one
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    # flush the trailing sample, if any
    if valid_sample(current_sample):
        yield current_sample
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """Expand a stream of tar shard urls into grouped samples.

    NOTE: this is a re-impl of the webdataset pipeline with a group_by_keys
    stage that doesn't throw on duplicate suffixes.
    """
    opened = url_opener(src, handler=handler)
    expanded = tar_file_expander(opened, handler=handler)
    return group_by_keys_nothrow(expanded, handler=handler)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is None:
        # not inside a dataloader worker -> fallback to wds rank based seed
        return wds.utils.pytorch_worker_seed()
    # favour using the seed already created for pytorch dataloader workers if it exists
    seed = worker_info.seed
    if increment:
        # space out seed increments so they can't overlap across workers in different iterations
        seed += increment * max(1, worker_info.num_workers)
    return seed
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class detshuffle2(wds.PipelineStage):
    """Deterministic shuffle stage re-seeded every epoch.

    With a non-negative seed, every worker/node shuffles identically within
    an epoch; with a negative seed each worker derives its own seed.
    """

    def __init__(
        self,
        bufsize=1000,
        initial=100,
        seed=0,
        epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # If seed is negative, we use the worker's seed, this will be different across all nodes/workers
            rng.seed(pytorch_worker_seed(epoch))
        else:
            # This seed to be deterministic AND the same across all nodes/workers in each epoch
            rng.seed(self.seed + epoch)
        return _shuffle(src, self.bufsize, self.initial, rng)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.
        :param urls: a list of URLs as a Python list or brace notation string
        :param nshards: number of shard urls to yield per iteration pass
        :param worker_seed: optional callable returning a per-worker seed;
            only consulted when `deterministic` is True
        :param deterministic: if True, re-seed the RNG from (seed + epoch)
            at the start of each iteration
        :param epoch: int counter or a SharedEpoch shared across workers
        """
        super().__init__()
        urls = wds.shardlists.expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch

        if self.deterministic:
            # reset seed w/ epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        # sample with replacement: the same shard may be yielded repeatedly
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))
|
open_flamingo/train/distributed.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Util functions for setting up distributed training.
|
| 3 |
+
Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import horovod.torch as hvd
|
| 11 |
+
except ImportError:
|
| 12 |
+
hvd = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def is_global_master(args):
    """True iff this process holds global rank 0 across the whole job."""
    return args.rank == 0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def is_local_master(args):
    """True iff this process holds rank 0 on its own node."""
    return args.local_rank == 0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def is_master(args, local=False):
    """Master check: local-node master when *local* is set, else global master."""
    check = is_local_master if local else is_global_master
    return check(args)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def is_using_horovod():
    """Heuristic: does the environment look like an MPI/PMI (horovod) launch?

    NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will
    be set. Differentiating between horovod and DDP use via SLURM may not be
    possible, so horovod arg still required...
    """
    ompi_vars = ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE")
    pmi_vars = ("PMI_RANK", "PMI_SIZE")
    return all(v in os.environ for v in ompi_vars) or all(
        v in os.environ for v in pmi_vars
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def is_using_distributed():
    """True when launcher env vars indicate more than one process.

    Checks torchrun-style WORLD_SIZE first, then SLURM's SLURM_NTASKS.
    """
    world_size = os.environ.get("WORLD_SIZE")
    if world_size is not None:
        return int(world_size) > 1
    ntasks = os.environ.get("SLURM_NTASKS")
    if ntasks is not None:
        return int(ntasks) > 1
    return False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def world_info_from_env():
    """Read ``(local_rank, global_rank, world_size)`` from launcher env vars.

    Each value is taken from the first matching variable (torchrun, MPI,
    SLURM, OpenMPI order); defaults are (0, 0, 1) for a single process.
    """

    def _first_int(names, default):
        # first env var present wins, matching the original scan order
        for name in names:
            if name in os.environ:
                return int(os.environ[name])
        return default

    local_rank = _first_int(
        (
            "LOCAL_RANK",
            "MPI_LOCALRANKID",
            "SLURM_LOCALID",
            "OMPI_COMM_WORLD_LOCAL_RANK",
        ),
        0,
    )
    global_rank = _first_int(
        ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"), 0
    )
    world_size = _first_int(
        ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"), 1
    )
    return local_rank, global_rank, world_size
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def init_distributed_device(args):
    """Initialize the distributed backend and pick this process's device.

    Mutates `args` in place, setting: distributed, world_size, rank,
    local_rank, and device. Returns the selected ``torch.device``.

    Three launch modes are handled: horovod (`args.horovod`), DDP (SLURM or
    torchrun env vars), and single-process (still inits a size-1 process
    group so later collective calls work).
    """
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        # mirror horovod's view into torch.distributed-style env vars
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            # torchrun already knows the topology; read it back
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # needed to run on single gpu
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            # one GPU per local rank
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    # args.device keeps the string form; the return value is a torch.device
    args.device = device
    device = torch.device(device)
    return device
|
open_flamingo/train/train.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Main training script """
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import glob
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import wandb
|
| 11 |
+
from data import get_data
|
| 12 |
+
from distributed import init_distributed_device, world_info_from_env
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 15 |
+
from train_utils import (
|
| 16 |
+
train_one_epoch,
|
| 17 |
+
get_mp_policy_dtype,
|
| 18 |
+
save_checkpoint,
|
| 19 |
+
)
|
| 20 |
+
from transformers import (
|
| 21 |
+
get_constant_schedule_with_warmup,
|
| 22 |
+
get_cosine_schedule_with_warmup,
|
| 23 |
+
get_linear_schedule_with_warmup,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from torch.distributed.fsdp import (
|
| 27 |
+
CPUOffload,
|
| 28 |
+
MixedPrecision,
|
| 29 |
+
ShardingStrategy,
|
| 30 |
+
BackwardPrefetch,
|
| 31 |
+
)
|
| 32 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
| 33 |
+
checkpoint_wrapper,
|
| 34 |
+
CheckpointWrapper,
|
| 35 |
+
CheckpointImpl,
|
| 36 |
+
apply_activation_checkpointing,
|
| 37 |
+
)
|
| 38 |
+
from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups
|
| 39 |
+
from torch.distributed.distributed_c10d import _get_default_group
|
| 40 |
+
import functools
|
| 41 |
+
|
| 42 |
+
from open_flamingo import create_model_and_transforms
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def random_seed(seed=42, rank=0):
    """Seed torch, numpy and the stdlib RNGs with a rank-offset seed.

    Offsetting by `rank` gives each distributed worker a distinct but
    reproducible random stream.
    """
    effective_seed = seed + rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    random.seed(effective_seed)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main():
|
| 52 |
+
parser = argparse.ArgumentParser()
|
| 53 |
+
# model configuration args
|
| 54 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
| 55 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
| 56 |
+
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--tokenizer_path",
|
| 59 |
+
default="facebook/opt-30b",
|
| 60 |
+
type=str,
|
| 61 |
+
help="path to tokenizer",
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--cross_attn_every_n_layers",
|
| 65 |
+
type=int,
|
| 66 |
+
default=1,
|
| 67 |
+
help="how often to add a cross-attention layer after each transformer layer",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# training args
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--run_name",
|
| 73 |
+
type=str,
|
| 74 |
+
default="openflamingo3B",
|
| 75 |
+
help="used to name saving directory and wandb run",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--resume_from_checkpoint",
|
| 79 |
+
type=str,
|
| 80 |
+
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default",
|
| 81 |
+
default=None,
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--delete_previous_checkpoint",
|
| 85 |
+
action="store_true",
|
| 86 |
+
help="delete previous checkpoint when saving new checkpoint",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument("--batch_size_mmc4", type=int, default=128)
|
| 89 |
+
parser.add_argument("--batch_size_laion", type=int, default=128)
|
| 90 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
| 91 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 92 |
+
parser.add_argument("--learning_rate", default=1e-4, type=float)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--lr_scheduler",
|
| 95 |
+
default="constant",
|
| 96 |
+
type=str,
|
| 97 |
+
help="constant, linear, or cosine",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
|
| 100 |
+
parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
|
| 101 |
+
parser.add_argument("--warmup_steps", default=5000, type=int)
|
| 102 |
+
parser.add_argument("--weight_decay", default=0.1, type=float)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--precision",
|
| 105 |
+
choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
|
| 106 |
+
default="fp32",
|
| 107 |
+
help="Floating point precision.",
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--gradient_checkpointing",
|
| 111 |
+
action="store_true",
|
| 112 |
+
help="whether to train with gradient/activation checkpointing",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--num_epochs",
|
| 116 |
+
type=int,
|
| 117 |
+
default=1,
|
| 118 |
+
help="we define an 'epoch' as a fixed number of examples (train_num_samples_mmc4, train_num_samples_laion), not a pass through the entire dataset",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument("--offline", action="store_true")
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--freeze_lm_embeddings",
|
| 123 |
+
action="store_true",
|
| 124 |
+
help="if True, we freeze the LM embeddings during training. Otherwise, we train the <image> and <|endofchunk|> embeddings.",
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--logging_steps", type=int, default=100, help="log loss every n steps"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# data args
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--laion_shards",
|
| 133 |
+
type=str,
|
| 134 |
+
help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--mmc4_shards",
|
| 138 |
+
type=str,
|
| 139 |
+
help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument("--workers", type=int, default=1)
|
| 142 |
+
parser.add_argument("--train_num_samples_mmc4", type=int, default=10000)
|
| 143 |
+
parser.add_argument("--train_num_samples_laion", type=int, default=10000)
|
| 144 |
+
parser.add_argument("--dataset_resampled", action="store_true")
|
| 145 |
+
parser.add_argument(
|
| 146 |
+
"--mmc4_textsim_threshold",
|
| 147 |
+
default=30,
|
| 148 |
+
type=float,
|
| 149 |
+
help="threshold for filtering images in mmc4 based on image-text similarity",
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--mmc4_max_num_images",
|
| 153 |
+
default=6,
|
| 154 |
+
type=int,
|
| 155 |
+
help="max number of images per sequence in mmc4 / chatgpt",
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--mmc4_min_num_images",
|
| 159 |
+
default=1,
|
| 160 |
+
type=int,
|
| 161 |
+
help="min number of images per sequence in mmc4 / chatgpt",
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# distributed training args
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--dist-url",
|
| 167 |
+
default="env://",
|
| 168 |
+
type=str,
|
| 169 |
+
help="url used to set up distributed training",
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--dist-backend", default="nccl", type=str, help="distributed backend"
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--horovod",
|
| 176 |
+
default=False,
|
| 177 |
+
action="store_true",
|
| 178 |
+
help="Use horovod for distributed training.",
|
| 179 |
+
)
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--no-set-device-rank",
|
| 182 |
+
default=False,
|
| 183 |
+
action="store_true",
|
| 184 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--fsdp",
|
| 188 |
+
default=False,
|
| 189 |
+
action="store_true",
|
| 190 |
+
help="Use FullyShardedDataParallel for distributed training.",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--fsdp_use_orig_params",
|
| 194 |
+
default=False,
|
| 195 |
+
action="store_true",
|
| 196 |
+
help="Passed into the FSDP constructor. Enables param_groups and gradient masking for weight_decay. Does not work with OPT.",
|
| 197 |
+
)
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# wandb args
|
| 203 |
+
parser.add_argument("--report_to_wandb", default=False, action="store_true")
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--wandb_project",
|
| 206 |
+
type=str,
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--wandb_entity",
|
| 210 |
+
type=str,
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--save_checkpoints_to_wandb",
|
| 214 |
+
default=False,
|
| 215 |
+
action="store_true",
|
| 216 |
+
help="save checkpoints to wandb",
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
args = parser.parse_args()
|
| 220 |
+
|
| 221 |
+
# Validate args
|
| 222 |
+
if args.laion_shards.startswith("s3"):
|
| 223 |
+
args.laion_shards = f"pipe:aws s3 cp {args.laion_shards} -"
|
| 224 |
+
|
| 225 |
+
if args.mmc4_shards.startswith("s3"):
|
| 226 |
+
args.mmc4_shards = f"pipe:aws s3 cp {args.mmc4_shards} -"
|
| 227 |
+
|
| 228 |
+
if args.save_checkpoints_to_wandb and not args.report_to_wandb:
|
| 229 |
+
raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
|
| 230 |
+
|
| 231 |
+
if args.fsdp and not args.fsdp_use_orig_params:
|
| 232 |
+
print(
|
| 233 |
+
"Warning: FSDP is running without fsdp_use_orig_params flag. "
|
| 234 |
+
+ "This is not recommended because it means we will use uniform weight decay"
|
| 235 |
+
+ " and train all embeddings, not just the newly added ones. "
|
| 236 |
+
+ "Note: OPT models are not compatible with fsdp_use_orig_params flag."
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
|
| 240 |
+
print(
|
| 241 |
+
"Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding."
|
| 242 |
+
+ "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py"
|
| 243 |
+
+ "Copy and paste the code from the _optim_utils.py in this repo into the torch file."
|
| 244 |
+
+ "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
assert (args.train_num_samples_laion // args.batch_size_laion) == (
|
| 248 |
+
args.train_num_samples_mmc4 // args.batch_size_mmc4
|
| 249 |
+
), "number of samples per epoch must be equal for mmc4 and laion"
|
| 250 |
+
|
| 251 |
+
# Set up distributed training
|
| 252 |
+
if args.offline:
|
| 253 |
+
os.environ["WANDB_MODE"] = "offline"
|
| 254 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
| 255 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
| 256 |
+
device_id = init_distributed_device(args)
|
| 257 |
+
random_seed(args.seed)
|
| 258 |
+
|
| 259 |
+
# Initialize model
|
| 260 |
+
model, image_processor, tokenizer = create_model_and_transforms(
|
| 261 |
+
args.vision_encoder_path,
|
| 262 |
+
args.vision_encoder_pretrained,
|
| 263 |
+
args.lm_path,
|
| 264 |
+
args.tokenizer_path if args.tokenizer_path else args.lm_path,
|
| 265 |
+
cross_attn_every_n_layers=args.cross_attn_every_n_layers,
|
| 266 |
+
use_local_files=args.offline,
|
| 267 |
+
gradient_checkpointing=args.gradient_checkpointing,
|
| 268 |
+
freeze_lm_embeddings=args.freeze_lm_embeddings,
|
| 269 |
+
)
|
| 270 |
+
random_seed(args.seed, args.rank)
|
| 271 |
+
|
| 272 |
+
# Initialize logging
|
| 273 |
+
print(f"Start running training on rank {args.rank}.")
|
| 274 |
+
if args.rank == 0 and args.report_to_wandb:
|
| 275 |
+
wandb.init(
|
| 276 |
+
project=args.wandb_project,
|
| 277 |
+
entity=args.wandb_entity,
|
| 278 |
+
name=args.run_name,
|
| 279 |
+
config=vars(args),
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Load model checkpoint on CPU
|
| 283 |
+
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
|
| 284 |
+
# if args do not specify a checkpoint to resume from, check if checkpoints exist for this run
|
| 285 |
+
# and automatically resume from the latest checkpoint
|
| 286 |
+
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
|
| 287 |
+
if len(checkpoint_list) == 0:
|
| 288 |
+
print(f"Found no checkpoints for run {args.run_name}.")
|
| 289 |
+
else:
|
| 290 |
+
args.resume_from_checkpoint = sorted(
|
| 291 |
+
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
|
| 292 |
+
)[-1]
|
| 293 |
+
print(
|
| 294 |
+
f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}."
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
resume_from_epoch = 0
|
| 298 |
+
if args.resume_from_checkpoint is not None:
|
| 299 |
+
if args.rank == 0:
|
| 300 |
+
print(f"Loading checkpoint from {args.resume_from_checkpoint}")
|
| 301 |
+
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
|
| 302 |
+
msd = checkpoint["model_state_dict"]
|
| 303 |
+
msd = {k.replace("module.", ""): v for k, v in msd.items()}
|
| 304 |
+
resume_from_epoch = checkpoint["epoch"] + 1
|
| 305 |
+
|
| 306 |
+
# for fsdp, only one rank needs to load the state dict
|
| 307 |
+
if not args.fsdp or args.rank == 0:
|
| 308 |
+
model.load_state_dict(msd, False)
|
| 309 |
+
|
| 310 |
+
# Initialize FSDP / DDP, and ensure the model is on GPU
|
| 311 |
+
print(f"Initializing distributed training with {args.world_size} GPUs.")
|
| 312 |
+
if args.fsdp:
|
| 313 |
+
print(
|
| 314 |
+
f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# init MixedPrecision
|
| 318 |
+
if args.precision != "fp32":
|
| 319 |
+
cast_dtype = get_mp_policy_dtype(args.precision)
|
| 320 |
+
mp_policy = MixedPrecision(
|
| 321 |
+
param_dtype=torch.float32,
|
| 322 |
+
reduce_dtype=cast_dtype, # gradient communication
|
| 323 |
+
buffer_dtype=cast_dtype,
|
| 324 |
+
)
|
| 325 |
+
else:
|
| 326 |
+
mp_policy = None
|
| 327 |
+
|
| 328 |
+
# init process groups
|
| 329 |
+
if args.fsdp_sharding_strategy == "hybrid":
|
| 330 |
+
intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
|
| 331 |
+
_get_default_group()
|
| 332 |
+
)
|
| 333 |
+
args.my_group = intra_node_group # for optimizer saving
|
| 334 |
+
process_group = (intra_node_group, inter_node_group) # for FSDP init
|
| 335 |
+
else:
|
| 336 |
+
args.my_group = None # for optimizer saving
|
| 337 |
+
process_group = None # for FSDP init
|
| 338 |
+
|
| 339 |
+
# init FSDP
|
| 340 |
+
wrapper_kwargs = dict(
|
| 341 |
+
process_group=process_group,
|
| 342 |
+
cpu_offload=CPUOffload(offload_params=False),
|
| 343 |
+
device_id=device_id,
|
| 344 |
+
sync_module_states=True, # broadcast loaded ckpt from rank 0 -> all ranks
|
| 345 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD
|
| 346 |
+
if args.fsdp_sharding_strategy == "full"
|
| 347 |
+
else ShardingStrategy.HYBRID_SHARD,
|
| 348 |
+
use_orig_params=args.fsdp_use_orig_params,
|
| 349 |
+
mixed_precision=mp_policy,
|
| 350 |
+
forward_prefetch=True,
|
| 351 |
+
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
|
| 352 |
+
limit_all_gathers=True,
|
| 353 |
+
)
|
| 354 |
+
model.wrap_fsdp(wrapper_kwargs, device_id)
|
| 355 |
+
ddp_model = model
|
| 356 |
+
|
| 357 |
+
print(
|
| 358 |
+
f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
|
| 359 |
+
)
|
| 360 |
+
print(
|
| 361 |
+
f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
else:
|
| 365 |
+
model = model.to(device_id)
|
| 366 |
+
ddp_model = DDP(model, device_ids=[device_id])
|
| 367 |
+
|
| 368 |
+
# Initialize gradient checkpointing
|
| 369 |
+
if args.gradient_checkpointing:
|
| 370 |
+
non_reentrant_wrapper = functools.partial(
|
| 371 |
+
checkpoint_wrapper,
|
| 372 |
+
offload_to_cpu=True,
|
| 373 |
+
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
| 374 |
+
)
|
| 375 |
+
apply_activation_checkpointing(
|
| 376 |
+
ddp_model,
|
| 377 |
+
checkpoint_wrapper_fn=non_reentrant_wrapper,
|
| 378 |
+
check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
|
| 379 |
+
and not isinstance(m, FSDP)
|
| 380 |
+
and not isinstance(m, CheckpointWrapper),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Initialize optimizer
|
| 384 |
+
params_to_optimize = ddp_model.named_parameters()
|
| 385 |
+
params_to_optimize = list(
|
| 386 |
+
filter(
|
| 387 |
+
lambda x: x[1].requires_grad
|
| 388 |
+
and not getattr(x[1], "exclude_from_optimizer", False),
|
| 389 |
+
params_to_optimize,
|
| 390 |
+
)
|
| 391 |
+
)
|
| 392 |
+
if not args.fsdp or args.fsdp_use_orig_params:
|
| 393 |
+
# apply weight decay only to params in the xattn layers
|
| 394 |
+
def get_grouped_params(model):
|
| 395 |
+
params_with_wd, params_without_wd = [], []
|
| 396 |
+
for n, p in params_to_optimize:
|
| 397 |
+
if "gated_cross_attn" in n:
|
| 398 |
+
params_with_wd.append(p)
|
| 399 |
+
else:
|
| 400 |
+
params_without_wd.append(p)
|
| 401 |
+
return [
|
| 402 |
+
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
| 403 |
+
{"params": params_without_wd, "weight_decay": 0.0},
|
| 404 |
+
]
|
| 405 |
+
|
| 406 |
+
optimizer = torch.optim.AdamW(
|
| 407 |
+
get_grouped_params(params_to_optimize), lr=args.learning_rate
|
| 408 |
+
)
|
| 409 |
+
else:
|
| 410 |
+
# unclear if we should be using no weight decay or small weight decay for all parameters
|
| 411 |
+
optimizer = torch.optim.AdamW(
|
| 412 |
+
(p for _, p in params_to_optimize),
|
| 413 |
+
lr=args.learning_rate,
|
| 414 |
+
weight_decay=args.weight_decay,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# load optimizer checkpoint
|
| 418 |
+
if args.resume_from_checkpoint is not None:
|
| 419 |
+
osd = checkpoint["optimizer_state_dict"]
|
| 420 |
+
if args.fsdp:
|
| 421 |
+
osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
|
| 422 |
+
optimizer.load_state_dict(osd)
|
| 423 |
+
|
| 424 |
+
# Initialize data loaders
|
| 425 |
+
laion_dataset = get_data(args, image_processor, tokenizer, "image_text")
|
| 426 |
+
mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4")
|
| 427 |
+
total_training_steps = (
|
| 428 |
+
(args.train_num_samples_mmc4) // (args.batch_size_mmc4 * args.world_size)
|
| 429 |
+
) * args.num_epochs
|
| 430 |
+
|
| 431 |
+
if args.rank == 0:
|
| 432 |
+
print(f"Total training steps: {total_training_steps}")
|
| 433 |
+
|
| 434 |
+
# Initialize lr scheduler
|
| 435 |
+
if args.lr_scheduler == "linear":
|
| 436 |
+
lr_scheduler = get_linear_schedule_with_warmup(
|
| 437 |
+
optimizer,
|
| 438 |
+
num_warmup_steps=args.warmup_steps,
|
| 439 |
+
num_training_steps=total_training_steps,
|
| 440 |
+
)
|
| 441 |
+
elif args.lr_scheduler == "cosine":
|
| 442 |
+
lr_scheduler = get_cosine_schedule_with_warmup(
|
| 443 |
+
optimizer,
|
| 444 |
+
num_warmup_steps=args.warmup_steps,
|
| 445 |
+
num_training_steps=total_training_steps,
|
| 446 |
+
)
|
| 447 |
+
else:
|
| 448 |
+
lr_scheduler = get_constant_schedule_with_warmup(
|
| 449 |
+
optimizer, num_warmup_steps=args.warmup_steps
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# load lr scheduler checkpoint
|
| 453 |
+
if args.resume_from_checkpoint is not None:
|
| 454 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
|
| 455 |
+
|
| 456 |
+
# Start training!
|
| 457 |
+
ddp_model.train()
|
| 458 |
+
|
| 459 |
+
for epoch in range(resume_from_epoch, args.num_epochs):
|
| 460 |
+
laion_dataset.set_epoch(epoch)
|
| 461 |
+
laion_loader = laion_dataset.dataloader
|
| 462 |
+
mmc4_dataset.set_epoch(epoch)
|
| 463 |
+
mmc4_loader = mmc4_dataset.dataloader
|
| 464 |
+
|
| 465 |
+
train_one_epoch(
|
| 466 |
+
args=args,
|
| 467 |
+
model=ddp_model,
|
| 468 |
+
epoch=epoch,
|
| 469 |
+
tokenizer=tokenizer,
|
| 470 |
+
optimizer=optimizer,
|
| 471 |
+
lr_scheduler=lr_scheduler,
|
| 472 |
+
laion_loader=laion_loader,
|
| 473 |
+
mmc4_loader=mmc4_loader,
|
| 474 |
+
device_id=device_id,
|
| 475 |
+
wandb=wandb,
|
| 476 |
+
)
|
| 477 |
+
save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
|
| 478 |
+
|
| 479 |
+
# save final checkpoint
|
| 480 |
+
save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
open_flamingo/train/train_utils.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from contextlib import suppress
|
| 3 |
+
import torch
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 6 |
+
from torch.distributed.fsdp import (
|
| 7 |
+
FullStateDictConfig,
|
| 8 |
+
StateDictType,
|
| 9 |
+
)
|
| 10 |
+
from torch.distributed.fsdp.api import FullOptimStateDictConfig
|
| 11 |
+
import os
|
| 12 |
+
import wandb
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_cast_dtype(precision: str):
    """Map a precision string to the dtype inputs should be cast to.

    Returns ``torch.bfloat16`` for "bf16", ``torch.float16`` for "fp16",
    and ``None`` for any other precision (meaning: do not cast).
    """
    dtype_by_precision = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    return dtype_by_precision.get(precision)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_mp_policy_dtype(precision: str):
    """Map a precision string to the dtype for the FSDP mixed-precision policy.

    Any precision string mentioning bfloat16 ("bf16" or "bfloat16", including
    amp variants) maps to ``torch.bfloat16``; exactly "fp16" maps to
    ``torch.float16``; everything else falls back to full ``torch.float32``.
    """
    if "bf16" in precision or "bfloat16" in precision:
        return torch.bfloat16
    if precision == "fp16":
        return torch.float16
    return torch.float32
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_autocast(precision, cache_enabled=True):
|
| 35 |
+
if precision == "amp":
|
| 36 |
+
return torch.cuda.amp.autocast(cache_enabled=cache_enabled)
|
| 37 |
+
elif precision == "amp_bfloat16" or precision == "amp_bf16":
|
| 38 |
+
# amp_bfloat16 is more stable than amp float16 for clip training
|
| 39 |
+
return lambda: torch.cuda.amp.autocast(
|
| 40 |
+
dtype=torch.bfloat16, cache_enabled=cache_enabled
|
| 41 |
+
)
|
| 42 |
+
else:
|
| 43 |
+
return suppress
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def train_one_epoch(
    args,
    model,
    epoch,
    laion_loader,
    mmc4_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    """Train for one epoch on interleaved LAION and MMC4 batches.

    Each step runs one forward/backward pass on a LAION (image-text pair)
    batch and one on an MMC4 (interleaved document) batch, accumulates
    gradients for ``args.gradient_accumulation_steps`` steps, clips the
    gradient norm, then steps the optimizer and LR scheduler.

    Args:
        args: parsed CLI namespace (precision, fsdp flags, batch sizes, ...).
        model: DDP- or FSDP-wrapped Flamingo model.
        epoch: current epoch index; offsets the global step for logging.
        laion_loader: LAION dataloader; must have the same ``num_batches``
            as ``mmc4_loader``.
        mmc4_loader: MMC4 dataloader.
        tokenizer: tokenizer providing the <image>, <|endofchunk|>, pad,
            and eos token ids.
        optimizer: optimizer to step.
        lr_scheduler: LR scheduler to step alongside the optimizer.
        device_id: CUDA device for this rank.
        wandb: the wandb module (used only when ``args.report_to_wandb``).
    """
    # setup loaders
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches
    assert (
        num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
    ), "Number of batches in laion and mmc4 datasets must be the same"
    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    # setup model
    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[
        "input_ids"
    ][-1]
    model.train()

    # setup logging
    step_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    # loop through dataloader
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)
        global_step = num_steps + epoch * num_batches_per_epoch

        #### LAION FORWARD PASS ####
        images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True)
        # LAION samples are single-image/single-frame: unflatten the batch
        # to (b, T_img=1, F=1, c, h, w).
        images = rearrange(images, "(b t f) c h w -> b t f c h w", t=1, f=1)
        # NOTE(review): casting integer token ids / attention masks to
        # cast_dtype (a float dtype under bf16/fp16) looks suspect —
        # confirm the collator's output dtypes.
        input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch_laion[1][1].to(
            device_id, dtype=cast_dtype, non_blocking=True
        )

        # set up labels; language model is expected to handle shifting
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        # Bug fix: compare against the eos token *id*, not the eos string.
        # The previous tensor == str comparison never matched, so eos tokens
        # were silently left unmasked.
        labels[labels == tokenizer.eos_token_id] = -100
        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
        with autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]

        divided_loss_laion = loss_laion / args.gradient_accumulation_steps
        (divided_loss_laion * args.loss_multiplier_laion).backward()

        #### MMC4 FORWARD PASS ####
        images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True)
        images = rearrange(images, "b (t f) c h w -> b t f c h w", f=1)
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)

        # set up labels; language model is expected to handle shifting
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        # Bug fix: was ``tokenizer.eos_token`` (a string), which never matched.
        labels[labels == tokenizer.eos_token_id] = -100
        for i in range(labels.shape[0]):
            # remove loss for any token before the first <image> token
            label_idx = 0
            while (
                label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id
            ):
                labels[i][label_idx] = -100
                label_idx += 1

            # get index of all endofchunk tokens in the sequence
            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
            for endofchunk_idx in endofchunk_idxs:
                # mask the span between an <|endofchunk|> and the next <image>,
                # i.e. text that is not conditioned on any image
                token_idx = endofchunk_idx + 1
                while (
                    token_idx < labels.shape[1]
                    and labels[i][token_idx] != media_token_id
                ):
                    labels[i][token_idx] = -100
                    token_idx += 1

        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
        with autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids.to(device_id),
                attention_mask=attention_mask.to(device_id),
                labels=labels,
            )[0]

        # if loss is nan, skip this batch
        # this hack of skipping the batch is not FSDP-compatible
        if torch.isnan(loss_mmc4):
            print("loss is nan, skipping this batch")
            print("input_ids: ", tokenizer.batch_decode(input_ids))
            print("labels: ", labels)
            print("images: ", images)
            optimizer.zero_grad(set_to_none=True)
            continue

        divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps
        (divided_loss_mmc4 * args.loss_multiplier_mmc4).backward()

        if (not args.freeze_lm_embeddings) and (
            not args.fsdp or args.fsdp_use_orig_params
        ):
            # Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
            if args.fsdp:
                embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
            else:
                embed_grad = (
                    model.module.lang_encoder.get_input_embeddings().weight.grad
                )
            zero_mask = torch.zeros_like(embed_grad)
            zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            zero_mask[endofchunk_token_id] = torch.ones_like(
                zero_mask[endofchunk_token_id]
            )
            if args.fsdp:
                model.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )
            else:
                model.module.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )

        # clip gradient norm
        if args.fsdp:
            """
            The way we clip gradients with FSDP is different than the non-FSDP case,
            because during FSDP, gradient norms are computed over certain submodules,
            rather than the entire model.
            At least for OPT-125M, this didn't seem to make a difference in performance.
            """
            model.clip_grad_norm_(1.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
            num_steps == num_batches_per_epoch - 1
        ):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            # rank 0 logging
            if args.rank == 0 and args.report_to_wandb:
                laion_samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size_laion
                    * args.world_size
                    / step_time_m.val
                )
                laion_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps
                    * args.batch_size_laion
                    / step_time_m.val
                )
                c4_samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size_mmc4
                    * args.world_size
                    / step_time_m.val
                )
                c4_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps
                    * args.batch_size_mmc4
                    / step_time_m.val
                )
                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "c4_samples_per_second": c4_samples_per_second,
                        "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "loss_laion": loss_laion.item(),
                        "global_step": global_step,
                    },
                    commit=False,
                )
                wandb.log(
                    {"loss_mmc4": loss_mmc4.item(), "global_step": global_step},
                    commit=True,
                )

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
            )
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class AverageMeter(object):
    """Tracks the most recent value and a running (count-weighted) average.

    Exposed attributes: ``val`` (last value), ``sum``, ``count``, ``avg``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def filter_state_dict_to_trainable(model, state_dict):
    """
    Remove non-trainable parameters from model state dict.

    Exception: embeddings are kept even when frozen, because the newly added
    <image> and <|endofchunk|> token rows must stay consistent across
    initializations.

    Mutates and returns ``state_dict``.
    """
    # Drop frozen parameters (won't work for fsdp + use_orig_params=False).
    for param_name, param in model.named_parameters():
        if "fsdp" in param_name:
            continue
        if "embed" in param_name or isinstance(param, torch.nn.Embedding):
            # keep embeddings regardless of requires_grad
            continue
        if param.requires_grad:
            continue
        key = param_name.replace("._checkpoint_wrapped_module", "")
        if key in state_dict:
            del state_dict[key]
        else:
            print(f"WARNING: filtering but {key} not in state_dict")

    # Also drop keys duplicated elsewhere in the state dict:
    # lang_encoder.old_decoder_blocks and lang_encoder.gated_cross_attn_layers
    # are already saved under lang_encoder.model..., and the vision encoder
    # weights come from the frozen pretrained checkpoint.
    redundant_substrings = (
        "lang_encoder.old_decoder_blocks",
        "lang_encoder.gated_cross_attn_layers",
        "vision_encoder",
    )
    redundant_keys = [
        key
        for key in state_dict.keys()
        if any(sub in key for sub in redundant_substrings)
    ]
    for key in redundant_keys:
        del state_dict[key]
    return state_dict
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
    """
    Save training checkpoint with model, optimizer, and lr_scheduler state.

    Under FSDP, the full (unsharded) model and optimizer state dicts are
    gathered onto rank 0 and offloaded to CPU; only rank 0 writes the file
    to ``{args.run_name}/checkpoint_{epoch}.pt``. Optionally uploads the
    checkpoint to wandb and deletes the previous epoch's file.
    """
    if args.fsdp:
        # Configure FSDP to materialize a full state dict on rank 0 only,
        # offloaded to CPU to avoid GPU OOM while gathering.
        FSDP.set_state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            FullOptimStateDictConfig(rank0_only=True),
        )
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group)

    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if args.rank == 0:
        # With use_orig_params=False, FSDP flattens parameters, so filtering
        # frozen weights by name is not possible — save everything instead.
        if not (args.fsdp and not args.fsdp_use_orig_params):
            model_state = filter_state_dict_to_trainable(model, model_state)

        if not os.path.exists(args.run_name):
            os.makedirs(args.run_name)

        checkpoint_dict = {
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": optim_state,
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        }

        print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt")
        torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt")
        if args.report_to_wandb and args.save_checkpoints_to_wandb:
            wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt")

        # Optionally keep only the latest checkpoint on disk.
        if args.delete_previous_checkpoint:
            if epoch > 0:
                os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")
|
requirements.txt
CHANGED
|
@@ -4,14 +4,9 @@ torchvision==0.16.0
|
|
| 4 |
transformers==4.35.0
|
| 5 |
huggingface_hub>=0.16.4,<1.0
|
| 6 |
einops==0.7.0
|
| 7 |
-
einops-exts==0.0.4
|
| 8 |
numpy==1.24.3
|
| 9 |
Pillow==10.1.0
|
| 10 |
matplotlib==3.8.0
|
| 11 |
open_clip_torch==2.23.0
|
| 12 |
accelerate==0.24.1
|
| 13 |
-
safetensors==0.4.0
|
| 14 |
-
zarr==2.16.1
|
| 15 |
-
numcodecs==0.12.1
|
| 16 |
-
hydra-core==1.3.2
|
| 17 |
-
omegaconf==2.3.0
|
|
|
|
| 4 |
transformers==4.35.0
|
| 5 |
huggingface_hub>=0.16.4,<1.0
|
| 6 |
einops==0.7.0
|
|
|
|
| 7 |
numpy==1.24.3
|
| 8 |
Pillow==10.1.0
|
| 9 |
matplotlib==3.8.0
|
| 10 |
open_clip_torch==2.23.0
|
| 11 |
accelerate==0.24.1
|
| 12 |
+
safetensors==0.4.0
|
|
|
|
|
|
|
|
|
|
|
|