| import csv |
| import os |
| import random |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import transformers |
| from peft import LoraConfig, get_peft_model |
| from torch.utils.data import DataLoader |
| from transformers import AutoConfig |
|
|
| from configs import args |
| from datasets import REFAVS |
| from load_model import collate_fn, dict_to_cuda |
| from models.avs_model import Simtoken_ForCausalLM |
|
|
|
|
def set_seed(seed=42):
    """Make runs reproducible by seeding every RNG this script touches.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same value. It also switches cuDNN to
    deterministic, non-autotuned kernels.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Autotuning may pick different kernels run-to-run; disable it.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
|
|
|
|
def find_lora_target_modules(model, target_modules=("q_proj", "v_proj")):
    """Return the sorted, de-duplicated names of Linear layers to wrap with LoRA.

    A submodule is selected when all three conditions hold:
      * it is a ``torch.nn.Linear``;
      * its qualified name contains none of the excluded substrings (the
        vision/projection/audio parts of the model);
      * its qualified name contains at least one of ``target_modules``.

    All matching is plain substring matching on the qualified name produced
    by ``model.named_modules()``.
    """
    skip_markers = [
        "visual_model",
        "vision_tower",
        "mm_projector",
        "text_hidden_fcs",
        "audio_feature_layer",
    ]

    def _wanted(qualified_name, submodule):
        if not isinstance(submodule, torch.nn.Linear):
            return False
        if any(marker in qualified_name for marker in skip_markers):
            return False
        return any(target in qualified_name for target in target_modules)

    selected = {
        qualified_name
        for qualified_name, submodule in model.named_modules()
        if _wanted(qualified_name, submodule)
    }
    return sorted(selected)
|
|
|
|
def _load_checkpoint(model, ckpt_path):
    """Load ``ckpt_path`` into ``model`` non-strictly and report key mismatches.

    Raises:
        RuntimeError: if none of the checkpoint's keys matched the model.
            Without this check, ``strict=False`` would silently ignore the
            whole checkpoint and evaluation would run on untrained weights.
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"Loaded checkpoint: {ckpt_path}")
    print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
    if len(unexpected) == len(state):
        raise RuntimeError(
            f"No keys from checkpoint {ckpt_path} matched the model; "
            "check that it was saved from the same LoRA-wrapped architecture."
        )


def build_model(tokenizer, seg_token_idx):
    """Build the LoRA-wrapped Simtoken model and load the evaluation checkpoint.

    Args:
        tokenizer: tokenizer already extended with the ``[SEG]`` token. Its
            eos/bos/pad ids are copied into the model config, and its length
            is used to resize the token embeddings.
        seg_token_idx: id of the ``[SEG]`` token, forwarded to the model.

    Returns:
        The model on CUDA, in eval mode, with ``args.saved_model`` loaded.

    Raises:
        RuntimeError: if no key of the checkpoint matches the model.
    """
    model_args = {
        "train_mask_decoder": True,
        "out_dim": 256,
        "ce_loss_weight": 1.0,
        "dice_loss_weight": 0.5,
        "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_im_start_end": False,
        "compress": args.compress,
        "start": args.start,
    }

    model = Simtoken_ForCausalLM.from_pretrained(
        args.mllm,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        **model_args,
    )

    # Keep the special-token ids in sync with the tokenizer used for the data.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    # The vision tower is moved to float32 on CUDA; the LLM was loaded in bf16.
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.float32, device="cuda")

    # Cluster-module settings are hard-coded here rather than read from ``args``.
    # Presumably they must match the training configuration (TODO confirm).
    model_args_from_pt = AutoConfig.from_pretrained(args.mllm)
    model_args_from_pt.use_cluster = True
    model_args_from_pt.freeze = False
    model_args_from_pt.mm_tune = True
    model_args_from_pt.spatial_cluster_rate0 = 64
    model_args_from_pt.spatial_cluster_rate1 = 32
    model_args_from_pt.spatial_cluster_rate2 = 16
    model_args_from_pt.temporal_cluster_rate = 0.0625
    model_args_from_pt.vision_tune = False
    model.get_model().initialize_cluster_modules(model_args_from_pt)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    # The order (LoRA wrap -> CUDA -> resize embeddings -> load checkpoint) is
    # kept unchanged. It presumably mirrors how the checkpoint's keys were produced.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=find_lora_target_modules(model),
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model = model.to("cuda")
    model.resize_token_embeddings(len(tokenizer))

    _load_checkpoint(model, args.saved_model)

    model.eval()
    return model
|
|
|
|
def get_seg_embedding(model, batch):
    """Run the full model with ``inference=True`` and return one ``[SEG]`` embedding.

    ``batch`` is a collated batch whose tensors are already on CUDA. The
    returned value is ``output["seg_embeddings"][0][0:1]``. Presumably this is
    the first ``[SEG]`` embedding of the single sample in the batch (TODO
    confirm). The ``0:1`` slice keeps a leading dimension of size 1.
    """
    # torch.autocast(device_type="cuda") replaces the deprecated
    # torch.cuda.amp.autocast spelling; the behaviour is the same.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model.forward(
            images=batch["images"],
            images_clip=batch["images_clip"],
            audio_features=batch["audio_feats"],
            image_features=batch["image_feats"],
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_masks=batch["attention_masks"],
            masks_list=batch["masks"],
            resize_list=batch["resizes"],
            orgsize_list=batch["orgsizes"],
            conversation_list=batch["convs"],
            refs_num=batch["refs_num"],
            fids=batch["fids"],
            vids=batch["vids"],
            contrast=args.ct_weight,
            ref_ids=batch["ref_ids"],
            inference=True,
        )
    return output["seg_embeddings"][0][0:1]
