"""Check whether the mask decoder's output depends on how frames are batched.

For each checked sample, the ``[SEG]`` embedding produced by the fine-tuned
``Simtoken_ForCausalLM`` is turned into prompt embeddings and fed to the
visual model's mask decoder twice: once with every frame of the sample in a
single call, and once per frame. Per-frame absolute differences of the mask
logits and IoU predictions are printed, summarized, and optionally written
to the CSV file named by the ``DECODER_INVARIANCE_CSV`` environment variable.
"""

import csv
import itertools
import os
import random
from functools import partial

import numpy as np
import torch
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import AutoConfig

from configs import args
from datasets import REFAVS
from load_model import collate_fn, dict_to_cuda
from models.avs_model import Simtoken_ForCausalLM

# Sub-module name fragments that never receive LoRA adapters.
_LORA_EXCLUDED = (
    "visual_model",
    "vision_tower",
    "mm_projector",
    "text_hidden_fcs",
    "audio_feature_layer",
)


def _bf16_autocast():
    """Return a CUDA bf16 autocast context (replaces deprecated torch.cuda.amp.autocast)."""
    return torch.autocast("cuda", dtype=torch.bfloat16)


def set_seed(seed=42):
    """Seed torch (CPU + all CUDA devices), NumPy and ``random``; force deterministic cuDNN."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def find_lora_target_modules(model, target_modules=("q_proj", "v_proj")):
    """Return the sorted names of ``nn.Linear`` modules that should get LoRA adapters.

    A module is selected when its qualified name contains any fragment in
    ``target_modules`` and none of the fragments in ``_LORA_EXCLUDED``.
    """
    modules = set()
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if any(x in name for x in _LORA_EXCLUDED):
            continue
        if any(x in name for x in target_modules):
            modules.add(name)
    return sorted(modules)


def _cluster_config():
    """Load the ``args.mllm`` config and set the options passed to ``initialize_cluster_modules``."""
    cfg = AutoConfig.from_pretrained(args.mllm)
    cfg.use_cluster = True
    cfg.freeze = False
    cfg.mm_tune = True
    cfg.spatial_cluster_rate0 = 64
    cfg.spatial_cluster_rate1 = 32
    cfg.spatial_cluster_rate2 = 16
    cfg.temporal_cluster_rate = 0.0625
    cfg.vision_tune = False
    return cfg


def build_model(tokenizer, seg_token_idx):
    """Build the LoRA-wrapped Simtoken model and load the checkpoint ``args.saved_model``.

    Args:
        tokenizer: Tokenizer whose special-token ids and vocabulary size the
            model config / embeddings are aligned to.
        seg_token_idx: Token id of ``[SEG]``, forwarded to the model.

    Returns:
        The PEFT-wrapped model on CUDA, in eval mode. The checkpoint is loaded
        with ``strict=False``; only the counts of missing/unexpected keys are printed.
    """
    model_args = {
        "train_mask_decoder": True,
        "out_dim": 256,
        "ce_loss_weight": 1.0,
        "dice_loss_weight": 0.5,
        "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_im_start_end": False,
        "compress": args.compress,
        "start": args.start,
    }
    model = Simtoken_ForCausalLM.from_pretrained(
        args.mllm,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        **model_args,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    # The vision tower is kept in float32 even though the LLM is loaded in bf16.
    vision_tower.to(dtype=torch.float32, device="cuda")

    model.get_model().initialize_cluster_modules(_cluster_config())
    model.get_model().initialize_lisa_modules(model.get_model().config)

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=find_lora_target_modules(model),
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model = model.to("cuda")
    # [SEG] was added to the tokenizer, so the embedding table must grow to match.
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(args.saved_model, map_location="cpu")
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"Loaded checkpoint: {args.saved_model}")
    print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")

    model.eval()
    return model


def get_seg_embedding(model, batch):
    """Run the model in inference mode and return the first ``[SEG]`` embedding (kept 2-D via ``0:1``)."""
    with _bf16_autocast():
        output = model.forward(
            images=batch["images"],
            images_clip=batch["images_clip"],
            audio_features=batch["audio_feats"],
            image_features=batch["image_feats"],
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_masks=batch["attention_masks"],
            masks_list=batch["masks"],
            resize_list=batch["resizes"],
            orgsize_list=batch["orgsizes"],
            conversation_list=batch["convs"],
            refs_num=batch["refs_num"],
            fids=batch["fids"],
            vids=batch["vids"],
            contrast=args.ct_weight,
            ref_ids=batch["ref_ids"],
            inference=True,
        )
    return output["seg_embeddings"][0][0:1]


def check_one_sample(model, batch):
    """Compare batched vs. per-frame mask-decoder outputs for one sample.

    The first dimension of ``batch["image_feats"][0]`` is treated as the frame
    axis (assumption inherited from the original script — confirm against the
    dataset). Both decoder paths run under the same bf16 autocast so any
    difference reflects batching, not a precision mismatch.

    Returns:
        One dict per frame with keys ``vid``, ``ref``, ``frame``,
        ``max_abs_diff``, ``mean_abs_diff`` and ``iou_pred_diff``.
    """
    q = get_seg_embedding(model, batch)
    image_embeddings = batch["image_feats"][0]

    visual_model = model.get_model().visual_model
    sparse, dense = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.unsqueeze(1),
    )
    sparse = sparse.to(q.dtype)
    dense = dense.to(q.dtype)

    decoder = visual_model.mask_decoder
    image_pe = visual_model.prompt_encoder.get_dense_pe()

    def run_decoder(embeddings):
        # Same prompts / positional encoding for every call; only the image input varies.
        return decoder(
            image_embeddings=embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )

    rows = []
    with _bf16_autocast():
        full_masks, full_iou = run_decoder(image_embeddings)
        for t in range(image_embeddings.shape[0]):
            single_masks, single_iou = run_decoder(image_embeddings[t : t + 1])
            # Upcast before subtracting so the difference itself is not rounded to bf16.
            diff = (full_masks[t : t + 1].float() - single_masks.float()).abs()
            iou_diff = (full_iou[t : t + 1].float() - single_iou.float()).abs()
            rows.append(
                {
                    "vid": batch["vids"][0],
                    "ref": batch["refs"][0][0],
                    "frame": t,
                    "max_abs_diff": diff.max().item(),
                    "mean_abs_diff": diff.mean().item(),
                    "iou_pred_diff": iou_diff.max().item(),
                }
            )
    return rows


def _print_sample(sample_idx, batch, rows):
    """Print the per-frame difference table for one sample."""
    # Read vid/ref from the batch (same values as rows[0]) so zero-frame samples don't crash.
    print(f"\nSample {sample_idx}: vid={batch['vids'][0]} ref={batch['refs'][0][0]}")
    print("frame | max_abs_diff | mean_abs_diff | iou_pred_diff")
    for row in rows:
        print(
            f"{row['frame']:02d} | "
            f"{row['max_abs_diff']:.8e} | "
            f"{row['mean_abs_diff']:.8e} | "
            f"{row['iou_pred_diff']:.8e}"
        )


def _print_summary(all_rows):
    """Print global max / average statistics over every checked frame."""
    max_diff = max(row["max_abs_diff"] for row in all_rows)
    mean_diff = sum(row["mean_abs_diff"] for row in all_rows) / len(all_rows)
    max_iou_diff = max(row["iou_pred_diff"] for row in all_rows)
    print("\nSummary")
    print(f"checked frames: {len(all_rows)}")
    print(f"global max_abs_diff: {max_diff:.8e}")
    print(f"average mean_abs_diff: {mean_diff:.8e}")
    print(f"global max_iou_pred_diff: {max_iou_diff:.8e}")


def _write_csv(csv_path, all_rows):
    """Write every per-frame row to ``csv_path``, creating parent directories."""
    os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
        writer.writeheader()
        writer.writerows(all_rows)
    print(f"Saved CSV: {csv_path}")


def main():
    """Check the first ``args.max_eval_rows`` samples (default 1) of ``args.eval_split``.

    Raises:
        RuntimeError: if no frame was checked (e.g. the split is empty).
    """
    set_seed(42)
    torch.set_grad_enabled(False)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
    print(f"Split: {args.eval_split} | samples to check: {limit}")

    model = build_model(tokenizer, seg_token_idx)

    all_rows = []
    # islice stops before loading (and collating) a sample past the limit.
    for sample_idx, batch in enumerate(itertools.islice(loader, limit)):
        batch = dict_to_cuda(batch)
        rows = check_one_sample(model, batch)
        all_rows.extend(rows)
        _print_sample(sample_idx, batch, rows)

    if not all_rows:
        raise RuntimeError("No rows were checked. Is the selected split empty?")

    _print_summary(all_rows)

    csv_path = os.environ.get("DECODER_INVARIANCE_CSV")
    if csv_path:
        _write_csv(csv_path, all_rows)


if __name__ == "__main__":
    main()