"""
embed_worker.py — Single-GPU embedding worker.

Usage (run one per GPU in separate terminals):

    CUDA_VISIBLE_DEVICES=5 python embed_worker.py --start 12586 --end 16000 --out embeddings_gpu5.pkl
    CUDA_VISIBLE_DEVICES=6 python embed_worker.py --start 16000 --end 20000 --out embeddings_gpu6.pkl
    CUDA_VISIBLE_DEVICES=7 python embed_worker.py --start 20000 --end 24586 --out embeddings_gpu7.pkl

Folder mode (ignores --start/--end and embeds every *.jpg under one directory):

    CUDA_VISIBLE_DEVICES=5 python embed_worker.py --folder /path/to/frames --out embeddings_folder.pkl

The existing embeddings_blip2.pkl already has frames 0-12585 — don't re-do those.

--start/--end index into the sorted frame list built by discover_frames();
omitting --end (or passing a negative value) means "through the last frame".
The output is a pickle of {frame_path: embedding tensor on CPU}. If --out
already exists it is loaded and only the missing frames are computed (resume).
"""
import argparse
import contextlib
import glob
import inspect
import math
import os
import pickle
import ssl
import sys

# NOTE(security): this disables TLS certificate verification for every HTTPS
# connection that uses Python's default SSL context (urllib / http.client),
# e.g. checkpoint downloads. Presumably a workaround for a proxy or missing CA
# bundle on this machine — prefer fixing the CA bundle and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

# ── PATHS — adjust if needed ──────────────────────────────────────────────────
FRAMES_ROOT = "/media/RTCIN15TB/Datasets/NvidiaPhisicalAIFrames"
CACHE_DIR = "model_cache"

script_dir = os.path.dirname(os.path.abspath(__file__))
local_cache_path = os.path.join(script_dir, CACHE_DIR)
os.makedirs(local_cache_path, exist_ok=True)
# Set before transformers / LAVIS are imported below so they pick up the
# local cache location.
os.environ['HF_HOME'] = local_cache_path
os.environ['LAVIS_CACHE_ROOT'] = local_cache_path

# ── LAVIS path ────────────────────────────────────────────────────────────────
# Prefer a LAVIS checkout next to this project; fall back to the shared copy.
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")
if not (os.path.isdir(path_to_lavis_parent_dir)
        and os.path.isdir(os.path.join(path_to_lavis_parent_dir, "lavis"))):
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
sys.path.insert(0, path_to_lavis_parent_dir)

# ── Patches (same as main script) ─────────────────────────────────────────────
import torch
import torch.distributions.constraints as constraints
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel

from lavis.models.blip2_models.blip2_qformer import Blip2Qformer

if hasattr(constraints, '_PositiveDefinite') and hasattr(constraints._PositiveDefinite, 'check'):
    _orig_pdc = constraints._PositiveDefinite.check

    def _patched_pdc(self, value):
        """Skip the positive-definite check for meta tensors (they hold no data).

        Returns a ``ones_like`` boolean placeholder on the same (meta) device
        instead of running the real check; non-meta values go to the original.
        """
        if isinstance(value, torch.Tensor) and value.is_meta:
            return torch.ones_like(value, dtype=torch.bool, device=value.device)
        return _orig_pdc(self, value)

    constraints._PositiveDefinite.check = _patched_pdc