|
|
|
|
def _decode_masks(decoder, image_embeddings, image_pe, sparse, dense):
    """Run the mask decoder once under the bf16 autocast context used by this check."""
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return decoder(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )


def check_one_sample(model, batch):
    """Compare batched and per-frame mask decoding for one sample.

    The ``[SEG]`` embedding from ``get_seg_embedding`` is turned into prompt
    embeddings by the prompt encoder. The mask decoder then runs once over
    all of ``batch["image_feats"][0]`` and once per slice along its first
    dimension (one frame per slice). Both paths use the same autocast context,
    so any difference reflects batch dependence in the decoder rather than a
    precision mismatch.

    Returns:
        One dict per frame with keys ``vid``, ``ref``, ``frame``,
        ``max_abs_diff``, ``mean_abs_diff`` (mask logits) and
        ``iou_pred_diff`` (IoU predictions). Each value is the absolute
        difference between the batched and the single-frame decode.

    Raises:
        ValueError: if a single-frame output's shape differs from the matching
            slice of the batched output. Broadcasting would otherwise silently
            produce meaningless differences.
    """
    q = get_seg_embedding(model, batch)
    image_embeddings = batch["image_feats"][0]

    visual_model = model.get_model().visual_model
    sparse, dense = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.unsqueeze(1),
    )
    sparse = sparse.to(q.dtype)
    dense = dense.to(q.dtype)

    decoder = visual_model.mask_decoder
    image_pe = visual_model.prompt_encoder.get_dense_pe()

    full_masks, full_iou = _decode_masks(decoder, image_embeddings, image_pe, sparse, dense)

    vid = batch["vids"][0]
    ref = batch["refs"][0][0]
    rows = []
    for t in range(image_embeddings.shape[0]):
        single_masks, single_iou = _decode_masks(
            decoder, image_embeddings[t : t + 1], image_pe, sparse, dense
        )
        full_mask_slice = full_masks[t : t + 1]
        full_iou_slice = full_iou[t : t + 1]
        if (
            full_mask_slice.shape != single_masks.shape
            or full_iou_slice.shape != single_iou.shape
        ):
            raise ValueError(
                f"Frame {t}: batched output shapes "
                f"{tuple(full_mask_slice.shape)}/{tuple(full_iou_slice.shape)} do not match "
                f"single-frame shapes {tuple(single_masks.shape)}/{tuple(single_iou.shape)}"
            )

        diff = (full_mask_slice - single_masks).float().abs()
        iou_diff = (full_iou_slice - single_iou).float().abs()
        rows.append(
            {
                "vid": vid,
                "ref": ref,
                "frame": t,
                "max_abs_diff": diff.max().item(),
                "mean_abs_diff": diff.mean().item(),
                "iou_pred_diff": iou_diff.max().item(),
            }
        )
    return rows
|
|
|
|
def _print_sample_report(sample_idx, vid, ref, rows):
    """Print the per-frame difference table for one checked sample."""
    print(f"\nSample {sample_idx}: vid={vid} ref={ref}")
    print("frame | max_abs_diff | mean_abs_diff | iou_pred_diff")
    for row in rows:
        print(
            f"{row['frame']:02d} | "
            f"{row['max_abs_diff']:.8e} | "
            f"{row['mean_abs_diff']:.8e} | "
            f"{row['iou_pred_diff']:.8e}"
        )


def _print_summary(all_rows):
    """Print global aggregates over every checked frame (``all_rows`` must be non-empty)."""
    max_diff = max(row["max_abs_diff"] for row in all_rows)
    mean_diff = sum(row["mean_abs_diff"] for row in all_rows) / len(all_rows)
    max_iou_diff = max(row["iou_pred_diff"] for row in all_rows)

    print("\nSummary")
    print(f"checked frames: {len(all_rows)}")
    print(f"global max_abs_diff: {max_diff:.8e}")
    print(f"average mean_abs_diff: {mean_diff:.8e}")
    print(f"global max_iou_pred_diff: {max_iou_diff:.8e}")


def _write_csv(csv_path, all_rows):
    """Write all per-frame rows to ``csv_path``, creating parent directories."""
    os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
    # Explicit UTF-8: the ``ref`` column may contain non-ASCII text that the
    # platform default encoding cannot represent.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
        writer.writeheader()
        writer.writerows(all_rows)
    print(f"Saved CSV: {csv_path}")


def main():
    """Check mask-decoder batch invariance on the first samples of ``args.eval_split``.

    Checks ``args.max_eval_rows`` samples, or 1 when that value is not
    positive. Prints a per-frame table for each sample and then a global
    summary. If the ``DECODER_INVARIANCE_CSV`` environment variable is set,
    all rows are also written to that path.

    Raises:
        RuntimeError: if no frame was checked.
    """
    set_seed(42)
    # Inference only: disable autograd globally.
    torch.set_grad_enabled(False)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
    print(f"Split: {args.eval_split} | samples to check: {limit}")

    model = build_model(tokenizer, seg_token_idx)

    all_rows = []
    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        rows = check_one_sample(model, batch)
        all_rows.extend(rows)
        # Take vid/ref from the batch so a sample with zero frames doesn't
        # crash on rows[0].
        _print_sample_report(sample_idx, batch["vids"][0], batch["refs"][0][0], rows)

    if not all_rows:
        raise RuntimeError("No rows were checked. Is the selected split empty?")

    _print_summary(all_rows)

    csv_path = os.environ.get("DECODER_INVARIANCE_CSV")
    if csv_path:
        _write_csv(csv_path, all_rows)


if __name__ == "__main__":
    main()
|
|