| """EC-SimToken training script. |
| |
| Adds existence head + synthetic null augmentation to SimToken. |
| |
| Key differences from train.py: |
| - Uses ECSimtoken_ForCausalLM (adds existence_head: Linear(256,1)) |
| - Audio-swap null augmentation: p_null fraction of batch items have |
  their audio replaced with another sample's audio -> synthetic null
| - is_null tensor passed to model_forward to gate mask loss |
- test_n evaluation uses the existence head (p_exist threshold) for the Null_S metric
| |
| Usage (training): |
| python train_ec_simtoken.py \ |
| --data_dir data \ |
| --mllm Chat-UniVi/Chat-UniVi-7B-v1.5 \ |
| --vision_pretrained path/to/sam_vit_h_4b8939.pth \ |
| --name ec_simtoken_v1 \ |
| --epochs 10 \ |
| --batch_size 12 \ |
| --null_aug_prob 0.25 \ |
| --exist_loss_weight 1.0 |
| |
| Usage (eval only): |
| python train_ec_simtoken.py --run eval \ |
| --saved_model checkpoints/ec_simtoken_v1.pth \ |
| --eval_splits test_s,test_u,test_n |
| """ |
|
|
import argparse
import os
import random
import re
import warnings
from functools import partial
|
|
| import numpy as np |
| import torch |
| import torch.multiprocessing as mp |
| import transformers |
| from peft import LoraConfig, get_peft_model |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import AutoConfig, get_cosine_schedule_with_warmup, logging |
|
|
| warnings.filterwarnings("ignore") |
| logging.set_verbosity_error() |
|
|
|
|
| |
|
|
| IGNORE_INDEX = -100 |
| IMAGE_TOKEN_INDEX = -200 |
| AUDIO_TOKEN_INDEX = -300 |
|
|
|
|
| |
|
|
| from configs import args as base_args |
|
|
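# EC-specific CLI flags, parsed with parse_known_args and layered onto the shared args.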
| _ec_parser = argparse.ArgumentParser(add_help=False) |
| _ec_parser.add_argument("--null_aug_prob", type=float, default=0.25, |
| help="Fraction of batch items with swapped audio (null aug)") |
| _ec_parser.add_argument("--exist_loss_weight", type=float, default=1.0, |
| help="Weight for BCE existence loss") |
| _ec_parser.add_argument("--exist_threshold", type=float, default=0.5, |
| help="p_exist sigmoid threshold for null classification") |
|
|
| ec_args, _ = _ec_parser.parse_known_args() |
| |
| args = base_args |
| args.null_aug_prob = ec_args.null_aug_prob |
| args.exist_loss_weight = ec_args.exist_loss_weight |
| args.exist_threshold = ec_args.exist_threshold |
|
|
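# Pin the visible GPU(s) before the model/dataset imports below can initialize CUDA.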
| os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id |
|
|
|
|
| from datasets import REFAVS |
| from models.ec_simtoken_model import ECSimtoken_ForCausalLM |
| from utils import utility |
|
|
|
|
| |
|
|
| def set_seed(seed: int = 42): |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| os.environ["PYTHONHASHSEED"] = str(seed) |
| |
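    # benchmark=True trades exact reproducibility for faster conv kernel selection.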
| torch.backends.cudnn.benchmark = True |
|
|
|
|
| def seed_worker(worker_id): |
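    """Give each DataLoader worker its own numpy/random seed stream."""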
| seed = torch.initial_seed() % 2 ** 32 |
| np.random.seed(seed) |
| random.seed(seed) |
|
|
|
|
| def dict_to_cuda(d: dict) -> dict: |
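    """Move tensor values (and lists of tensors) in a batch dict onto the GPU."""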
| for k, v in d.items(): |
| if isinstance(v, torch.Tensor): |
| d[k] = v.cuda(non_blocking=True) |
| elif isinstance(v, list) and v and isinstance(v[0], torch.Tensor): |
| d[k] = [x.cuda(non_blocking=True) for x in v] |
| return d |
|
|
|
|
| |
|
|
| def apply_null_augmentation( |
| audio_feats: list, p_null: float = 0.25 |
| ) -> tuple[list, torch.BoolTensor]: |
| """Randomly replace some audio features with mismatched ones. |
| |
| Returns the (possibly mutated) audio_feats list and a bool tensor |
| `is_null` where True means the sample's audio was swapped. |
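
    Hypothetical usage:
        feats, is_null = apply_null_augmentation(batch_audio, p_null=0.25)
        # feats[i] was replaced wherever is_null[i] is True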
| """ |
| B = len(audio_feats) |
| is_null = torch.zeros(B, dtype=torch.bool) |
| if B < 2 or p_null <= 0.0: |
| return audio_feats, is_null |
|
|
    # Swap against a snapshot of the originals so a sample can never be handed
    # its own audio back by an earlier in-place swap (which would mislabel it
    # as null).
    originals = list(audio_feats)
    for i in range(B):
        if random.random() < p_null:
            j = random.choice([k for k in range(B) if k != i])
            audio_feats[i] = originals[j].clone()
            is_null[i] = True
|
|
| return audio_feats, is_null |
|
|
|
|
| |
|
|
| def tokenizer_image_audio_token( |
| prompt, tokenizer, |
| image_token_index=IMAGE_TOKEN_INDEX, |
| audio_token_index=AUDIO_TOKEN_INDEX, |
| num_frames=10, |
| return_tensors=None, |
| ): |
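    """Tokenize a prompt containing <image>/<audio>/<video> placeholders.

    Text spans are tokenized separately; each placeholder is spliced in as a
    negative sentinel index, with <video> expanding to `num_frames` image
    tokens. The interleaving assumes every placeholder is preceded by a text
    span.
    """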
| prompt_chunks = re.split(r'(<image>|<audio>|<video>)', prompt) |
| prompt_chunks = [c for c in prompt_chunks if c] |
|
|
| text_chunks, token_types = [], [] |
| for chunk in prompt_chunks: |
| if chunk == "<image>": |
| token_types.append("image") |
| elif chunk == "<audio>": |
| token_types.append("audio") |
| elif chunk == "<video>": |
| token_types.append("video") |
| else: |
| text_chunks.append(chunk) |
|
|
| tokenized_chunks = [tokenizer(c).input_ids for c in text_chunks] |
|
|
| input_ids = [] |
| offset = 0 |
| if tokenized_chunks and tokenized_chunks[0] and tokenized_chunks[0][0] == tokenizer.bos_token_id: |
| offset = 1 |
| input_ids.append(tokenized_chunks[0][0]) |
|
|
| min_len = min(len(text_chunks), len(token_types)) |
| for i in range(min_len): |
| input_ids.extend(tokenized_chunks[i][offset:]) |
| if token_types[i] == "image": |
| input_ids.append(image_token_index) |
| elif token_types[i] == "audio": |
| input_ids.append(audio_token_index) |
| elif token_types[i] == "video": |
| input_ids.extend([image_token_index] * num_frames) |
| if len(text_chunks) > min_len: |
| input_ids.extend(tokenized_chunks[min_len][offset:]) |
|
|
| if return_tensors == "pt": |
| return torch.tensor(input_ids, dtype=torch.long) |
| return input_ids |
|
|
|
|
| def collate_fn(batch, tokenizer=None): |
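    """Collate REFAVS samples: pad prompts, build attention masks, and mask labels for SEG supervision."""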
| vids, images, image_clips, masks, conversations = [], [], [], [], [] |
| audio_feats, image_feats, resizes, orgsizes = [], [], [], [] |
| refs, refs_num, fids = [], [], [] |
|
|
| for data in batch: |
| vids.append(data["vid"]) |
| images.append(data["image"]) |
| image_clips.append(data["img_clip"]) |
| masks.append(data["mask"]) |
| conversations.append(data["conversation"]) |
| audio_feats.append(data["feat_aud"]) |
| resizes.append(data["resize"]) |
| orgsizes.append(data["orgsize"]) |
| image_feats.append(data["feat_sam"]) |
| refs_num.append(len(data["ref"])) |
| fids.append(data["fids"]) |
| refs.append(data["ref"][0]) |
|
|
| input_ids = [ |
| tokenizer_image_audio_token(c, tokenizer, return_tensors="pt") |
| for c in conversations |
| ] |
| input_ids = torch.nn.utils.rnn.pad_sequence( |
| input_ids, batch_first=True, padding_value=tokenizer.pad_token_id |
| ) |
| attention_masks = input_ids.ne(tokenizer.pad_token_id) |
|
|
| ref_ids = [ |
| tokenizer_image_audio_token(r, tokenizer, return_tensors="pt") for r in refs |
| ] |
|
|
| labels = input_ids.clone() |
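    # Supervise only the "Sure, it is [SEG]" answers: mask the BOS token,
    # every prompt segment, and all trailing tokens (the -1/-2 offsets
    # appear to discount special tokens added by each tokenizer call).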
| sep = "Sure, it is [SEG]" |
| for conversation, target in zip(conversations, labels): |
| parts = conversation.split(sep) |
| cur_len = 1 |
| target[:cur_len] = IGNORE_INDEX |
| sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1 |
| for i in range(len(parts) - 1): |
| part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2 |
| target[cur_len : cur_len + part_len] = IGNORE_INDEX |
| cur_len += part_len + sep_len |
| target[cur_len:] = IGNORE_INDEX |
|
|
| return { |
| "vids": vids, |
| "images": images, |
| "images_clip": image_clips, |
| "masks": masks, |
| "convs": conversations, |
| "input_ids": input_ids, |
| "attention_masks": attention_masks, |
| "labels": labels, |
| "audio_feats": audio_feats, |
| "resizes": resizes, |
| "orgsizes": orgsizes, |
| "image_feats": image_feats, |
| "ref_ids": ref_ids, |
| "refs_num": refs_num, |
| "fids": fids, |
| } |
|
|
|
|
| |
|
|
| def build_model(args, tokenizer, seg_token_idx) -> ECSimtoken_ForCausalLM: |
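    """Assemble the EC-SimToken model: base MLLM, SAM vision modules, LoRA, and trainable heads."""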
| model_args = { |
| "train_mask_decoder": True, |
| "out_dim": 256, |
| "ce_loss_weight": 1.0, |
| "dice_loss_weight": 0.5, |
| "bce_loss_weight": 2.0, |
| "seg_token_idx": seg_token_idx, |
| "vision_pretrained": args.vision_pretrained, |
| "vision_tower": args.vision_tower, |
| "use_im_start_end": False, |
| "compress": args.compress, |
| "start": args.start, |
| "exist_loss_weight": args.exist_loss_weight, |
| } |
| model = ECSimtoken_ForCausalLM.from_pretrained( |
| args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args |
| ) |
| model.config.eos_token_id = tokenizer.eos_token_id |
| model.config.bos_token_id = tokenizer.bos_token_id |
| model.config.pad_token_id = tokenizer.pad_token_id |
|
|
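    # Let input embeddings carry gradients so adapter fine-tuning works while base weights stay frozen.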
| model.enable_input_require_grads() |
| |
| |
| |
| |
|
|
| model.get_model().initialize_vision_modules(model.get_model().config) |
| vision_tower = model.get_model().get_vision_tower() |
| vision_tower.to(dtype=torch.bfloat16, device="cuda") |
|
|
| cfg_pt = AutoConfig.from_pretrained(args.mllm) |
| cfg_pt.use_cluster = True |
| cfg_pt.freeze = False |
| cfg_pt.mm_tune = True |
| cfg_pt.spatial_cluster_rate0 = 64 |
| cfg_pt.spatial_cluster_rate1 = 32 |
| cfg_pt.spatial_cluster_rate2 = 16 |
| cfg_pt.temporal_cluster_rate = 0.0625 |
| cfg_pt.vision_tune = False |
| model.get_model().initialize_cluster_modules(cfg_pt) |
| model.get_model().initialize_lisa_modules(model.get_model().config) |
|
|
| for p in vision_tower.parameters(): |
| p.requires_grad = False |
| for p in model.get_model().mm_projector.parameters(): |
| p.requires_grad = False |
|
|
| |
| lora_r = 8 |
|
|
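    # Apply LoRA only to the LLM attention projections; vision, SAM, and task-head modules are excluded.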
| def find_linear_layers(m, targets): |
| names = set() |
| skip = {"visual_model", "vision_tower", "mm_projector", |
| "text_hidden_fcs", "audio_feature_layer", "existence_head"} |
| for name, mod in m.named_modules(): |
| if ( |
| isinstance(mod, torch.nn.Linear) |
| and not any(s in name for s in skip) |
| and any(t in name for t in targets) |
| ): |
| names.add(name) |
| return sorted(names) |
|
|
| lora_config = LoraConfig( |
| r=lora_r, |
| lora_alpha=16, |
| target_modules=find_linear_layers(model, ["q_proj", "v_proj"]), |
| lora_dropout=0.05, |
| bias="none", |
| task_type="CAUSAL_LM", |
| ) |
| model = get_peft_model(model, lora_config) |
| model.print_trainable_parameters() |
|
|
| model = model.to("cuda") |
| |
| |
| model = model.to(torch.bfloat16) |
| model.resize_token_embeddings(len(tokenizer)) |
|
|
| |
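    # Re-enable grads on the embeddings and task heads that the LoRA wrap left frozen.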
| for n, p in model.named_parameters(): |
| if any(x in n for x in ["lm_head", "embed_tokens", "mask_decoder", |
| "text_hidden_fcs", "audio_feature_layer", |
| "existence_head"]): |
| p.requires_grad = True |
|
|
| trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| print(f"Trainable parameters: {trainable:,}") |
| return model |
|
|
|
|
| |
|
|
| @torch.no_grad() |
| def evaluate(model, dataloader, split_name: str, exist_threshold: float = 0.5): |
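    """Evaluate one split: Null_S via the existence head on test_n, mIoU/F-score otherwise."""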
| model.eval() |
| total_iou = total_fscore = count = 0.0 |
| |
| total_null_metric = null_count = 0.0 |
| null_tp = 0 |
| null_fn = 0 |
|
|
| for batch in tqdm(dataloader, desc=f"Eval {split_name}", leave=False): |
| batch = dict_to_cuda(batch) |
| with torch.autocast("cuda", dtype=torch.bfloat16): |
| out = model.forward( |
| images=batch["images"], |
| images_clip=batch["images_clip"], |
| audio_features=batch["audio_feats"], |
| image_features=batch["image_feats"], |
| input_ids=batch["input_ids"], |
| labels=batch["labels"], |
| attention_masks=batch["attention_masks"], |
| masks_list=batch["masks"], |
| resize_list=batch["resizes"], |
| orgsize_list=batch["orgsizes"], |
| conversation_list=batch["convs"], |
| refs_num=batch["refs_num"], |
| fids=batch["fids"], |
| vids=batch["vids"], |
| ref_ids=batch["ref_ids"], |
| inference=True, |
| ) |
| pred_masks = out["pred_masks"] |
| gt_masks = out["gt_masks"] |
| |
| p_exist = torch.sigmoid(out["exist_logit"]).squeeze(-1).cpu() |
|
|
| for i in range(len(pred_masks)): |
| pred_i = pred_masks[i] |
| gt_i = gt_masks[i] |
| pe = p_exist[i].item() |
|
|
| if split_name == "test_n": |
| |
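                # test_n references have no matching object: a low p_exist means
                # we correctly emit an empty mask (a hit); otherwise it's a miss.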
| if pe < exist_threshold: |
| null_score = utility.metric_s_for_null(torch.zeros_like(pred_i)) |
| null_tp += 1 |
| else: |
| null_score = utility.metric_s_for_null(pred_i) |
| null_fn += 1 |
| total_null_metric += null_score * pred_i.shape[0] * pred_i.shape[1] |
| null_count += pred_i.shape[0] * pred_i.shape[1] |
| else: |
| iou = utility.mask_iou(pred_i, gt_i) |
| fscore = utility.Eval_Fmeasure(pred_i, gt_i, None) |
| n = pred_i.shape[0] * pred_i.shape[1] |
| total_iou += iou * n |
| total_fscore += fscore * n |
| count += n |
|
|
| if split_name == "test_n": |
| null_s = total_null_metric / (null_count + 1e-8) |
| total_n = null_tp + null_fn |
| print(f"\n [{split_name}] Null_S={null_s:.4f} " |
| f"null_tp={null_tp}/{total_n} null_fn={null_fn}/{total_n}") |
| return {"null_s": null_s, "null_tp": null_tp, "null_fn": null_fn} |
| else: |
| miou = total_iou / (count + 1e-8) |
| fscore = total_fscore / (count + 1e-8) |
| print(f"\n [{split_name}] mIoU={miou:.4f} F={fscore:.4f}") |
| return {"miou": miou, "fscore": fscore} |
|
|
|
|
| |
|
|
| if __name__ == "__main__": |
| mp.set_start_method("spawn") |
| set_seed(42) |
|
|
| os.makedirs(args.log_root, exist_ok=True) |
| os.makedirs(args.checkpoint_root, exist_ok=True) |
|
|
| tokenizer = transformers.AutoTokenizer.from_pretrained( |
| args.mllm, |
| cache_dir=None, |
| model_max_length=2048, |
| padding_side="right", |
| use_fast=False, |
| ) |
| tokenizer.pad_token = tokenizer.unk_token |
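    # Register the segmentation trigger token; its id is fed to the model's seg head.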
| tokenizer.add_tokens("[SEG]") |
| seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] |
| print(f"seg_token_idx: {seg_token_idx}") |
|
|
| |
| train_dataset = REFAVS("train", args, tokenizer, input_type="refer") |
| val_dataset_s = REFAVS("test_s", args, tokenizer, input_type="refer") |
| val_dataset_u = REFAVS("test_u", args, tokenizer, input_type="refer") |
| val_dataset_n = REFAVS("test_n", args, tokenizer, input_type="refer") |
|
|
| g = torch.Generator() |
| g.manual_seed(42) |
|
|
| train_loader = DataLoader( |
| train_dataset, |
| batch_size=args.batch_size, |
| shuffle=True, |
| num_workers=4, |
| worker_init_fn=seed_worker, |
| collate_fn=partial(collate_fn, tokenizer=tokenizer), |
| generator=g, |
| pin_memory=True, |
| persistent_workers=False, |
| prefetch_factor=2, |
| ) |
| val_loader_s = DataLoader( |
| val_dataset_s, batch_size=4, shuffle=False, num_workers=4, |
| collate_fn=partial(collate_fn, tokenizer=tokenizer), |
| pin_memory=True, persistent_workers=False, |
| ) |
| val_loader_u = DataLoader( |
| val_dataset_u, batch_size=4, shuffle=False, num_workers=4, |
| collate_fn=partial(collate_fn, tokenizer=tokenizer), |
| pin_memory=True, persistent_workers=False, |
| ) |
| val_loader_n = DataLoader( |
| val_dataset_n, batch_size=4, shuffle=False, num_workers=4, |
| collate_fn=partial(collate_fn, tokenizer=tokenizer), |
| pin_memory=True, persistent_workers=False, |
| ) |
|
|
| |
| model = build_model(args, tokenizer, seg_token_idx) |
|
|
| if args.saved_model and os.path.exists(args.saved_model): |
| ckpt = torch.load(args.saved_model, map_location="cuda") |
| |
| state = ckpt.get("model", ckpt) |
| missing, unexpected = model.load_state_dict(state, strict=False) |
| print(f"Loaded {args.saved_model} missing={len(missing)} unexpected={len(unexpected)}") |
|
|
| if args.run == "eval": |
| for split, loader in [("test_s", val_loader_s), |
| ("test_u", val_loader_u), |
| ("test_n", val_loader_n)]: |
            if split in args.eval_splits.split(","):
| evaluate(model, loader, split, args.exist_threshold) |
| exit(0) |
|
|
| |
| model.train() |
| optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01) |
|
|
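    # Accumulate gradients toward an effective batch size of roughly 16 samples.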
| gradient_accumulation_steps = max(1, int(16 // args.batch_size)) |
| steps_per_epoch = len(train_loader) // gradient_accumulation_steps |
| total_steps = args.epochs * steps_per_epoch |
| warmup_steps = max(1, int(total_steps * 0.1)) |
|
|
| scheduler = get_cosine_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=warmup_steps, |
| num_training_steps=total_steps, |
| ) |
|
|
| log_path = os.path.join(args.log_root, f"{args.name}.txt") |
|
|
| for epoch in range(args.epochs): |
| model.train() |
| optimizer.zero_grad() |
| running = {"loss": 0.0, "ce": 0.0, "mask": 0.0, "exist": 0.0} |
| n_steps = 0 |
|
|
| loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}") |
| for step, batch in enumerate(loop): |
| |
| batch["audio_feats"], is_null = apply_null_augmentation( |
| batch["audio_feats"], p_null=args.null_aug_prob |
| ) |
| batch = dict_to_cuda(batch) |
| is_null = is_null.cuda() |
|
|
| with torch.autocast("cuda", dtype=torch.bfloat16): |
| out = model.forward( |
| images=batch["images"], |
| images_clip=batch["images_clip"], |
| audio_features=batch["audio_feats"], |
| image_features=batch["image_feats"], |
| input_ids=batch["input_ids"], |
| labels=batch["labels"], |
| attention_masks=batch["attention_masks"], |
| masks_list=batch["masks"], |
| resize_list=batch["resizes"], |
| orgsize_list=batch["orgsizes"], |
| conversation_list=batch["convs"], |
| refs_num=batch["refs_num"], |
| fids=batch["fids"], |
| vids=batch["vids"], |
| ref_ids=batch["ref_ids"], |
| epoch=epoch, |
| inference=False, |
| contrast=args.ct_weight, |
| is_null=is_null, |
| ) |
|
|
| loss = out["loss"] / gradient_accumulation_steps |
| loss.backward() |
|
|
| for k, key in [("loss", "loss"), ("ce", "ce_loss"), |
| ("mask", "mask_loss"), ("exist", "exist_loss")]: |
| v = out.get(key, torch.tensor(0.0)) |
| running[k] += v.item() if isinstance(v, torch.Tensor) else v |
|
|
| if (step + 1) % gradient_accumulation_steps == 0: |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
| n_steps += 1 |
| lr = scheduler.get_last_lr()[0] |
| avg = {k: running[k] / n_steps for k in running} |
| loop.set_postfix( |
| lr=f"{lr:.2e}", |
| loss=f"{avg['loss']:.4f}", |
| exist=f"{avg['exist']:.4f}", |
| ) |
|
|
| |
| denom = max(n_steps, 1) |
| epoch_loss = running["loss"] / denom |
| print( |
| f"Epoch {epoch+1} loss={epoch_loss:.4f} " |
| f"ce={running['ce']/denom:.4f} " |
| f"mask={running['mask']/denom:.4f} " |
| f"exist={running['exist']/denom:.4f} " |
| f"lr={scheduler.get_last_lr()[0]:.2e}" |
| ) |
|
|
| with open(log_path, "a") as f: |
| f.write( |
| f"epoch={epoch+1} loss={epoch_loss:.4f} " |
| f"ce={running['ce']/denom:.4f} " |
| f"mask={running['mask']/denom:.4f} " |
| f"exist={running['exist']/denom:.4f}\n" |
| ) |
|
|
| |
| ckpt_ep = os.path.join(args.checkpoint_root, f"{args.name}_ep{epoch+1}.pth") |
| torch.save(model.state_dict(), ckpt_ep) |
| print(f"Saved: {ckpt_ep}") |
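        # Prune older files so only the two most recent epoch checkpoints remain.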
| prev_ckpt = os.path.join(args.checkpoint_root, f"{args.name}_ep{epoch-1}.pth") |
| if epoch >= 2 and os.path.exists(prev_ckpt): |
| os.remove(prev_ckpt) |
|
|
| evaluate(model, val_loader_s, "test_s", args.exist_threshold) |
| evaluate(model, val_loader_u, "test_u", args.exist_threshold) |
| evaluate(model, val_loader_n, "test_n", args.exist_threshold) |
|
|
| |
| ckpt_path = os.path.join(args.checkpoint_root, f"{args.name}.pth") |
| torch.save(model.state_dict(), ckpt_path) |
| print(f"Saved: {ckpt_path}") |
|
|