| |
| """EC-SimToken 2-step smoke test. |
| |
| Verifies three core invariants before committing to 9-hour full training: |
| 1. exist_loss > 0 → is_null is reaching model_forward and BCE is computed |
| 2. mask_loss ≈ 0 → null gate skips mask loss for null samples |
| 3. exist_logit.shape[0] == batch_size → tensor shapes are consistent |
| |
| Expected runtime: ~3-4 minutes (model load dominates), 2 forward passes. |
| |
| Usage: |
| cd /workspace/SimToken && conda activate simtoken |
| python tools/ec_simtoken_smoke_test.py 2>&1 | tee runs/ec_simtoken_smoke.log |
| """ |
|
|
| from __future__ import annotations |
| import os, sys, random |
| from argparse import Namespace |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import transformers |
| from peft import LoraConfig, get_peft_model |
| from torch.utils.data import DataLoader |
| from transformers import AutoConfig |
|
|
# Repository root: the parent of the directory that contains this script.
# It is put first on sys.path so the project-local imports below resolve
# regardless of the current working directory.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)

# Default to the first GPU, but respect a CUDA_VISIBLE_DEVICES value the caller
# already exported (e.g. to run the smoke test on a free device).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
|
|
| from datasets.dataset_refavs import REFAVS |
| from models.ec_simtoken_model import ECSimtoken_ForCausalLM |
|
|
| |
# ---------------------------------------------------------------------------
# Fixed configuration for the smoke test (paths are machine-specific).
# ---------------------------------------------------------------------------
MLLM = "/workspace/hf_models/Chat-UniVi-7B-v1.5"
VISION_TOWER = "/workspace/hf_models/clip-vit-large-patch14"
SAM_CKPT = "/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth"
SIMTOKEN_CKPT = "/workspace/SimToken/checkpoints/simtoken_pretrained.pth"
DATA_DIR = "/workspace/SimToken/data"
BATCH_SIZE = 4

# Value written into label positions that collate_fn masks out.
IGNORE_INDEX = -100
# Placeholder ids emitted by tokenizer_image_audio_token for <image>/<video>
# and <audio> markers respectively.
IMAGE_TOKEN_INDEX = -200
AUDIO_TOKEN_INDEX = -300

# Argument bundle handed to the dataset and to build_model().
_ARG_VALUES = {
    "mllm": MLLM,
    "vision_pretrained": SAM_CKPT,
    "vision_tower": VISION_TOWER,
    "data_dir": DATA_DIR,
    "compress": True,
    "start": 0,
    "batch_size": BATCH_SIZE,
    "exist_loss_weight": 1.0,
    "frame_n": 10,
    "text_max_len": 25,
    "input_type": "refer",
    "ct_weight": 0.0,
    "conv_template": 1,
}
args = Namespace(**_ARG_VALUES)
|
|
|
|
| |
|
|
| import re |
|
|
def tokenizer_image_audio_token(prompt, tokenizer,
                                image_token_index=IMAGE_TOKEN_INDEX,
                                audio_token_index=AUDIO_TOKEN_INDEX,
                                num_frames=10, return_tensors=None):
    """Tokenise *prompt*, replacing multimodal markers with placeholder ids.

    The prompt is split on the literal markers ``<image>``, ``<audio>`` and
    ``<video>``.  Text between markers is tokenised with ``tokenizer``; each
    marker becomes a placeholder: ``<image>`` -> one ``image_token_index``,
    ``<audio>`` -> one ``audio_token_index``, ``<video>`` -> ``num_frames``
    copies of ``image_token_index``.  Chunks are emitted strictly in the order
    they appear in the prompt, so prompts that start with a marker or contain
    consecutive markers keep every placeholder in its original position.

    If the first text chunk tokenises with a leading BOS id, that BOS is
    emitted once at the very start and the first id of every text chunk is
    stripped.

    Args:
        prompt: Prompt string, possibly containing the markers above.
        tokenizer: Callable returning an object with ``input_ids`` and
            exposing ``bos_token_id``.
        image_token_index: Placeholder id for ``<image>`` / ``<video>``.
        audio_token_index: Placeholder id for ``<audio>``.
        num_frames: Number of image placeholders a ``<video>`` expands to.
        return_tensors: ``"pt"`` to get a 1-D ``torch.long`` tensor; any other
            value returns a plain list of ints.

    Returns:
        The token-id sequence as a list, or as a tensor when
        ``return_tensors == "pt"``.
    """
    placeholders = {
        "<image>": [image_token_index],
        "<audio>": [audio_token_index],
        "<video>": [image_token_index] * num_frames,
    }

    # The capturing group makes re.split keep the markers; drop empty strings
    # produced by adjacent or leading/trailing markers.
    chunks = [c for c in re.split(r'(<image>|<audio>|<video>)', prompt) if c]

    # Tokenise text chunks up front (None marks a placeholder chunk) so the
    # BOS decision can be made before anything is emitted.
    tokenized = [
        None if chunk in placeholders else tokenizer(chunk).input_ids
        for chunk in chunks
    ]

    input_ids = []
    offset = 0
    first_text = next((ids for ids in tokenized if ids is not None), None)
    if first_text and first_text[0] == tokenizer.bos_token_id:
        # Tokenizer prepends BOS to every chunk: keep one, strip the rest.
        offset = 1
        input_ids.append(first_text[0])

    for chunk, ids in zip(chunks, tokenized):
        if ids is None:
            input_ids.extend(placeholders[chunk])
        else:
            input_ids.extend(ids[offset:])

    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids
