"""D2 diagnostic: score how well a [SEG] query embedding ``q`` aligns with the
image features inside vs. outside a mask, for the real query and for control
queries (random, shuffled, wrong-reference), and dump per-frame rows to CSV.

Environment variables read by this script:
    D2_BETAS      comma-separated floats, default "0.5"
    D2_BASIC_CSV  output CSV path (a default under /workspace/SimToken is used
                  when unset)
"""

import csv
import math
import os
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch.utils.data import DataLoader

from configs import args
from datasets import REFAVS
from decoder_invariance_check import build_model, set_seed
from load_model import collate_fn, dict_to_cuda

# Probability threshold used everywhere a predicted mask is binarised.
MASK_THRESHOLD = 0.4
# beta^2 weight of the F-measure proxy.
FSCORE_BETA2 = 0.3
# Number of samples evaluated when ``args.max_eval_rows`` is not positive.
DEFAULT_SAMPLE_LIMIT = 30


def make_loader(tokenizer):
    """Build a deterministic (unshuffled, batch size 1) loader over the eval split."""
    dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,  # run_d2 relies on the same order as collect_q_pool
        num_workers=0,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )


def build_tokenizer():
    """Load the tokenizer for ``args.mllm`` and register the ``[SEG]`` token.

    Returns:
        ``(tokenizer, seg_token_idx)`` where ``seg_token_idx`` is the first
        input id produced by tokenising ``"[SEG]"``.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    return tokenizer, seg_token_idx


def get_q(model, batch):
    """Run the model in inference mode and return one segmentation embedding.

    The forward pass runs under bf16 autocast; the returned tensor is
    ``output["seg_embeddings"][0][0]`` cast to float32.
    """
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.forward(
            images=batch["images"],
            images_clip=batch["images_clip"],
            audio_features=batch["audio_feats"],
            image_features=batch["image_feats"],
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_masks=batch["attention_masks"],
            masks_list=batch["masks"],
            resize_list=batch["resizes"],
            orgsize_list=batch["orgsizes"],
            conversation_list=batch["convs"],
            refs_num=batch["refs_num"],
            fids=batch["fids"],
            vids=batch["vids"],
            contrast=args.ct_weight,
            ref_ids=batch["ref_ids"],
            inference=True,
        )
    return output["seg_embeddings"][0][0].float()


def decode_low_res(model, batch, q):
    """Decode low-resolution mask logits for query ``q`` with the visual model.

    ``q`` is fed to the prompt encoder as a single text embedding; the mask
    decoder then runs under bf16 autocast on ``batch["image_feats"][0]``.

    Returns:
        ``(low_res_masks, iou_predictions)`` as float32 tensors, with the last
        dimension of ``iou_predictions`` squeezed.
    """
    visual_model = model.get_model().visual_model
    # The prompt encoder is called in the visual model's parameter dtype ...
    param_dtype = next(visual_model.parameters()).dtype
    sparse, dense = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.view(1, 1, -1).to(param_dtype),
    )
    # ... and its outputs are cast back to the dtype of ``q``.
    sparse = sparse.to(q.dtype)
    dense = dense.to(q.dtype)
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        low_res_masks, iou_predictions = visual_model.mask_decoder(
            image_embeddings=batch["image_feats"][0],
            image_pe=visual_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
    return low_res_masks.float(), iou_predictions.float().squeeze(-1)


def masks_to_64(mask_logits_or_binary):
    """Bilinearly resize masks to 64x64 and clamp the result to [0, 1].

    A 3-D input gets a singleton dimension inserted at position 1 first, so
    the output is always 4-D.
    """
    masks = mask_logits_or_binary
    if masks.ndim == 3:
        masks = masks.unsqueeze(1)
    resized = F.interpolate(
        masks.float(),
        size=(64, 64),
        mode="bilinear",
        align_corners=False,
    )
    return resized.clamp(0.0, 1.0)


def d2_scores(image_embeddings, mask64, q, beta):
    """Per-frame D2 score: ``cos(z_in, q) - beta * cos(z_out, q)``.

    ``z_in`` / ``z_out`` are the mask-weighted and complement-weighted spatial
    averages of ``image_embeddings``; both and ``q`` are L2-normalised before
    the dot product.

    Raises:
        ValueError: if ``mask64`` and ``image_embeddings`` differ in their
            first dimension.
    """
    feats = image_embeddings.float()
    if mask64.shape[0] != feats.shape[0]:
        raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")
    q = F.normalize(q.float().view(1, -1), dim=-1)
    inside = mask64.float()
    outside = 1.0 - inside
    spatial = (2, 3)
    # clamp_min keeps the averages finite when a mask (or its complement) is empty.
    z_in = (feats * inside).sum(dim=spatial) / inside.sum(dim=spatial).clamp_min(1e-6)
    z_out = (feats * outside).sum(dim=spatial) / outside.sum(dim=spatial).clamp_min(1e-6)
    z_in = F.normalize(z_in, dim=-1)
    z_out = F.normalize(z_out, dim=-1)
    return (z_in @ q.T).squeeze(-1) - beta * (z_out @ q.T).squeeze(-1)


def _binarize(pred_logits):
    """Turn logits into a {0, 1} float mask using ``MASK_THRESHOLD`` on the sigmoid."""
    return (torch.sigmoid(pred_logits.float()) > MASK_THRESHOLD).float()


def frame_iou(pred_logits, gt_masks):
    """Per-frame IoU between the binarised prediction and the ground truth.

    Frames whose ground truth is empty are scored as the fraction of pixels
    the prediction also leaves empty.
    """
    pred = _binarize(pred_logits)
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)
    hw = (1, 2)
    inter = (pred * gt).sum(dim=hw)
    union = torch.maximum(pred, gt).sum(dim=hw)
    num_pixels = pred.shape[-1] * pred.shape[-2]
    no_obj = gt.sum(dim=hw) == 0
    # Empty-GT frames: count agreeing background pixels over all pixels.
    inter_no_obj = ((1.0 - pred) * (1.0 - gt)).sum(dim=hw)
    inter = torch.where(no_obj, inter_no_obj, inter)
    union = torch.where(no_obj, torch.full_like(union, float(num_pixels)), union)
    return inter / union.clamp_min(1e-7)


def frame_fscore_proxy(pred_logits, gt_masks):
    """Per-frame F-measure (beta^2 = ``FSCORE_BETA2``) of the binarised prediction.

    Frames whose ground truth is empty get a score of 0.
    """
    pred = _binarize(pred_logits)
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)
    hw = (1, 2)
    gt_sum = gt.sum(dim=hw)
    tp = (pred * gt).sum(dim=hw)
    precision = tp / pred.sum(dim=hw).clamp_min(1e-7)
    recall = tp / gt_sum.clamp_min(1e-7)
    fscore = (
        (1 + FSCORE_BETA2)
        * precision
        * recall
        / (FSCORE_BETA2 * precision + recall).clamp_min(1e-7)
    )
    return torch.where(gt_sum == 0, torch.zeros_like(fscore), fscore)


def parse_betas():
    """Parse the comma-separated ``D2_BETAS`` environment variable (default "0.5").

    Raises:
        ValueError: if an entry is not a float, or if no value is left after
            dropping blank entries (e.g. ``D2_BETAS=""``). Failing here avoids
            an opaque ``IndexError`` later, after the model has been loaded.
    """
    raw = os.environ.get("D2_BETAS", "0.5")
    betas = [float(token) for token in (part.strip() for part in raw.split(",")) if token]
    if not betas:
        raise ValueError(f"D2_BETAS must contain at least one number, got {raw!r}")
    return betas


def collect_q_pool(model, tokenizer, limit):
    """Collect the query embedding of the first ``limit`` samples.

    Returns:
        A list of dicts with keys ``sample_idx``, ``vid``, ``ref``, ``fid`` and
        ``q`` (a CPU tensor); list position equals ``sample_idx``.

    Raises:
        RuntimeError: if the loader produced no samples.
    """
    q_pool = []
    for sample_idx, batch in enumerate(make_loader(tokenizer)):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        entry = {
            "sample_idx": sample_idx,
            "vid": batch["vids"][0],
            "ref": batch["refs"][0][0],
            "fid": int(batch["fids"][0][0]),
            # Kept on CPU so the pool does not hold GPU memory.
            "q": get_q(model, batch).cpu(),
        }
        q_pool.append(entry)
        print(f"Collected q {sample_idx}: vid={entry['vid']} ref={entry['ref']}")
    if not q_pool:
        raise RuntimeError("No q vectors collected. Is the selected split empty?")
    return q_pool


def choose_shuffled_idx(sample_idx, q_pool):
    """Return the cyclically next pool index, or ``None`` for a pool of size <= 1."""
    if len(q_pool) <= 1:
        return None
    return (sample_idx + 1) % len(q_pool)


def choose_wrong_ref_idx(sample_idx, q_pool):
    """Pick another pool entry from the same video to act as a wrong reference.

    Entries with a different ``fid`` are preferred; failing that, the first
    entry with a different ``ref`` is used. Returns ``None`` when the pool
    holds no such entry.
    """
    current = q_pool[sample_idx]
    # Two passes, in priority order: differing fid first, then differing ref.
    for key in ("fid", "ref"):
        for item in q_pool:
            if item["sample_idx"] == sample_idx:
                continue
            if item["vid"] == current["vid"] and item[key] != current[key]:
                return item["sample_idx"]
    return None


def run_d2(model, tokenizer, q_pool, betas, limit):
    """Compute D2 scores for every sample, frame, beta and query control.

    For each of the first ``limit`` samples the real query is decoded into a
    mask; D2 scores are then computed against both the predicted and the
    ground-truth 64x64 masks for the real query and for the control queries
    (``random`` always; ``shuffled`` / ``wrong_ref`` when available).

    Returns:
        A list of row dicts (one per sample x beta x q_type x frame) whose key
        order defines the CSV column order.

    Raises:
        ValueError: if ``betas`` is empty.
    """
    if not betas:
        raise ValueError("betas must contain at least one value")
    # The per-sample progress line reports scores for the first beta only.
    primary_beta = betas[0]

    rows = []
    q_lookup = {item["sample_idx"]: item for item in q_pool}
    generator = torch.Generator(device="cuda")
    generator.manual_seed(1234)
    visual_model = model.get_model().visual_model

    for sample_idx, batch in enumerate(make_loader(tokenizer)):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        item = q_lookup[sample_idx]
        real_q = item["q"].cuda()

        low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
        pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
        gt_masks = batch["masks"][0][0].float()
        gt_mask64 = masks_to_64(gt_masks)
        image_embeddings = batch["image_feats"][0].float()
        pred_logits_hr = visual_model.postprocess_masks(
            low_res_masks.to(batch["image_feats"][0].dtype),
            input_size=batch["resizes"][0],
            original_size=batch["orgsizes"][0],
        ).squeeze(1)

        frame_ious = frame_iou(pred_logits_hr, gt_masks)
        # Per-frame metrics do not depend on beta or q: move them to Python
        # floats once per sample instead of calling .item() for every row.
        # NOTE(review): assumes each of these is 1-D over frames — confirm.
        frame_iou_list = frame_ious.tolist()
        frame_fscore_list = frame_fscore_proxy(pred_logits_hr, gt_masks).tolist()
        iou_pred_list = iou_predictions.tolist()
        pred_area_list = _binarize(pred_logits_hr).mean(dim=(1, 2)).tolist()
        gt_area_list = gt_masks.float().mean(dim=(1, 2)).tolist()

        shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
        wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
        q_controls = [
            ("real", real_q, sample_idx),
            (
                "random",
                torch.randn(real_q.shape, device=real_q.device, generator=generator),
                None,
            ),
        ]
        if shuffled_idx is not None:
            q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
        if wrong_ref_idx is not None:
            q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))

        # Collected while rows are built, so the summary line below does not
        # need to rescan every row produced so far.
        real_s_pred = []
        for beta in betas:
            for q_type, q, q_source_idx in q_controls:
                pred_scores = d2_scores(image_embeddings, pred_mask64, q, beta).tolist()
                gt_scores = d2_scores(image_embeddings, gt_mask64, q, beta).tolist()
                track_real = q_type == "real" and beta == primary_beta
                for frame_idx, (s_pred, s_gt) in enumerate(zip(pred_scores, gt_scores)):
                    # Key order is the CSV column order; keep it stable.
                    rows.append(
                        {
                            "sample_idx": sample_idx,
                            "vid": item["vid"],
                            "ref": item["ref"],
                            "fid": item["fid"],
                            "split": args.eval_split,
                            "frame_iou": frame_iou_list[frame_idx],
                            "frame_fscore_proxy": frame_fscore_list[frame_idx],
                            "iou_pred": iou_pred_list[frame_idx],
                            "pred_area": pred_area_list[frame_idx],
                            "gt_area": gt_area_list[frame_idx],
                            "frame": frame_idx,
                            "q_type": q_type,
                            "beta": beta,
                            "s_pred": s_pred,
                            "s_gt": s_gt,
                            "q_source_idx": q_source_idx if q_source_idx is not None else "",
                        }
                    )
                    if track_real:
                        real_s_pred.append(s_pred)

        mean_s_pred = np.mean(real_s_pred) if real_s_pred else math.nan
        min_s_pred = np.min(real_s_pred) if real_s_pred else math.nan
        print(
            f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
            f"mean_s_pred={mean_s_pred:.4f} min_s_pred={min_s_pred:.4f} "
            f"mean_iou={frame_ious.mean().item():.4f}"
        )
    return rows


def print_summary(rows):
    """Print per-beta mean scores per q_type and the s_pred/IoU correlation.

    Prints nothing when ``rows`` holds no row with ``q_type == "real"``.
    """
    real_rows = [r for r in rows if r["q_type"] == "real"]
    if not real_rows:
        return
    print("\nSummary")
    print(f"rows: {len(rows)}")
    for beta in sorted({r["beta"] for r in real_rows}):
        beta_rows = [r for r in rows if r["beta"] == beta]
        print(f"\nbeta={beta}")
        for q_type in sorted({r["q_type"] for r in beta_rows}):
            qr = [r for r in beta_rows if r["q_type"] == q_type]
            print(
                f"{q_type:10s} "
                f"mean_s_pred={np.mean([r['s_pred'] for r in qr]):+.4f} "
                f"mean_s_gt={np.mean([r['s_gt'] for r in qr]):+.4f}"
            )
        real_beta = [r for r in beta_rows if r["q_type"] == "real"]
        s_pred = np.array([r["s_pred"] for r in real_beta])
        frame_iou_values = np.array([r["frame_iou"] for r in real_beta])
        # Pearson correlation is undefined for constant or single-point series.
        has_variance = (
            len(s_pred) > 1
            and np.std(s_pred) > 1e-8
            and np.std(frame_iou_values) > 1e-8
        )
        if has_variance:
            corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
            print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
        else:
            print("corr(real s_pred, frame_iou)=nan")


def main():
    """Run the D2 diagnostic end to end and write the rows to a CSV file.

    Raises:
        ValueError: if ``D2_BETAS`` is invalid (raised before the model loads).
        RuntimeError: if no samples or no rows were produced.
    """
    set_seed(42)
    torch.set_grad_enabled(False)
    betas = parse_betas()
    tokenizer, seg_token_idx = build_tokenizer()
    limit = args.max_eval_rows if args.max_eval_rows > 0 else DEFAULT_SAMPLE_LIMIT
    print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")

    model = build_model(tokenizer, seg_token_idx)
    q_pool = collect_q_pool(model, tokenizer, limit)
    rows = run_d2(model, tokenizer, q_pool, betas, limit)
    print_summary(rows)
    if not rows:
        # The CSV header comes from rows[0]; without rows there is nothing to write.
        raise RuntimeError("D2 produced no rows; nothing to write.")

    csv_path = os.environ.get(
        "D2_BASIC_CSV",
        f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv",
    )
    os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    print(f"\nSaved CSV: {csv_path}")


if __name__ == "__main__":
    main()