if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    _orig_iae = PreTrainedModel._init_added_embeddings_weights_with_mean

    def _patched_iae(self, new_emb, old_emb, num_added, *args, **kwargs):
        """Avoid mean-based init of added embedding rows when the old weights are meta.

        With the old weights on the meta device their mean cannot be computed,
        so the added rows ``[old_rows, new_rows)`` get a plain normal init
        (std = ``config.initializer_range``) if the new weight is materialized.
        Every other case is delegated to the original transformers method.
        """
        if not (isinstance(new_emb, nn.Embedding) and isinstance(old_emb, nn.Embedding)):
            return _orig_iae(self, new_emb, old_emb, num_added, *args, **kwargs)
        new_w, old_w = new_emb.weight, old_emb.weight
        if num_added > 0 and old_w.device.type == 'meta':
            start, end = old_w.shape[0], new_w.shape[0]
            if new_w.device.type != 'meta' and start < end:
                with torch.no_grad():
                    new_w[start:end].normal_(mean=0.0, std=self.config.initializer_range)
            # NOTE(review): the original file's indentation was lost; this
            # return is placed so the meta branch never falls through to the
            # original (data-dependent) mean-based init.
            return
        return _orig_iae(self, new_emb, old_emb, num_added, *args, **kwargs)

    PreTrainedModel._init_added_embeddings_weights_with_mean = _patched_iae

_orig_lsd = nn.Module.load_state_dict
# Checked once: older torch versions' load_state_dict has no ``assign`` kwarg.
_LSD_ACCEPTS_ASSIGN = 'assign' in inspect.signature(_orig_lsd).parameters

# Patch-grid layouts (rows, cols) for token counts that the "most square"
# factorization in _factor_grid would get wrong. 1380 = 23 * 60 =
# (322/14) * (840/14): a ViT with 14px patches at 322x840 input (cf. the
# "gen3_322_840" model type). Its most-square factorization is 30x46, which
# would scramble the positional embeddings during interpolation.
# NOTE(review): assumes height=322, width=840 — TODO confirm against the
# checkpoint's training config. Add entries for other non-square checkpoints.
KNOWN_POS_GRIDS = {1380: (23, 60)}


def _factor_grid(n):
    """Return ``(rows, cols)`` of a patch grid holding ``n`` positions.

    Uses KNOWN_POS_GRIDS when ``n`` is listed there; otherwise returns the
    most square factorization (rows <= cols), which is exact for square grids.
    """
    if n in KNOWN_POS_GRIDS:
        return KNOWN_POS_GRIDS[n]
    if n <= 0:
        raise ValueError(f"cannot lay out {n} positions on a patch grid")
    rows = math.isqrt(n)
    while n % rows:
        rows -= 1
    return rows, n // rows


def _resize_pos_embed(ckpt_pos, n_model):
    """Resize a ``[1, 1 + N, dim]`` pos-embed (cls token first) to ``1 + n_model`` tokens.

    The cls token is kept as-is; the N patch positions are reshaped to their
    2-D grid, bicubically interpolated to the model's grid, flattened back and
    cast to the checkpoint's dtype.
    """
    cls_tok, patches = ckpt_pos[:, :1, :], ckpt_pos[:, 1:, :]
    dim = patches.shape[-1]
    src_h, src_w = _factor_grid(patches.shape[1])
    dst_h, dst_w = _factor_grid(n_model)
    print(f"DEBUG: N_ckpt={patches.shape[1]}, N_model={n_model}, dim={dim}")
    print(f"DEBUG: ckpt grid={src_h}x{src_w}, model grid={dst_h}x{dst_w}")
    grid = patches.reshape(1, src_h, src_w, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid.float(), size=(dst_h, dst_w), mode='bicubic', align_corners=False)
    patches = grid.permute(0, 2, 3, 1).reshape(1, dst_h * dst_w, dim).to(ckpt_pos.dtype)
    return torch.cat([cls_tok, patches], dim=1)