|
|
|
|
def collate_fn(batch, tokenizer=None):
    """Collate dataset samples into one batch dict for the model.

    Each sample must provide the keys ``vid``, ``image``, ``img_clip``,
    ``mask``, ``conversation``, ``feat_aud``, ``resize``, ``orgsize``,
    ``feat_sam``, ``ref`` (non-empty sequence) and ``fids``.

    Conversations are tokenised with ``tokenizer_image_audio_token`` and
    right-padded with ``tokenizer.pad_token_id``.  Labels start as a copy of
    the padded ids; the first position, every prompt part preceding a
    ``"Sure, it is [SEG]"`` separator, and everything after the last
    separator span (padding included) are overwritten with ``IGNORE_INDEX``.

    Args:
        batch: Iterable of sample dicts as described above.
        tokenizer: Tokenizer used for ids and padding.  Required; the
            ``None`` default only exists so the function can be bound with
            ``functools.partial``.

    Returns:
        Dict with keys ``vids``, ``images``, ``images_clip``, ``masks``,
        ``convs``, ``input_ids``, ``attention_masks``, ``labels``,
        ``audio_feats``, ``resizes``, ``orgsizes``, ``image_feats``,
        ``ref_ids``, ``refs_num`` and ``fids``.

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
    """
    if tokenizer is None:
        raise ValueError("collate_fn requires a tokenizer")

    vids, images, image_clips, masks, conversations = [], [], [], [], []
    audio_feats, image_feats, resizes, orgsizes = [], [], [], []
    refs, refs_num, fids = [], [], []
    for data in batch:
        vids.append(data["vid"])
        images.append(data["image"])
        image_clips.append(data["img_clip"])
        masks.append(data["mask"])
        conversations.append(data["conversation"])
        audio_feats.append(data["feat_aud"])
        resizes.append(data["resize"])
        orgsizes.append(data["orgsize"])
        image_feats.append(data["feat_sam"])
        refs_num.append(len(data["ref"]))
        fids.append(data["fids"])
        # Only the first referring expression is tokenised into ref_ids.
        refs.append(data["ref"][0])

    input_ids = torch.nn.utils.rnn.pad_sequence(
        [
            tokenizer_image_audio_token(c, tokenizer, return_tensors="pt")
            for c in conversations
        ],
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    )
    attention_masks = input_ids.ne(tokenizer.pad_token_id)
    ref_ids = [
        tokenizer_image_audio_token(r, tokenizer, return_tensors="pt")
        for r in refs
    ]

    labels = input_ids.clone()
    sep = "Sure, it is [SEG]"
    # The separator is constant, so its token length is computed once.
    # NOTE(review): the "- 1" here and "- 2" below presumably compensate for
    # BOS / chunk-boundary tokens — confirm against the tokenizer in use.
    sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1
    for conversation, target in zip(conversations, labels):
        parts = conversation.split(sep)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for part in parts[:-1]:
            part_len = len(tokenizer_image_audio_token(part, tokenizer)) - 2
            target[cur_len: cur_len + part_len] = IGNORE_INDEX
            # Skip over the separator span, leaving it supervised.
            cur_len += part_len + sep_len
        # Everything past the last separator span (incl. padding) is ignored.
        target[cur_len:] = IGNORE_INDEX

    return {
        "vids": vids, "images": images, "images_clip": image_clips,
        "masks": masks, "convs": conversations, "input_ids": input_ids,
        "attention_masks": attention_masks, "labels": labels,
        "audio_feats": audio_feats, "resizes": resizes, "orgsizes": orgsizes,
        "image_feats": image_feats, "ref_ids": ref_ids,
        "refs_num": refs_num, "fids": fids,
    }
