Spaces:
Runtime error
Runtime error
| """ | |
embed_worker.py — Single-GPU embedding worker.

Usage (run one per GPU in separate terminals):
    CUDA_VISIBLE_DEVICES=5 python embed_worker.py --start 12586 --end 16000 --out embeddings_gpu5.pkl
    CUDA_VISIBLE_DEVICES=6 python embed_worker.py --start 16000 --end 20000 --out embeddings_gpu6.pkl
    CUDA_VISIBLE_DEVICES=7 python embed_worker.py --start 20000 --end 24586 --out embeddings_gpu7.pkl

The existing embeddings_blip2.pkl already has frames 0-12585 — don't re-do those.
| """ | |
| import argparse, os, sys, pickle | |
| import ssl | |
# SECURITY NOTE: this swaps ssl's default HTTPS context factory for the
# unverified one, so any code in this process that relies on that default
# context will skip TLS certificate verification.
# NOTE(review): presumably a workaround for model downloads behind a proxy or
# custom CA — confirm, and prefer installing the proper CA bundle if possible.
ssl._create_default_https_context = ssl._create_unverified_context
| # ββ PATHS β adjust if needed ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Root directory that holds the extracted frames (chunk/video/frame.jpg).
FRAMES_ROOT = "/media/RTCIN15TB/Datasets/NvidiaPhisicalAIFrames"
# Name of the model cache directory, created next to this script.
CACHE_DIR = "model_cache"

script_dir = os.path.dirname(os.path.abspath(__file__))
local_cache_path = os.path.join(script_dir, CACHE_DIR)
os.makedirs(local_cache_path, exist_ok=True)
# Point both cache environment variables at the local cache directory.
for _cache_env_var in ('HF_HOME', 'LAVIS_CACHE_ROOT'):
    os.environ[_cache_env_var] = local_cache_path

# -- LAVIS path ---------------------------------------------------------------
# Prefer a LAVIS checkout that sits beside the project root; fall back to the
# hard-coded location when that checkout (or its inner "lavis" package
# directory) is missing.
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")
_lavis_pkg_dir = os.path.join(path_to_lavis_parent_dir, "lavis")
if not os.path.isdir(path_to_lavis_parent_dir) or not os.path.isdir(_lavis_pkg_dir):
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
sys.path.insert(0, path_to_lavis_parent_dir)
| # ββ Patches (same as main script) βββββββββββββββββββββββββββββββββββββββββββββ | |
| import torch, torch.nn as nn, torch.distributions.constraints as constraints, inspect | |
| from transformers.modeling_utils import PreTrainedModel | |
| from lavis.models.blip2_models.blip2_qformer import Blip2Qformer | |
# Patch the positive-definite constraint check so that it short-circuits for
# meta tensors instead of running the original check on them.
_pd_constraint_cls = getattr(constraints, '_PositiveDefinite', None)
if _pd_constraint_cls is not None and hasattr(_pd_constraint_cls, 'check'):
    _orig_pdc = _pd_constraint_cls.check

    def _patched_pdc(self, value):
        """Return an all-True bool tensor for meta tensors; otherwise defer.

        Non-meta inputs (and non-tensor inputs) are passed straight to the
        original ``check`` implementation.
        """
        on_meta_device = isinstance(value, torch.Tensor) and value.is_meta
        if not on_meta_device:
            return _orig_pdc(self, value)
        # Same shape and device as the input, filled with True.
        return torch.ones_like(value, dtype=torch.bool, device=value.device)

    constraints._PositiveDefinite.check = _patched_pdc
