| |
| """EC-SimToken standalone evaluation: score distribution + threshold sweep. |
| |
| Loads a saved checkpoint and reports: |
| 1. p_exist distribution per split (mean/median/p10/p25/p75/p90) |
| 2. AUC-ROC (test_n as null class vs test_s+test_u as positive class) |
3. Threshold sweep 0.05–0.95: J&F, Null_S, null_tp_rate, positive_fnr
| |
| Usage: |
| cd /workspace/SimToken |
| python tools/ec_simtoken_eval.py \ |
| --checkpoint checkpoints/ec_simtoken/ec_simtoken_v1_ep2.pth \ |
| --out_dir runs/ec_simtoken/eval_ep2 |
| """ |
|
|
| from __future__ import annotations |
| import argparse, os, sys |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import transformers |
| from peft import LoraConfig, get_peft_model |
| from torch.utils.data import DataLoader |
| from transformers import AutoConfig |
| from tqdm import tqdm |
|
|
# Make the repo root importable when running as `python tools/ec_simtoken_eval.py`.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
# Single-GPU evaluation: pin everything to device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
| from datasets.dataset_refavs import REFAVS |
| from models.ec_simtoken_model import ECSimtoken_ForCausalLM |
| from utils import utility |
|
|
| |
# Hard-coded local default paths; all of them can be overridden via the CLI
# flags defined in main().
MLLM = "/workspace/hf_models/Chat-UniVi-7B-v1.5"    # Chat-UniVi LLM backbone
SAM_CKPT = "/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth"  # SAM ViT-H weights
VISION_TOWER = "/workspace/hf_models/clip-vit-large-patch14"  # CLIP vision encoder
DATA_DIR = "data"


# Sentinel token ids.  IGNORE_INDEX masks positions out of the LM loss in
# collate_fn; IMAGE/AUDIO_TOKEN_INDEX are negative placeholder ids spliced
# into input_ids — presumably replaced by multimodal embeddings inside the
# model's forward (TODO: confirm against ECSimtoken_ForCausalLM).
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
AUDIO_TOKEN_INDEX = -300


import re
|
|
def tokenizer_image_audio_token(prompt, tokenizer,
                                image_token_index=-200,   # == IMAGE_TOKEN_INDEX
                                audio_token_index=-300,   # == AUDIO_TOKEN_INDEX
                                num_frames=10, return_tensors=None):
    """Tokenize *prompt*, expanding multimodal placeholders in position.

    Placeholders are replaced inline: ``<image>`` -> one image_token_index,
    ``<audio>`` -> one audio_token_index, ``<video>`` -> *num_frames* image
    tokens.  Text segments are tokenized with *tokenizer*; if the tokenizer
    prepends BOS, it is kept exactly once at the very front and stripped
    from every text segment.

    Fixes over the previous version: a leading or consecutive placeholder is
    now emitted at its actual position (it used to be appended *after* the
    following text chunk), and trailing surplus placeholders are no longer
    silently dropped.  Output is unchanged for the common
    text/placeholder-alternating prompts.

    Returns a list[int], or a 1-D LongTensor when return_tensors == "pt".
    """
    placeholder_ids = {
        "<image>": [image_token_index],
        "<audio>": [audio_token_index],
        "<video>": [image_token_index] * num_frames,
    }
    # Capturing split keeps the placeholders as their own chunks.
    chunks = [c for c in re.split(r'(<image>|<audio>|<video>)', prompt) if c]

    input_ids = []
    offset = 0          # number of leading ids (BOS) to strip from text chunks
    saw_text = False
    for chunk in chunks:
        if chunk in placeholder_ids:
            input_ids.extend(placeholder_ids[chunk])
            continue
        ids = tokenizer(chunk).input_ids
        if not saw_text:
            saw_text = True
            if ids and ids[0] == tokenizer.bos_token_id:
                offset = 1
                input_ids.insert(0, ids[0])  # BOS always belongs at position 0
        input_ids.extend(ids[offset:])

    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids
