# event_retrieval / embed_worker.py
# sanskar407
# change requirement.txt
# 546ee63
"""
embed_worker.py — Single-GPU embedding worker.
Usage (run one per GPU in separate terminals):
CUDA_VISIBLE_DEVICES=5 python embed_worker.py --start 12586 --end 16000 --out embeddings_gpu5.pkl
CUDA_VISIBLE_DEVICES=6 python embed_worker.py --start 16000 --end 20000 --out embeddings_gpu6.pkl
CUDA_VISIBLE_DEVICES=7 python embed_worker.py --start 20000 --end 24586 --out embeddings_gpu7.pkl
The existing embeddings_blip2.pkl already has frames 0-12585 — don't re-do those.
"""
import argparse, os, sys, pickle
import ssl
# SECURITY NOTE(review): this swaps the ssl module's default HTTPS context
# factory for the unverified one, so HTTPS connections created through that
# factory in this process skip certificate verification. Presumably a
# workaround for model downloads behind a broken/intercepting certificate
# chain — confirm it is still needed and prefer installing the proper CA bundle.
ssl._create_default_https_context = ssl._create_unverified_context
# ── PATHS — adjust if needed ──────────────────────────────────────────────────
# Root directory of the extracted frames, and the name of the model cache
# directory that is created next to this script.
FRAMES_ROOT = "/media/RTCIN15TB/Datasets/NvidiaPhisicalAIFrames"
CACHE_DIR = "model_cache"

script_dir = os.path.dirname(os.path.abspath(__file__))
local_cache_path = os.path.join(script_dir, CACHE_DIR)
os.makedirs(local_cache_path, exist_ok=True)

# Point both the Hugging Face and the LAVIS cache variables at the
# script-local cache directory.
os.environ.update(HF_HOME=local_cache_path, LAVIS_CACHE_ROOT=local_cache_path)
# ── LAVIS path ────────────────────────────────────────────────────────────────
# Prefer a LAVIS checkout located at <project root>/LAVIS (it must contain the
# ``lavis`` package directory); otherwise fall back to the hard-coded checkout.
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")
# A directory <LAVIS>/lavis can only exist if <LAVIS> itself is a directory,
# so one isdir() check covers both conditions.
if not os.path.isdir(os.path.join(path_to_lavis_parent_dir, "lavis")):
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
# Put the chosen checkout first on sys.path so ``import lavis`` resolves to it.
sys.path.insert(0, path_to_lavis_parent_dir)
# ── Patches (same as main script) ─────────────────────────────────────────────
import torch, torch.nn as nn, torch.distributions.constraints as constraints, inspect
from transformers.modeling_utils import PreTrainedModel
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
# Patch torch's positive-definite constraint so that it can be evaluated on
# meta tensors. The patch is only installed when this torch build exposes
# ``constraints._PositiveDefinite.check``.
if hasattr(getattr(constraints, '_PositiveDefinite', None), 'check'):
    _orig_pdc = constraints._PositiveDefinite.check

    def _patched_pdc(self, value):
        """Constraint check that short-circuits for meta tensors.

        For a meta tensor an all-True boolean meta tensor shaped like
        ``value`` is returned (presumably because the real check needs actual
        tensor data — TODO confirm). Every other input is forwarded to the
        original ``check``.
        """
        if isinstance(value, torch.Tensor) and value.is_meta:
            # ones_like keeps value's shape and (meta) device.
            return torch.ones_like(value, dtype=torch.bool)
        return _orig_pdc(self, value)

    constraints._PositiveDefinite.check = _patched_pdc
