Image Segmentation
Transformers
PyTorch
pixdlm
cvpr-2026
compute-transparency
reasoning-segmentation
uav
remote-sensing
vision-language
Instructions to use WhynotHug/PixDLM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WhynotHug/PixDLM with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="WhynotHug/PixDLM")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("WhynotHug/PixDLM", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| from functools import partial | |
| import logging | |
| import deepspeed | |
| import numpy as np | |
| import torch | |
| import tqdm | |
| import transformers | |
| import copy | |
| from peft import LoraConfig, get_peft_model | |
| from torch.utils.tensorboard import SummaryWriter | |
| import torch.distributed as dist | |
| from model.PixDLM import PixDLMForCausalLM | |
| from model.llava import conversation as conversation_lib | |
| from utils.dataset import HybridDataset, ValDataset, collate_fn | |
| from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, | |
| AverageMeter, ProgressMeter, Summary, dict_to_cuda, | |
| intersectionAndUnionGPU) | |
| from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IMAGE_PATCH_TOKEN) | |
| from utils.matcher import match_pred | |
| from utils.multi_reason_seg_val_dataset import MultiReasonSegValDataset | |
| from model.llava.mm_utils import tokenizer_image_token | |
| import requests | |
| import json | |
| import base64 | |
| import cv2 | |
def parse_args(args):
    """Parse the PixDLM training/evaluation command line.

    Args:
        args: Sequence of command-line tokens (without the program name).

    Returns:
        The ``argparse.Namespace`` holding every option declared below.

    Note:
        ``--gradient_checkpointing``, ``--train_mask_decoder``,
        ``--use_mm_start_end`` and ``--auto_resume`` are ``store_true`` flags
        whose default is already ``True``; passing them has no effect and
        they cannot be turned off from the command line.
    """
    parser = argparse.ArgumentParser(description="PixDLM Model Training")
    add = parser.add_argument

    def flag(option, default=False):
        # Boolean switch: present on the command line -> True.
        add(option, action="store_true", default=default)

    # --- runtime / model identity -------------------------------------
    add("--local_rank", type=int, default=0, help="node rank")
    add("--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview")
    add("--vis_save_path", type=str, default="./vis_output")
    add(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    add("--image_size", type=int, default=1024, help="image size")
    add("--model_max_length", type=int, default=512)
    add("--lora_r", type=int, default=8)
    # argparse exposes this one as ``args.vision_tower``.
    add("--vision-tower", type=str, default="openai/clip-vit-large-patch14")
    flag("--load_in_8bit")
    flag("--load_in_4bit")

    # --- datasets ------------------------------------------------------
    add("--dataset", type=str, default="sem_seg||refer_seg||vqa||reason_seg")
    add("--sample_rates", type=str, default="9,3,3,1")
    add(
        "--sem_seg_data",
        type=str,
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
    )
    add("--refer_seg_data", type=str, default="refclef||refcoco||refcoco+||refcocog")
    add("--vqa_data", type=str, default="llava_instruct_150k")
    add("--reason_seg_data", type=str, default="ReasonSeg|train")
    add("--val_dataset", type=str, default="ReasonSeg|val")
    add("--dataset_dir", type=str, default="./dataset")
    add("--log_base_dir", type=str, default="./runs")
    add("--exp_name", type=str, default="pixdlm")

    # --- optimisation schedule ----------------------------------------
    add("--epochs", type=int, default=5)
    add("--steps_per_epoch", type=int, default=200)
    add("--batch_size", type=int, default=2, help="batch size per device per step")
    add("--grad_accumulation_steps", type=int, default=10)
    add("--val_batch_size", type=int, default=1)
    add("--workers", type=int, default=4)
    add("--lr", type=float, default=0.0003)
    add("--ce_loss_weight", type=float, default=1.0)
    add("--dice_loss_weight", type=float, default=0.5)
    add("--bce_loss_weight", type=float, default=2.0)
    add("--lora_alpha", type=int, default=16)
    add("--lora_dropout", type=float, default=0.05)
    add("--lora_target_modules", type=str, default="q_proj,v_proj")
    add("--explanatory", type=float, default=0.1)
    add("--beta1", type=float, default=0.9)
    add("--beta2", type=float, default=0.95)
    add("--num_classes_per_sample", type=int, default=3)
    flag("--exclude_val")
    flag("--no_eval")
    flag("--eval_only")
    add("--vision_pretrained", type=str, default="")
    add("--out_dim", type=int, default=256)
    add("--resume", type=str, default="")
    add("--print_freq", type=int, default=1)
    add("--start_epoch", type=int, default=0)
    flag("--gradient_checkpointing", default=True)
    flag("--train_mask_decoder", default=True)
    flag("--use_mm_start_end", default=True)
    flag("--auto_resume", default=True)

    # --- segmentation tokens / vision tower ---------------------------
    add("--seg_token_num", type=int, default=1)
    add("--num_classes_per_question", type=int, default=1)
    flag("--pad_train_clip_images")
    flag("--masks_process_with_clip")
    add("--preprocessor_config", type=str, default="")
    flag("--resize_vision_tower")
    add("--resize_vision_tower_size", type=int, default=224)
    flag("--vision_tower_for_mask")
    add("--weight", type=str, default="")
    flag("--use_expand_question_list")
    flag("--separate_mm_projector")
    add("--image_feature_scale_num", type=int, default=1)
    flag("--Three_Level_Multi_Scale_Decoder")
    add(
        "--conv_type",
        type=str,
        default="llava_v1",
        choices=["llava_v1", "llava_llama_2"],
    )
    flag("--is_multipath_encoder")
    add("--sam2_config", type=str, default="./sam2/configs/sam2.1/sam2.1_hiera_l.yaml")
    flag("--freeze_vision")

    return parser.parse_args(args)
def get_language_backbone(model):
    """Unwrap ``model`` down to its innermost reachable backbone object.

    One optional ``.module`` attribute is followed first, then up to two
    nested ``.model`` attributes. Any attribute that is missing leaves the
    current object unchanged, so a bare backbone is returned as-is.
    """
    backbone = getattr(model, "module", model)
    for _ in range(2):
        backbone = getattr(backbone, "model", backbone)
    return backbone
| def _safe_name(name): | |
| return "".join(c if c.isalnum() or c in "._-" else "_" for c in name) | |
| def _first_text(value): | |
| if value is None: | |
| return None | |
| if isinstance(value, (list, tuple)): | |
| return _first_text(value[0]) if value else None | |
| return str(value) | |
| def _mask_union(mask_tensor): | |
| arr = mask_tensor.detach().float().cpu().numpy() | |
| if arr.ndim == 0: | |
| arr = arr.reshape(1, 1) | |
| if arr.ndim == 3: | |
| arr = arr.max(axis=0) | |
| elif arr.ndim > 3: | |
| arr = arr.max(axis=tuple(range(arr.ndim - 2))) | |
| return (arr > 0).astype(np.uint8) | |
def save_eval_artifacts(args, input_dict, dataset_name, cot_type, output_list, masks_list,
                        question_text, condition_text, answer_text, per_image_ciou,
                        per_image_giou):
    """Write per-image evaluation artifacts (images, masks, JSON) to disk.

    Nothing is written unless ``args.local_rank`` is 0 (or absent). Output
    goes to ``<vis_save_path>/<safe dataset name>/<cot_type>/``; a relative
    ``args.vis_save_path`` is resolved against ``args.log_dir``. If the
    source image cannot be read by OpenCV the function returns silently
    after creating the directory.

    Files written, all prefixed with the sanitised image stem: the input
    image, the union of predicted masks, the union of ground-truth masks, an
    overlay blending both onto the image, and a ``_result.json`` with the
    texts, the two IoU values and the artifact paths.

    Args:
        args: Namespace providing ``vis_save_path``, ``log_dir`` and
            optionally ``local_rank``.
        input_dict: Batch dict; only ``input_dict["image_paths"][0]`` is read.
        dataset_name: Dataset identifier, sanitised for use as a directory.
        cot_type: Sub-directory name, used verbatim.
        output_list: Tensor of predicted masks (reduced via ``_mask_union``).
        masks_list: Tensor of ground-truth masks (reduced via ``_mask_union``).
        question_text, condition_text, answer_text: Text or nested
            lists/tuples of text; the first leaf of each is stored.
        per_image_ciou, per_image_giou: Values convertible with ``float``.
    """
    if getattr(args, "local_rank", 0) != 0:
        return

    out_root = args.vis_save_path
    if not os.path.isabs(out_root):
        out_root = os.path.join(args.log_dir, out_root)
    target_dir = os.path.join(out_root, _safe_name(dataset_name), cot_type)
    os.makedirs(target_dir, exist_ok=True)

    source_path = input_dict["image_paths"][0]
    frame = cv2.imread(source_path)
    if frame is None:
        return

    stem = _safe_name(os.path.splitext(os.path.basename(source_path))[0])

    def artifact(suffix):
        return os.path.join(target_dir, stem + suffix)

    files = {
        "input": artifact("_input.jpg"),
        "pred_mask": artifact("_pred_mask.png"),
        "gt_mask": artifact("_gt_mask.png"),
        "overlay": artifact("_overlay_pred_red_gt_green.jpg"),
    }
    json_path = artifact("_result.json")

    predicted = _mask_union(output_list)
    truth = _mask_union(masks_list)
    rows, cols = frame.shape[:2]

    def fit(mask):
        # Bring a mask to the image resolution without interpolating labels.
        if mask.shape[:2] == (rows, cols):
            return mask
        return cv2.resize(mask, (cols, rows), interpolation=cv2.INTER_NEAREST)

    predicted = fit(predicted)
    truth = fit(truth)

    blended = frame.copy()
    truth_px = truth > 0
    predicted_px = predicted > 0

    def tint(where, keep, mix, colour):
        blended[where] = (keep * blended[where] + mix * np.array(colour)).astype(np.uint8)

    # Colours are in OpenCV channel order; per the overlay filename the
    # ground truth is green and the prediction red. Overlap is tinted last.
    tint(truth_px, 0.55, 0.45, [0, 255, 0])
    tint(predicted_px, 0.55, 0.45, [0, 0, 255])
    tint(truth_px & predicted_px, 0.35, 0.65, [0, 255, 255])

    cv2.imwrite(files["input"], frame)
    cv2.imwrite(files["pred_mask"], predicted * 255)
    cv2.imwrite(files["gt_mask"], truth * 255)
    cv2.imwrite(files["overlay"], blended)

    summary = {
        "dataset": dataset_name,
        "cot_type": cot_type,
        "image": source_path,
        "question": _first_text(question_text),
        "answer": _first_text(answer_text),
        "conditioning_text": _first_text(condition_text),
        "metrics": {
            "cIoU": float(per_image_ciou),
            "gIoU": float(per_image_giou),
        },
        "artifacts": files,
    }
    with open(json_path, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)
    print("Saved eval artifact:", json_path)
def main(args):
    """Entry point: build tokenizer, model and data, then train or evaluate.

    Args:
        args: Command-line tokens, forwarded to ``parse_args``.

    The function parses options, sets up logging on rank 0, registers the
    segmentation tokens on the tokenizer, loads ``PixDLMForCausalLM``,
    optionally wraps it with LoRA, builds the training/validation datasets,
    initialises DeepSpeed, optionally resumes from a checkpoint and then
    either runs a single evaluation pass (``--eval_only``) or the
    epoch loop with per-epoch checkpointing.
    """
    args = parse_args(args)
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    # Only rank 0 owns the log directory, TensorBoard writer and file logger;
    # every other rank carries ``None`` for both.
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
        # Pick the first log file name that does not exist yet
        # (meta.log, meta_1.log, meta_2.log, ...).
        log_filename = os.path.join(args.log_dir, 'meta.log')
        i = 1
        while os.path.exists(log_filename):
            log_filename = os.path.join(args.log_dir, 'meta_{}.log'.format(str(i)))
            i += 1
        logger = logging.getLogger('pixdlm_logger')
        logger.setLevel(logging.INFO)
        file_handler = logging.FileHandler(log_filename)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.info(args)
    else:
        writer = None
        logger = None
    # ---- Tokenizer -----------------------------------------------------
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
        legacy=True
    )
    # The unknown token doubles as the padding token.
    tokenizer.pad_token = tokenizer.unk_token
    # Register either one "[SEG]" token or a numbered family "[SEG0]",
    # "[SEG1]", ...; ``args.seg_token_idx`` is an int in the first case and
    # a list of ints in the second.
    if args.seg_token_num*args.image_feature_scale_num == 1:
        num_added_tokens = tokenizer.add_tokens("[SEG]")
        args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    else:
        new_tokens = ["[SEG{}]".format(i) for i in range(args.seg_token_num*args.image_feature_scale_num)]
        num_added_tokens = tokenizer.add_tokens(new_tokens)
        args.seg_token_idx = [tokenizer(token, add_special_tokens=False).input_ids[0] for token in new_tokens]
    # NOTE(review): the return values ``num_added_tokens`` and
    # ``num_added_tokens_think`` are not read again in this function.
    num_added_tokens_think = tokenizer.add_tokens(["<think>", "</think>", "<answer>", "</answer>"])
    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
    # ---- Model ---------------------------------------------------------
    # Extra keyword arguments forwarded to ``PixDLMForCausalLM.from_pretrained``.
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
        "seg_token_num": args.seg_token_num,
        "logger": logger,
        "tokenizer": tokenizer,
        "local_rank": args.local_rank,
        "pad_train_clip_images": args.pad_train_clip_images,
        "resize_vision_tower": args.resize_vision_tower,
        "resize_vision_tower_size": args.resize_vision_tower_size,
        "vision_tower_for_mask": args.vision_tower_for_mask,
        "separate_mm_projector": args.separate_mm_projector,
        "masks_process_with_clip": args.masks_process_with_clip,
        "image_feature_scale_num": args.image_feature_scale_num,
        "three_level_multi_scale_decoder": args.Three_Level_Multi_Scale_Decoder,
        "is_multipath_encoder": args.is_multipath_encoder,
        "sam2_config": args.sam2_config,
        "freeze_vision":args.freeze_vision
    }
    # Map the --precision string onto a torch dtype (fp32 is the fallback).
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    ignore_mismatched_sizes = args.separate_mm_projector
    model = PixDLMForCausalLM.from_pretrained(
        args.version,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        **model_args,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.enable_input_require_grads()
    # NOTE(review): gradient checkpointing is enabled unconditionally here;
    # ``args.gradient_checkpointing`` is not consulted.
    model.gradient_checkpointing_enable()
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    # The vision tower is always frozen; the multimodal projector is frozen
    # only when the vision tower keeps the 224 input size.
    for p in vision_tower.parameters():
        p.requires_grad = False
    if args.resize_vision_tower_size == 224:
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = False
    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]
    # ---- LoRA ----------------------------------------------------------
    lora_r = args.lora_r
    if lora_r > 0:
        def find_linear_layers(model, lora_target_modules):
            """Return sorted names of ``nn.Linear`` modules to adapt.

            A module qualifies when its name contains at least one entry of
            ``lora_target_modules`` and none of the excluded substrings
            listed below.
            """
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                                "mask_decoder",
                                "image_feature_neck",
                                "prompt_encoder",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))
        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    # Grow the embedding table to cover the tokens added above.
    model.resize_token_embeddings(len(tokenizer))
    # Optional extra weights, loaded non-strictly on top of the base model.
    if args.weight:
        state_dict = torch.load(args.weight, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
    # Re-enable gradients for every parameter whose name contains one of
    # these substrings, regardless of earlier freezing or PEFT wrapping.
    trainable_list = ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs", "sam_to_embed_conv", "prompt_encoder", "image_feature_neck"]
    if args.resize_vision_tower_size != 224:
        trainable_list.append('mm_projector')
    for n, p in model.named_parameters():
        if any(
            [
                x in n
                for x in trainable_list
            ]
        ):
            p.requires_grad = True
    # ---- Datasets ------------------------------------------------------
    # NOTE(review): world size is taken from the CUDA device count seen by
    # this process, not from the distributed process group — confirm this
    # matches the launch configuration.
    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
        seg_token_num=args.seg_token_num*args.image_feature_scale_num,
        num_classes_per_question=args.num_classes_per_question,
        pad_train_clip_images=args.pad_train_clip_images,
        masks_process_with_clip=args.masks_process_with_clip,
        preprocessor_config=args.preprocessor_config,
        use_expand_question_list=args.use_expand_question_list,
    )
    print("____seg_token_num in data:________: ", args.seg_token_num*args.image_feature_scale_num)
    # ``--val_dataset`` may name one dataset or several joined by "||".
    # With several, ``val_dataset`` becomes a list and ``multi_val`` is True.
    # ``val_dataset_names`` is only assigned when evaluation is enabled.
    multi_val = False
    if args.no_eval == False:
        token_num = args.seg_token_num*args.image_feature_scale_num
        if len(args.val_dataset.split('||')) == 1:
            if args.val_dataset.split('|')[0] == 'MultiReasonSeg':
                ValDataset_type = MultiReasonSegValDataset
            else:
                ValDataset_type = ValDataset
            val_dataset_names = [args.val_dataset]
            val_dataset = ValDataset_type(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
                seg_token_num=token_num,
                pad_val_clip_images=args.pad_train_clip_images,
                masks_process_with_clip=args.masks_process_with_clip,
                preprocessor_config=args.preprocessor_config,
            )
            print(
                f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
            )
        else:
            multi_val = True
            val_dataset_names = args.val_dataset.split('||')
            val_dataset = []
            for val_dataset_name in val_dataset_names:
                if val_dataset_name.split('|')[0] == 'MultiReasonSeg':
                    ValDataset_type = MultiReasonSegValDataset
                else:
                    ValDataset_type = ValDataset
                val_dataset.append(
                    ValDataset_type(
                        args.dataset_dir,
                        tokenizer,
                        args.vision_tower,
                        val_dataset_name,
                        args.image_size,
                        seg_token_num=token_num,
                        pad_val_clip_images=args.pad_train_clip_images,
                        masks_process_with_clip=args.masks_process_with_clip,
                        preprocessor_config=args.preprocessor_config,
                    )
                )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")
    # ---- DeepSpeed -----------------------------------------------------
    # AdamW with linear warm-up/decay, ZeRO stage 2, gradient clipping at 1.0.
    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )
    # ---- Resume --------------------------------------------------------
    # With --auto_resume and no explicit --resume, pick up
    # <log_dir>/ckpt_model when it exists.
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume
    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        # The first line of the "latest" file is read as "global_step<N>";
        # N // steps_per_epoch gives the epoch to restart from.
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )
    # ---- Validation loaders -------------------------------------------
    # One non-shuffling DistributedSampler/DataLoader per validation dataset.
    # ``val_loader`` is only defined when a validation dataset exists.
    if val_dataset is not None:
        assert args.val_batch_size == 1
        if multi_val:
            val_sampler = [torch.utils.data.distributed.DistributedSampler(
                dataset, shuffle=False, drop_last=False
            ) for dataset in val_dataset]
            val_loader = [torch.utils.data.DataLoader(
                dataset,
                batch_size=args.val_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=False,
                sampler=sampler,
                collate_fn=partial(
                    collate_fn,
                    tokenizer=tokenizer,
                    conv_type=args.conv_type,
                    use_mm_start_end=args.use_mm_start_end,
                    local_rank=args.local_rank,
                ),
            ) for dataset, sampler in zip(val_dataset, val_sampler)]
        else:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, shuffle=False, drop_last=False
            )
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.val_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=False,
                sampler=val_sampler,
                collate_fn=partial(
                    collate_fn,
                    tokenizer=tokenizer,
                    conv_type=args.conv_type,
                    use_mm_start_end=args.use_mm_start_end,
                    local_rank=args.local_rank,
                ),
            )
    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0
    # ---- Evaluation-only mode -----------------------------------------
    # NOTE(review): combining --eval_only with --no_eval would reach this
    # branch with ``val_loader``/``val_dataset_names`` unassigned.
    if args.eval_only:
        if args.val_dataset.split('|')[0] == 'MultiReasonSeg':
            ar_validate(val_loader, model_engine, 0, writer, args, logger, val_dataset_names, tokenizer, args.seg_token_num, args.image_feature_scale_num)
        else:
            giou, ciou = validate(val_loader, model_engine, 0, writer, args, logger, val_dataset_names,tokenizer)
            print(giou,ciou)
        exit()
    # ---- Training loop -------------------------------------------------
    for epoch in range(args.start_epoch, args.epochs):
        # ``train`` hands back the (possibly re-created) data iterator so the
        # next epoch continues where this one stopped.
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
            tokenizer,
        )
        if args.no_eval == False:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args, logger, val_dataset_names,tokenizer)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou
        # Save a "best" checkpoint when gIoU improved, or every epoch when
        # evaluation is disabled (``is_best`` is only read if no_eval is False).
        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "best_ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # All ranks wait until rank 0 has removed the old directory.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)
        # Always refresh the rolling checkpoint that --auto_resume looks for.
        save_dir = os.path.join(args.log_dir, "ckpt_model")
        if args.local_rank == 0:
            if os.path.exists(save_dir):
                shutil.rmtree(save_dir)
        torch.distributed.barrier()
        model_engine.save_checkpoint(save_dir)
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
    tokenizer,
):
    """Run ``args.steps_per_epoch`` optimisation steps for one epoch.

    Each step performs ``args.grad_accumulation_steps`` micro-batches. For
    every micro-batch the sampled class names are tokenised and passed
    through the language backbone without gradients; the last hidden states
    of the non-padding tokens are attached to the batch as ``txt_feat``
    before the batch is fed to ``model``.

    Args:
        train_loader: Iterable re-opened with ``iter`` once ``train_iter``
            is exhausted.
        model: Training engine exposing ``train()``, ``backward(loss)``,
            ``step()`` and a forward call returning a dict of losses.
        epoch: Epoch number, used only in the progress prefix.
        scheduler: LR scheduler providing ``get_last_lr()``.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed options (``steps_per_epoch``, ``grad_accumulation_steps``,
            ``precision``, ``print_freq``, ``distributed``, ``local_rank``).
        tokenizer: Tokenizer used for the class-name prompts.

    Returns:
        The data iterator in its current state, for the next call.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")
    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: start a new pass. Only StopIteration is
                # handled so Ctrl-C and real data errors are not hidden.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)
            data_time.update(time.time() - end)

            # One text prompt per sample, built from its sampled class names.
            texts = []
            for cls_group in input_dict["sampled_classes_list"]:
                if isinstance(cls_group, list) and cls_group and isinstance(cls_group[0], list):
                    # Nested list: only the first inner group is used.
                    text = " ".join(cls_group[0])
                elif isinstance(cls_group, list):
                    # Flat list (an empty one yields an empty prompt).
                    text = " ".join(cls_group)
                else:
                    text = str(cls_group)
                texts.append(text)

            input_ids_list = [
                tokenizer_image_token(
                    text,
                    tokenizer,
                    image_token_index=-200,
                    return_tensors="pt"
                )
                for text in texts
            ]
            input_ids = torch.nn.utils.rnn.pad_sequence(
                input_ids_list,
                batch_first=True,
                padding_value=tokenizer.pad_token_id
            )
            attention_mask = input_ids.ne(tokenizer.pad_token_id)
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()

            # Text features are computed without gradients.
            with torch.no_grad():
                outputs = get_language_backbone(model)(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )
                embeddings = outputs.hidden_states[-1]
                # Drop padded positions per sample, then re-pad with zeros.
                text_embeddings = [
                    sample_embeddings[sample_mask]
                    for sample_embeddings, sample_mask in zip(embeddings, attention_mask)
                ]
                text_embeddings = torch.nn.utils.rnn.pad_sequence(
                    text_embeddings,
                    batch_first=True,
                    padding_value=0.0
                )
            input_dict["txt_feat"] = text_embeddings

            input_dict = dict_to_cuda(input_dict)
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)
            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            batch_size = input_dict["images"].size(0)
            losses.update(loss.item(), batch_size)
            ce_losses.update(ce_loss.item(), batch_size)
            mask_bce_losses.update(mask_bce_loss.item(), batch_size)
            mask_dice_losses.update(mask_dice_loss.item(), batch_size)
            mask_losses.update(mask_loss.item(), batch_size)

            model.backward(loss)
            model.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()
                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                # NOTE(review): ``global_step`` restarts at 0 on every call,
                # so scalars from different epochs share the same x values.
                writer.add_scalar("train/loss", losses.avg, global_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, global_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, global_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, global_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, global_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, global_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, global_step
                )

            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], global_step)

    return train_iter
| def ar_validate(val_loader, model_engine, epoch, writer, args, logger, val_dataset_names, tokenizer, seg_token_num=1, image_feature_scale_num=1): | |
| pred_file = [] | |
| acc_iou_list = [] | |
| log_dir = args.log_dir | |
| out_file = os.path.join(log_dir, 'out_file_{}.json'.format(args.local_rank)) | |
| acc_iou_out_file = os.path.join(log_dir, 'acc_list_{}.json'.format(args.local_rank)) | |
| model_engine.eval() | |
| if not isinstance(val_loader, list): | |
| val_loader = [val_loader] | |
| assert len(val_dataset_names) == len(val_loader) | |
| k = 0 | |
| for loader, dataset_name in zip(val_loader, val_dataset_names): | |
| intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM) | |
| union_meter = AverageMeter("Union", ":6.3f", Summary.SUM) | |
| acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM) | |
| for input_dict in tqdm.tqdm(loader): | |
| image_pred = {} | |
| image_pred['answers'] = [] | |
| image_pred['question_gt_category_name'] = [] | |
| input_dict = dict_to_cuda(input_dict) | |
| if args.precision == "fp16": | |
| input_dict["images"] = input_dict["images"].half() | |
| input_dict["images_clip"] = input_dict["images_clip"].half() | |
| elif args.precision == "bf16": | |
| input_dict["images"] = input_dict["images"].bfloat16() | |
| input_dict["images_clip"] = input_dict["images_clip"].bfloat16() | |
| else: | |
| input_dict["images"] = input_dict["images"].float() | |
| input_dict["images_clip"] = input_dict["images_clip"].float() | |
| image_paths = input_dict['image_paths'] | |
| images = input_dict['images'] | |
| images_clip = input_dict['images_clip'] | |
| resize_list = input_dict['resize_list'] | |
| clip_resize_list = input_dict['clip_resize_list'] | |
| label_list = input_dict['label_list'] | |
| input_ids = input_dict['input_ids'] | |
| gt_masks = input_dict['masks_list'] | |
| questions_list = input_dict['questions_list'] | |
| original_size_list = [label.shape for label in label_list] | |
| if k == 0: | |
| model_engine(**input_dict) | |
| output_ids, pred_masks, batch_seg_token_counts, mask_scores = model_engine.base_model.evaluate(images_clip, images, input_ids, resize_list, clip_resize_list, original_size_list, max_new_tokens=512, tokenizer=tokenizer) | |
| text_outputs = [] | |
| for output_id in output_ids: | |
| _output_id = copy.deepcopy(output_id[0]) | |
| _output_id[_output_id==-200] = 31999 | |
| text_output = tokenizer.decode(_output_id, skip_special_tokens=False) | |
| text_output = ( | |
| text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "") | |
| .replace("\n", "") | |
| .replace(" ", "") | |
| ) | |
| text_outputs.append(text_output) | |
| image_path = input_dict['image_paths'][0] | |
| print("idx:", k, "image_path:", input_dict['image_paths'][0], "text_output: ", text_outputs) | |
| k += 1 | |
| batch_seg_token_count = batch_seg_token_counts[0] | |
| batch_seg_token_count = batch_seg_token_count.cumsum(-1) | |
| batch_seg_token_count = torch.cat( | |
| [torch.zeros(1).long().cuda(), batch_seg_token_count], dim=0 | |
| ) | |
| pred_mask = pred_masks[0] | |
| gt_mask = gt_masks[0] | |
| mask_score = mask_scores[0] | |
| max_num = max(len(pred_masks[0]), len(gt_masks[0])) | |
| assigned_gt_masks = [] | |
| assigned_pred_masks = [] | |
| questions_list = input_dict['questions_list'] | |
| gt_target_count = questions_list[0][1] | |
| gt_category_name = questions_list[0][2] | |
| prompt_ins = questions_list[0][3] | |
| gt_target_count = torch.tensor(gt_target_count).to(batch_seg_token_count).cumsum(-1) | |
| gt_target_count = torch.cat( | |
| [torch.zeros(1).long().cuda(), gt_target_count], dim=0 | |
| ) | |
| assign_length = [] | |
| assign_indice = [] | |
| assign_acc = [] | |
| total_pred_count = [] | |
| pred_count = [] | |
| assert len(batch_seg_token_count) == len(gt_target_count) | |
| for j in range(len(batch_seg_token_count) -1): | |
| start_i = batch_seg_token_count[j] | |
| end_i = batch_seg_token_count[j+1] | |
| q_start_i = gt_target_count[j] | |
| q_end_i = gt_target_count[j+1] | |
| question_inputs = pred_mask[start_i:end_i] | |
| question_mask_scores = mask_score[start_i:end_i] | |
| question_targets = gt_mask[q_start_i:q_end_i] | |
| indice = match_pred(question_inputs.detach(), question_targets.detach()) | |
| assigned_pred_mask = pred_mask[start_i:end_i][indice[0]] | |
| assigned_pred_mask = (assigned_pred_mask > 0).int() | |
| assigned_gt_mask = gt_mask[q_start_i:q_end_i][indice[1]] | |
| unassugned_indice = [] | |
| unassugned_indice_pred = [] | |
| for i in range(len(gt_mask[q_start_i:q_end_i])): | |
| if i not in indice[1]: | |
| unassugned_indice.append(i) | |
| for i in range(len(pred_mask[start_i:end_i])): | |
| if i not in indice[0]: | |
| unassugned_indice_pred.append(i) | |
| unassugned_indice = np.array(unassugned_indice) | |
| unassugned_indice_pred = np.array(unassugned_indice_pred) | |
| unassigned_gt_mask = gt_mask[q_start_i:q_end_i][unassugned_indice] | |
| unassigned_pred = pred_mask[start_i:end_i][unassugned_indice_pred] | |
| empty_gt = torch.zeros_like(unassigned_pred) | |
| empty_pred = torch.zeros_like(unassigned_gt_mask) | |
| assigned_gt_mask = torch.cat((assigned_gt_mask, unassigned_gt_mask)) | |
| assigned_pred_mask = torch.cat((assigned_pred_mask, empty_pred)) | |
| assigned_gt_mask = torch.cat((assigned_gt_mask, empty_gt)) | |
| assigned_pred_mask = torch.cat((assigned_pred_mask, unassigned_pred)) | |
| assigned_gt_masks.append(assigned_gt_mask) | |
| assigned_pred_masks.append(assigned_pred_mask) | |
| question_gt_category_name = gt_category_name[j] | |
| text_output = text_outputs[j] | |
| sorted_id = sorted(range(len(indice[0])), key=lambda k: indice[0][k], reverse=False) | |
| sorted_gt_indice = indice[1][sorted_id] | |
| sorted_pred_indice = indice[0][sorted_id] | |
| seg_token = ' '.join(['[SEG{}]'.format(str(s)) for s in range(seg_token_num*image_feature_scale_num)]) if seg_token_num*image_feature_scale_num > 1 else '[SEG]' | |
| _text_output = text_output | |
| in_count = 0 | |
| question_gt_category_name_list = [] | |
| for count in range(text_output.count(seg_token)): | |
| if count in sorted_pred_indice: | |
| _text_output = _text_output.replace(seg_token, question_gt_category_name[sorted_gt_indice[in_count]], 1) | |
| question_gt_category_name_list.append(question_gt_category_name[sorted_gt_indice[in_count]][1:-1]) | |
| in_count += 1 | |
| else: | |
| question_gt_category_name_list.append('None []') | |
| _text_output = _text_output.replace(seg_token, '(None [])', 1) | |
| image_pred['image_path'] = input_dict['image_paths'][0] | |
| image_pred['questions'] = questions_list[0][0] | |
| answer = _text_output.split('ASSISTANT:')[-1] | |
| answer = answer.replace('<unk>', '') | |
| image_pred['answers'].append(answer) | |
| image_pred['question_gt_category_name'].append(question_gt_category_name_list) | |
| assign_length.extend([True]*len(indice[0])) | |
| assign_length.extend([False]*(len(assigned_gt_mask)-len(indice[0]))) | |
| assign_indice.append(indice[0].tolist()) | |
| total_pred_count.append(len(assigned_gt_mask)) | |
| pred_count.append(len(pred_mask[start_i:end_i])) | |
| assigned_gt_masks = torch.cat(assigned_gt_masks) | |
| output_list = torch.cat(assigned_pred_masks) | |
| intersection, union, acc_iou = 0.0, 0.0, 0.0 | |
| for mask_i, output_i, is_assign in zip(assigned_gt_masks, output_list, assign_length): | |
| intersection_i, union_i, _ = intersectionAndUnionGPU( | |
| output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255 | |
| ) | |
| intersection += intersection_i | |
| union += union_i | |
| acc_iou += intersection_i / (union_i + 1e-5) | |
| acc_iou[union_i == 0] += 1.0 | |
| assign_acc.append((intersection_i.tolist(), union_i.tolist())) | |
| image_pred['assign_length'] = assign_length | |
| image_pred['assign_indice'] = assign_indice | |
| image_pred['assign_acc'] = assign_acc | |
| image_pred['total_pred_count'] = total_pred_count | |
| image_pred['pred_count'] = pred_count | |
| image_pred['prompt_ins'] = prompt_ins | |
| pred_file.append(image_pred) | |
| intersection, union = intersection.cpu().numpy(), union.cpu().numpy() | |
| acc_iou = acc_iou.cpu().numpy() / max_num | |
| intersection_meter.update(intersection), union_meter.update( | |
| union | |
| ), acc_iou_meter.update(acc_iou, n=max_num) | |
| print(acc_iou) | |
| _acc_iou = acc_iou.tolist() | |
| _acc_iou.append(max_num) | |
| _acc_iou.append(input_dict['image_paths'][0]) | |
| acc_iou_list.append(_acc_iou) | |
| intersection_meter.all_reduce() | |
| union_meter.all_reduce() | |
| acc_iou_meter.all_reduce() | |
| with open(acc_iou_out_file, 'w') as f: | |
| json.dump(acc_iou_list, f) | |
| with open(out_file, 'w') as f: | |
| json.dump(pred_file, f) | |
| iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) | |
| ciou = iou_class[1] | |
| giou = acc_iou_meter.avg[1] | |
| if args.local_rank == 0: | |
| writer.add_scalar("val/giou", giou, epoch) | |
| writer.add_scalar("val/ciou", ciou, epoch) | |
| print("{}, epoch: {}, giou: {:.4f}, ciou: {:.4f}".format(dataset_name, epoch, giou, ciou)) | |
| logger.info("{}, epoch: {}, giou: {:.4f}, ciou: {:.4f}".format(dataset_name, epoch, giou, ciou)) | |
def _validate_first_answer(answers_list):
    """Return the first reference answer of the batch as text, or None if absent."""
    answer_raw = answers_list[0] if len(answers_list) > 0 else None
    if not answer_raw:
        return None
    if isinstance(answer_raw, list):
        return answer_raw[0] if len(answer_raw) > 0 else None
    if isinstance(answer_raw, str):
        return answer_raw
    return str(answer_raw)


def _validate_build_texts(sampled_classes_list, answer):
    """Build the per-sample condition texts with and without the <think> block."""
    import re

    answer_with_cot = str(answer) if answer else ""
    answer_without_cot = ""
    if answer:
        # The answer is the same for every class group, so strip the
        # chain-of-thought span and collapse whitespace only once.
        stripped = re.sub(r'<think>.*?</think>', '', str(answer), flags=re.DOTALL)
        answer_without_cot = ' '.join(stripped.split())

    texts_with_cot = []
    texts_without_cot = []
    for cls_group in sampled_classes_list:
        # `cls_group and ...` guards the [0] lookup against an empty list.
        if isinstance(cls_group, list) and cls_group and isinstance(cls_group[0], list):
            text = " ".join(cls_group[0])
        elif isinstance(cls_group, list):
            text = " ".join(cls_group)
        else:
            text = str(cls_group)
        texts_with_cot.append(text + " " + answer_with_cot if answer else text)
        texts_without_cot.append(
            text + " " + answer_without_cot if answer_without_cot else text
        )
    return texts_with_cot, texts_without_cot


def _validate_cast_images(input_dict, precision):
    """Cast the two image tensors of ``input_dict`` in place to the run precision."""
    for key in ("images", "images_clip"):
        if precision == "fp16":
            input_dict[key] = input_dict[key].half()
        elif precision == "bf16":
            input_dict[key] = input_dict[key].bfloat16()
        else:
            input_dict[key] = input_dict[key].float()


def _validate_encode_texts(texts, tokenizer, model_engine):
    """Run the language backbone on ``texts`` and return zero-padded last-layer states."""
    token_ids = [
        tokenizer_image_token(
            text,
            tokenizer,
            image_token_index=-200,
            return_tensors="pt"
        )
        for text in texts
    ]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        token_ids,
        batch_first=True,
        padding_value=tokenizer.pad_token_id
    )
    # The mask is derived from the pad id, so any real token that equals the
    # pad id is dropped from the embeddings as well.
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    with torch.no_grad():
        outputs = get_language_backbone(model_engine)(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
    embeddings = outputs.hidden_states[-1]
    # Keep only the non-padded positions of each sequence, then re-pad with
    # zeros so the batch is rectangular again.
    per_text = [embeddings[i][attention_mask[i]] for i in range(len(texts))]
    return torch.nn.utils.rnn.pad_sequence(
        per_text,
        batch_first=True,
        padding_value=0.0
    )


def _validate_iou_stats(masks_list, output_list):
    """Accumulate intersection, union and per-mask IoU over one image's masks."""
    intersection, union, acc_iou = 0.0, 0.0, 0.0
    for mask_i, output_i in zip(masks_list, output_list):
        intersection_i, union_i, _ = intersectionAndUnionGPU(
            output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
        )
        intersection += intersection_i
        union += union_i
        acc_iou += intersection_i / (union_i + 1e-5)
        # An empty union (nothing predicted, nothing annotated) counts as a hit.
        acc_iou[union_i == 0] += 1.0
    intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
    acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
    return intersection, union, acc_iou


