"""EC-SimToken training script. Adds existence head + synthetic null augmentation to SimToken. Key differences from train.py: - Uses ECSimtoken_ForCausalLM (adds existence_head: Linear(256,1)) - Audio-swap null augmentation: p_null fraction of batch items have their audio replaced with another sample's audio → synthetic null - is_null tensor passed to model_forward to gate mask loss - test_n evaluation uses existence head (p_exist threshold) for Null S Usage (training): python train_ec_simtoken.py \ --data_dir data \ --mllm Chat-UniVi/Chat-UniVi-7B-v1.5 \ --vision_pretrained path/to/sam_vit_h_4b8939.pth \ --name ec_simtoken_v1 \ --epochs 10 \ --batch_size 12 \ --null_aug_prob 0.25 \ --exist_loss_weight 1.0 Usage (eval only): python train_ec_simtoken.py --run eval \ --saved_model checkpoints/ec_simtoken_v1.pth \ --eval_splits test_s,test_u,test_n """ import argparse import os import random import warnings from functools import partial import numpy as np import torch import torch.multiprocessing as mp import transformers from peft import LoraConfig, get_peft_model from torch.optim import AdamW from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoConfig, get_cosine_schedule_with_warmup, logging warnings.filterwarnings("ignore") logging.set_verbosity_error() import re # ── Token constants ─────────────────────────────────────────────────────────── IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 AUDIO_TOKEN_INDEX = -300 # ── Args: base (from configs) + EC-SimToken extensions ─────────────────────── from configs import args as base_args # parses base SimToken args _ec_parser = argparse.ArgumentParser(add_help=False) _ec_parser.add_argument("--null_aug_prob", type=float, default=0.25, help="Fraction of batch items with swapped audio (null aug)") _ec_parser.add_argument("--exist_loss_weight", type=float, default=1.0, help="Weight for BCE existence loss") _ec_parser.add_argument("--exist_threshold", type=float, default=0.5, help="p_exist sigmoid threshold for null classification") ec_args, _ = _ec_parser.parse_known_args() # Merge EC-SimToken args into the base args namespace args = base_args args.null_aug_prob = ec_args.null_aug_prob args.exist_loss_weight = ec_args.exist_loss_weight args.exist_threshold = ec_args.exist_threshold os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id from datasets import REFAVS from models.ec_simtoken_model import ECSimtoken_ForCausalLM from utils import utility # ── Utilities ───────────────────────────────────────────────────────────────── def set_seed(seed: int = 42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) os.environ["PYTHONHASHSEED"] = str(seed) # benchmark=True lets cuDNN find fastest conv algorithms (important for SAM) torch.backends.cudnn.benchmark = True def seed_worker(worker_id): seed = torch.initial_seed() % 2 ** 32 np.random.seed(seed) random.seed(seed) def dict_to_cuda(d: dict) -> dict: for k, v in d.items(): if isinstance(v, torch.Tensor): d[k] = v.cuda(non_blocking=True) elif isinstance(v, list) and v and isinstance(v[0], torch.Tensor): d[k] = [x.cuda(non_blocking=True) for x in v] return d # ── Null augmentation ───────────────────────────────────────────────────────── def apply_null_augmentation( audio_feats: list, p_null: float = 0.25 ) -> tuple[list, torch.BoolTensor]: """Randomly replace some audio features with mismatched ones. Returns the (possibly mutated) audio_feats list and a bool tensor `is_null` where True means the sample's audio was swapped. """ B = len(audio_feats) is_null = torch.zeros(B, dtype=torch.bool) if B < 2 or p_null <= 0.0: return audio_feats, is_null for i in range(B): if random.random() < p_null: candidates = [j for j in range(B) if j != i] j = random.choice(candidates) audio_feats[i] = audio_feats[j].clone() is_null[i] = True return audio_feats, is_null # ── Collate (identical to train.py) ────────────────────────────────────────── def tokenizer_image_audio_token( prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, audio_token_index=AUDIO_TOKEN_INDEX, num_frames=10, return_tensors=None, ): prompt_chunks = re.split(r'(|