def _patched_lsd(self, state_dict, strict=True, assign=False):
    """Drop-in ``nn.Module.load_state_dict`` tolerating BLIP-2 checkpoint/model mismatches.

    For Blip2Qformer modules:
      * ``Qformer.cls.predictions.{bias,decoder.weight}`` are truncated along
        dim 0 when their row count differs from the model's (``narrow`` assumes
        the checkpoint is the larger one);
      * ``visual_encoder.pos_embed`` is interpolated when its shape differs.
    For every module, ``assign=True`` is forced when any parameter is still on
    the meta device, so checkpoint tensors are assigned rather than copied into
    storage-less parameters.

    Note: ``state_dict`` is modified in place.
    """
    if isinstance(self, Blip2Qformer):
        model_sd = self.state_dict()
        for key in ["Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight"]:
            if key in state_dict and key in model_sd:
                ckpt_t, model_t = state_dict[key], model_sd[key]
                if ckpt_t.shape[0] != model_t.shape[0]:
                    state_dict[key] = ckpt_t.narrow(0, 0, model_t.shape[0])

        pos_key = "visual_encoder.pos_embed"
        if pos_key in state_dict and pos_key in model_sd:
            ckpt_pos, model_pos = state_dict[pos_key], model_sd[pos_key]
            if ckpt_pos.shape != model_pos.shape:
                print(f"DEBUG: ckpt_pos={ckpt_pos.shape}, model_pos={model_pos.shape}")
                state_dict[pos_key] = _resize_pos_embed(ckpt_pos, model_pos.shape[1] - 1)
                print(f"INFO: Interpolated pos_embed {ckpt_pos.shape} → {state_dict[pos_key].shape}")

    if any(p.is_meta for p in self.parameters()):
        assign = True
    if _LSD_ACCEPTS_ASSIGN:
        return _orig_lsd(self, state_dict, strict=strict, assign=assign)
    return _orig_lsd(self, state_dict, strict=strict)


nn.Module.load_state_dict = _patched_lsd
print("INFO: Patched nn.Module.load_state_dict.")

# ── Main ──────────────────────────────────────────────────────────────────────
from lavis.models import load_model_and_preprocess
from PIL import Image
import numpy as np


def discover_frames(root):
    """Return every ``.jpg`` frame under ``root/<chunk>/<video>/`` in a stable order.

    Chunks and videos are visited in sorted order and frames are sorted within
    each video, so a frame's index in the returned list is reproducible; the
    --start/--end ranges rely on this. Chunk entries that are not directories
    or start with '.' are skipped, as are video entries that are not
    directories. The extension match is case-insensitive.
    """
    all_paths = []
    for chunk in sorted(os.listdir(root)):
        cp = os.path.join(root, chunk)
        if not os.path.isdir(cp) or chunk.startswith('.'):
            continue
        for vid in sorted(os.listdir(cp)):
            vp = os.path.join(cp, vid)
            if not os.path.isdir(vp):
                continue
            all_paths.extend(sorted(
                os.path.join(vp, f) for f in os.listdir(vp)
                if f.lower().endswith('.jpg')
            ))
    return all_paths