# Patch the embedding-resize initialiser of transformers so that it tolerates
# an old embedding matrix that lives on the meta device. Only installed when
# this transformers version has the method.
if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    _orig_iae = PreTrainedModel._init_added_embeddings_weights_with_mean

    def _patched_iae(self, new_emb, old_emb, num_added, *args, **kwargs):
        """Initialise newly added embedding rows when the old weights are on meta.

        Falls through to the original implementation unless both arguments
        are ``nn.Embedding`` modules, ``num_added`` is positive and the old
        weight is a meta tensor. In that case the rows
        ``old_rows .. new_rows`` of the new weight are drawn from
        ``N(0, self.config.initializer_range)`` (when the new weight is
        materialised) and the original implementation is skipped.
        """
        if not (isinstance(new_emb, nn.Embedding) and isinstance(old_emb, nn.Embedding)):
            return _orig_iae(self, new_emb, old_emb, num_added, *args, **kwargs)

        new_w, old_w = new_emb.weight, old_emb.weight
        if num_added > 0 and old_w.device.type == 'meta':
            first_new, end_new = old_w.shape[0], new_w.shape[0]
            if new_w.device.type != 'meta' and first_new < end_new:
                with torch.no_grad():
                    new_w[first_new:end_new].normal_(
                        mean=0.0, std=self.config.initializer_range
                    )
            # NOTE(review): the nesting level of this return was reconstructed.
            # It is placed so the original initialiser is never called with
            # meta old weights — confirm against the main script.
            return

        return _orig_iae(self, new_emb, old_emb, num_added, *args, **kwargs)

    PreTrainedModel._init_added_embeddings_weights_with_mean = _patched_iae
_orig_lsd = nn.Module.load_state_dict
# Not every torch release accepts the ``assign`` keyword; inspect the original
# signature once here instead of on every load_state_dict call.
_LSD_SUPPORTS_ASSIGN = 'assign' in inspect.signature(_orig_lsd).parameters


def _factor_grid(num_patches):
    """Return ``(rows, cols)`` with ``rows * cols == num_patches``.

    ``rows`` is the largest divisor not exceeding ``sqrt(num_patches)``, so a
    perfect square yields a square grid. For non-squares this is only a guess
    derived from the token count; the true patch layout is not recoverable
    from the count alone.
    """
    import math

    for rows in range(math.isqrt(num_patches), 0, -1):
        if num_patches % rows == 0:
            return rows, num_patches // rows
    return 0, 0


def _resize_pos_embed(ckpt_pos, src_grid, dst_grid):
    """Bicubically resample a ``[1, 1 + N, dim]`` positional embedding.

    The first token is carried over unchanged (the original code treats it as
    the class token); the remaining ``N`` tokens are laid out on ``src_grid``,
    interpolated to ``dst_grid`` and flattened again. The result keeps the
    checkpoint tensor's dtype.
    """
    import torch.nn.functional as F

    cls_tok, patches = ckpt_pos[:, :1, :], ckpt_pos[:, 1:, :]
    dim = patches.shape[-1]
    grid = patches.reshape(1, src_grid[0], src_grid[1], dim).permute(0, 3, 1, 2)
    # Interpolate in float32, then cast back so the concatenated tensor does
    # not silently change dtype relative to the checkpoint.
    grid = F.interpolate(grid.float(), size=dst_grid, mode='bicubic', align_corners=False)
    patches = grid.permute(0, 2, 3, 1).reshape(1, dst_grid[0] * dst_grid[1], dim)
    return torch.cat([cls_tok, patches.to(cls_tok.dtype)], dim=1)


def _patched_lsd(self, state_dict, strict=True, assign=False):
    """Replacement for ``nn.Module.load_state_dict`` with BLIP-2 fix-ups.

    For ``Blip2Qformer`` instances, ``state_dict`` is adjusted IN PLACE before
    loading:

    * the Q-Former prediction-head tensors are truncated along dim 0 when the
      checkpoint has more rows than the model;
    * ``visual_encoder.pos_embed`` is interpolated to the model's token count
      when the shapes differ.

    For any module that still has meta parameters, ``assign=True`` is forced
    (when the installed torch supports it). The call is then delegated to the
    original ``load_state_dict`` and its result returned.
    """
    if isinstance(self, Blip2Qformer):
        model_sd = self.state_dict()

        for key in ("Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight"):
            if key in state_dict and key in model_sd:
                model_rows = model_sd[key].shape[0]
                # narrow() can only shrink; a smaller checkpoint tensor is left
                # alone so the regular size-mismatch handling applies.
                if state_dict[key].shape[0] > model_rows:
                    state_dict[key] = state_dict[key].narrow(0, 0, model_rows)

        pos_key = "visual_encoder.pos_embed"
        if pos_key in state_dict and pos_key in model_sd:
            ckpt_pos, model_pos = state_dict[pos_key], model_sd[pos_key]
            if ckpt_pos.shape != model_pos.shape:
                src_grid = _factor_grid(ckpt_pos.shape[1] - 1)
                dst_grid = _factor_grid(model_pos.shape[1] - 1)
                if src_grid[0] != src_grid[1] or dst_grid[0] != dst_grid[1]:
                    print(
                        f"WARNING: pos_embed grid is not square; guessed "
                        f"{src_grid[0]}x{src_grid[1]} -> {dst_grid[0]}x{dst_grid[1]} from the "
                        f"token count alone. Verify this matches the checkpoint's patch layout."
                    )
                state_dict[pos_key] = _resize_pos_embed(ckpt_pos, src_grid, dst_grid)
                print(f"INFO: Interpolated pos_embed {ckpt_pos.shape} → {state_dict[pos_key].shape}")

    # Meta parameters hold no storage to copy into, so loaded tensors must be
    # assigned instead.
    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LSD_SUPPORTS_ASSIGN:
        return _orig_lsd(self, state_dict, strict=strict, assign=assign)
    return _orig_lsd(self, state_dict, strict=strict)