|
|
|
|
def dict_to_cuda(d: dict) -> dict:
    """Move tensor values of *d* to the current CUDA device, in place.

    Top-level tensors are moved directly.  For list values, every tensor
    element is moved and non-tensor elements are kept as they are; lists that
    contain no tensor at all are left untouched (same list object).  Nested
    containers are not traversed.

    Args:
        d: Batch dictionary; it is modified in place.

    Returns:
        The same dictionary object that was passed in.
    """
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            d[key] = value.cuda(non_blocking=True)
        elif isinstance(value, list) and any(
            isinstance(item, torch.Tensor) for item in value
        ):
            d[key] = [
                item.cuda(non_blocking=True)
                if isinstance(item, torch.Tensor) else item
                for item in value
            ]
    return d
|
|
|
|
| |
|
|
def build_model(args, tokenizer, seg_token_idx) -> ECSimtoken_ForCausalLM:
    """Construct the LoRA-wrapped EC-SimToken model on the GPU in bfloat16.

    Steps, in order: load ``ECSimtoken_ForCausalLM`` from ``args.mllm``; copy
    the tokenizer's special-token ids into the model config; enable input
    grads and gradient checkpointing; initialise the vision, cluster and LISA
    modules; freeze the vision tower and ``mm_projector``; wrap the model with
    LoRA on ``q_proj``/``v_proj`` linear layers; move to CUDA/bfloat16; resize
    the token embeddings to ``len(tokenizer)``; finally re-enable gradients on
    the heads listed at the bottom of the function.

    Args:
        args: Namespace providing ``mllm``, ``vision_pretrained``,
            ``vision_tower``, ``compress``, ``start`` and
            ``exist_loss_weight``.
        tokenizer: Tokenizer whose special-token ids and length are used.
        seg_token_idx: Id of the ``[SEG]`` token, forwarded to the model.

    Returns:
        The PEFT-wrapped model.
    """
    init_kwargs = dict(
        train_mask_decoder=True,
        out_dim=256,
        ce_loss_weight=1.0,
        dice_loss_weight=0.5,
        bce_loss_weight=2.0,
        seg_token_idx=seg_token_idx,
        vision_pretrained=args.vision_pretrained,
        vision_tower=args.vision_tower,
        use_im_start_end=False,
        compress=args.compress,
        start=args.start,
        exist_loss_weight=args.exist_loss_weight,
    )
    model = ECSimtoken_ForCausalLM.from_pretrained(
        args.mllm,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        **init_kwargs,
    )

    # Keep the model config's special-token ids in sync with the tokenizer.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()

    # Vision tower lives on the GPU in bf16.
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16, device="cuda")

    # Cluster-module settings are applied on a fresh copy of the pretrained
    # config.  NOTE(review): the meaning of the individual rates is defined by
    # the model code, which is not visible here.
    cluster_cfg = AutoConfig.from_pretrained(args.mllm)
    cluster_cfg.use_cluster = True
    cluster_cfg.freeze = False
    cluster_cfg.mm_tune = True
    cluster_cfg.spatial_cluster_rate0 = 64
    cluster_cfg.spatial_cluster_rate1 = 32
    cluster_cfg.spatial_cluster_rate2 = 16
    cluster_cfg.temporal_cluster_rate = 0.0625
    cluster_cfg.vision_tune = False
    model.get_model().initialize_cluster_modules(cluster_cfg)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    # Freeze the vision tower and the multimodal projector.
    for param in vision_tower.parameters():
        param.requires_grad = False
    for param in model.get_model().mm_projector.parameters():
        param.requires_grad = False

    lora_r = 8

    def find_linear_layers(module, targets):
        """Return sorted names of Linear sub-modules eligible for LoRA."""
        excluded = ("visual_model", "vision_tower", "mm_projector",
                    "text_hidden_fcs", "audio_feature_layer", "existence_head")
        return sorted({
            name
            for name, sub in module.named_modules()
            if isinstance(sub, torch.nn.Linear)
            and not any(tag in name for tag in excluded)
            and any(target in name for target in targets)
        })

    lora_targets = find_linear_layers(model, ["q_proj", "v_proj"])
    model = get_peft_model(
        model,
        LoraConfig(
            r=lora_r,
            lora_alpha=16,
            target_modules=lora_targets,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        ),
    )

    model = model.to("cuda").to(torch.bfloat16)
    model.resize_token_embeddings(len(tokenizer))

    # Re-enable gradients on these modules after the PEFT wrapping above.
    trainable_tags = ("lm_head", "embed_tokens", "mask_decoder",
                      "text_hidden_fcs", "audio_feature_layer",
                      "existence_head")
    for name, param in model.named_parameters():
        if any(tag in name for tag in trainable_tags):
            param.requires_grad = True

    return model