def validate(val_loader, model_engine, epoch, writer, args, logger, val_dataset_names,tokenizer):
    """Evaluate segmentation quality on each validation loader, with and without CoT.

    For every batch two condition texts are built from ``sampled_classes_list``
    and the first entry of ``answers_list``: one with the full answer and one
    with the ``<think>...</think>`` span removed. Each text is encoded by the
    language backbone, passed to the model as ``txt_feat``, and the predicted
    masks are scored against ``gt_masks``. Only the with-CoT pass feeds the
    gIoU/cIoU meters and the per-reasoning-type meters; the without-CoT pass
    is used for the comparison counters.

    Args:
        val_loader: A loader or list of loaders; zipped with ``val_dataset_names``.
        model_engine: Model wrapper; called as ``model_engine(**input_dict)``.
        epoch: Step value used for the TensorBoard scalars.
        writer: Object exposing ``add_scalar`` (used when ``args.local_rank == 0``).
        args: Namespace providing ``precision`` and ``local_rank``; also forwarded
            to ``save_eval_artifacts``.
        logger: Logger used for the summary lines.
        val_dataset_names: Names matching ``val_loader``; names containing
            ``'NYU'`` are skipped.
        tokenizer: Tokenizer passed to ``tokenizer_image_token``; its
            ``pad_token_id`` is used for padding.

    Returns:
        ``(giou, ciou)`` of the last dataset that was evaluated, or
        ``(0.0, 0.0)`` when no dataset was evaluated.

    NOTE(review): this block was reconstructed from a listing whose indentation
    had been stripped. Two nesting choices could not be read from the text and
    are marked with NOTE(review) below - confirm against the original file.
    """
    import time
    from collections import defaultdict

    correct_threshold = 25  # per-image mIoU (%) above which a pass counts as correct
    compare_threshold = 50  # per-image mIoU (%) used for the CoT vs no-CoT cases

    model_engine.eval()
    if not isinstance(val_loader, list):
        val_loader = [val_loader]

    # Defined up front so the final return is valid even when every dataset is
    # skipped or the loader list is empty.
    giou, ciou = 0.0, 0.0

    for loader, dataset_name in zip(val_loader, val_dataset_names):
        if 'NYU' in dataset_name:
            continue

        intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
        reasoning_type_meters = defaultdict(lambda: {
            'intersection': AverageMeter("Intersec", ":6.3f", Summary.SUM),
            'union': AverageMeter("Union", ":6.3f", Summary.SUM),
            'acc_iou': AverageMeter("gIoU", ":6.3f", Summary.SUM),
            'count': 0
        })
        correct_with_cot = 0
        correct_without_cot = 0
        total_samples = 0
        cot_right_no_cot_wrong = 0
        cot_wrong_no_cot_right = 0
        total_time = 0
        num_images = 0

        for input_dict in tqdm.tqdm(loader):
            start_time = time.time()
            torch.cuda.empty_cache()
            input_dict = dict_to_cuda(input_dict)
            # The cast does not depend on the CoT pass, so do it once per batch.
            _validate_cast_images(input_dict, args.precision)

            answers_list = input_dict.get("answers_list", [None] * len(input_dict["image_paths"]))
            answer = _validate_first_answer(answers_list)
            texts_with_cot, texts_without_cot = _validate_build_texts(
                input_dict["sampled_classes_list"], answer
            )

            miou_by_pass = {}
            for texts, cot_type in [(texts_with_cot, "with_cot"), (texts_without_cot, "without_cot")]:
                input_dict["txt_feat"] = _validate_encode_texts(texts, tokenizer, model_engine)
                input_dict["inference"] = True
                with torch.no_grad():
                    output_dict = model_engine(**input_dict)
                pred_masks = output_dict["pred_masks"]
                # Only the first element is scored below.
                assert len(pred_masks) == 1
                masks_list = output_dict["gt_masks"][0].int()
                output_list = (pred_masks[0] > 0).int()

                intersection, union, acc_iou = _validate_iou_stats(masks_list, output_list)
                per_image_miou = (intersection / (union + 1e-10)).mean()
                per_image_miou_percent = per_image_miou * 100
                miou_by_pass[cot_type] = per_image_miou_percent

                if cot_type != "with_cot":
                    if per_image_miou_percent > correct_threshold:
                        correct_without_cot += 1
                    continue

                if per_image_miou_percent > correct_threshold:
                    correct_with_cot += 1
                # Headline metrics are accumulated from the with-CoT pass only.
                intersection_meter.update(intersection)
                union_meter.update(union)
                acc_iou_meter.update(acc_iou, n=masks_list.shape[0])
                per_image_ciou = intersection[1] / (union[1] + 1e-10)
                per_image_giou = acc_iou[1]

                categories = input_dict.get("categories", ["unknown"] * len(input_dict["image_paths"]))
                if isinstance(categories, list) and len(categories) > 0:
                    type_meters = reasoning_type_meters[categories[0]]
                    type_meters['intersection'].update(intersection)
                    type_meters['union'].update(union)
                    type_meters['acc_iou'].update(acc_iou, n=masks_list.shape[0])
                    type_meters['count'] += 1

                # NOTE(review): kept inside the with-CoT branch because the
                # per-image metrics passed here are only computed in this
                # branch; confirm the call was not meant to run for both passes.
                save_eval_artifacts(
                    args=args,
                    input_dict=input_dict,
                    dataset_name=dataset_name,
                    cot_type=cot_type,
                    output_list=output_list,
                    masks_list=masks_list,
                    question_text=input_dict.get("questions_list", [[None]])[0],
                    condition_text=texts[0] if len(texts) > 0 else None,
                    answer_text=answer,
                    per_image_ciou=per_image_ciou,
                    per_image_giou=per_image_giou,
                )

            cot_right = miou_by_pass["with_cot"] > compare_threshold
            no_cot_right = miou_by_pass["without_cot"] > compare_threshold
            if cot_right and not no_cot_right:
                cot_right_no_cot_wrong += 1
            if not cot_right and no_cot_right:
                cot_wrong_no_cot_right += 1
            total_samples += 1

            batch_time = time.time() - start_time
            total_time += batch_time
            num_images += 1
            print(f"Image: {input_dict['image_paths'][0]}, cIoU: {per_image_ciou:.4f}, gIoU: {per_image_giou:.4f}, Time: {batch_time:.4f}s")

        avg_time = total_time / num_images if num_images > 0 else 0
        fps = num_images / total_time if total_time > 0 else 0
        print(f"\n{'='*50}")
        print(f"Validation Speed Statistics - {dataset_name}")
        print(f"{'='*50}")
        print(f"Total images: {num_images}")
        print(f"Total time: {total_time:.4f}s")
        print(f"Average time per image: {avg_time:.4f}s")
        print(f"FPS: {fps:.2f}")
        print(f"{'='*50}\n")

        intersection_meter.all_reduce()
        union_meter.all_reduce()
        acc_iou_meter.all_reduce()
        # Iterate in sorted order (as the report below does) so that ranks that
        # met the reasoning types in a different order still reduce the same
        # type at the same step.
        # NOTE(review): ranks that saw different *sets* of types would still
        # issue a different number of reductions - confirm this cannot happen.
        for reasoning_type in sorted(reasoning_type_meters):
            type_meters = reasoning_type_meters[reasoning_type]
            type_meters['intersection'].all_reduce()
            type_meters['union'].all_reduce()
            type_meters['acc_iou'].all_reduce()
            if dist.is_initialized():
                count_tensor = torch.tensor(type_meters['count'], dtype=torch.long, device='cuda')
                dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
                type_meters['count'] = count_tensor.item()

        if num_images == 0:
            # Nothing was scored: the percentages below would divide by zero.
            print(f"No validation samples for {dataset_name}; skipping metrics.")
            continue

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        ciou = iou_class[1]
        giou = acc_iou_meter.avg[1]

        # NOTE(review): the whole report is assumed to belong to the rank-0
        # branch; the CoT counters are local to this process (never reduced).
        if args.local_rank == 0:
            writer.add_scalar("val/giou", giou, epoch)
            writer.add_scalar("val/ciou", ciou, epoch)
            logger.info("{}, epoch: {}, giou: {:.4f}, ciou: {:.4f}".format(dataset_name, epoch, giou, ciou))
            print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

            print(f"\n{'='*50}")
            print(f"CoT Comparison Results - {dataset_name}")
            print(f"{'='*50}")
            print(f"Total samples: {total_samples}")
            print(f"Correct with CoT (mIoU > 25): {correct_with_cot}/{total_samples} ({correct_with_cot/total_samples*100:.2f}%)")
            print(f"Correct without CoT (mIoU > 25): {correct_without_cot}/{total_samples} ({correct_without_cot/total_samples*100:.2f}%)")
            print(f"\nSpecial Cases (mIoU > 50 as threshold):")
            print(f" Cases where CoT is correct but no-CoT is wrong: {cot_right_no_cot_wrong}")
            print(f" Cases where CoT is wrong but no-CoT is correct: {cot_wrong_no_cot_right}")
            print(f"{'='*50}\n")
            logger.info("{}, CoT comparison: with_cot={}/{}, without_cot={}/{}".format(
                dataset_name, correct_with_cot, total_samples, correct_without_cot, total_samples))
            logger.info("{}, Special cases (mIoU>50): cot_right_no_cot_wrong={}, cot_wrong_no_cot_right={}".format(
                dataset_name, cot_right_no_cot_wrong, cot_wrong_no_cot_right))

            if reasoning_type_meters:
                print(f"\n{'='*50}")
                print(f"Results by Reasoning Type - {dataset_name}")
                print(f"{'='*50}")
                for reasoning_type in sorted(reasoning_type_meters.keys()):
                    meters = reasoning_type_meters[reasoning_type]
                    type_iou_class = meters['intersection'].sum / (meters['union'].sum + 1e-10)
                    type_ciou = type_iou_class[1]
                    type_giou = meters['acc_iou'].avg[1]
                    type_count = meters['count']
                    print(f"{reasoning_type}: gIoU: {type_giou:.4f}, cIoU: {type_ciou:.4f}, Count: {type_count}")
                    logger.info("{}, reasoning_type: {}, giou: {:.4f}, ciou: {:.4f}, count: {}".format(
                        dataset_name, reasoning_type, type_giou, type_ciou, type_count))
                print(f"{'='*50}\n")

    return giou, ciou
# Script entry point: forward the command-line arguments (without the program
# name) to main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main(sys.argv[1:])