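# SimToken / d2_basic.py
# NOTE: the lines below are residue from the Hugging Face file viewer, kept as comments.
# yfan07's picture
# Add files using upload-large-folder tool
# 65fb4ac verified
# raw
# history blame
# 12.6 kB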
import csv
import math
import os
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch.utils.data import DataLoader
from configs import args
from datasets import REFAVS
from decoder_invariance_check import build_model, set_seed
from load_model import collate_fn, dict_to_cuda
def make_loader(tokenizer):
    """Return a deterministic, one-sample-per-batch loader over the eval split.

    The dataset is ``REFAVS`` for ``args.eval_split`` with ``input_type="refer"``.
    Batches are produced in dataset order (no shuffling) in the main process
    (``num_workers=0``), so two loaders built by this function enumerate the
    samples in the same order.
    """
    refer_dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    collate = partial(collate_fn, tokenizer=tokenizer)
    loader_options = {
        "batch_size": 1,
        "shuffle": False,
        "num_workers": 0,
        "collate_fn": collate,
    }
    return DataLoader(refer_dataset, **loader_options)
def build_tokenizer():
    """Load the slow tokenizer for ``args.mllm`` and register the ``[SEG]`` token.

    Returns:
        A ``(tokenizer, seg_token_idx)`` tuple, where ``seg_token_idx`` is the
        first input id produced when tokenizing ``"[SEG]"`` without special
        tokens.
    """
    seg_token = "[SEG]"
    load_options = {
        "cache_dir": None,
        "model_max_length": 2048,
        "padding_side": "right",
        "use_fast": False,
    }
    tok = transformers.AutoTokenizer.from_pretrained(args.mllm, **load_options)
    # Padding reuses the unknown token rather than introducing a new one.
    tok.pad_token = tok.unk_token
    tok.add_tokens(seg_token)
    seg_ids = tok(seg_token, add_special_tokens=False).input_ids
    return tok, seg_ids[0]
def get_q(model, batch):
    """Run the model in inference mode and return the first segmentation embedding.

    The forward pass runs under bfloat16 CUDA autocast. The returned tensor is
    ``output["seg_embeddings"][0][0]`` cast to float32.
    """
    forward_kwargs = {
        "images": batch["images"],
        "images_clip": batch["images_clip"],
        "audio_features": batch["audio_feats"],
        "image_features": batch["image_feats"],
        "input_ids": batch["input_ids"],
        "labels": batch["labels"],
        "attention_masks": batch["attention_masks"],
        "masks_list": batch["masks"],
        "resize_list": batch["resizes"],
        "orgsize_list": batch["orgsizes"],
        "conversation_list": batch["convs"],
        "refs_num": batch["refs_num"],
        "fids": batch["fids"],
        "vids": batch["vids"],
        "contrast": args.ct_weight,
        "ref_ids": batch["ref_ids"],
        "inference": True,
    }
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        result = model.forward(**forward_kwargs)
    seg_embedding = result["seg_embeddings"][0][0]
    return seg_embedding.float()
def decode_low_res(model, batch, q):
    """Decode low-resolution mask logits for the first sample from a prompt vector.

    ``q`` is reshaped to ``(1, 1, -1)`` and passed to the visual model's prompt
    encoder as ``text_embeds``. The resulting sparse/dense embeddings are cast
    back to ``q``'s dtype and fed, with ``batch["image_feats"][0]``, to the mask
    decoder under bfloat16 autocast.

    Returns:
        ``(low_res_masks, iou_predictions)``, both float32; the last dimension
        of the IoU predictions is squeezed away.
    """
    visual = model.get_model().visual_model
    # The prompt encoder is called outside autocast, so match its parameter dtype.
    encoder_dtype = next(visual.parameters()).dtype
    prompt = q.view(1, 1, -1).to(encoder_dtype)
    sparse_embeds, dense_embeds = visual.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=prompt,
    )
    sparse_embeds = sparse_embeds.to(q.dtype)
    dense_embeds = dense_embeds.to(q.dtype)
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        mask_logits, iou_preds = visual.mask_decoder(
            image_embeddings=batch["image_feats"][0],
            image_pe=visual.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeds,
            dense_prompt_embeddings=dense_embeds,
            multimask_output=False,
        )
    return mask_logits.float(), iou_preds.float().squeeze(-1)
def masks_to_64(mask_logits_or_binary):
    """Resize masks to 64x64 with bilinear interpolation and clamp to [0, 1].

    A 3-D input gets a singleton dimension inserted at position 1 before
    resizing, so the result is always 4-D. Values outside [0, 1] are clipped,
    which means raw logits should be squashed (e.g. by a sigmoid) by the caller
    if soft values are wanted.
    """
    masks = mask_logits_or_binary
    if masks.ndim == 3:
        masks = masks[:, None]
    resized = F.interpolate(
        masks.float(),
        size=(64, 64),
        mode="bilinear",
        align_corners=False,
    )
    return resized.clamp(0.0, 1.0)
def d2_scores(image_embeddings, mask64, q, beta):
    """Score how well ``q`` aligns with features inside vs. outside a mask.

    For each frame, features are mask-weighted-averaged over the two spatial
    dimensions, once with ``mask64`` and once with its complement. Both pooled
    vectors and ``q`` are L2-normalised, and the score is
    ``cos(inside, q) - beta * cos(outside, q)``.

    Raises:
        ValueError: if the leading dimensions of ``mask64`` and
            ``image_embeddings`` differ.

    Returns:
        A 1-D tensor with one score per frame.
    """
    feats = image_embeddings.float()
    if feats.shape[0] != mask64.shape[0]:
        raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")

    query = F.normalize(q.float().view(1, -1), dim=-1)
    inside = mask64.float()
    outside = 1.0 - inside

    def pooled_similarity(weights):
        # Weighted mean over the spatial dims; the clamp keeps an all-zero
        # weight map from dividing by zero (the pooled vector is then zero).
        pooled = (feats * weights).sum(dim=(2, 3)) / weights.sum(dim=(2, 3)).clamp_min(1e-6)
        return (F.normalize(pooled, dim=-1) @ query.T).squeeze(-1)

    return pooled_similarity(inside) - beta * pooled_similarity(outside)
def frame_iou(pred_logits, gt_masks, *, threshold=0.4):
    """Compute a per-frame IoU between thresholded predictions and ground truth.

    Args:
        pred_logits: Mask logits; a 4-D input has dimension 1 squeezed so that
            it lines up with the ground truth.
        gt_masks: Ground-truth masks; a 4-D input is squeezed the same way.
        threshold: Probability cut-off applied to ``sigmoid(pred_logits)``.
            Defaults to the 0.4 this script has always used.

    Returns:
        A 1-D tensor with one IoU per frame. For frames whose ground truth is
        empty, the score is the fraction of pixels correctly predicted as
        background instead of the (undefined) foreground IoU.
    """
    pred = (torch.sigmoid(pred_logits.float()) > threshold).float()
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)
    # Squeeze the ground truth too; otherwise (T, H, W) * (T, 1, H, W) would
    # silently broadcast to (T, T, H, W) and produce meaningless sums.
    if gt.ndim == 4:
        gt = gt.squeeze(1)

    inter = (pred * gt).sum(dim=(1, 2))
    union = torch.maximum(pred, gt).sum(dim=(1, 2))

    # Empty ground truth: score agreement on the background over all pixels.
    num_pixels = pred.shape[-1] * pred.shape[-2]
    no_obj = gt.sum(dim=(1, 2)) == 0
    inter_no_obj = ((1.0 - pred) * (1.0 - gt)).sum(dim=(1, 2))
    inter = torch.where(no_obj, inter_no_obj, inter)
    union = torch.where(no_obj, torch.full_like(union, float(num_pixels)), union)
    return inter / union.clamp_min(1e-7)
def frame_fscore_proxy(pred_logits, gt_masks, *, threshold=0.4, beta2=0.3):
    """Compute a per-frame weighted F-score from a single binarised prediction.

    Args:
        pred_logits: Mask logits; a 4-D input has dimension 1 squeezed.
        gt_masks: Ground-truth masks; a 4-D input is squeezed the same way.
        threshold: Probability cut-off applied to ``sigmoid(pred_logits)``.
            Defaults to the 0.4 this script has always used.
        beta2: Weight in ``(1 + beta2) * P * R / (beta2 * P + R)``.
            Defaults to the original hard-coded 0.3.

    Returns:
        A 1-D tensor with one F-score per frame; frames with empty ground
        truth score 0.
    """
    pred = (torch.sigmoid(pred_logits.float()) > threshold).float()
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)
    # Keep ground truth rank-aligned with the prediction to avoid an
    # accidental (T, T, H, W) broadcast.
    if gt.ndim == 4:
        gt = gt.squeeze(1)

    gt_pixels = gt.sum(dim=(1, 2))
    tp = (pred * gt).sum(dim=(1, 2))
    precision = tp / pred.sum(dim=(1, 2)).clamp_min(1e-7)
    recall = tp / gt_pixels.clamp_min(1e-7)
    fscore = (1 + beta2) * precision * recall / (beta2 * precision + recall).clamp_min(1e-7)

    no_obj = gt_pixels == 0
    return torch.where(no_obj, torch.zeros_like(fscore), fscore)
def parse_betas():
    """Parse the comma-separated ``D2_BETAS`` environment variable into floats.

    Blank entries (e.g. from trailing commas) are ignored. When the variable is
    unset, the single default beta ``0.5`` is used.

    Returns:
        A non-empty list of floats, in the order given.

    Raises:
        ValueError: if an entry is not a valid float, or if the variable is set
            but contains no values. Failing here is clearer than the downstream
            crash an empty beta list would otherwise cause.
    """
    raw = os.environ.get("D2_BETAS", "0.5")
    betas = [float(token) for token in (part.strip() for part in raw.split(",")) if token]
    if not betas:
        raise ValueError(f"D2_BETAS must list at least one beta value, got {raw!r}")
    return betas
def collect_q_pool(model, tokenizer, limit):
    """Collect the model's segmentation embedding for the first ``limit`` samples.

    Each pool entry records the loader position (``sample_idx``), the sample's
    ``vid``, ``ref`` and ``fid`` fields, and the embedding ``q`` moved to CPU.

    Raises:
        RuntimeError: if no sample was processed.
    """
    pool = []
    for position, batch in enumerate(make_loader(tokenizer)):
        if position >= limit:
            break
        batch = dict_to_cuda(batch)
        embedding = get_q(model, batch)
        entry = {
            "sample_idx": position,
            "vid": batch["vids"][0],
            "ref": batch["refs"][0][0],
            "fid": int(batch["fids"][0][0]),
            # Stored on CPU so the pool does not hold GPU memory.
            "q": embedding.cpu(),
        }
        pool.append(entry)
        print(f"Collected q {position}: vid={entry['vid']} ref={entry['ref']}")
    if len(pool) == 0:
        raise RuntimeError("No q vectors collected. Is the selected split empty?")
    return pool
def choose_shuffled_idx(sample_idx, q_pool):
    """Pick the next pool index (wrapping around) as a mismatched-query control.

    Returns ``None`` when the pool has at most one entry, since no other sample
    is available to borrow a query from.
    """
    pool_size = len(q_pool)
    return (sample_idx + 1) % pool_size if pool_size > 1 else None
def choose_wrong_ref_idx(sample_idx, q_pool):
    """Find another pool entry from the same video to use as a wrong-reference query.

    Candidates are the other entries sharing the current entry's ``vid``. The
    first one whose ``fid`` differs is preferred; failing that, the first one
    whose ``ref`` differs is used.

    The current entry is located by its ``sample_idx`` field (not by list
    position), matching how every other comparison here and the caller's
    lookup table identify entries.

    Returns:
        The ``sample_idx`` of the chosen entry, or ``None`` if ``sample_idx``
        is not in the pool or no suitable entry exists.
    """
    current = next((item for item in q_pool if item["sample_idx"] == sample_idx), None)
    if current is None:
        return None

    same_video = [
        item
        for item in q_pool
        if item["sample_idx"] != sample_idx and item["vid"] == current["vid"]
    ]
    # Two passes in priority order: a differing fid first, then a differing ref.
    for field in ("fid", "ref"):
        for item in same_video:
            if item[field] != current[field]:
                return item["sample_idx"]
    return None
def run_d2(model, tokenizer, q_pool, betas, limit):
    """Compute D2 scores per frame for real and control queries.

    For each of the first ``limit`` samples, the sample's own query (``real``)
    is decoded into a mask and compared with the ground truth (IoU, F-score
    proxy, predicted IoU, mask areas). D2 scores are then computed on both the
    predicted and ground-truth 64x64 masks for every beta and for each query
    control: ``real``, ``random`` (seeded Gaussian), ``shuffled`` (another
    sample's query) and ``wrong_ref`` (another query from the same video),
    the last two only when available.

    Args:
        model: Model exposing ``get_model().visual_model``.
        tokenizer: Tokenizer passed through to ``make_loader``.
        q_pool: Entries as produced by ``collect_q_pool``.
        betas: Non-empty sequence of beta weights.
        limit: Maximum number of samples to process.

    Returns:
        A list of row dicts, one per (sample, beta, query control, frame).

    Raises:
        ValueError: if ``betas`` is empty.
    """
    if not betas:
        raise ValueError("betas must contain at least one value")

    # Must match the probability cut-off used inside frame_iou / frame_fscore_proxy.
    mask_threshold = 0.4

    rows = []
    loader = make_loader(tokenizer)
    q_lookup = {item["sample_idx"]: item for item in q_pool}
    generator = torch.Generator(device="cuda")
    generator.manual_seed(1234)

    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        item = q_lookup[sample_idx]
        real_q = item["q"].cuda()

        low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
        pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
        gt_masks = batch["masks"][0][0].float()
        gt_mask64 = masks_to_64(gt_masks)
        image_embeddings = batch["image_feats"][0].float()

        pred_logits_hr = model.get_model().visual_model.postprocess_masks(
            low_res_masks.to(batch["image_feats"][0].dtype),
            input_size=batch["resizes"][0],
            original_size=batch["orgsizes"][0],
        ).squeeze(1)
        frame_ious = frame_iou(pred_logits_hr, gt_masks)
        frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
        pred_area = (torch.sigmoid(pred_logits_hr.float()) > mask_threshold).float().mean(dim=(1, 2))
        gt_area = gt_masks.float().mean(dim=(1, 2))

        # These per-frame metrics do not depend on beta or the query control,
        # so pull them off the GPU once instead of once per emitted row.
        frame_iou_vals = frame_ious.tolist()
        frame_fscore_vals = frame_fscores.tolist()
        iou_pred_vals = iou_predictions.tolist()
        pred_area_vals = pred_area.tolist()
        gt_area_vals = gt_area.tolist()

        shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
        wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
        q_controls = [
            ("real", real_q, sample_idx),
            ("random", torch.randn(real_q.shape, device=real_q.device, generator=generator), None),
        ]
        if shuffled_idx is not None:
            q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
        if wrong_ref_idx is not None:
            q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))

        # Real-query scores at the first beta, kept for the progress line so it
        # does not have to re-scan every row accumulated so far.
        real_s_pred_values = []
        for beta_pos, beta in enumerate(betas):
            for q_type, q, q_source_idx in q_controls:
                pred_scores = d2_scores(image_embeddings, pred_mask64, q, beta).tolist()
                gt_scores = d2_scores(image_embeddings, gt_mask64, q, beta).tolist()
                if beta_pos == 0 and q_type == "real":
                    real_s_pred_values = pred_scores
                source = q_source_idx if q_source_idx is not None else ""
                for frame_idx, (s_pred, s_gt) in enumerate(zip(pred_scores, gt_scores)):
                    # Key order defines the CSV column order; keep it stable.
                    rows.append(
                        {
                            "sample_idx": sample_idx,
                            "vid": item["vid"],
                            "ref": item["ref"],
                            "fid": item["fid"],
                            "split": args.eval_split,
                            "frame_iou": frame_iou_vals[frame_idx],
                            "frame_fscore_proxy": frame_fscore_vals[frame_idx],
                            "iou_pred": iou_pred_vals[frame_idx],
                            "pred_area": pred_area_vals[frame_idx],
                            "gt_area": gt_area_vals[frame_idx],
                            "frame": frame_idx,
                            "q_type": q_type,
                            "beta": beta,
                            "s_pred": s_pred,
                            "s_gt": s_gt,
                            "q_source_idx": source,
                        }
                    )

        print(
            f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
            f"mean_s_pred={np.mean(real_s_pred_values):.4f} min_s_pred={np.min(real_s_pred_values):.4f} "
            f"mean_iou={frame_ious.mean().item():.4f}"
        )
    return rows
def print_summary(rows):
    """Print per-beta mean scores for each query type and a score/IoU correlation.

    Nothing is printed when ``rows`` contains no ``real`` query rows. For each
    beta seen among the real rows, the mean ``s_pred`` and ``s_gt`` are printed
    per query type, followed by the Pearson correlation between the real rows'
    ``s_pred`` and ``frame_iou`` (reported as ``nan`` when there are too few
    rows or either series is constant).
    """
    betas_seen = sorted({row["beta"] for row in rows if row["q_type"] == "real"})
    if not betas_seen:
        return

    print("\nSummary")
    print(f"rows: {len(rows)}")
    for beta in betas_seen:
        at_beta = [row for row in rows if row["beta"] == beta]
        print(f"\nbeta={beta}")

        for q_type in sorted({row["q_type"] for row in at_beta}):
            group = [row for row in at_beta if row["q_type"] == q_type]
            mean_pred = np.mean([row["s_pred"] for row in group])
            mean_gt = np.mean([row["s_gt"] for row in group])
            print(f"{q_type:10s} mean_s_pred={mean_pred:+.4f} mean_s_gt={mean_gt:+.4f}")

        real_at_beta = [row for row in at_beta if row["q_type"] == "real"]
        scores = np.array([row["s_pred"] for row in real_at_beta])
        ious = np.array([row["frame_iou"] for row in real_at_beta])
        # Correlation is undefined without variation in both series.
        has_spread = len(scores) > 1 and np.std(scores) > 1e-8 and np.std(ious) > 1e-8
        if has_spread:
            corr = np.corrcoef(scores, ious)[0, 1]
            print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
        else:
            print("corr(real s_pred, frame_iou)=nan")
def main():
    """Run the D2 diagnostic end to end and save the per-frame rows as CSV.

    The sample limit is ``args.max_eval_rows`` when positive, otherwise 30.
    The output path comes from the ``D2_BASIC_CSV`` environment variable, with
    a default under ``/workspace/SimToken``.

    Raises:
        RuntimeError: if the run produced no rows, so there is nothing to
            summarise or write.
    """
    set_seed(42)
    torch.set_grad_enabled(False)
    betas = parse_betas()
    tokenizer, seg_token_idx = build_tokenizer()
    limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
    print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
    model = build_model(tokenizer, seg_token_idx)
    q_pool = collect_q_pool(model, tokenizer, limit)
    rows = run_d2(model, tokenizer, q_pool, betas, limit)
    if not rows:
        # Without this, the CSV header lookup below fails with a bare IndexError.
        raise RuntimeError("D2 produced no rows; nothing to summarise or save.")
    print_summary(rows)
    csv_path = os.environ.get("D2_BASIC_CSV", f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv")
    os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
    # Explicit UTF-8 so text fields are written the same way on every platform.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    print(f"\nSaved CSV: {csv_path}")


if __name__ == "__main__":
    main()