def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` via a sibling temp file + ``os.replace``.

    An interrupted write therefore never leaves a truncated ``path`` behind for
    the resume logic in run() to choke on.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def _load_image(path, processor):
    """Open ``path`` as RGB, apply ``processor`` and close the file handle."""
    with Image.open(path) as img:
        return processor(img.convert("RGB"))


def _embed_paths(model, processor, paths, embedding_dict, device, out_pkl,
                 batch_size, save_every):
    """Embed ``paths`` in batches into ``embedding_dict``, checkpointing to ``out_pkl``.

    Unreadable images are reported and skipped. A checkpoint is written each
    time at least ``save_every`` new embeddings have accumulated since the last
    one (a plain modulo test would rarely fire with batch-sized increments).
    """
    computed = 0
    last_saved = 0
    with torch.no_grad():
        for i in range(0, len(paths), batch_size):
            images, valid_paths = [], []
            for p in paths[i:i + batch_size]:
                try:
                    images.append(_load_image(p, processor))
                except Exception as e:
                    # Best effort: a corrupt/unreadable frame must not abort the run.
                    print(f" WARNING: {p}: {e}")
                    continue
                valid_paths.append(p)
            if not images:
                continue

            image_tensor = torch.stack(images, dim=0).to(device)
            # Row 0 along dim 1 of the projected image embeddings is kept per image.
            feats = model.extract_features(
                {"image": image_tensor}, mode="image"
            ).image_embeds_proj[:, 0, :]
            for path, emb in zip(valid_paths, feats):
                embedding_dict[path] = emb.cpu()

            computed += len(valid_paths)
            if computed - last_saved >= save_every:
                print(f" [{computed}/{len(paths)}] Saving checkpoint → {out_pkl}")
                _atomic_pickle_dump(embedding_dict, out_pkl)
                last_saved = computed


def run(start: int, end: int, out_pkl: str, override_paths=None,
        batch_size: int = 32, save_every: int = 500):
    """Embed frames ``[start, end)`` and pickle ``{path: embedding}`` to ``out_pkl``.

    Args:
        start: first frame index (inclusive) into discover_frames(FRAMES_ROOT).
        end: last frame index (exclusive); None or negative means "through the
            last frame". Values past the end are clamped.
        out_pkl: output pickle path. If it already exists it is loaded and
            frames present in it are skipped (resume).
        override_paths: if given, embed exactly these paths instead of
            discovering frames; ``start``/``end`` are then ignored.
        batch_size: images per forward pass.
        save_every: write a checkpoint after at least this many new embeddings.

    Returns early, without writing ``out_pkl``, when ``start`` is not below the
    number of available frames.
    """
    device = "cuda:0"  # CUDA_VISIBLE_DEVICES remaps the selected GPU to index 0

    if override_paths is not None:
        all_paths = list(override_paths)
        start, end = 0, len(all_paths)
    else:
        print("Discovering frames …")
        all_paths = discover_frames(FRAMES_ROOT)
    total = len(all_paths)

    # The CLI passes --end -1 by default; treat any negative end like None.
    if end is None or end < 0:
        end = total
    end = min(end, total)
    if start >= total:
        print(f"⚠️ Start {start} exceeds total frames {total}")
        return
    slice_paths = all_paths[start:end]
    print(f"Processing range [{start}, {end}) out of {total}")

    # Resume support: load existing partial output if present.
    # NOTE(security): pickle.load can execute arbitrary code — only point --out
    # at files this script produced.
    if os.path.exists(out_pkl):
        with open(out_pkl, "rb") as f:
            embedding_dict = pickle.load(f)
        print(f"Resumed from {out_pkl}: {len(embedding_dict)} already done.")
    else:
        embedding_dict = {}

    todo = [p for p in slice_paths if p not in embedding_dict]
    print(f"{len(todo)} frames still need embedding.")

    if todo:
        # Loaded only when there is work to do (the load is the slow part).
        model, vis_processors, _ = load_model_and_preprocess(
            name="blip2", model_type="gen3_518_518", is_eval=True, device=device
        )
        model.eval()
        _embed_paths(model, vis_processors["eval"], todo, embedding_dict, device,
                     out_pkl, batch_size, save_every)

    print(f"Done. Saving final output → {out_pkl} ({len(embedding_dict)} embeddings)")
    _atomic_pickle_dump(embedding_dict, out_pkl)


def _collect_folder_jpgs(folder):
    """Return all ``*.jpg`` files under ``folder`` (recursively), sorted.

    With ``recursive=True`` the ``**`` component also matches zero
    directories, so JPGs directly inside ``folder`` are included.
    """
    return sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--start", type=int, default=0,
                        help="Start frame index (inclusive)")
    parser.add_argument("--end", type=int, default=-1,
                        help="End frame index (exclusive); negative = through the last frame")
    parser.add_argument("--out", type=str, required=True,
                        help="Output .pkl filename")
    parser.add_argument("--folder", type=str, default=None,
                        help="If set, embed only JPGs under this specific folder (overrides --start/--end)")
    args = parser.parse_args()

    if args.folder:
        specific_paths = _collect_folder_jpgs(args.folder)
        print(f"Folder mode: found {len(specific_paths)} JPGs under {args.folder}")
        run(0, len(specific_paths), args.out, override_paths=specific_paths)
    else:
        run(args.start, args.end, args.out)