# Patch PreTrainedModel's mean-based initialisation of newly added embedding
# rows so that it does not run the original routine when the old embedding
# weights live on the meta device.
if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    _orig_iae = PreTrainedModel._init_added_embeddings_weights_with_mean

    def _patched_iae(self, new_emb, old_emb, num_added, *args, **kwargs):
        """Initialise added embedding rows, with a fallback for meta weights.

        When both arguments are ``nn.Embedding``, tokens were added, and the
        old weights are on the meta device, the rows past the old vocabulary
        size are drawn from ``N(0, config.initializer_range)`` (only if the
        new weights are materialised). Every other case is delegated to the
        original implementation unchanged.
        """
        if not isinstance(new_emb, nn.Embedding) or not isinstance(old_emb, nn.Embedding):
            return _orig_iae(self, new_emb, old_emb, num_added, *args, **kwargs)

        fresh_weight = new_emb.weight
        prior_weight = old_emb.weight
        use_fallback = num_added > 0 and prior_weight.device.type == 'meta'
        if not use_fallback:
            return _orig_iae(self, new_emb, old_emb, num_added, *args, **kwargs)

        # Rows [old_vocab, new_vocab) are the newly added ones.
        first_new_row, end_new_row = prior_weight.shape[0], fresh_weight.shape[0]
        if fresh_weight.device.type != 'meta' and first_new_row < end_new_row:
            with torch.no_grad():
                fresh_weight[first_new_row:end_new_row].normal_(
                    mean=0.0, std=self.config.initializer_range
                )
        return None

    PreTrainedModel._init_added_embeddings_weights_with_mean = _patched_iae
_orig_lsd = nn.Module.load_state_dict
# The `assign` keyword is not available on every torch version; detect it once
# here instead of inspecting the signature on every load_state_dict call.
_LSD_ACCEPTS_ASSIGN = 'assign' in inspect.signature(_orig_lsd).parameters


def _closest_grid(n_tokens):
    """Guess an (h, w) patch grid with h * w == n_tokens, as square as possible.

    Perfect squares return (side, side). Otherwise the largest divisor not
    exceeding sqrt(n_tokens) is used as the height.

    NOTE(review): this is only a guess. The token count alone does not
    determine the real grid; e.g. 1380 tokens factor here as 30x46, whereas a
    322x840 input with 14-pixel patches would be 23x60 — confirm against the
    checkpoint's actual input resolution before trusting non-square results.
    """
    import math

    side = int(math.sqrt(n_tokens))
    if side * side == n_tokens:
        return side, side
    for height in range(side, 0, -1):
        if n_tokens % height == 0:
            return height, n_tokens // height
    return side, side


def _resize_pos_embed(ckpt_pos, model_pos):
    """Bicubically resample checkpoint position embeddings to the model's size.

    The first token along dim 1 is kept as-is; the remaining tokens are
    reshaped to a 2-D grid, interpolated to the model's grid, and re-attached.
    The result keeps the checkpoint tensor's dtype.
    """
    print(f"DEBUG: ckpt_pos={ckpt_pos.shape}, model_pos={model_pos.shape}")
    cls_tok = ckpt_pos[:, :1, :]
    patches = ckpt_pos[:, 1:, :]
    dim = patches.shape[-1]
    n_ckpt = patches.shape[1]
    n_model = model_pos.shape[1] - 1
    print(f"DEBUG: N_ckpt={n_ckpt}, N_model={n_model}, dim={dim}")

    src_h, src_w = _closest_grid(n_ckpt)
    dst_h, dst_w = _closest_grid(n_model)
    print(f"DEBUG: ckpt grid={src_h}x{src_w}, model grid={dst_h}x{dst_w}")

    # [1, N, dim] -> [1, dim, h, w] for interpolation, then back again.
    grid = patches.reshape(1, src_h, src_w, dim).permute(0, 3, 1, 2)
    # Interpolate in float32, then restore the checkpoint dtype so a
    # half-precision checkpoint does not silently become float32.
    grid = nn.functional.interpolate(
        grid.float(), size=(dst_h, dst_w), mode='bicubic', align_corners=False
    )
    patches = grid.permute(0, 2, 3, 1).reshape(1, dst_h * dst_w, dim).to(ckpt_pos.dtype)
    return torch.cat([cls_tok, patches], dim=1)