|
|
|
|
| |
|
|
def run_forward(model, batch, is_null):
    """Run one training-mode forward pass under bfloat16 autocast.

    Args:
        model: Model whose ``forward`` accepts the keyword arguments below.
        batch: Batch dict as produced by ``collate_fn`` (already on the GPU).
        is_null: Per-sample boolean tensor; it is moved to CUDA here and
            passed through as ``is_null``.

    Returns:
        Whatever ``model.forward`` returns.
    """
    forward_kwargs = {
        "images": batch["images"],
        "images_clip": batch["images_clip"],
        "audio_features": batch["audio_feats"],
        "image_features": batch["image_feats"],
        "input_ids": batch["input_ids"],
        "labels": batch["labels"],
        "attention_masks": batch["attention_masks"],
        "masks_list": batch["masks"],
        "resize_list": batch["resizes"],
        "orgsize_list": batch["orgsizes"],
        "conversation_list": batch["convs"],
        "refs_num": batch["refs_num"],
        "fids": batch["fids"],
        "vids": batch["vids"],
        "ref_ids": batch["ref_ids"],
        # Fixed settings for the smoke test.
        "epoch": 0,
        "inference": False,
        "contrast": 0.0,
        "is_null": is_null.cuda(),
    }
    with torch.autocast("cuda", dtype=torch.bfloat16):
        return model.forward(**forward_kwargs)