|
|
|
|
def collate_fn(batch, tokenizer=None):
    """Collate a list of REFAVS samples into one padded batch dict.

    Tokenizes each conversation with tokenizer_image_audio_token, right-pads
    input_ids to the longest sequence, and builds `labels` so that the LM
    loss is computed only on the "Sure, it is [SEG]" answer tokens — every
    human-side span and all padding are masked with IGNORE_INDEX.
    """
    vids, images, image_clips, masks, conversations = [], [], [], [], []
    audio_feats, image_feats, resizes, orgsizes = [], [], [], []
    refs, refs_num, fids = [], [], []
    for data in batch:
        vids.append(data["vid"]); images.append(data["image"])
        image_clips.append(data["img_clip"]); masks.append(data["mask"])
        conversations.append(data["conversation"])
        audio_feats.append(data["feat_aud"]); resizes.append(data["resize"])
        orgsizes.append(data["orgsize"]); image_feats.append(data["feat_sam"])
        refs_num.append(len(data["ref"])); fids.append(data["fids"])
        # Only the first referring expression of each sample is tokenized.
        refs.append(data["ref"][0])
    input_ids = [tokenizer_image_audio_token(c, tokenizer, return_tensors="pt")
                 for c in conversations]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = input_ids.ne(tokenizer.pad_token_id)
    ref_ids = [tokenizer_image_audio_token(r, tokenizer, return_tensors="pt")
               for r in refs]
    labels = input_ids.clone()
    sep = "Sure, it is [SEG]"
    for conversation, target in zip(conversations, labels):
        parts = conversation.split(sep)
        cur_len = 1; target[:cur_len] = IGNORE_INDEX  # mask BOS
        # -1: drop the BOS the tokenizer prepends to the separator itself.
        sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1
        for i in range(len(parts) - 1):
            # -2 presumably accounts for BOS plus one token merged across the
            # split boundary — TODO confirm against the actual tokenizer.
            part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2
            # Mask the human-side part; the separator tokens that follow are
            # intentionally left unmasked (they are the supervised answer).
            target[cur_len: cur_len + part_len] = IGNORE_INDEX
            cur_len += part_len + sep_len
        # Mask everything after the final answer span (incl. right padding).
        target[cur_len:] = IGNORE_INDEX
    return {"vids": vids, "images": images, "images_clip": image_clips,
            "masks": masks, "convs": conversations, "input_ids": input_ids,
            "attention_masks": attention_masks, "labels": labels,
            "audio_feats": audio_feats, "resizes": resizes, "orgsizes": orgsizes,
            "image_feats": image_feats, "ref_ids": ref_ids,
            "refs_num": refs_num, "fids": fids}
|
|
|
|
def dict_to_cuda(batch):
    """Move every tensor value (or list of tensors) in *batch* to the GPU.

    Mutates *batch* in place and also returns it for convenience; non-tensor
    values are left untouched.
    """
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.cuda(non_blocking=True)
        elif isinstance(value, list) and value and isinstance(value[0], torch.Tensor):
            batch[key] = [t.cuda(non_blocking=True) for t in value]
    return batch