def _patched_lsd(self, state_dict, strict=True, assign=False):
    """Drop-in replacement for ``nn.Module.load_state_dict`` with BLIP-2 fixes.

    For ``Blip2Qformer`` instances only, the incoming ``state_dict`` is
    adjusted in place before loading:
      * the Q-Former prediction head tensors are truncated along dim 0 to the
        model's size when the checkpoint's dim 0 differs;
      * ``visual_encoder.pos_embed`` is interpolated when its shape differs
        from the model's.
    For every module, ``assign=True`` is forced when any parameter is on the
    meta device, and ``assign`` is only forwarded if this torch version's
    ``load_state_dict`` accepts it.

    Args:
        state_dict: checkpoint mapping; may be modified in place (see above).
        strict: forwarded to the original ``load_state_dict``.
        assign: forwarded when supported; overridden to True for meta params.

    Returns:
        Whatever the original ``load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_sd = self.state_dict()

        for key in ["Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight"]:
            if key in state_dict and key in model_sd:
                ckpt_t, model_t = state_dict[key], model_sd[key]
                if ckpt_t.shape[0] != model_t.shape[0]:
                    state_dict[key] = ckpt_t.narrow(0, 0, model_t.shape[0])

        pos_key = "visual_encoder.pos_embed"
        if pos_key in state_dict and pos_key in model_sd:
            ckpt_pos = state_dict[pos_key]
            model_pos = model_sd[pos_key]
            if ckpt_pos.shape != model_pos.shape:
                state_dict[pos_key] = _resize_pos_embed(ckpt_pos, model_pos)
                print(f"INFO: Interpolated pos_embed {ckpt_pos.shape} → {state_dict[pos_key].shape}")

    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LSD_ACCEPTS_ASSIGN:
        return _orig_lsd(self, state_dict, strict=strict, assign=assign)
    return _orig_lsd(self, state_dict, strict=strict)


nn.Module.load_state_dict = _patched_lsd
print("INFO: Patched nn.Module.load_state_dict.")
| # ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| from lavis.models import load_model_and_preprocess | |
| from PIL import Image | |
| import numpy as np | |
def discover_frames(root):
    """Collect every ``.jpg`` frame stored as ``root/<chunk>/<video>/<frame>``.

    Chunk directories whose name starts with a dot are ignored, as are plain
    files at the chunk and video levels. The extension match is
    case-insensitive. Chunks, videos, and frame paths are each visited in
    sorted order, so the returned list has a stable ordering.

    Returns:
        A list of frame paths (each joined onto ``root``).
    """
    frame_paths = []
    for chunk_name in sorted(os.listdir(root)):
        chunk_dir = os.path.join(root, chunk_name)
        if chunk_name.startswith('.') or not os.path.isdir(chunk_dir):
            continue
        for video_name in sorted(os.listdir(chunk_dir)):
            video_dir = os.path.join(chunk_dir, video_name)
            if os.path.isdir(video_dir):
                video_frames = [
                    os.path.join(video_dir, file_name)
                    for file_name in os.listdir(video_dir)
                    if file_name.lower().endswith('.jpg')
                ]
                video_frames.sort()
                frame_paths += video_frames
    return frame_paths
def _save_embeddings(embedding_dict, out_pkl):
    """Pickle ``embedding_dict`` to ``out_pkl`` via a temp file and atomic rename.

    Writing to a sibling temp file first means an interrupted save cannot
    leave a truncated ``out_pkl`` behind and break the resume logic.
    """
    tmp_path = out_pkl + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(embedding_dict, f)
    os.replace(tmp_path, out_pkl)


def run(start: int, end: int, out_pkl: str, override_paths=None):
    """Embed a slice of frames with BLIP-2 and pickle a ``{path: tensor}`` dict.

    Args:
        start: first frame index (inclusive) into the discovered frame list.
            Ignored when ``override_paths`` is given.
        end: end frame index (exclusive). ``None`` or a negative value means
            "through the last frame"; values past the end are clamped.
            Ignored when ``override_paths`` is given.
        out_pkl: output pickle path. If it already exists it is loaded and
            paths already present in it are skipped (resume support).
        override_paths: optional explicit list of image paths; when given,
            all of them are processed and frame discovery is skipped.

    Returns:
        None. Results are written to ``out_pkl``.
    """
    device = "cuda:0"  # CUDA_VISIBLE_DEVICES remaps the GPU to index 0

    # Resolve the work list first: it is cheap, and an invalid range should
    # not cost a full model load.
    if override_paths is not None:
        all_paths = list(override_paths)
        start, end = 0, len(all_paths)
    else:
        all_paths = discover_frames(FRAMES_ROOT)
    total = len(all_paths)

    if end is None or end < 0:
        end = total
    end = min(end, total)
    start = max(start, 0)
    if start >= total:
        print(f"⚠️ Start {start} exceeds total frames {total}")
        return
    print(f"Processing range [{start}, {end}) out of {total}")
    slice_paths = all_paths[start:end]

    # Resume support: load existing partial output if present.
    if os.path.exists(out_pkl):
        # SECURITY NOTE: pickle.load can execute arbitrary code; only resume
        # from files this script wrote itself.
        with open(out_pkl, "rb") as f:
            embedding_dict = pickle.load(f)
        print(f"Resumed from {out_pkl}: {len(embedding_dict)} already done.")
    else:
        embedding_dict = {}

    todo = [p for p in slice_paths if p not in embedding_dict]
    print(f"{len(todo)} frames still need embedding.")

    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2",
        model_type="gen3_518_518",
        is_eval=True,
        device=device,
    )
    model.eval()

    BATCH_SIZE = 32
    SAVE_EVERY = 500
    computed = 0
    last_saved = 0

    with torch.no_grad():
        for i in range(0, len(todo), BATCH_SIZE):
            batch_paths = todo[i: i + BATCH_SIZE]
            images, valid_paths = [], []
            for p in batch_paths:
                # Best effort per image: an unreadable frame is reported and
                # skipped so it does not abort the whole run.
                try:
                    with Image.open(p) as raw:
                        img = raw.convert("RGB")
                    images.append(vis_processors["eval"](img))
                    valid_paths.append(p)
                except Exception as e:
                    print(f" WARNING: {p}: {e}")
            if not images:
                continue

            image_tensor = torch.stack(images, dim=0).to(device)
            # Keep index 0 along dim 1 of the projected image embeddings.
            # NOTE(review): presumably the first query token — confirm.
            feats = model.extract_features(
                {"image": image_tensor}, mode="image"
            ).image_embeds_proj[:, 0, :]
            for path, emb in zip(valid_paths, feats):
                embedding_dict[path] = emb.cpu()
            computed += len(valid_paths)

            # Checkpoint once at least SAVE_EVERY new frames have accumulated.
            # (A `computed % SAVE_EVERY == 0` test rarely fires because
            # `computed` advances in batch-sized steps.)
            if computed - last_saved >= SAVE_EVERY:
                print(f" [{computed}/{len(todo)}] Saving checkpoint → {out_pkl}")
                _save_embeddings(embedding_dict, out_pkl)
                last_saved = computed

    print(f"Done. Saving final output → {out_pkl} ({len(embedding_dict)} embeddings)")
    _save_embeddings(embedding_dict, out_pkl)
if __name__ == "__main__":
    # CLI entry point: either embed an index range of the discovered frames
    # (--start/--end) or every JPG under one folder (--folder).
    parser = argparse.ArgumentParser()
    parser.add_argument("--start", type=int, default=0,
                        help="Start frame index (inclusive)")
    # Default is None so that run() sees its "through the last frame" value.
    parser.add_argument("--end", type=int, default=None,
                        help="End frame index (exclusive); omit or pass a negative "
                             "value to process through the last frame")
    parser.add_argument("--out", type=str, required=True,
                        help="Output .pkl filename")
    parser.add_argument("--folder", type=str, default=None,
                        help="If set, embed only JPGs under this specific folder (overrides --start/--end)")
    args = parser.parse_args()

    if args.folder:
        import glob

        # With recursive=True, "**" also matches zero directories, so JPGs
        # directly inside the folder are included without a second glob.
        specific_paths = sorted(
            glob.glob(os.path.join(args.folder, "**", "*.jpg"), recursive=True)
        )
        print(f"Folder mode: found {len(specific_paths)} JPGs under {args.folder}")
        if not specific_paths:
            # Stop before run() loads the model for an empty work list.
            sys.exit(f"No JPGs found under {args.folder}; nothing to embed.")
        run(0, len(specific_paths), args.out, override_paths=specific_paths)
    else:
        # Treat a negative --end (the old default was -1) the same as omitted.
        end_index = args.end if args.end is not None and args.end >= 0 else None
        run(args.start, end_index, args.out)