| import csv |
| import math |
| import os |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import transformers |
| from torch.utils.data import DataLoader |
|
|
| from configs import args |
| from datasets import REFAVS |
| from decoder_invariance_check import build_model, set_seed |
| from load_model import collate_fn, dict_to_cuda |
|
|
|
|
def make_loader(tokenizer):
    """Build the evaluation ``DataLoader`` for ``args.eval_split``.

    The ``REFAVS`` dataset is created with ``input_type="refer"`` and wrapped
    in a loader that yields one sample per batch, never shuffles, and loads
    in the main process (``num_workers=0``). Batches are assembled by
    ``collate_fn`` with ``tokenizer`` bound as a keyword argument.

    Args:
        tokenizer: Tokenizer handed to both the dataset and ``collate_fn``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    eval_dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    batch_builder = partial(collate_fn, tokenizer=tokenizer)
    loader = DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=batch_builder,
    )
    return loader
|
|
|
|
def build_tokenizer():
    """Load the tokenizer named by ``args.mllm`` and register ``[SEG]``.

    The tokenizer is loaded in slow mode (``use_fast=False``) with right
    padding and a 2048-token maximum length. Its pad token is set to its
    unknown token, and the literal ``"[SEG]"`` is added to the vocabulary.

    Returns:
        A ``(tokenizer, seg_token_idx)`` tuple, where ``seg_token_idx`` is the
        first input id produced when tokenizing ``"[SEG]"`` without special
        tokens.
    """
    seg_token = "[SEG]"
    load_options = {
        "cache_dir": None,
        "model_max_length": 2048,
        "padding_side": "right",
        "use_fast": False,
    }
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.mllm, **load_options)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens(seg_token)
    seg_encoding = tokenizer(seg_token, add_special_tokens=False)
    return tokenizer, seg_encoding.input_ids[0]
|
|
|
|
def get_q(model, batch):
    """Run the model in inference mode and return one segmentation embedding.

    The forward pass is executed under bfloat16 CUDA autocast with
    ``inference=True`` and ``contrast=args.ct_weight``; every other argument
    is read from ``batch``.

    Args:
        model: Object exposing ``forward(...)`` that returns a mapping with a
            ``"seg_embeddings"`` entry.
        batch: Collated batch dictionary (already moved to the GPU by the
            caller).

    Returns:
        ``output["seg_embeddings"][0][0]`` cast to float32.
    """
    forward_kwargs = dict(
        images=batch["images"],
        images_clip=batch["images_clip"],
        audio_features=batch["audio_feats"],
        image_features=batch["image_feats"],
        input_ids=batch["input_ids"],
        labels=batch["labels"],
        attention_masks=batch["attention_masks"],
        masks_list=batch["masks"],
        resize_list=batch["resizes"],
        orgsize_list=batch["orgsizes"],
        conversation_list=batch["convs"],
        refs_num=batch["refs_num"],
        fids=batch["fids"],
        vids=batch["vids"],
        contrast=args.ct_weight,
        ref_ids=batch["ref_ids"],
        inference=True,
    )
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.forward(**forward_kwargs)
    seg_embeddings = output["seg_embeddings"]
    return seg_embeddings[0][0].float()
|
|
|
|
def decode_low_res(model, batch, q):
    """Decode low-resolution masks for ``batch`` from the prompt vector ``q``.

    ``q`` is reshaped to ``(1, 1, -1)``, cast to the dtype of the visual
    model's first parameter, and passed to the prompt encoder as
    ``text_embeds`` (no points, boxes or masks). The resulting sparse and
    dense prompt embeddings are cast back to ``q``'s dtype and fed, together
    with ``batch["image_feats"][0]``, to the mask decoder under bfloat16
    CUDA autocast with ``multimask_output=False``.

    Args:
        model: Object whose ``get_model().visual_model`` exposes
            ``prompt_encoder`` and ``mask_decoder``.
        batch: Batch dictionary providing ``"image_feats"``.
        q: Prompt embedding tensor.

    Returns:
        ``(low_res_masks, iou_predictions)``, both float32, with the last
        dimension of ``iou_predictions`` squeezed.
    """
    visual_model = model.get_model().visual_model
    prompt_encoder = visual_model.prompt_encoder
    encoder_dtype = next(visual_model.parameters()).dtype

    text_embeds = q.view(1, 1, -1).to(encoder_dtype)
    sparse, dense = prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=text_embeds,
    )
    # Bring the prompt embeddings back to the dtype of the incoming q.
    sparse, dense = sparse.to(q.dtype), dense.to(q.dtype)

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        decoded = visual_model.mask_decoder(
            image_embeddings=batch["image_feats"][0],
            image_pe=prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
        low_res_masks, iou_predictions = decoded
    return low_res_masks.float(), iou_predictions.float().squeeze(-1)
|
|
|
|
def masks_to_64(mask_logits_or_binary):
    """Resize masks to 64x64 with bilinear interpolation, clamped to [0, 1].

    Args:
        mask_logits_or_binary: A 3-D tensor (a channel axis is inserted at
            position 1) or a tensor already laid out for
            ``F.interpolate`` with two spatial dimensions.

    Returns:
        A float32 tensor whose last two dimensions are ``(64, 64)`` and whose
        values are clamped to the range ``[0.0, 1.0]``.
    """
    masks = mask_logits_or_binary
    if masks.ndim == 3:
        # interpolate needs an explicit channel axis.
        masks = masks[:, None]
    resized = F.interpolate(
        masks.float(),
        size=(64, 64),
        mode="bilinear",
        align_corners=False,
    )
    return torch.clamp(resized, min=0.0, max=1.0)
|
|
|
|
def d2_scores(image_embeddings, mask64, q, beta):
    """Score how well ``q`` aligns with features inside vs. outside a mask.

    For every frame the features are averaged with the mask as weights
    (inside) and with ``1 - mask`` as weights (outside). Both averages and
    ``q`` are L2-normalised, and the score is
    ``cos(inside, q) - beta * cos(outside, q)``.

    Args:
        image_embeddings: Feature tensor; dimension 0 indexes frames and
            dimensions 2 and 3 are summed over.
        mask64: Mask tensor with the same leading (frame) size, multiplied
            element-wise with the features.
        q: Query vector; flattened to a single row.
        beta: Weight applied to the outside-region similarity.

    Returns:
        One score per frame.

    Raises:
        ValueError: If ``mask64`` and ``image_embeddings`` disagree on the
            number of frames.
    """
    feats = image_embeddings.float()
    if feats.shape[0] != mask64.shape[0]:
        raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")

    query = F.normalize(q.float().view(1, -1), dim=-1)
    inside = mask64.float()
    outside = 1.0 - inside

    def _pooled_direction(weights):
        # Weighted mean over the two spatial axes, then unit-normalised.
        total = weights.sum(dim=(2, 3)).clamp_min(1e-6)
        pooled = (feats * weights).sum(dim=(2, 3)) / total
        return F.normalize(pooled, dim=-1)

    sim_inside = (_pooled_direction(inside) @ query.T).squeeze(-1)
    sim_outside = (_pooled_direction(outside) @ query.T).squeeze(-1)
    return sim_inside - beta * sim_outside
|
|
|
|
def frame_iou(pred_logits, gt_masks):
    """Compute a per-frame IoU between thresholded predictions and ground truth.

    Predictions are binarised with ``sigmoid(logits) > 0.4``. For frames whose
    ground truth is entirely empty, the score is instead the fraction of
    pixels where prediction and ground truth are both background.

    Args:
        pred_logits: Prediction logits; a 4-D input has dimension 1 squeezed.
        gt_masks: Ground-truth masks, one 2-D mask per frame.

    Returns:
        One IoU value per frame.
    """
    spatial = (1, 2)
    pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)

    overlap = (pred * gt).sum(dim=spatial)
    combined = torch.maximum(pred, gt).sum(dim=spatial)

    # Empty ground truth: score agreement on background over all pixels.
    empty_gt = gt.sum(dim=spatial) == 0
    background_agreement = ((1.0 - pred) * (1.0 - gt)).sum(dim=spatial)
    total_pixels = float(pred.shape[-1] * pred.shape[-2])

    numerator = torch.where(empty_gt, background_agreement, overlap)
    denominator = torch.where(
        empty_gt,
        torch.full_like(combined, total_pixels),
        combined,
    )
    return numerator / denominator.clamp_min(1e-7)
|
|
|
|
def frame_fscore_proxy(pred_logits, gt_masks):
    """Compute a per-frame weighted F-score on thresholded predictions.

    Predictions are binarised with ``sigmoid(logits) > 0.4`` and the score is
    ``(1 + b) * P * R / (b * P + R)`` with ``b = 0.3``. Frames with an empty
    ground-truth mask are assigned a score of 0.

    Args:
        pred_logits: Prediction logits; a 4-D input has dimension 1 squeezed.
        gt_masks: Ground-truth masks, one 2-D mask per frame.

    Returns:
        One F-score value per frame.
    """
    spatial = (1, 2)
    weight = 0.3

    pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)

    true_pos = (pred * gt).sum(dim=spatial)
    precision = true_pos / pred.sum(dim=spatial).clamp_min(1e-7)
    recall = true_pos / gt.sum(dim=spatial).clamp_min(1e-7)

    denominator = (weight * precision + recall).clamp_min(1e-7)
    fscore = (1 + weight) * precision * recall / denominator

    empty_gt = gt.sum(dim=spatial) == 0
    return torch.where(empty_gt, torch.zeros_like(fscore), fscore)
|
|
|
|
def parse_betas():
    """Parse the ``D2_BETAS`` environment variable into a list of floats.

    ``D2_BETAS`` is a comma-separated list (default ``"0.5"``). Surrounding
    whitespace is stripped and empty entries are skipped.

    Returns:
        A non-empty list of beta values in the order they were given.

    Raises:
        ValueError: If an entry cannot be parsed as a float, or if no entries
            remain after skipping empty ones. Failing here avoids a later,
            less informative failure when the first beta is looked up.
    """
    raw = os.environ.get("D2_BETAS", "0.5")
    betas = []
    for token in raw.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            betas.append(float(token))
        except ValueError as err:
            raise ValueError(f"Invalid beta {token!r} in D2_BETAS={raw!r}") from err
    if not betas:
        raise ValueError(f"D2_BETAS={raw!r} does not contain any beta values")
    return betas
|
|
|
|
def collect_q_pool(model, tokenizer, limit):
    """Collect the query embedding of the first ``limit`` evaluation samples.

    Each sample is moved to the GPU, passed through ``get_q``, and stored
    with its identifying metadata; the embedding itself is kept on the CPU.

    Args:
        model: Model passed to ``get_q``.
        tokenizer: Tokenizer used to build the evaluation loader.
        limit: Maximum number of samples to collect.

    Returns:
        A list of dicts with keys ``sample_idx``, ``vid``, ``ref``, ``fid``
        and ``q``, in loader order.

    Raises:
        RuntimeError: If no sample was collected.
    """
    pool = []
    for idx, batch in enumerate(make_loader(tokenizer)):
        if idx >= limit:
            break
        batch = dict_to_cuda(batch)
        embedding = get_q(model, batch)
        entry = {
            "sample_idx": idx,
            "vid": batch["vids"][0],
            "ref": batch["refs"][0][0],
            "fid": int(batch["fids"][0][0]),
            "q": embedding.cpu(),
        }
        pool.append(entry)
        print(f"Collected q {idx}: vid={entry['vid']} ref={entry['ref']}")

    if len(pool) == 0:
        raise RuntimeError("No q vectors collected. Is the selected split empty?")
    return pool
|
|
|
|
def choose_shuffled_idx(sample_idx, q_pool):
    """Pick the index of the next pool entry, wrapping around at the end.

    Args:
        sample_idx: Index of the current sample.
        q_pool: Sequence of pooled entries; only its length is used.

    Returns:
        ``(sample_idx + 1) % len(q_pool)``, or ``None`` when the pool holds
        at most one entry and no different entry exists.
    """
    pool_size = len(q_pool)
    return (sample_idx + 1) % pool_size if pool_size > 1 else None
|
|
|
|
def choose_wrong_ref_idx(sample_idx, q_pool):
    """Find another pool entry from the same video to use as a wrong reference.

    The pool is scanned twice in order: first for an entry of the same
    ``vid`` whose ``fid`` differs from the current one, then for an entry of
    the same ``vid`` whose ``ref`` differs. The current sample is always
    skipped.

    Args:
        sample_idx: Index of the current sample; also used as its position
            in ``q_pool``.
        q_pool: List of dicts with ``sample_idx``, ``vid``, ``fid`` and
            ``ref`` keys.

    Returns:
        The ``sample_idx`` of the first matching entry, or ``None`` if the
        pool has no suitable entry.
    """
    current = q_pool[sample_idx]
    # "fid" candidates take priority over "ref" candidates.
    for key in ("fid", "ref"):
        for candidate in q_pool:
            if candidate["sample_idx"] == sample_idx or candidate["vid"] != current["vid"]:
                continue
            if candidate[key] != current[key]:
                return candidate["sample_idx"]
    return None
|
|
|
|
def run_d2(model, tokenizer, q_pool, betas, limit):
    """Score predicted and ground-truth masks against real and control queries.

    For each of the first ``limit`` samples the masks are decoded from the
    sample's own pooled query, per-frame quality metrics are computed, and
    ``d2_scores`` is evaluated for every beta against each query control:

    * ``real``      - the sample's own query,
    * ``random``    - a Gaussian vector drawn from a generator seeded with
      1234 (drawn once per sample),
    * ``shuffled``  - the query of the next pool entry (if any),
    * ``wrong_ref`` - a query chosen by ``choose_wrong_ref_idx`` (if any).

    Args:
        model: Model used for mask decoding and post-processing.
        tokenizer: Tokenizer used to build the evaluation loader.
        q_pool: Entries produced by ``collect_q_pool``; looked up by their
            ``sample_idx``.
        betas: Non-empty sequence of beta values; the first one is used for
            the per-sample progress line.
        limit: Maximum number of samples to process.

    Returns:
        A list with one dict per (sample, beta, query control, frame). Keys,
        in insertion order: ``sample_idx``, ``vid``, ``ref``, ``fid``,
        ``split``, ``frame_iou``, ``frame_fscore_proxy``, ``iou_pred``,
        ``pred_area``, ``gt_area``, ``frame``, ``q_type``, ``beta``,
        ``s_pred``, ``s_gt``, ``q_source_idx`` (empty string when the
        control has no source sample).

    Raises:
        ValueError: If ``betas`` is empty.
    """
    if not betas:
        raise ValueError("betas must contain at least one value")

    rows = []
    loader = make_loader(tokenizer)
    q_lookup = {item["sample_idx"]: item for item in q_pool}
    generator = torch.Generator(device="cuda")
    generator.manual_seed(1234)

    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        item = q_lookup[sample_idx]
        real_q = item["q"].cuda()

        low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
        pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
        gt_masks = batch["masks"][0][0].float()
        gt_mask64 = masks_to_64(gt_masks)
        image_embeddings = batch["image_feats"][0].float()

        pred_logits_hr = model.get_model().visual_model.postprocess_masks(
            low_res_masks.to(batch["image_feats"][0].dtype),
            input_size=batch["resizes"][0],
            original_size=batch["orgsizes"][0],
        ).squeeze(1)

        frame_ious = frame_iou(pred_logits_hr, gt_masks)
        frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
        pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
        gt_area = gt_masks.float().mean(dim=(1, 2))

        # Convert the per-frame metrics to Python floats once per sample
        # instead of calling .item() for every (beta, control, frame) row.
        # NOTE(review): assumes each of these is 1-D over frames - confirm.
        frame_iou_values = frame_ious.tolist()
        frame_fscore_values = frame_fscores.tolist()
        iou_pred_values = iou_predictions.tolist()
        pred_area_values = pred_area.tolist()
        gt_area_values = gt_area.tolist()

        shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
        wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
        q_controls = [
            ("real", real_q, sample_idx),
            ("random", torch.randn(real_q.shape, device=real_q.device, generator=generator), None),
        ]
        if shuffled_idx is not None:
            q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
        if wrong_ref_idx is not None:
            q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))

        # Fields shared by every row of this sample (keeps the CSV column order).
        sample_info = {
            "sample_idx": sample_idx,
            "vid": item["vid"],
            "ref": item["ref"],
            "fid": item["fid"],
            "split": args.eval_split,
        }

        # s_pred of the real query at the first beta, gathered while the rows
        # are built so the progress line does not rescan the whole rows list.
        real_s_pred_values = []

        for beta in betas:
            for q_type, q, q_source_idx in q_controls:
                s_pred_values = d2_scores(image_embeddings, pred_mask64, q, beta).tolist()
                s_gt_values = d2_scores(image_embeddings, gt_mask64, q, beta).tolist()
                for frame_idx, s_pred_value in enumerate(s_pred_values):
                    rows.append(
                        {
                            **sample_info,
                            "frame_iou": frame_iou_values[frame_idx],
                            "frame_fscore_proxy": frame_fscore_values[frame_idx],
                            "iou_pred": iou_pred_values[frame_idx],
                            "pred_area": pred_area_values[frame_idx],
                            "gt_area": gt_area_values[frame_idx],
                            "frame": frame_idx,
                            "q_type": q_type,
                            "beta": beta,
                            "s_pred": s_pred_value,
                            "s_gt": s_gt_values[frame_idx],
                            "q_source_idx": q_source_idx if q_source_idx is not None else "",
                        }
                    )
                if q_type == "real" and beta == betas[0]:
                    real_s_pred_values.extend(s_pred_values)

        print(
            f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
            f"mean_s_pred={np.mean(real_s_pred_values):.4f} min_s_pred={np.min(real_s_pred_values):.4f} "
            f"mean_iou={frame_ious.mean().item():.4f}"
        )

    return rows
|
|
|
|
def print_summary(rows):
    """Print per-beta mean scores and the s_pred / IoU correlation.

    Nothing is printed when ``rows`` contains no ``"real"`` query rows. For
    every beta seen among the real rows, the mean ``s_pred`` and ``s_gt`` of
    each query type are printed, followed by the Pearson correlation between
    the real rows' ``s_pred`` and ``frame_iou``; ``nan`` is printed when there
    are fewer than two real rows or either series is (near-)constant.

    Args:
        rows: Row dicts with at least ``q_type``, ``beta``, ``s_pred`` and
            ``s_gt`` keys; real rows also need ``frame_iou``.
    """
    real_rows = [row for row in rows if row["q_type"] == "real"]
    if len(real_rows) == 0:
        return

    betas_seen = sorted({row["beta"] for row in real_rows})
    print("\nSummary")
    print(f"rows: {len(rows)}")

    for beta in betas_seen:
        print(f"\nbeta={beta}")

        # Bucket this beta's rows by query type, keeping their original order.
        rows_by_type = {}
        for row in rows:
            if row["beta"] == beta:
                rows_by_type.setdefault(row["q_type"], []).append(row)

        for q_type in sorted(rows_by_type):
            group = rows_by_type[q_type]
            mean_s_pred = np.mean([row["s_pred"] for row in group])
            mean_s_gt = np.mean([row["s_gt"] for row in group])
            print(f"{q_type:10s} mean_s_pred={mean_s_pred:+.4f} mean_s_gt={mean_s_gt:+.4f}")

        real_group = rows_by_type["real"]
        s_pred = np.array([row["s_pred"] for row in real_group])
        frame_iou_values = np.array([row["frame_iou"] for row in real_group])

        correlation_defined = (
            len(s_pred) > 1
            and np.std(s_pred) > 1e-8
            and np.std(frame_iou_values) > 1e-8
        )
        if not correlation_defined:
            print("corr(real s_pred, frame_iou)=nan")
            continue
        corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
        print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
|
|
|
|
def main():
    """Run the D2 evaluation end to end and save the per-frame rows as CSV.

    Steps: seed everything, disable autograd, parse betas, build the
    tokenizer and model, collect the query pool, compute the D2 rows, print
    the summary, and write the rows to a CSV file.

    The number of samples is ``args.max_eval_rows`` when positive, otherwise
    30. The output path is taken from the ``D2_BASIC_CSV`` environment
    variable, falling back to a path under ``/workspace/SimToken``.

    Raises:
        RuntimeError: If the evaluation produced no rows. Raised before the
            output file is opened, so an existing CSV is not truncated.
    """
    set_seed(42)
    torch.set_grad_enabled(False)
    betas = parse_betas()
    tokenizer, seg_token_idx = build_tokenizer()
    limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
    print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")

    model = build_model(tokenizer, seg_token_idx)
    q_pool = collect_q_pool(model, tokenizer, limit)
    rows = run_d2(model, tokenizer, q_pool, betas, limit)
    print_summary(rows)

    # The CSV header is derived from the first row, so there must be one.
    if not rows:
        raise RuntimeError("D2 produced no rows; nothing to save.")

    csv_path = os.environ.get("D2_BASIC_CSV", f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv")
    os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
    # Explicit UTF-8 so the file does not depend on the locale's default encoding.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    print(f"\nSaved CSV: {csv_path}")


if __name__ == "__main__":
    main()
|
|