|
|
|
|
def build_model(args, tokenizer, seg_token_idx):
    """Construct the ECSimtoken model (Chat-UniVi backbone + SAM decoder)
    for evaluation and move it to GPU in bfloat16.

    The construction must match training so the checkpoint loaded later in
    main() (load_state_dict with strict=False) lines up shape-for-shape:
    vision/cluster/LISA modules are initialized and q_proj/v_proj linears
    are wrapped with LoRA adapters.
    """
    model_args = {
        "train_mask_decoder": True, "out_dim": 256,
        # Loss weights are unused at eval time but required by the constructor.
        "ce_loss_weight": 1.0, "dice_loss_weight": 0.5, "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,  # SAM ViT-H checkpoint path
        "vision_tower": args.vision_tower,            # CLIP vision encoder path
        "use_im_start_end": False, "compress": True, "start": 0,
        "exist_loss_weight": 1.0,
    }
    model = ECSimtoken_ForCausalLM.from_pretrained(
        args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args)
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id


    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16, device="cuda")


    # Chat-UniVi token-cluster settings — presumably mirrors the training
    # config; TODO confirm against the training script.
    cfg_pt = AutoConfig.from_pretrained(args.mllm)
    cfg_pt.use_cluster = True; cfg_pt.freeze = False; cfg_pt.mm_tune = True
    cfg_pt.spatial_cluster_rate0 = 64; cfg_pt.spatial_cluster_rate1 = 32
    cfg_pt.spatial_cluster_rate2 = 16; cfg_pt.temporal_cluster_rate = 0.0625
    cfg_pt.vision_tune = False
    model.get_model().initialize_cluster_modules(cfg_pt)
    model.get_model().initialize_lisa_modules(model.get_model().config)


    def find_linear_layers(m, targets):
        # Names of nn.Linear modules to wrap with LoRA; skips vision/SAM,
        # projection, and existence-head modules so they keep full weights.
        names = set()
        skip = {"visual_model", "vision_tower", "mm_projector",
                "text_hidden_fcs", "audio_feature_layer", "existence_head"}
        for name, mod in m.named_modules():
            if (isinstance(mod, torch.nn.Linear)
                and not any(s in name for s in skip)
                and any(t in name for t in targets)):
                names.add(name)
        return sorted(names)


    lora_config = LoraConfig(
        r=8, lora_alpha=16,
        target_modules=find_linear_layers(model, ["q_proj", "v_proj"]),
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model = model.to("cuda").to(torch.bfloat16)
    # Embedding table must cover the added [SEG] token before weight loading.
    model.resize_token_embeddings(len(tokenizer))
    return model
|
|
|
|
| |
|
|
@torch.no_grad()
def collect(model, dataloader, split_name: str):
    """Run one inference pass over *dataloader* for the given split.

    Returns a dict with the per-sample p_exist scores, the predicted and
    ground-truth masks (retained for the later threshold sweep), plus the
    default-threshold aggregate: Null_S for the "test_n" split, otherwise
    pixel-weighted mIoU and F-score.
    """
    model.eval()
    p_exist_scores = []
    pred_mask_list, gt_mask_list = [], []
    iou_sum = f_sum = pos_weight = 0.0
    null_s_sum = null_weight = 0.0


    for batch in tqdm(dataloader, desc=split_name, leave=False):
        batch = dict_to_cuda(batch)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            out = model.forward(
                images=batch["images"], images_clip=batch["images_clip"],
                audio_features=batch["audio_feats"], image_features=batch["image_feats"],
                input_ids=batch["input_ids"], labels=batch["labels"],
                attention_masks=batch["attention_masks"], masks_list=batch["masks"],
                resize_list=batch["resizes"], orgsize_list=batch["orgsizes"],
                conversation_list=batch["convs"], refs_num=batch["refs_num"],
                fids=batch["fids"], vids=batch["vids"], ref_ids=batch["ref_ids"],
                inference=True,
            )
        scores = torch.sigmoid(out["exist_logit"]).squeeze(-1).cpu().float()
        p_exist_scores.extend(scores.tolist())


        for pred, gt in zip(out["pred_masks"], out["gt_masks"]):
            pred, gt = pred.cpu(), gt.cpu()
            pred_mask_list.append(pred)
            gt_mask_list.append(gt)
            # Weight each sample by its first two dims so averages are
            # frame/pixel-group weighted rather than per-sample.
            weight = pred.shape[0] * pred.shape[1]
            if split_name == "test_n":
                null_s_sum += utility.metric_s_for_null(pred) * weight
                null_weight += weight
            else:
                iou_sum += utility.mask_iou(pred, gt) * weight
                f_sum += utility.Eval_Fmeasure(pred, gt, None) * weight
                pos_weight += weight


    result = {
        "p_exist": np.array(p_exist_scores, dtype=np.float32),
        "pred_masks": pred_mask_list,
        "gt_masks": gt_mask_list,
        "split": split_name,
    }
    if split_name == "test_n":
        result["null_s_default"] = null_s_sum / (null_weight + 1e-8)
    else:
        result["miou"] = iou_sum / (pos_weight + 1e-8)
        result["fscore"] = f_sum / (pos_weight + 1e-8)
    return result
|
|
|
|
| |
|
|
def dist_stats(arr: np.ndarray) -> dict:
    """Summary statistics of a score array for the distribution table."""
    p10, p25, p75, p90 = np.percentile(arr, [10, 25, 75, 90])
    return {
        "n": len(arr), "mean": arr.mean(), "median": np.median(arr),
        "p10": p10, "p25": p25,
        "p75": p75, "p90": p90,
        "min": arr.min(), "max": arr.max(),
    }
|
|
|
|
def auc_roc(null_scores: np.ndarray, pos_scores: np.ndarray) -> float:
    """AUC-ROC for separating null clips from positive clips.

    Positives should receive HIGHER p_exist, so
    AUC = P(null < pos) + 0.5 * P(null == pos); 0.5 = random, 1.0 = perfect.

    Uses sklearn when available; otherwise falls back to the normalized
    Mann-Whitney U statistic.  The fallback now gives tied scores half
    credit, matching sklearn — the previous side="right"-only count gave
    ties full credit and over-estimated the AUC.
    """
    try:
        from sklearn.metrics import roc_auc_score
        y = np.concatenate([np.zeros(len(null_scores)), np.ones(len(pos_scores))])
        s = np.concatenate([null_scores, pos_scores])
        return float(roc_auc_score(y, s))
    except ImportError:
        null_sorted = np.sort(null_scores)
        # For every positive: #nulls strictly below it, plus half the ties.
        below = np.searchsorted(null_sorted, pos_scores, side="left")
        ties = np.searchsorted(null_sorted, pos_scores, side="right") - below
        wins = float((below + 0.5 * ties).sum())
        return wins / (len(null_scores) * len(pos_scores))
|
|
|
|
| |
|
|
def threshold_sweep(null_p: np.ndarray, pos_p: np.ndarray,
                    pos_pred_masks, pos_gt_masks,
                    null_pred_masks):
    """Evaluate null-detection thresholds from 0.05 to 0.95 in 0.05 steps.

    For each threshold t:
      - null_tp_rate: fraction of null samples with p_exist < t
      - positive_fnr: fraction of positive samples with p_exist < t
      - Null_S: metric_s over null samples, scoring a zero mask whenever
        the sample is flagged as null
      - pos_J&F: J&F over positives, with falsely-flagged positives scored
        against an all-zero prediction
    """
    rows = []
    for t in np.round(np.arange(0.05, 1.00, 0.05), 2):
        flagged_null = null_p < t
        flagged_pos = pos_p < t

        # Null split: a detected null is scored with an all-zero mask.
        ns_sum, ns_weight = 0.0, 0
        for is_null, pm in zip(flagged_null, null_pred_masks):
            scored = torch.zeros_like(pm) if is_null else pm
            w = pm.shape[0] * pm.shape[1]
            ns_sum += utility.metric_s_for_null(scored) * w
            ns_weight += w

        # Positive splits: a falsely-flagged positive is zeroed out.
        iou_sum = f_sum = w_total = 0.0
        for is_null, pm, gm in zip(flagged_pos, pos_pred_masks, pos_gt_masks):
            if is_null:
                pm = torch.zeros_like(pm)
            w = pm.shape[0] * pm.shape[1]
            iou_sum += utility.mask_iou(pm, gm) * w
            f_sum += utility.Eval_Fmeasure(pm, gm, None) * w
            w_total += w

        miou = iou_sum / (w_total + 1e-8)
        fscore = f_sum / (w_total + 1e-8)
        rows.append({
            "threshold": t,
            "null_tp_rate": int(flagged_null.sum()) / len(null_p),
            "positive_fnr": int(flagged_pos.sum()) / len(pos_p),
            "Null_S": ns_sum / (ns_weight + 1e-8),
            "pos_mIoU": miou,
            "pos_F": fscore,
            "pos_J&F": (miou + fscore) / 2,
        })
    return rows
|
|
|
|
| |
|
|
def main():
    """Entry point: load a checkpoint, score the three test splits, and write
    a text report (p_exist distribution, AUC, threshold sweep, auto-pick)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--mllm", default=MLLM)
    parser.add_argument("--vision_pretrained", default=SAM_CKPT)
    parser.add_argument("--vision_tower", default=VISION_TOWER)
    parser.add_argument("--data_dir", default=DATA_DIR)
    parser.add_argument("--out_dir", default="runs/ec_simtoken/eval")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()


    os.makedirs(args.out_dir, exist_ok=True)
    # Report file is named after the checkpoint basename (e.g. "..._ep2").
    ep_tag = os.path.basename(args.checkpoint).replace(".pth", "")
    out_path = os.path.join(args.out_dir, f"{ep_tag}_report.txt")


    # --- tokenizer --------------------------------------------------------
    print("Loading tokenizer β¦")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm, model_max_length=2048, padding_side="right", use_fast=False)
    tokenizer.pad_token = tokenizer.unk_token
    # [SEG] is the added segmentation-trigger token; its id is handed to the model.
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]


    # --- datasets / dataloaders ------------------------------------------
    from argparse import Namespace
    cfg = Namespace(data_dir=args.data_dir, frame_n=10, text_max_len=25,
                    conv_template=1, vision_tower=args.vision_tower)
    cfn = partial(collate_fn, tokenizer=tokenizer)
    dl_kw = dict(batch_size=args.batch_size, shuffle=False,
                 num_workers=args.num_workers, collate_fn=cfn,
                 pin_memory=True, persistent_workers=False)


    # test_s / test_u are positive splits; test_n is the null split (see
    # module docstring).
    ds_s = REFAVS("test_s", cfg, tokenizer, input_type="refer")
    ds_u = REFAVS("test_u", cfg, tokenizer, input_type="refer")
    ds_n = REFAVS("test_n", cfg, tokenizer, input_type="refer")
    loader_s = DataLoader(ds_s, **dl_kw)
    loader_u = DataLoader(ds_u, **dl_kw)
    loader_n = DataLoader(ds_n, **dl_kw)


    # --- model ------------------------------------------------------------
    print("Building model β¦")
    model = build_model(args, tokenizer, seg_token_idx)
    ckpt = torch.load(args.checkpoint, map_location="cuda")
    # Accept both {"model": state_dict} wrappers and bare state dicts.
    state = ckpt.get("model", ckpt)
    # strict=False: LoRA/base-model key mismatches are expected; counts are
    # printed so gross mismatches are still visible.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"Loaded {args.checkpoint} missing={len(missing)} unexpected={len(unexpected)}")
    model.eval()


    # --- single inference pass per split ---------------------------------
    print("Collecting test_s β¦")
    res_s = collect(model, loader_s, "test_s")
    print("Collecting test_u β¦")
    res_u = collect(model, loader_u, "test_u")
    print("Collecting test_n β¦")
    res_n = collect(model, loader_n, "test_n")


    lines = []
    def log(s=""):
        # Echo to stdout and buffer for the report file written at the end.
        print(s); lines.append(s)


    log(f"\n{'='*64}")
    log(f"EC-SimToken Eval | {ep_tag}")
    log(f"{'='*64}")


    # --- p_exist distribution table --------------------------------------
    log("\nββ p_exist distribution βββββββββββββββββββββββββββββββββββββ")
    hdr = f"{'split':<10} {'n':>6} {'mean':>6} {'med':>6} {'p10':>6} {'p25':>6} {'p75':>6} {'p90':>6} {'min':>6} {'max':>6}"
    log(hdr)
    for res, label in [(res_s, "test_s(+)"), (res_u, "test_u(+)"), (res_n, "test_n(null)")]:
        st = dist_stats(res["p_exist"])
        log(f"{label:<10} {st['n']:>6} {st['mean']:>6.3f} {st['median']:>6.3f} "
            f"{st['p10']:>6.3f} {st['p25']:>6.3f} {st['p75']:>6.3f} {st['p90']:>6.3f} "
            f"{st['min']:>6.3f} {st['max']:>6.3f}")


    # --- AUC: pooled positives vs. nulls ---------------------------------
    pos_p = np.concatenate([res_s["p_exist"], res_u["p_exist"]])
    null_p = res_n["p_exist"]
    auc = auc_roc(null_p, pos_p)
    log(f"\nAUC-ROC (null vs positive): {auc:.4f}")
    log(" (0.5 = random, 1.0 = perfect separation)")


    # --- metrics at the default 0.5 threshold ----------------------------
    log(f"\nββ Default threshold = 0.50 ββββββββββββββββββββββββββββββββββ")
    jf_s = (res_s["miou"] + res_s["fscore"]) / 2
    jf_u = (res_u["miou"] + res_u["fscore"]) / 2
    log(f" test_s mIoU={res_s['miou']:.4f} F={res_s['fscore']:.4f} J&F={jf_s:.4f}")
    log(f" test_u mIoU={res_u['miou']:.4f} F={res_u['fscore']:.4f} J&F={jf_u:.4f}")
    null_tp_50 = int((null_p < 0.5).sum())
    log(f" test_n Null_S={res_n['null_s_default']:.4f} "
        f"null_tp={null_tp_50}/{len(null_p)} ({100*null_tp_50/len(null_p):.1f}%)")


    # --- full threshold sweep --------------------------------------------
    log(f"\nββ Threshold sweep βββββββββββββββββββββββββββββββββββββββββββ")


    pos_preds = res_s["pred_masks"] + res_u["pred_masks"]
    pos_gts = res_s["gt_masks"] + res_u["gt_masks"]
    pos_p2 = np.concatenate([res_s["p_exist"], res_u["p_exist"]])
    null_preds_n = res_n["pred_masks"]
    p_n = res_n["p_exist"]


    sweep_rows = threshold_sweep(p_n, pos_p2, pos_preds, pos_gts, null_preds_n)


    hdr2 = (f"{'thresh':>7} {'null_tp%':>9} {'pos_fnr%':>9} "
            f"{'Null_S':>8} {'pos_J&F':>8} {'pos_mIoU':>9} {'pos_F':>7}")
    log(hdr2)
    log("-" * 65)
    for r in sweep_rows:
        flag = ""
        # Flag thresholds detecting >=30% of nulls while losing <=10% of positives.
        if r["null_tp_rate"] >= 0.30 and r["positive_fnr"] <= 0.10:
            flag = " β candidate"
        log(f"{r['threshold']:>7.2f} {100*r['null_tp_rate']:>8.1f}% {100*r['positive_fnr']:>8.1f}%"
            f" {r['Null_S']:>8.4f} {r['pos_J&F']:>8.4f}"
            f" {r['pos_mIoU']:>9.4f} {r['pos_F']:>7.4f}{flag}")


    # --- auto-pick: lowest Null_S among thresholds that keep J&F ----------
    log(f"\nββ Auto-selection (pos J&F drop β€ 0.5 pt from default) ββββββ")
    # Size-weighted average of the two positive splits' default-threshold J&F.
    default_jf = (jf_s * len(res_s["p_exist"]) + jf_u * len(res_u["p_exist"])) / (
        len(res_s["p_exist"]) + len(res_u["p_exist"]))
    candidates = [r for r in sweep_rows
                  if default_jf - r["pos_J&F"] <= 0.005]
    if candidates:
        # Among J&F-preserving thresholds, pick the one with the best
        # (lowest) Null_S.
        best = min(candidates, key=lambda r: r["Null_S"])
        log(f" Best threshold = {best['threshold']:.2f}"
            f" Null_S={best['Null_S']:.4f}"
            f" null_tp={100*best['null_tp_rate']:.1f}%"
            f" pos_fnr={100*best['positive_fnr']:.1f}%"
            f" pos_J&F={best['pos_J&F']:.4f}")
    else:
        log(" No threshold meets J&F constraint β sweep shows extreme trade-off.")


    # --- write report ----------------------------------------------------
    with open(out_path, "w") as f:
        f.write("\n".join(lines))
    print(f"\nReport saved: {out_path}")
|
|
|
|
if __name__ == "__main__":
    # DataLoader workers must use "spawn" because CUDA is initialized in the
    # parent process; a RuntimeError means a start method was already set.
    import torch.multiprocessing as _mp
    try:
        _mp.set_start_method("spawn")
    except RuntimeError:
        pass
    main()
|
|