# SimToken/tools/ec_simtoken_smoke_test.py
# Hugging Face Hub page residue (uploader: yfan07, "Upload folder using
# huggingface_hub", revision 9af2926 verified) — kept as a comment so the
# file parses as Python.
#!/usr/bin/env python
"""EC-SimToken 2-step smoke test.
Verifies three core invariants before committing to 9-hour full training:
1. exist_loss > 0 β€” is_null is reaching model_forward and BCE is computed
2. mask_loss β‰ˆ 0 β€” null gate skips mask loss for null samples
3. exist_logit.shape[0] == batch_size β€” tensor shapes are consistent
Expected runtime: ~3-4 minutes (model load dominates), 2 forward passes.
Usage:
cd /workspace/SimToken && conda activate simtoken
python tools/ec_simtoken_smoke_test.py 2>&1 | tee runs/ec_simtoken_smoke.log
"""
from __future__ import annotations
import os, sys, random
from argparse import Namespace
from functools import partial
import numpy as np
import torch
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import AutoConfig
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from datasets.dataset_refavs import REFAVS
from models.ec_simtoken_model import ECSimtoken_ForCausalLM
# ── Paths & constants ─────────────────────────────────────────────────────────
MLLM = "/workspace/hf_models/Chat-UniVi-7B-v1.5"
SAM_CKPT = "/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth"
SIMTOKEN_CKPT = "/workspace/SimToken/checkpoints/simtoken_pretrained.pth"
DATA_DIR = "/workspace/SimToken/data"
VISION_TOWER = "/workspace/hf_models/clip-vit-large-patch14"
BATCH_SIZE = 4
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
AUDIO_TOKEN_INDEX = -300
# ── Minimal args namespace ─────────────────────────────────────────────────────
args = Namespace(
mllm=MLLM,
vision_pretrained=SAM_CKPT,
vision_tower=VISION_TOWER,
data_dir=DATA_DIR,
compress=True,
start=0,
batch_size=BATCH_SIZE,
exist_loss_weight=1.0,
frame_n=10,
text_max_len=25,
input_type="refer",
ct_weight=0.0, # disable contrastive for smoke test
conv_template=1,
)
# ── Collate (mirrors train_ec_simtoken.py) ────────────────────────────────────
import re
def tokenizer_image_audio_token(prompt, tokenizer,
image_token_index=IMAGE_TOKEN_INDEX,
audio_token_index=AUDIO_TOKEN_INDEX,
num_frames=10, return_tensors=None):
prompt_chunks = re.split(r'(<image>|<audio>|<video>)', prompt)
prompt_chunks = [c for c in prompt_chunks if c]
text_chunks, token_types = [], []
for chunk in prompt_chunks:
if chunk == "<image>":
token_types.append("image")
elif chunk == "<audio>":
token_types.append("audio")
elif chunk == "<video>":
token_types.append("video")
else:
text_chunks.append(chunk)
tokenized_chunks = [tokenizer(c).input_ids for c in text_chunks]
input_ids = []
offset = 0
if tokenized_chunks and tokenized_chunks[0] and tokenized_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(tokenized_chunks[0][0])
min_len = min(len(text_chunks), len(token_types))
for i in range(min_len):
input_ids.extend(tokenized_chunks[i][offset:])
if token_types[i] == "image":
input_ids.append(image_token_index)
elif token_types[i] == "audio":
input_ids.append(audio_token_index)
elif token_types[i] == "video":
input_ids.extend([image_token_index] * num_frames)
if len(text_chunks) > min_len:
input_ids.extend(tokenized_chunks[min_len][offset:])
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
return input_ids
def collate_fn(batch, tokenizer=None):
vids, images, image_clips, masks, conversations = [], [], [], [], []
audio_feats, image_feats, resizes, orgsizes = [], [], [], []
refs, refs_num, fids = [], [], []
for data in batch:
vids.append(data["vid"])
images.append(data["image"])
image_clips.append(data["img_clip"])
masks.append(data["mask"])
conversations.append(data["conversation"])
audio_feats.append(data["feat_aud"])
resizes.append(data["resize"])
orgsizes.append(data["orgsize"])
image_feats.append(data["feat_sam"])
refs_num.append(len(data["ref"]))
fids.append(data["fids"])
refs.append(data["ref"][0])
input_ids = [
tokenizer_image_audio_token(c, tokenizer, return_tensors="pt")
for c in conversations
]
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
)
attention_masks = input_ids.ne(tokenizer.pad_token_id)
ref_ids = [
tokenizer_image_audio_token(r, tokenizer, return_tensors="pt") for r in refs
]
labels = input_ids.clone()
sep = "Sure, it is [SEG]"
for conversation, target in zip(conversations, labels):
parts = conversation.split(sep)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1
for i in range(len(parts) - 1):
part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2
target[cur_len: cur_len + part_len] = IGNORE_INDEX
cur_len += part_len + sep_len
target[cur_len:] = IGNORE_INDEX
return {
"vids": vids, "images": images, "images_clip": image_clips,
"masks": masks, "convs": conversations, "input_ids": input_ids,
"attention_masks": attention_masks, "labels": labels,
"audio_feats": audio_feats, "resizes": resizes, "orgsizes": orgsizes,
"image_feats": image_feats, "ref_ids": ref_ids,
"refs_num": refs_num, "fids": fids,
}
def dict_to_cuda(d: dict) -> dict:
for k, v in d.items():
if isinstance(v, torch.Tensor):
d[k] = v.cuda(non_blocking=True)
elif isinstance(v, list) and v and isinstance(v[0], torch.Tensor):
d[k] = [x.cuda(non_blocking=True) for x in v]
return d
# ── Build model (mirrors train_ec_simtoken.build_model) ───────────────────────
def build_model(args, tokenizer, seg_token_idx) -> ECSimtoken_ForCausalLM:
model_args = {
"train_mask_decoder": True,
"out_dim": 256,
"ce_loss_weight": 1.0,
"dice_loss_weight": 0.5,
"bce_loss_weight": 2.0,
"seg_token_idx": seg_token_idx,
"vision_pretrained": args.vision_pretrained,
"vision_tower": args.vision_tower,
"use_im_start_end": False,
"compress": args.compress,
"start": args.start,
"exist_loss_weight": args.exist_loss_weight,
}
model = ECSimtoken_ForCausalLM.from_pretrained(
args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args
)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch.bfloat16, device="cuda")
cfg_pt = AutoConfig.from_pretrained(args.mllm)
cfg_pt.use_cluster = True
cfg_pt.freeze = False
cfg_pt.mm_tune = True
cfg_pt.spatial_cluster_rate0 = 64
cfg_pt.spatial_cluster_rate1 = 32
cfg_pt.spatial_cluster_rate2 = 16
cfg_pt.temporal_cluster_rate = 0.0625
cfg_pt.vision_tune = False
model.get_model().initialize_cluster_modules(cfg_pt)
model.get_model().initialize_lisa_modules(model.get_model().config)
for p in vision_tower.parameters():
p.requires_grad = False
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
lora_r = 8
def find_linear_layers(m, targets):
names = set()
skip = {"visual_model", "vision_tower", "mm_projector",
"text_hidden_fcs", "audio_feature_layer", "existence_head"}
for name, mod in m.named_modules():
if (isinstance(mod, torch.nn.Linear)
and not any(s in name for s in skip)
and any(t in name for t in targets)):
names.add(name)
return sorted(names)
lora_config = LoraConfig(
r=lora_r, lora_alpha=16,
target_modules=find_linear_layers(model, ["q_proj", "v_proj"]),
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model = model.to("cuda")
model = model.to(torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
for n, p in model.named_parameters():
if any(x in n for x in ["lm_head", "embed_tokens", "mask_decoder",
"text_hidden_fcs", "audio_feature_layer",
"existence_head"]):
p.requires_grad = True
return model
# ── Forward helper ────────────────────────────────────────────────────────────
def run_forward(model, batch, is_null):
is_null_cuda = is_null.cuda()
with torch.autocast("cuda", dtype=torch.bfloat16):
out = model.forward(
images=batch["images"],
images_clip=batch["images_clip"],
audio_features=batch["audio_feats"],
image_features=batch["image_feats"],
input_ids=batch["input_ids"],
labels=batch["labels"],
attention_masks=batch["attention_masks"],
masks_list=batch["masks"],
resize_list=batch["resizes"],
orgsize_list=batch["orgsizes"],
conversation_list=batch["convs"],
refs_num=batch["refs_num"],
fids=batch["fids"],
vids=batch["vids"],
ref_ids=batch["ref_ids"],
epoch=0,
inference=False,
contrast=0.0,
is_null=is_null_cuda,
)
return out
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
print("=" * 60)
print("EC-SimToken Smoke Test")
print("=" * 60)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# ── Tokenizer ─────────────────────────────────────────────────────────────
print("\n[1/4] Loading tokenizer...")
tokenizer = transformers.AutoTokenizer.from_pretrained(
MLLM, model_max_length=2048, padding_side="right", use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.add_tokens("[SEG]")
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
print(f" seg_token_idx = {seg_token_idx}")
# ── Dataset (train split, 2 batches) ─────────────────────────────────────
print("\n[2/4] Loading dataset (train split)...")
dataset = REFAVS("train", args, tokenizer, input_type="refer")
loader = DataLoader(
dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
collate_fn=partial(collate_fn, tokenizer=tokenizer),
)
batch_iter = iter(loader)
batch0 = next(batch_iter)
batch1 = next(batch_iter)
print(f" Loaded 2 batches, batch_size={BATCH_SIZE}")
# ── Model ─────────────────────────────────────────────────────────────────
print("\n[3/4] Building model and loading SimToken weights...")
model = build_model(args, tokenizer, seg_token_idx)
if os.path.exists(SIMTOKEN_CKPT):
ckpt = torch.load(SIMTOKEN_CKPT, map_location="cuda")
state = ckpt.get("model", ckpt)
missing, unexpected = model.load_state_dict(state, strict=False)
print(f" Loaded {SIMTOKEN_CKPT}")
print(f" missing={len(missing)}, unexpected={len(unexpected)}")
# existence_head should be in missing (not in SimToken checkpoint)
eh_missing = [k for k in missing if "existence_head" in k]
print(f" existence_head keys in missing: {eh_missing} ← expected")
else:
print(f" WARNING: {SIMTOKEN_CKPT} not found β€” using random init")
model.train()
# ── Smoke assertions ──────────────────────────────────────────────────────
print("\n[4/4] Running 2-step smoke verification...")
results = {}
# ─── Step 0: mixed null β€” verify exist_loss > 0 and shape ────────────────
print("\n Step 0: mixed null (every other sample is null)")
is_null_mixed = torch.zeros(BATCH_SIZE, dtype=torch.bool)
is_null_mixed[::2] = True # indices 0, 2 are null
print(f" is_null = {is_null_mixed.tolist()}")
b0 = dict_to_cuda({k: v for k, v in batch0.items()})
out0 = run_forward(model, b0, is_null_mixed)
exist_loss_val = out0["exist_loss"].item()
exist_logit_shape = out0["exist_logit"].shape
mask_loss_mixed = out0["mask_loss"].item()
print(f" exist_loss = {exist_loss_val:.4f}")
print(f" exist_logit shape= {exist_logit_shape}")
print(f" mask_loss (mixed)= {mask_loss_mixed:.4f}")
# Assertion 1: exist_loss > 0
if exist_loss_val > 0:
print(" βœ“ PASS: exist_loss > 0 (BCE is being computed)")
results["exist_loss_nonzero"] = True
else:
print(" βœ— FAIL: exist_loss == 0 β†’ is_null not reaching model_forward!")
results["exist_loss_nonzero"] = False
# Assertion 3: shape consistency
if exist_logit_shape[0] == BATCH_SIZE:
print(f" βœ“ PASS: exist_logit.shape[0] == batch_size ({BATCH_SIZE})")
results["shape_consistent"] = True
else:
print(f" βœ— FAIL: exist_logit.shape[0]={exist_logit_shape[0]} != batch_size={BATCH_SIZE}")
results["shape_consistent"] = False
# ─── Step 1: all-null β€” verify mask_loss β‰ˆ 0 ─────────────────────────────
print("\n Step 1: all null (mask_loss gate check)")
is_null_all = torch.ones(BATCH_SIZE, dtype=torch.bool)
print(f" is_null = {is_null_all.tolist()}")
b1 = dict_to_cuda({k: v for k, v in batch1.items()})
out1 = run_forward(model, b1, is_null_all)
mask_loss_all_null = out1["mask_loss"].item()
exist_loss_all_null = out1["exist_loss"].item()
print(f" mask_loss (all null) = {mask_loss_all_null:.6f}")
print(f" exist_loss (all null)= {exist_loss_all_null:.4f}")
# Assertion 2: mask_loss β‰ˆ 0 when all null
MASK_LOSS_TOL = 1e-3
if mask_loss_all_null < MASK_LOSS_TOL:
print(f" βœ“ PASS: mask_loss < {MASK_LOSS_TOL} when all-null (null gate works)")
results["mask_gated"] = True
else:
print(f" βœ— FAIL: mask_loss={mask_loss_all_null:.6f} is not near 0 when all samples are null!")
results["mask_gated"] = False
# ─── Summary ──────────────────────────────────────────────────────────────
print("\n" + "=" * 60)
print("SMOKE TEST SUMMARY")
print("=" * 60)
checks = [
("exist_loss > 0 (is_null reaches model_forward)", results.get("exist_loss_nonzero")),
("mask_loss β‰ˆ 0 when all-null (null gate works)", results.get("mask_gated")),
("exist_logit.shape[0] == batch_size", results.get("shape_consistent")),
]
all_pass = True
for desc, passed in checks:
symbol = "βœ“ PASS" if passed else "βœ— FAIL"
print(f" {symbol} {desc}")
if not passed:
all_pass = False
print()
if all_pass:
print("ALL CHECKS PASSED β€” safe to proceed with full EC-SimToken training.")
else:
print("ONE OR MORE CHECKS FAILED β€” fix before starting full training.")
sys.exit(1)
if __name__ == "__main__":
import torch.multiprocessing as mp
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
main()