nn.Module.load_state_dict = _patched_lsd
print("INFO: Patched nn.Module.load_state_dict.")
# ── Main ──────────────────────────────────────────────────────────────────────
from lavis.models import load_model_and_preprocess
from PIL import Image
import numpy as np
def discover_frames(root):
    """Return every frame path under ``root`` in a deterministic order.

    The expected layout is ``root/<chunk>/<video>/<frame>.jpg``. Chunk and
    video directories are visited in sorted order, and the frames of each
    video are sorted by name, so list indices are stable across runs (the
    ``--start``/``--end`` ranges rely on this).

    Args:
        root: Directory that contains the chunk directories.

    Returns:
        list[str]: Paths of all entries whose name ends in ``.jpg``
        (case-insensitive). Chunk entries that start with ``.`` and
        non-directory entries at the chunk/video levels are skipped.

    Raises:
        OSError: If ``root`` (or a directory below it) cannot be listed.
    """
    frame_paths = []
    for chunk in sorted(os.listdir(root)):
        chunk_dir = os.path.join(root, chunk)
        if chunk.startswith('.') or not os.path.isdir(chunk_dir):
            continue
        for video in sorted(os.listdir(chunk_dir)):
            video_dir = os.path.join(chunk_dir, video)
            if not os.path.isdir(video_dir):
                continue
            # All frames share the same directory prefix, so sorting the file
            # names gives the same order as sorting the joined paths.
            frame_paths.extend(
                os.path.join(video_dir, name)
                for name in sorted(os.listdir(video_dir))
                if name.lower().endswith('.jpg')
            )
    return frame_paths
def _resolve_frame_range(start, end, total):
    """Clamp a requested ``[start, end)`` window to ``[0, total]``.

    ``end`` of ``None`` or any negative value means "through the last frame".
    """
    if end is None or end < 0:
        end = total
    return max(start, 0), min(end, total)


def _save_embeddings(embedding_dict, out_pkl):
    """Pickle ``embedding_dict`` to ``out_pkl`` via a temp file plus rename.

    Writing to ``<out_pkl>.tmp`` first means an interrupted save cannot leave
    a truncated ``out_pkl`` behind (which would break the resume logic).
    """
    tmp_path = out_pkl + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(embedding_dict, f)
    os.replace(tmp_path, out_pkl)


def _embed_pending(todo, embedding_dict, out_pkl):
    """Embed every path in ``todo`` with BLIP-2, checkpointing periodically.

    Results are written into ``embedding_dict`` (path -> CPU tensor). Images
    that cannot be opened or preprocessed are reported and skipped.
    """
    device = "cuda:0"  # CUDA_VISIBLE_DEVICES remaps the GPU to index 0
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2",
        model_type="gen3_518_518",
        is_eval=True,
        device=device
    )
    model.eval()
    preprocess = vis_processors["eval"]

    BATCH_SIZE = 32
    SAVE_EVERY = 500
    computed = 0
    last_saved = 0
    with torch.no_grad():
        for i in range(0, len(todo), BATCH_SIZE):
            images, valid_paths = [], []
            for p in todo[i: i + BATCH_SIZE]:
                try:
                    # The context manager closes the file handle promptly.
                    with Image.open(p) as img:
                        tensor = preprocess(img.convert("RGB"))
                except Exception as e:  # best effort: one bad frame must not stop the run
                    print(f" WARNING: {p}: {e}")
                    continue
                images.append(tensor)
                valid_paths.append(p)
            if not images:
                continue
            image_tensor = torch.stack(images, dim=0).to(device)
            feats = model.extract_features(
                {"image": image_tensor}, mode="image"
            ).image_embeds_proj[:, 0, :]
            for path, emb in zip(valid_paths, feats):
                embedding_dict[path] = emb.cpu()
            computed += len(valid_paths)
            # ``computed`` grows in batch-sized steps, so test the distance to
            # the last checkpoint rather than exact divisibility.
            if computed - last_saved >= SAVE_EVERY:
                print(f" [{computed}/{len(todo)}] Saving checkpoint → {out_pkl}")
                _save_embeddings(embedding_dict, out_pkl)
                last_saved = computed