|
|
|
|
| |
|
|
def main():
    """Run the two-step EC-SimToken smoke test; exit with status 1 on failure.

    Stages:
      1. Load the tokenizer and register the ``[SEG]`` token.
      2. Build the REFAVS train loader and fetch two batches.
      3. Build the model and, if present, load the SimToken checkpoint
         non-strictly.
      4. Forward pass 0 with every other sample flagged null: checks that
         ``exist_loss`` is positive and ``exist_logit.shape[0]`` equals the
         batch size.  Forward pass 1 with all samples flagged null: checks
         that ``mask_loss`` is within ``MASK_LOSS_TOL`` of zero.

    A summary is printed at the end; ``sys.exit(1)`` is called if any check
    failed.
    """
    rule = "=" * 60
    print(rule)
    print("EC-SimToken Smoke Test")
    print(rule)

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # ── 1. Tokenizer ────────────────────────────────────────────────────────
    print("\n[1/4] Loading tokenizer...")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        MLLM, model_max_length=2048, padding_side="right", use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    print(f" seg_token_idx = {seg_token_idx}")

    # ── 2. Data ─────────────────────────────────────────────────────────────
    print("\n[2/4] Loading dataset (train split)...")
    dataset = REFAVS("train", args, tokenizer, input_type="refer")
    loader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )
    batch_iter = iter(loader)
    batch0 = next(batch_iter)
    batch1 = next(batch_iter)
    print(f" Loaded 2 batches, batch_size={BATCH_SIZE}")

    # ── 3. Model ────────────────────────────────────────────────────────────
    print("\n[3/4] Building model and loading SimToken weights...")
    model = build_model(args, tokenizer, seg_token_idx)

    if os.path.exists(SIMTOKEN_CKPT):
        # Load onto the CPU: the model already occupies the GPU, and
        # load_state_dict copies into the existing parameters anyway, so a
        # second GPU-resident copy of the weights would only waste memory.
        ckpt = torch.load(SIMTOKEN_CKPT, map_location="cpu")
        state = ckpt.get("model", ckpt)
        missing, unexpected = model.load_state_dict(state, strict=False)
        del ckpt, state
        print(f" Loaded {SIMTOKEN_CKPT}")
        print(f" missing={len(missing)}, unexpected={len(unexpected)}")
        eh_missing = [k for k in missing if "existence_head" in k]
        print(f" existence_head keys in missing: {eh_missing} — expected")
    else:
        print(f" WARNING: {SIMTOKEN_CKPT} not found — using random init")

    model.train()

    # ── 4. Verification ─────────────────────────────────────────────────────
    print("\n[4/4] Running 2-step smoke verification...")
    results = {}

    # Step 0: alternate null / non-null samples.
    print("\n Step 0: mixed null (every other sample is null)")
    is_null_mixed = torch.zeros(BATCH_SIZE, dtype=torch.bool)
    is_null_mixed[::2] = True
    print(f" is_null = {is_null_mixed.tolist()}")

    # dict_to_cuda mutates its argument, so hand it a shallow copy.
    b0 = dict_to_cuda(dict(batch0))
    out0 = run_forward(model, b0, is_null_mixed)

    exist_loss_val = out0["exist_loss"].item()
    exist_logit_shape = out0["exist_logit"].shape
    mask_loss_mixed = out0["mask_loss"].item()

    print(f" exist_loss = {exist_loss_val:.4f}")
    print(f" exist_logit shape= {exist_logit_shape}")
    print(f" mask_loss (mixed)= {mask_loss_mixed:.4f}")

    # Check 1: existence loss is actually computed (NaN also fails here).
    if exist_loss_val > 0:
        print(" ✓ PASS: exist_loss > 0 (BCE is being computed)")
        results["exist_loss_nonzero"] = True
    else:
        print(" ✗ FAIL: exist_loss == 0 — is_null not reaching model_forward!")
        results["exist_loss_nonzero"] = False

    # Check 3: one existence logit per sample.
    if exist_logit_shape[0] == BATCH_SIZE:
        print(f" ✓ PASS: exist_logit.shape[0] == batch_size ({BATCH_SIZE})")
        results["shape_consistent"] = True
    else:
        print(f" ✗ FAIL: exist_logit.shape[0]={exist_logit_shape[0]} != batch_size={BATCH_SIZE}")
        results["shape_consistent"] = False

    # Step 1: every sample is null, so the mask loss should be gated off.
    print("\n Step 1: all null (mask_loss gate check)")
    is_null_all = torch.ones(BATCH_SIZE, dtype=torch.bool)
    print(f" is_null = {is_null_all.tolist()}")

    b1 = dict_to_cuda(dict(batch1))
    out1 = run_forward(model, b1, is_null_all)

    mask_loss_all_null = out1["mask_loss"].item()
    exist_loss_all_null = out1["exist_loss"].item()

    print(f" mask_loss (all null) = {mask_loss_all_null:.6f}")
    print(f" exist_loss (all null)= {exist_loss_all_null:.4f}")

    # Check 2: "near zero" is a magnitude test, so a large negative value (or
    # NaN, for which the comparison is False) must not count as a pass.
    MASK_LOSS_TOL = 1e-3
    if abs(mask_loss_all_null) < MASK_LOSS_TOL:
        print(f" ✓ PASS: mask_loss < {MASK_LOSS_TOL} when all-null (null gate works)")
        results["mask_gated"] = True
    else:
        print(f" ✗ FAIL: mask_loss={mask_loss_all_null:.6f} is not near 0 when all samples are null!")
        results["mask_gated"] = False

    # ── Summary ─────────────────────────────────────────────────────────────
    print("\n" + rule)
    print("SMOKE TEST SUMMARY")
    print(rule)

    checks = [
        ("exist_loss > 0 (is_null reaches model_forward)", results.get("exist_loss_nonzero")),
        ("mask_loss ≈ 0 when all-null (null gate works)", results.get("mask_gated")),
        ("exist_logit.shape[0] == batch_size", results.get("shape_consistent")),
    ]

    for desc, passed in checks:
        symbol = "✓ PASS" if passed else "✗ FAIL"
        print(f" {symbol} {desc}")
    all_pass = all(passed for _, passed in checks)

    print()
    if all_pass:
        print("ALL CHECKS PASSED — safe to proceed with full EC-SimToken training.")
    else:
        print("ONE OR MORE CHECKS FAILED — fix before starting full training.")
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # Request the "spawn" start method before any worker process is created.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # The start method was already fixed for this interpreter; keep it,
        # but say so instead of ignoring the situation silently.
        print(
            "NOTE: multiprocessing start method already set to "
            f"{mp.get_start_method()!r}; leaving it unchanged."
        )
    main()
|
|