SimToken / tools /ec_simtoken_eval.py
yfan07's picture
Upload folder using huggingface_hub
9af2926 verified
#!/usr/bin/env python
"""EC-SimToken standalone evaluation: score distribution + threshold sweep.
Loads a saved checkpoint and reports:
1. p_exist distribution per split (mean/median/p10/p25/p75/p90)
2. AUC-ROC (test_n as null class vs test_s+test_u as positive class)
3. Threshold sweep 0.05 → 0.95 (step 0.05): pos J&F, Null_S, null_tp_rate, positive_fnr
Usage:
cd /workspace/SimToken
python tools/ec_simtoken_eval.py \
--checkpoint checkpoints/ec_simtoken/ec_simtoken_v1_ep2.pth \
--out_dir runs/ec_simtoken/eval_ep2
"""
from __future__ import annotations
import argparse, os, sys
from functools import partial
import numpy as np
import torch
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import AutoConfig
from tqdm import tqdm
# Repository root: two directory levels above this file (tools/ -> repo).
# It is put first on sys.path so the project packages imported below
# (datasets, models, utils) resolve from this checkout.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
# Restrict the process to GPU 0. This unconditionally overrides any
# CUDA_VISIBLE_DEVICES exported by the caller, and only takes effect if it is
# set before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Project-local imports; they must stay after the sys.path insert above.
from datasets.dataset_refavs import REFAVS
from models.ec_simtoken_model import ECSimtoken_ForCausalLM
from utils import utility
# ── Defaults (match training command) ────────────────────────────────────────
# Each of these can be overridden on the command line (see main()).
MLLM = "/workspace/hf_models/Chat-UniVi-7B-v1.5"
SAM_CKPT = "/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth"
VISION_TOWER = "/workspace/hf_models/clip-vit-large-patch14"
DATA_DIR = "data"
# Label value written to positions that must not be supervised (see
# collate_fn); -100 is the default ignore_index of torch CrossEntropyLoss.
IGNORE_INDEX = -100
# Negative sentinel ids that tokenizer_image_audio_token splices into
# input_ids in place of <image>/<video> and <audio> placeholders.
IMAGE_TOKEN_INDEX = -200
AUDIO_TOKEN_INDEX = -300
import re
def tokenizer_image_audio_token(prompt, tokenizer,
                                image_token_index=IMAGE_TOKEN_INDEX,
                                audio_token_index=AUDIO_TOKEN_INDEX,
                                num_frames=10, return_tensors=None):
    """Tokenize a prompt containing ``<image>``/``<audio>``/``<video>`` placeholders.

    Text spans are tokenized with ``tokenizer``. Each placeholder is replaced,
    at the position where it occurs in ``prompt``, by sentinel ids:

    * ``<image>`` -> ``[image_token_index]``
    * ``<audio>`` -> ``[audio_token_index]``
    * ``<video>`` -> ``[image_token_index] * num_frames``

    Each text span is tokenized on its own. If the first text span starts
    with ``tokenizer.bos_token_id``, a single BOS is emitted at the very start,
    and the first token of every text span is then dropped (its per-span BOS).

    Args:
        prompt: Raw prompt string.
        tokenizer: HF-style tokenizer; ``tokenizer(text).input_ids`` must be a
            list of ints and ``tokenizer.bos_token_id`` must exist.
        image_token_index: Sentinel id emitted for ``<image>`` / ``<video>``.
        audio_token_index: Sentinel id emitted for ``<audio>``.
        num_frames: Number of image sentinels one ``<video>`` expands to.
        return_tensors: ``"pt"`` to get a 1-D ``torch.long`` tensor; any other
            value returns a plain list.

    Returns:
        The token ids, as a list or a ``torch.LongTensor``.
    """
    placeholder_ids = {
        "<image>": [image_token_index],
        "<audio>": [audio_token_index],
        "<video>": [image_token_index] * num_frames,
    }
    # The capturing group keeps the placeholders as their own chunks. Empty
    # strings (from leading, trailing or adjacent placeholders) are dropped.
    chunks = [c for c in re.split(r"(<image>|<audio>|<video>)", prompt) if c]
    text_ids = [tokenizer(c).input_ids for c in chunks if c not in placeholder_ids]

    input_ids, offset = [], 0
    if text_ids and text_ids[0] and text_ids[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(text_ids[0][0])

    # Walk the chunks in prompt order so every placeholder is emitted where it
    # occurs. This includes placeholders at the start or end of the prompt and
    # placeholders that directly follow one another.
    text_iter = iter(text_ids)
    for chunk in chunks:
        if chunk in placeholder_ids:
            input_ids.extend(placeholder_ids[chunk])
        else:
            input_ids.extend(next(text_iter)[offset:])

    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids
def collate_fn(batch, tokenizer=None):
    """Collate REFAVS samples into a single batch dict for ``model.forward``.

    Per-sample fields are gathered into lists. Conversation prompts are
    tokenized with tokenizer_image_audio_token and right-padded with
    ``tokenizer.pad_token_id``. ``attention_masks`` marks the non-pad
    positions. Only the first reference expression of each sample
    (``data["ref"][0]``) is tokenized into ``ref_ids``, while ``refs_num``
    records how many references the sample had.

    ``labels`` is a copy of ``input_ids``. It keeps only the token spans of
    each "Sure, it is [SEG]" occurrence; every other position, including
    padding, is set to IGNORE_INDEX. The span lengths are estimated by
    tokenizing the pieces separately.

    Args:
        batch: List of sample dicts from the REFAVS dataset.
        tokenizer: Tokenizer used for prompts and references. It is required
            despite the ``None`` default; main() binds it via functools.partial.

    Returns:
        Dict with keys vids, images, images_clip, masks, convs, input_ids,
        attention_masks, labels, audio_feats, resizes, orgsizes,
        image_feats, ref_ids, refs_num and fids.
    """
    vids, images, image_clips, masks, conversations = [], [], [], [], []
    audio_feats, image_feats, resizes, orgsizes = [], [], [], []
    refs, refs_num, fids = [], [], []
    for data in batch:
        vids.append(data["vid"])
        images.append(data["image"])
        image_clips.append(data["img_clip"])
        masks.append(data["mask"])
        conversations.append(data["conversation"])
        audio_feats.append(data["feat_aud"])
        resizes.append(data["resize"])
        orgsizes.append(data["orgsize"])
        image_feats.append(data["feat_sam"])
        refs_num.append(len(data["ref"]))
        fids.append(data["fids"])
        refs.append(data["ref"][0])

    input_ids = [tokenizer_image_audio_token(c, tokenizer, return_tensors="pt")
                 for c in conversations]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = input_ids.ne(tokenizer.pad_token_id)
    ref_ids = [tokenizer_image_audio_token(r, tokenizer, return_tensors="pt")
               for r in refs]

    labels = input_ids.clone()
    sep = "Sure, it is [SEG]"
    # The separator's length is the same for every sample, so compute it once.
    # The -1 presumably drops the BOS added by the tokenizer.
    sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1
    for conversation, target in zip(conversations, labels):
        parts = conversation.split(sep)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX  # leading (BOS) position
        for part in parts[:-1]:
            # The -2 is an inherited heuristic: BOS plus, presumably, one
            # boundary token shared with the separator. TODO confirm.
            part_len = len(tokenizer_image_audio_token(part, tokenizer)) - 2
            target[cur_len: cur_len + part_len] = IGNORE_INDEX
            cur_len += part_len + sep_len
        # Everything after the last separator, including padding, is ignored.
        target[cur_len:] = IGNORE_INDEX

    return {"vids": vids, "images": images, "images_clip": image_clips,
            "masks": masks, "convs": conversations, "input_ids": input_ids,
            "attention_masks": attention_masks, "labels": labels,
            "audio_feats": audio_feats, "resizes": resizes, "orgsizes": orgsizes,
            "image_feats": image_feats, "ref_ids": ref_ids,
            "refs_num": refs_num, "fids": fids}
def dict_to_cuda(d):
    """Move tensors held in ``d`` to the current CUDA device, in place.

    Two kinds of values are moved:
    * a top-level ``torch.Tensor``;
    * a non-empty list whose first element is a ``torch.Tensor`` (each
      element is moved).
    Everything else is left untouched. The same dict object is returned.
    """
    for key in list(d):
        value = d[key]
        if isinstance(value, torch.Tensor):
            d[key] = value.cuda(non_blocking=True)
            continue
        holds_tensors = (isinstance(value, list) and len(value) > 0
                         and isinstance(value[0], torch.Tensor))
        if holds_tensors:
            d[key] = [item.cuda(non_blocking=True) for item in value]
    return d
def build_model(args, tokenizer, seg_token_idx):
    """Build the EC-SimToken model: base MLLM, vision/cluster/LISA modules and LoRA.

    main() loads a checkpoint onto the returned model with
    ``load_state_dict(..., strict=False)``, so this construction must match
    the one used when the checkpoint was saved. The hard-coded values below
    presumably mirror the training command; TODO confirm against the
    training script.

    Args:
        args: Parsed CLI namespace. Reads ``mllm``, ``vision_pretrained`` and
            ``vision_tower``.
        tokenizer: Tokenizer that already contains the ``[SEG]`` token (added
            in main()). Its eos/bos/pad ids are copied into the model config,
            and its length sets the embedding-table size.
        seg_token_idx: Token id of ``[SEG]``, forwarded to the model.

    Returns:
        The PEFT (LoRA) wrapped model, moved to CUDA and cast to bfloat16.
    """
    # Constructor kwargs forwarded to ECSimtoken_ForCausalLM.from_pretrained.
    # The *_loss_weight entries are presumably only used by the training loss;
    # TODO confirm.
    model_args = {
        "train_mask_decoder": True, "out_dim": 256,
        "ce_loss_weight": 1.0, "dice_loss_weight": 0.5, "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_im_start_end": False, "compress": True, "start": 0,
        "exist_loss_weight": 1.0,
    }
    model = ECSimtoken_ForCausalLM.from_pretrained(
        args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args)
    # Copy the tokenizer's special-token ids (pad was set to unk in main()).
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16, device="cuda")
    # Cluster-module settings (spatial/temporal cluster rates) applied on top
    # of the base MLLM config.
    cfg_pt = AutoConfig.from_pretrained(args.mllm)
    cfg_pt.use_cluster = True; cfg_pt.freeze = False; cfg_pt.mm_tune = True
    cfg_pt.spatial_cluster_rate0 = 64; cfg_pt.spatial_cluster_rate1 = 32
    cfg_pt.spatial_cluster_rate2 = 16; cfg_pt.temporal_cluster_rate = 0.0625
    cfg_pt.vision_tune = False
    model.get_model().initialize_cluster_modules(cfg_pt)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    def find_linear_layers(m, targets):
        # Return the sorted qualified names of nn.Linear submodules whose name
        # contains one of `targets` and none of the `skip` substrings. These
        # are the modules LoRA is attached to below.
        names = set()
        skip = {"visual_model", "vision_tower", "mm_projector",
                "text_hidden_fcs", "audio_feature_layer", "existence_head"}
        for name, mod in m.named_modules():
            if (isinstance(mod, torch.nn.Linear)
                    and not any(s in name for s in skip)
                    and any(t in name for t in targets)):
                names.add(name)
        return sorted(names)

    # LoRA (r=8, alpha=16) on the q_proj / v_proj linears selected above.
    lora_config = LoraConfig(
        r=8, lora_alpha=16,
        target_modules=find_linear_layers(model, ["q_proj", "v_proj"]),
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model = model.to("cuda").to(torch.bfloat16)
    # Grow the embedding table to the tokenizer size (it includes the added [SEG]).
    model.resize_token_embeddings(len(tokenizer))
    return model
# ── Collect p_exist + metrics + per-sample masks (single inference pass) ──────
@torch.no_grad()
def collect(model, dataloader, split_name: str):
    """Run a single inference pass and gather everything the report needs.

    Args:
        model: Model from build_model; it is put in eval mode here.
        dataloader: Yields batches produced by collate_fn.
        split_name: Split label. ``"test_n"`` is scored as the null split
            with ``utility.metric_s_for_null``. Any other split is scored with
            ``utility.mask_iou`` / ``utility.Eval_Fmeasure`` against the GT.

    Returns:
        Dict with:
        - ``"p_exist"``: float32 array of ``sigmoid(exist_logit)``, flattened
          in batch order.
        - ``"pred_masks"`` / ``"gt_masks"``: per-sample CPU tensors, later
          reused by the threshold sweep.
        - ``"split"``: ``split_name``.
        - ``"null_s_default"`` for test_n, otherwise ``"miou"`` and
          ``"fscore"``. These are averages over the predictions as returned by
          the model (no p_exist gating applied here), each sample weighted by
          ``mask.shape[0] * mask.shape[1]``.
    """
    model.eval()
    all_p_exist = []
    all_pred_masks = []  # list of CPU tensors [num_seg, T, H, W]
    all_gt_masks = []
    total_iou = total_f = count = 0.0
    total_null_s = null_count = 0.0
    is_null_split = split_name == "test_n"
    for batch in tqdm(dataloader, desc=split_name, leave=False):
        batch = dict_to_cuda(batch)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            out = model.forward(
                images=batch["images"], images_clip=batch["images_clip"],
                audio_features=batch["audio_feats"], image_features=batch["image_feats"],
                input_ids=batch["input_ids"], labels=batch["labels"],
                attention_masks=batch["attention_masks"], masks_list=batch["masks"],
                resize_list=batch["resizes"], orgsize_list=batch["orgsizes"],
                conversation_list=batch["convs"], refs_num=batch["refs_num"],
                fids=batch["fids"], vids=batch["vids"], ref_ids=batch["ref_ids"],
                inference=True,
            )
        # Two robustness choices here:
        # - The sigmoid runs in float32, because the logits may be bf16 (bf16
        #   model under autocast), and bf16 quantises the probabilities coarsely.
        # - reshape(-1) is used instead of squeeze(-1). If exist_logit is 1-D
        #   and the batch holds one sample, squeeze(-1) returns a 0-d tensor;
        #   its .tolist() is a bare float, which list.extend rejects.
        p_exist = torch.sigmoid(out["exist_logit"].float()).reshape(-1).cpu()
        all_p_exist.extend(p_exist.tolist())

        pred_masks = out["pred_masks"]
        gt_masks = out["gt_masks"]
        for i, pred in enumerate(pred_masks):
            pred_i = pred.cpu()
            gt_i = gt_masks[i].cpu()
            all_pred_masks.append(pred_i)
            all_gt_masks.append(gt_i)
            n = pred_i.shape[0] * pred_i.shape[1]  # per-sample weight
            if is_null_split:
                total_null_s += utility.metric_s_for_null(pred_i) * n
                null_count += n
            else:
                total_iou += utility.mask_iou(pred_i, gt_i) * n
                total_f += utility.Eval_Fmeasure(pred_i, gt_i, None) * n
                count += n

    result = {
        "p_exist": np.array(all_p_exist, dtype=np.float32),
        "pred_masks": all_pred_masks,
        "gt_masks": all_gt_masks,
        "split": split_name,
    }
    if is_null_split:
        result["null_s_default"] = total_null_s / (null_count + 1e-8)
    else:
        result["miou"] = total_iou / (count + 1e-8)
        result["fscore"] = total_f / (count + 1e-8)
    return result
# ── Statistics ────────────────────────────────────────────────────────────────
def dist_stats(arr: np.ndarray) -> dict:
    """Summarise a 1-D score array.

    Returns a dict with the element count (``n``), ``mean``, ``median``, the
    10th/25th/75th/90th percentiles (``p10``..``p90``), and ``min``/``max``.
    """
    p10, p25, p75, p90 = np.percentile(arr, [10, 25, 75, 90])
    summary = dict(n=len(arr), mean=arr.mean(), median=np.median(arr))
    summary.update(p10=p10, p25=p25, p75=p75, p90=p90)
    summary.update(min=arr.min(), max=arr.max())
    return summary
def auc_roc(null_scores: np.ndarray, pos_scores: np.ndarray) -> float:
    """ROC AUC of p_exist for separating positives (label 1) from nulls (label 0).

    Equals ``P(pos > null) + 0.5 * P(pos == null)`` over all (null, pos)
    pairs, so a higher AUC means positives get higher p_exist (lower p_exist
    is more null-like). Uses ``sklearn.metrics.roc_auc_score`` when
    scikit-learn is installed. Otherwise it uses an equivalent rank-based
    count in which ties count half, matching sklearn.

    Returns:
        The AUC in [0, 1], or NaN if either group is empty (AUC undefined).
    """
    null_scores = np.asarray(null_scores, dtype=np.float64).ravel()
    pos_scores = np.asarray(pos_scores, dtype=np.float64).ravel()
    if null_scores.size == 0 or pos_scores.size == 0:
        return float("nan")

    try:
        from sklearn.metrics import roc_auc_score
    except ImportError:
        roc_auc_score = None

    if roc_auc_score is not None:
        y = np.concatenate([np.zeros(null_scores.size), np.ones(pos_scores.size)])
        s = np.concatenate([null_scores, pos_scores])
        return float(roc_auc_score(y, s))

    # For each positive: #nulls strictly below it (side="left") plus half of
    # the #nulls equal to it (right - left), vectorised with searchsorted.
    null_sorted = np.sort(null_scores)
    below = np.searchsorted(null_sorted, pos_scores, side="left")
    at_or_below = np.searchsorted(null_sorted, pos_scores, side="right")
    wins = (below.sum() + at_or_below.sum()) / 2.0
    return float(wins / (null_scores.size * pos_scores.size))
# ── Threshold sweep ───────────────────────────────────────────────────────────
def threshold_sweep(null_p: np.ndarray, pos_p: np.ndarray,
                    pos_pred_masks, pos_gt_masks,
                    null_pred_masks, thresholds=None):
    """Evaluate gating predicted masks on ``p_exist < t`` over several thresholds.

    A sample whose p_exist is below ``t`` is treated as "no object". Its
    predicted mask is replaced by an all-zero mask of the same shape.

    At each threshold t:
      - null_tp_rate = #nulls with p_exist < t / #nulls
      - positive_fnr = #positives with p_exist < t / #positives
      - Null_S       = utility.metric_s_for_null over the (gated) null masks
      - pos_mIoU / pos_F = utility.mask_iou / utility.Eval_Fmeasure over the
        (gated) positive masks; pos_J&F = (pos_mIoU + pos_F) / 2
    Each sample is weighted by ``mask.shape[0] * mask.shape[1]``, as in
    collect().

    Args:
        null_p: p_exist for the null samples, index-aligned with ``null_pred_masks``.
        pos_p: p_exist for the positive samples, index-aligned with
            ``pos_pred_masks`` and ``pos_gt_masks``.
        pos_pred_masks, pos_gt_masks: Per-sample predicted / GT mask tensors.
        null_pred_masks: Per-sample predicted mask tensors for the null samples.
        thresholds: Iterable of thresholds. Defaults to 0.05, 0.10, ..., 0.95.

    Returns:
        One dict per threshold with keys threshold, null_tp_rate,
        positive_fnr, Null_S, pos_mIoU, pos_F and pos_J&F. Rates are NaN when
        the corresponding group is empty.

    Raises:
        ValueError: If a score array and its mask list(s) differ in length.
    """
    if thresholds is None:
        thresholds = np.round(np.arange(0.05, 1.00, 0.05), 2)
    null_p = np.asarray(null_p)
    pos_p = np.asarray(pos_p)
    # Scores and masks are paired by index. A length mismatch would mis-pair
    # them silently or fail part-way through the sweep.
    if len(null_p) != len(null_pred_masks):
        raise ValueError(
            f"null_p has {len(null_p)} scores but null_pred_masks has "
            f"{len(null_pred_masks)} masks")
    if not len(pos_p) == len(pos_pred_masks) == len(pos_gt_masks):
        raise ValueError(
            f"pos_p has {len(pos_p)} scores, pos_pred_masks has "
            f"{len(pos_pred_masks)}, pos_gt_masks has {len(pos_gt_masks)}")

    # A sample's metric depends only on whether it is gated, not on t itself.
    # So evaluate each sample once with its own mask and once with a zero mask,
    # instead of re-running the metrics at every threshold.
    null_w = [pm.shape[0] * pm.shape[1] for pm in null_pred_masks]
    null_s_kept = [utility.metric_s_for_null(pm) for pm in null_pred_masks]
    null_s_gated = [utility.metric_s_for_null(torch.zeros_like(pm))
                    for pm in null_pred_masks]
    null_w_total = sum(null_w)

    pos_w = [pm.shape[0] * pm.shape[1] for pm in pos_pred_masks]
    pos_kept = [(utility.mask_iou(pm, gm), utility.Eval_Fmeasure(pm, gm, None))
                for pm, gm in zip(pos_pred_masks, pos_gt_masks)]
    pos_gated = []
    for pm, gm in zip(pos_pred_masks, pos_gt_masks):
        zero = torch.zeros_like(pm)
        pos_gated.append((utility.mask_iou(zero, gm),
                          utility.Eval_Fmeasure(zero, gm, None)))
    pos_w_total = sum(pos_w)

    rows = []
    for t in thresholds:
        null_is_gated = null_p < t
        pos_is_gated = pos_p < t
        null_tp_rate = (int(null_is_gated.sum()) / len(null_p)
                        if len(null_p) else float("nan"))
        pos_fnr = (int(pos_is_gated.sum()) / len(pos_p)
                   if len(pos_p) else float("nan"))

        # Null_S at this threshold
        total_ns = 0.0
        for gated, kept, zeroed, n in zip(null_is_gated, null_s_kept,
                                          null_s_gated, null_w):
            total_ns += (zeroed if gated else kept) * n
        null_s_t = total_ns / (null_w_total + 1e-8)

        # J&F at this threshold (positive samples)
        total_iou = total_f = 0.0
        for gated, kept, zeroed, n in zip(pos_is_gated, pos_kept,
                                          pos_gated, pos_w):
            iou, f = zeroed if gated else kept
            total_iou += iou * n
            total_f += f * n
        miou_t = total_iou / (pos_w_total + 1e-8)
        f_t = total_f / (pos_w_total + 1e-8)

        rows.append({
            "threshold": t,
            "null_tp_rate": null_tp_rate,
            "positive_fnr": pos_fnr,
            "Null_S": null_s_t,
            "pos_mIoU": miou_t,
            "pos_F": f_t,
            "pos_J&F": (miou_t + f_t) / 2,
        })
    return rows
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Evaluate one EC-SimToken checkpoint and write a text report.

    Runs a single inference pass over test_s and test_u (positives) and
    test_n (null split). It then logs the following, and saves them to
    ``<out_dir>/<checkpoint basename without .pth>_report.txt``:
    - the per-split p_exist distribution;
    - the null-vs-positive AUC;
    - the metrics returned by collect();
    - a p_exist threshold sweep;
    - an automatically selected threshold.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--mllm", default=MLLM)
    parser.add_argument("--vision_pretrained", default=SAM_CKPT)
    parser.add_argument("--vision_tower", default=VISION_TOWER)
    parser.add_argument("--data_dir", default=DATA_DIR)
    parser.add_argument("--out_dir", default="runs/ec_simtoken/eval")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    ep_tag = os.path.basename(args.checkpoint).replace(".pth", "")
    out_path = os.path.join(args.out_dir, f"{ep_tag}_report.txt")

    # ── Tokenizer ─────────────────────────────────────────────────────────────
    print("Loading tokenizer …")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm, model_max_length=2048, padding_side="right", use_fast=False)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    # ── Datasets ──────────────────────────────────────────────────────────────
    cfg = argparse.Namespace(data_dir=args.data_dir, frame_n=10, text_max_len=25,
                             conv_template=1, vision_tower=args.vision_tower)
    cfn = partial(collate_fn, tokenizer=tokenizer)
    dl_kw = dict(batch_size=args.batch_size, shuffle=False,
                 num_workers=args.num_workers, collate_fn=cfn,
                 pin_memory=True, persistent_workers=False)
    ds_s = REFAVS("test_s", cfg, tokenizer, input_type="refer")
    ds_u = REFAVS("test_u", cfg, tokenizer, input_type="refer")
    ds_n = REFAVS("test_n", cfg, tokenizer, input_type="refer")
    loader_s = DataLoader(ds_s, **dl_kw)
    loader_u = DataLoader(ds_u, **dl_kw)
    loader_n = DataLoader(ds_n, **dl_kw)

    # ── Model ─────────────────────────────────────────────────────────────────
    print("Building model …")
    model = build_model(args, tokenizer, seg_token_idx)
    # Map the checkpoint to the CPU. load_state_dict copies each tensor into
    # the already-CUDA parameters, so map_location="cuda" would only keep an
    # extra copy of the checkpoint tensors on the GPU while loading.
    # NOTE: torch.load unpickles the file; only evaluate trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    state = ckpt.get("model", ckpt)
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"Loaded {args.checkpoint} missing={len(missing)} unexpected={len(unexpected)}")
    del ckpt, state  # drop the CPU copy before the inference passes
    model.eval()

    # ── Collect ───────────────────────────────────────────────────────────────
    print("Collecting test_s …")
    res_s = collect(model, loader_s, "test_s")
    print("Collecting test_u …")
    res_u = collect(model, loader_u, "test_u")
    print("Collecting test_n …")
    res_n = collect(model, loader_n, "test_n")

    lines = []

    def log(s=""):
        # Echo to stdout and keep the line for the saved report.
        print(s)
        lines.append(s)

    # ── Distribution ──────────────────────────────────────────────────────────
    log(f"\n{'='*64}")
    log(f"EC-SimToken Eval | {ep_tag}")
    log(f"{'='*64}")
    log("\n── p_exist distribution ─────────────────────────────────────")
    hdr = f"{'split':<10} {'n':>6} {'mean':>6} {'med':>6} {'p10':>6} {'p25':>6} {'p75':>6} {'p90':>6} {'min':>6} {'max':>6}"
    log(hdr)
    for res, label in [(res_s, "test_s(+)"), (res_u, "test_u(+)"), (res_n, "test_n(null)")]:
        st = dist_stats(res["p_exist"])
        log(f"{label:<10} {st['n']:>6} {st['mean']:>6.3f} {st['median']:>6.3f} "
            f"{st['p10']:>6.3f} {st['p25']:>6.3f} {st['p75']:>6.3f} {st['p90']:>6.3f} "
            f"{st['min']:>6.3f} {st['max']:>6.3f}")

    # ── AUC ───────────────────────────────────────────────────────────────────
    pos_p = np.concatenate([res_s["p_exist"], res_u["p_exist"]])
    null_p = res_n["p_exist"]
    auc = auc_roc(null_p, pos_p)
    log(f"\nAUC-ROC (null vs positive): {auc:.4f}")
    log(" (0.5 = random, 1.0 = perfect separation)")

    # ── Default-threshold metrics ─────────────────────────────────────────────
    # NOTE(review): mIoU/F/Null_S below come straight from collect(), and this
    # script applies no p_exist gating to them. The "0.50" label presumably
    # reflects gating inside the model's inference path; confirm.
    log(f"\n── Default threshold = 0.50 ──────────────────────────────────")
    jf_s = (res_s["miou"] + res_s["fscore"]) / 2
    jf_u = (res_u["miou"] + res_u["fscore"]) / 2
    log(f" test_s mIoU={res_s['miou']:.4f} F={res_s['fscore']:.4f} J&F={jf_s:.4f}")
    log(f" test_u mIoU={res_u['miou']:.4f} F={res_u['fscore']:.4f} J&F={jf_u:.4f}")
    n_null = len(null_p)
    null_tp_50 = int((null_p < 0.5).sum())
    log(f" test_n Null_S={res_n['null_s_default']:.4f} "
        f"null_tp={null_tp_50}/{n_null} ({100*null_tp_50/max(n_null, 1):.1f}%)")

    # ── Threshold sweep ───────────────────────────────────────────────────────
    log(f"\n── Threshold sweep ───────────────────────────────────────────")
    # Per-sample masks are already cached by collect(); no second inference pass.
    pos_preds = res_s["pred_masks"] + res_u["pred_masks"]
    pos_gts = res_s["gt_masks"] + res_u["gt_masks"]
    sweep_rows = threshold_sweep(null_p, pos_p, pos_preds, pos_gts, res_n["pred_masks"])
    hdr2 = (f"{'thresh':>7} {'null_tp%':>9} {'pos_fnr%':>9} "
            f"{'Null_S':>8} {'pos_J&F':>8} {'pos_mIoU':>9} {'pos_F':>7}")
    log(hdr2)
    log("-" * 65)
    for r in sweep_rows:
        # Highlight rows with null_tp >= 30% AND positive_fnr <= 10%.
        flag = ""
        if r["null_tp_rate"] >= 0.30 and r["positive_fnr"] <= 0.10:
            flag = " ← candidate"
        log(f"{r['threshold']:>7.2f} {100*r['null_tp_rate']:>8.1f}% {100*r['positive_fnr']:>8.1f}%"
            f" {r['Null_S']:>8.4f} {r['pos_J&F']:>8.4f}"
            f" {r['pos_mIoU']:>9.4f} {r['pos_F']:>7.4f}{flag}")

    # ── Selection rule ────────────────────────────────────────────────────────
    log(f"\n── Auto-selection (pos J&F drop ≤ 0.5 pt from default) ──────")
    # default_jf weights each split's J&F by its number of p_exist scores.
    # The sweep's pos_J&F instead weights each sample by its mask size. The
    # two agree only when all samples carry equal weight.
    default_jf = (jf_s * len(res_s["p_exist"]) + jf_u * len(res_u["p_exist"])) / (
        len(res_s["p_exist"]) + len(res_u["p_exist"]))
    candidates = [r for r in sweep_rows
                  if default_jf - r["pos_J&F"] <= 0.005]  # ≤ 0.5 pt
    if candidates:
        # Lowest Null_S among the thresholds that keep positive J&F.
        best = min(candidates, key=lambda r: r["Null_S"])
        log(f" Best threshold = {best['threshold']:.2f}"
            f" Null_S={best['Null_S']:.4f}"
            f" null_tp={100*best['null_tp_rate']:.1f}%"
            f" pos_fnr={100*best['positive_fnr']:.1f}%"
            f" pos_J&F={best['pos_J&F']:.4f}")
    else:
        log(" No threshold meets J&F constraint — sweep shows extreme trade-off.")

    # ── Save report ───────────────────────────────────────────────────────────
    # The report contains non-ASCII box/arrow characters, so pin the encoding.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"\nReport saved: {out_path}")
if __name__ == "__main__":
    import torch.multiprocessing as tmp
    # Ask for the "spawn" start method for child processes (the DataLoader
    # workers). set_start_method raises RuntimeError when a start method has
    # already been fixed; in that case the existing one is kept.
    try:
        tmp.set_start_method("spawn")
    except RuntimeError:
        pass
    main()