def run(start: int, end: int, out_pkl: str, override_paths=None):
    """Embed a slice of frames and store the result as a pickle.

    Args:
        start: First frame index (inclusive) into the list returned by
            ``discover_frames(FRAMES_ROOT)``. Ignored when ``override_paths``
            is given.
        end: End frame index (exclusive). ``None`` or a negative value means
            "through the last frame"; values past the end are clamped.
            Ignored when ``override_paths`` is given.
        out_pkl: Output pickle path. If it already exists it is loaded and
            only frames missing from it are embedded (resume support).
        override_paths: Optional explicit list of image paths; when given, all
            of them are processed instead of a slice of the discovered frames.

    Returns:
        None. The path -> embedding dict is written to ``out_pkl``; nothing is
        written when ``start`` lies beyond the available frames.
    """
    if override_paths is not None:
        all_paths = list(override_paths)
        start, end = 0, len(all_paths)
    else:
        print("Discovering frames …")
        all_paths = discover_frames(FRAMES_ROOT)

    total = len(all_paths)
    start, end = _resolve_frame_range(start, end, total)
    if start >= total:
        print(f"⚠️ Start {start} exceeds total frames {total}")
        return
    slice_paths = all_paths[start:end]
    print(f"Processing range [{start}, {end}) out of {total}")

    # Resume support: load existing partial output if present.
    if os.path.exists(out_pkl):
        # NOTE: unpickling runs code embedded in the file — only resume from
        # output files produced by this script.
        with open(out_pkl, "rb") as f:
            embedding_dict = pickle.load(f)
        print(f"Resumed from {out_pkl}: {len(embedding_dict)} already done.")
    else:
        embedding_dict = {}

    todo = [p for p in slice_paths if p not in embedding_dict]
    print(f"{len(todo)} frames still need embedding.")

    # Loading the model is expensive; skip it when nothing is left to do.
    if todo:
        _embed_pending(todo, embedding_dict, out_pkl)

    print(f"Done. Saving final output → {out_pkl} ({len(embedding_dict)} embeddings)")
    _save_embeddings(embedding_dict, out_pkl)
def _collect_folder_jpgs(folder):
    """Return the sorted paths of all ``.jpg`` entries anywhere under ``folder``.

    The extension match is case-insensitive, in line with ``discover_frames``.
    The recursive ``**`` pattern also matches entries directly inside
    ``folder``, so no separate non-recursive pass is needed.
    """
    import glob

    candidates = glob.glob(os.path.join(folder, "**", "*"), recursive=True)
    return sorted(p for p in candidates if p.lower().endswith(".jpg"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Single-GPU embedding worker.")
    parser.add_argument("--start", type=int, default=0,
                        help="Start frame index (inclusive)")
    parser.add_argument("--end", type=int, default=None,
                        help="End frame index (exclusive); omit to process through the last frame")
    parser.add_argument("--out", type=str, required=True,
                        help="Output .pkl filename")
    parser.add_argument("--folder", type=str, default=None,
                        help="If set, embed only JPGs under this specific folder (overrides --start/--end)")
    args = parser.parse_args()

    if args.folder:
        # Folder mode: embed exactly the JPGs found under the given folder.
        specific_paths = _collect_folder_jpgs(args.folder)
        print(f"Folder mode: found {len(specific_paths)} JPGs under {args.folder}")
        run(0, len(specific_paths), args.out, override_paths=specific_paths)
    else:
        run(args.start, args.end, args.out)