import argparse
import os
import shutil
import sys
import time
from functools import partial
import logging
import deepspeed
import numpy as np
import torch
import tqdm
import transformers
import copy
from peft import LoraConfig, get_peft_model
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from model.PixDLM import PixDLMForCausalLM
from model.llava import conversation as conversation_lib
from utils.dataset import HybridDataset, ValDataset, collate_fn
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
AverageMeter, ProgressMeter, Summary, dict_to_cuda,
intersectionAndUnionGPU)
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN)
from utils.matcher import match_pred
from utils.multi_reason_seg_val_dataset import MultiReasonSegValDataset
from model.llava.mm_utils import tokenizer_image_token
import requests
import json
import base64
import cv2
def parse_args(args):
    """Parse the command-line options for PixDLM training / evaluation.

    Args:
        args: Sequence of argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace holding one attribute per option below.

    Notes:
        ``--train_mask_decoder``, ``--use_mm_start_end`` and ``--auto_resume``
        default to True. Because a ``store_true`` flag with a True default can
        never be switched off, each of them has a ``--no_<name>`` companion
        that writes False to the same destination. Passing the positive flag
        (or nothing) keeps the previous behaviour.
    """
    parser = argparse.ArgumentParser(description="PixDLM Model Training")
    parser.add_argument("--local_rank", default=0, type=int, help="node rank")
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    # The hyphenated option is stored as ``vision_tower`` by argparse.
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="pixdlm", type=str)
    parser.add_argument("--epochs", default=5, type=int)
    parser.add_argument("--steps_per_epoch", default=200, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    # Kept for CLI compatibility; always True (store_true with a True default).
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument(
        "--no_train_mask_decoder",
        dest="train_mask_decoder",
        action="store_false",
        help="disable --train_mask_decoder (which defaults to on)",
    )
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--no_use_mm_start_end",
        dest="use_mm_start_end",
        action="store_false",
        help="disable --use_mm_start_end (which defaults to on)",
    )
    parser.add_argument("--auto_resume", action="store_true", default=True)
    parser.add_argument(
        "--no_auto_resume",
        dest="auto_resume",
        action="store_false",
        help="disable --auto_resume (which defaults to on)",
    )
    parser.add_argument("--seg_token_num", default=1, type=int)
    parser.add_argument("--num_classes_per_question", default=1, type=int)
    parser.add_argument("--pad_train_clip_images", action="store_true", default=False)
    parser.add_argument("--masks_process_with_clip", default=False, action="store_true")
    parser.add_argument("--preprocessor_config", default='', type=str)
    parser.add_argument("--resize_vision_tower", action="store_true", default=False)
    parser.add_argument("--resize_vision_tower_size", default=224, type=int)
    parser.add_argument("--vision_tower_for_mask", action="store_true", default=False)
    parser.add_argument("--weight", default="", type=str)
    parser.add_argument("--use_expand_question_list", action="store_true", default=False)
    parser.add_argument("--separate_mm_projector", action="store_true", default=False)
    parser.add_argument("--image_feature_scale_num", default=1, type=int)
    parser.add_argument("--Three_Level_Multi_Scale_Decoder", action="store_true", default=False)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    parser.add_argument("--is_multipath_encoder", action="store_true", default=False)
    parser.add_argument("--sam2_config", default='./sam2/configs/sam2.1/sam2.1_hiera_l.yaml', type=str)
    parser.add_argument("--freeze_vision", action="store_true", default=False)
    return parser.parse_args(args)
def get_language_backbone(model):
    """Peel wrapper layers off *model* and return the innermost backbone.

    The object is first replaced by its ``module`` attribute when it has one,
    then by its ``model`` attribute up to two times. A missing attribute at any
    step leaves the current object unchanged.
    """
    node = getattr(model, "module", model)
    # Descend through at most two nested ``.model`` attributes.
    for _ in range(2):
        node = getattr(node, "model", node)
    return node
def _safe_name(name):
return "".join(c if c.isalnum() or c in "._-" else "_" for c in name)
def _first_text(value):
if value is None:
return None
if isinstance(value, (list, tuple)):
return _first_text(value[0]) if value else None
return str(value)
def _mask_union(mask_tensor):
arr = mask_tensor.detach().float().cpu().numpy()
if arr.ndim == 0:
arr = arr.reshape(1, 1)
if arr.ndim == 3:
arr = arr.max(axis=0)
elif arr.ndim > 3:
arr = arr.max(axis=tuple(range(arr.ndim - 2)))
return (arr > 0).astype(np.uint8)
def save_eval_artifacts(args, input_dict, dataset_name, cot_type, output_list, masks_list,
                        question_text, condition_text, answer_text, per_image_ciou,
                        per_image_giou):
    """Write per-image evaluation visuals and a JSON summary to disk.

    Only the process whose ``args.local_rank`` is 0 writes anything. Files go
    to ``<vis_save_path>/<safe dataset name>/<cot_type>/``; a relative
    ``vis_save_path`` is resolved under ``args.log_dir``. For the first image
    path in ``input_dict["image_paths"]`` this saves the input image, the
    unioned prediction and ground-truth masks, a colour overlay, and a
    ``*_result.json`` with the texts and the two metrics.

    Returns None; silently returns early when the image cannot be read.
    """
    if getattr(args, "local_rank", 0) != 0:
        return

    out_root = args.vis_save_path
    if not os.path.isabs(out_root):
        out_root = os.path.join(args.log_dir, out_root)
    target_dir = os.path.join(out_root, _safe_name(dataset_name), cot_type)
    os.makedirs(target_dir, exist_ok=True)

    image_path = input_dict["image_paths"][0]
    frame = cv2.imread(image_path)
    if frame is None:
        # cv2.imread signals an unreadable file with None; skip the sample.
        return

    stem = _safe_name(os.path.splitext(os.path.basename(image_path))[0])
    files = {
        "input": os.path.join(target_dir, stem + "_input.jpg"),
        "pred_mask": os.path.join(target_dir, stem + "_pred_mask.png"),
        "gt_mask": os.path.join(target_dir, stem + "_gt_mask.png"),
        "overlay": os.path.join(target_dir, stem + "_overlay_pred_red_gt_green.jpg"),
    }
    result_path = os.path.join(target_dir, stem + "_result.json")

    # Union each mask stack and bring it to the image resolution.
    rows, cols = frame.shape[:2]
    binary_masks = []
    for raw in (output_list, masks_list):
        flat = _mask_union(raw)
        if flat.shape[:2] != (rows, cols):
            flat = cv2.resize(flat, (cols, rows), interpolation=cv2.INTER_NEAREST)
        binary_masks.append(flat)
    pred_mask, gt_mask = binary_masks

    overlay = frame.copy()
    gt_pixels = gt_mask > 0
    pred_pixels = pred_mask > 0
    # (region, weight of existing pixel, weight of tint, tint in OpenCV BGR order).
    # Order matters: ground truth first, prediction second, their overlap last.
    layers = (
        (gt_pixels, 0.55, 0.45, [0, 255, 0]),
        (pred_pixels, 0.55, 0.45, [0, 0, 255]),
        (gt_pixels & pred_pixels, 0.35, 0.65, [0, 255, 255]),
    )
    for region, keep, mix, bgr in layers:
        overlay[region] = (keep * overlay[region] + mix * np.array(bgr)).astype(np.uint8)

    cv2.imwrite(files["input"], frame)
    cv2.imwrite(files["pred_mask"], pred_mask * 255)
    cv2.imwrite(files["gt_mask"], gt_mask * 255)
    cv2.imwrite(files["overlay"], overlay)

    result = {
        "dataset": dataset_name,
        "cot_type": cot_type,
        "image": image_path,
        "question": _first_text(question_text),
        "answer": _first_text(answer_text),
        "conditioning_text": _first_text(condition_text),
        "metrics": {
            "cIoU": float(per_image_ciou),
            "gIoU": float(per_image_giou),
        },
        "artifacts": files,
    }
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print("Saved eval artifact:", result_path)
def main(args):
    """Entry point: build tokenizer, model and datasets, then train and/or evaluate.

    Args:
        args: Sequence of command-line argument strings, handed to ``parse_args``.

    The function parses options, sets up logging on rank 0, extends the
    tokenizer with segmentation tokens, loads ``PixDLMForCausalLM``, optionally
    wraps it with LoRA, builds the training and validation datasets, hands
    everything to ``deepspeed.initialize`` and then either runs a single
    evaluation (``--eval_only``) or the epoch loop with checkpointing.
    """
    args = parse_args(args)
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    # Only rank 0 owns the TensorBoard writer and the file logger.
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
        log_filename = os.path.join(args.log_dir, 'meta.log')
        # Pick the first unused name among meta.log, meta_1.log, meta_2.log, ...
        i = 1
        while os.path.exists(log_filename):
            log_filename = os.path.join(args.log_dir, 'meta_{}.log'.format(str(i)))
            i += 1
        logger = logging.getLogger('pixdlm_logger')
        logger.setLevel(logging.INFO)
        file_handler = logging.FileHandler(log_filename)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.info(args)
    else:
        writer = None
        logger = None
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
        legacy=True
    )
    # The unknown token doubles as the padding token.
    tokenizer.pad_token = tokenizer.unk_token
    # One "[SEG]" token when a single segmentation token is needed, otherwise
    # "[SEG0]".."[SEGn-1]"; args.seg_token_idx is an int or a list accordingly.
    if args.seg_token_num*args.image_feature_scale_num == 1:
        num_added_tokens = tokenizer.add_tokens("[SEG]")
        args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    else:
        new_tokens = ["[SEG{}]".format(i) for i in range(args.seg_token_num*args.image_feature_scale_num)]
        num_added_tokens = tokenizer.add_tokens(new_tokens)
        args.seg_token_idx = [tokenizer(token, add_special_tokens=False).input_ids[0] for token in new_tokens]
    # NOTE(review): four empty strings are registered here; judging by the
    # variable name these were presumably reasoning-tag tokens whose text was
    # lost. Confirm the intended token strings against the original source.
    num_added_tokens_think = tokenizer.add_tokens(["", "", "", ""])
    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
    # Extra keyword arguments forwarded to PixDLMForCausalLM.from_pretrained.
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
        "seg_token_num": args.seg_token_num,
        "logger": logger,
        "tokenizer": tokenizer,
        "local_rank": args.local_rank,
        "pad_train_clip_images": args.pad_train_clip_images,
        "resize_vision_tower": args.resize_vision_tower,
        "resize_vision_tower_size": args.resize_vision_tower_size,
        "vision_tower_for_mask": args.vision_tower_for_mask,
        "separate_mm_projector": args.separate_mm_projector,
        "masks_process_with_clip": args.masks_process_with_clip,
        "image_feature_scale_num": args.image_feature_scale_num,
        "three_level_multi_scale_decoder": args.Three_Level_Multi_Scale_Decoder,
        "is_multipath_encoder": args.is_multipath_encoder,
        "sam2_config": args.sam2_config,
        "freeze_vision":args.freeze_vision
    }
    # Map the --precision choice onto a torch dtype (fp32 is the fallback).
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    ignore_mismatched_sizes = args.separate_mm_projector
    model = PixDLMForCausalLM.from_pretrained(
        args.version,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        **model_args,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.enable_input_require_grads()
    # Enabled unconditionally; args.gradient_checkpointing is not consulted here.
    model.gradient_checkpointing_enable()
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    # The vision tower is always frozen.
    for p in vision_tower.parameters():
        p.requires_grad = False
    # The multimodal projector is frozen only for the default 224 tower size.
    if args.resize_vision_tower_size == 224:
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = False
    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]
    lora_r = args.lora_r
    if lora_r > 0:
        def find_linear_layers(model, lora_target_modules):
            # Collect names of nn.Linear modules whose name contains one of the
            # requested target substrings and none of the excluded component
            # names listed below.
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                                "mask_decoder",
                                "image_feature_neck",
                                "prompt_encoder",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))
        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    # Grow the embedding table to cover the tokens added above.
    model.resize_token_embeddings(len(tokenizer))
    if args.weight:
        # NOTE(review): torch.load unpickles the file; only point --weight at
        # trusted checkpoints. strict=False tolerates missing/unexpected keys.
        state_dict = torch.load(args.weight, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
    # Parameters whose name contains any of these substrings are (re-)enabled
    # for training, regardless of the freezing done earlier.
    trainable_list = ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs", "sam_to_embed_conv", "prompt_encoder", "image_feature_neck"]
    if args.resize_vision_tower_size != 224:
        trainable_list.append('mm_projector')
    for n, p in model.named_parameters():
        if any(
            [
                x in n
                for x in trainable_list
            ]
        ):
            p.requires_grad = True
    # "World size" here is the number of CUDA devices visible to this process.
    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
        seg_token_num=args.seg_token_num*args.image_feature_scale_num,
        num_classes_per_question=args.num_classes_per_question,
        pad_train_clip_images=args.pad_train_clip_images,
        masks_process_with_clip=args.masks_process_with_clip,
        preprocessor_config=args.preprocessor_config,
        use_expand_question_list=args.use_expand_question_list,
    )
    print("____seg_token_num in data:________: ", args.seg_token_num*args.image_feature_scale_num)
    # --val_dataset may name one dataset or several joined by "||"; with
    # several, val_dataset (and later val_loader) become lists.
    multi_val = False
    if args.no_eval == False:
        token_num = args.seg_token_num*args.image_feature_scale_num
        if len(args.val_dataset.split('||')) == 1:
            # The text before the first "|" selects the dataset class.
            if args.val_dataset.split('|')[0] == 'MultiReasonSeg':
                ValDataset_type = MultiReasonSegValDataset
            else:
                ValDataset_type = ValDataset
            val_dataset_names = [args.val_dataset]
            val_dataset = ValDataset_type(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
                seg_token_num=token_num,
                pad_val_clip_images=args.pad_train_clip_images,
                masks_process_with_clip=args.masks_process_with_clip,
                preprocessor_config=args.preprocessor_config,
            )
            print(
                f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
            )
        else:
            multi_val = True
            val_dataset_names = args.val_dataset.split('||')
            val_dataset = []
            for val_dataset_name in val_dataset_names:
                if val_dataset_name.split('|')[0] == 'MultiReasonSeg':
                    ValDataset_type = MultiReasonSegValDataset
                else:
                    ValDataset_type = ValDataset
                val_dataset.append(
                    ValDataset_type(
                        args.dataset_dir,
                        tokenizer,
                        args.vision_tower,
                        val_dataset_name,
                        args.image_size,
                        seg_token_num=token_num,
                        pad_val_clip_images=args.pad_train_clip_images,
                        masks_process_with_clip=args.masks_process_with_clip,
                        preprocessor_config=args.preprocessor_config,
                    )
                )
    else:
        # NOTE(review): val_dataset_names and val_loader stay undefined on this
        # path, so --no_eval combined with --eval_only would fail later.
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")
    # DeepSpeed configuration: AdamW, linear warmup followed by decay, ZeRO stage 2.
    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    # DeepSpeed builds the engine, optimizer, training DataLoader and scheduler.
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )
    # Without an explicit --resume, auto-resume picks up <log_dir>/ckpt_model
    # when that directory exists.
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume
    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        # The first line of the "latest" file is read as "global_step<N>";
        # the starting epoch is N // steps_per_epoch.
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )
    if val_dataset is not None:
        # Validation is written for exactly one sample per batch.
        assert args.val_batch_size == 1
        if multi_val:
            # One sampler/loader pair per validation dataset.
            val_sampler = [torch.utils.data.distributed.DistributedSampler(
                dataset, shuffle=False, drop_last=False
            ) for dataset in val_dataset]
            val_loader = [torch.utils.data.DataLoader(
                dataset,
                batch_size=args.val_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=False,
                sampler=sampler,
                collate_fn=partial(
                    collate_fn,
                    tokenizer=tokenizer,
                    conv_type=args.conv_type,
                    use_mm_start_end=args.use_mm_start_end,
                    local_rank=args.local_rank,
                ),
            ) for dataset, sampler in zip(val_dataset, val_sampler)]
        else:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, shuffle=False, drop_last=False
            )
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.val_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=False,
                sampler=val_sampler,
                collate_fn=partial(
                    collate_fn,
                    tokenizer=tokenizer,
                    conv_type=args.conv_type,
                    use_mm_start_end=args.use_mm_start_end,
                    local_rank=args.local_rank,
                ),
            )
    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0
    if args.eval_only:
        # MultiReasonSeg uses the matching-based evaluator, which returns
        # nothing; every other dataset goes through validate().
        if args.val_dataset.split('|')[0] == 'MultiReasonSeg':
            ar_validate(val_loader, model_engine, 0, writer, args, logger, val_dataset_names, tokenizer, args.seg_token_num, args.image_feature_scale_num)
        else:
            giou, ciou = validate(val_loader, model_engine, 0, writer, args, logger, val_dataset_names,tokenizer)
            print(giou,ciou)
        # NOTE(review): exit() is the interactive helper from the site module;
        # sys.exit() would be the conventional call.
        exit()
    for epoch in range(args.start_epoch, args.epochs):
        # train() hands back the iterator so it carries over between epochs.
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
            tokenizer,
        )
        if args.no_eval == False:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args, logger, val_dataset_names,tokenizer)
            # "Best" is decided on gIoU; cIoU is recorded for that best epoch.
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou
        # With --no_eval, is_best is never evaluated (short-circuit) and the
        # "best" checkpoint is rewritten every epoch.
        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "best_ckpt_model")
            if args.local_rank == 0:
                # Small marker file whose name records the best scores.
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # All ranks wait until rank 0 has removed the old directory.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)
        # The latest checkpoint is refreshed every epoch; auto-resume reads it.
        save_dir = os.path.join(args.log_dir, "ckpt_model")
        if args.local_rank == 0:
            if os.path.exists(save_dir):
                shutil.rmtree(save_dir)
        torch.distributed.barrier()
        model_engine.save_checkpoint(save_dir)
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
    tokenizer,
):
    """Run one training epoch of ``args.steps_per_epoch`` optimizer steps.

    Args:
        train_loader: Iterable of training batches; re-iterated when exhausted.
        model: DeepSpeed engine exposing ``backward`` / ``step``.
        epoch: Zero-based epoch index (used for the progress prefix and to
            offset TensorBoard steps).
        scheduler: LR scheduler exposing ``get_last_lr``.
        writer: TensorBoard ``SummaryWriter``; only used when
            ``args.local_rank == 0``.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed options (``steps_per_epoch``, ``grad_accumulation_steps``,
            ``precision``, ``print_freq``, ``distributed``, ``local_rank``).
        tokenizer: Tokenizer used to encode the sampled class texts.

    Returns:
        The current training iterator (possibly re-created) so the caller can
        pass it to the next epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")
    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        # Scalars are keyed by a step that keeps increasing across epochs so
        # later epochs do not overwrite earlier ones in TensorBoard.
        tb_step = epoch * args.steps_per_epoch + global_step
        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: start a fresh pass. Other exceptions
                # (data errors, KeyboardInterrupt) propagate unchanged.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)
            data_time.update(time.time() - end)

            # One text per sample, built from its sampled class names.
            texts = []
            for cls_group in input_dict["sampled_classes_list"]:
                if isinstance(cls_group, list) and isinstance(cls_group[0], list):
                    text = " ".join(cls_group[0])
                elif isinstance(cls_group, list):
                    text = " ".join(cls_group)
                else:
                    text = str(cls_group)
                texts.append(text)

            input_ids_list = [
                tokenizer_image_token(
                    text,
                    tokenizer,
                    image_token_index=-200,
                    return_tensors="pt"
                )
                for text in texts
            ]
            input_ids = torch.nn.utils.rnn.pad_sequence(
                input_ids_list,
                batch_first=True,
                padding_value=tokenizer.pad_token_id
            )
            attention_mask = input_ids.ne(tokenizer.pad_token_id)
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()

            # The text features are computed without gradients.
            with torch.no_grad():
                outputs = get_language_backbone(model)(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )
                embeddings = outputs.hidden_states[-1]
                # Keep only non-padding positions, then re-pad with zeros.
                text_embeddings = [
                    sample_embeddings[sample_mask]
                    for sample_embeddings, sample_mask in zip(embeddings, attention_mask)
                ]
                text_embeddings = torch.nn.utils.rnn.pad_sequence(
                    text_embeddings,
                    batch_first=True,
                    padding_value=0.0
                )
            input_dict["txt_feat"] = text_embeddings
            input_dict = dict_to_cuda(input_dict)

            # Cast both image tensors to the configured precision.
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)
            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            batch_size = input_dict["images"].size(0)
            losses.update(loss.item(), batch_size)
            ce_losses.update(ce_loss.item(), batch_size)
            mask_bce_losses.update(mask_bce_loss.item(), batch_size)
            mask_dice_losses.update(mask_dice_loss.item(), batch_size)
            mask_losses.update(mask_loss.item(), batch_size)
            model.backward(loss)
            model.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()
                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()
            if args.local_rank == 0:
                progress.display(global_step + 1)
                writer.add_scalar("train/loss", losses.avg, tb_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, tb_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, tb_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, tb_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, tb_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, tb_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, tb_step
                )
            # Meters are averaged over one print interval only.
            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], tb_step)
    return train_iter
def ar_validate(val_loader, model_engine, epoch, writer, args, logger, val_dataset_names, tokenizer, seg_token_num=1, image_feature_scale_num=1):
    """Evaluate with generated text and matched masks, one sample per batch.

    For every sample the model generates text and masks through
    ``model_engine.base_model.evaluate``, predicted masks are matched to the
    ground-truth masks per question with ``match_pred``, and intersection /
    union statistics are accumulated. Per-rank JSON dumps of the predictions
    and per-image accuracies are written into ``args.log_dir``; rank 0 also
    logs gIoU / cIoU to TensorBoard, stdout and the file logger.

    Args:
        val_loader: A DataLoader or a list of DataLoaders.
        model_engine: Wrapped model; must expose ``base_model.evaluate``.
        epoch: Step value used for the TensorBoard scalars.
        writer: TensorBoard writer (used on rank 0 only).
        args: Parsed options (``log_dir``, ``local_rank``, ``precision``).
        logger: File logger (used on rank 0 only).
        val_dataset_names: One name per loader.
        tokenizer: Tokenizer used for generation and decoding.
        seg_token_num: Number of segmentation tokens per target.
        image_feature_scale_num: Number of image feature scales.

    Returns:
        None.
    """
    pred_file = []
    acc_iou_list = []
    log_dir = args.log_dir
    # Each rank writes its own pair of JSON files.
    out_file = os.path.join(log_dir, 'out_file_{}.json'.format(args.local_rank))
    acc_iou_out_file = os.path.join(log_dir, 'acc_list_{}.json'.format(args.local_rank))
    model_engine.eval()
    if not isinstance(val_loader, list):
        val_loader = [val_loader]
    assert len(val_dataset_names) == len(val_loader)
    # k counts processed samples across all datasets.
    k = 0
    for loader, dataset_name in zip(val_loader, val_dataset_names):
        intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
        for input_dict in tqdm.tqdm(loader):
            image_pred = {}
            image_pred['answers'] = []
            image_pred['question_gt_category_name'] = []
            input_dict = dict_to_cuda(input_dict)
            # Cast both image tensors to the configured precision.
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()
            image_paths = input_dict['image_paths']
            images = input_dict['images']
            images_clip = input_dict['images_clip']
            resize_list = input_dict['resize_list']
            clip_resize_list = input_dict['clip_resize_list']
            label_list = input_dict['label_list']
            input_ids = input_dict['input_ids']
            gt_masks = input_dict['masks_list']
            questions_list = input_dict['questions_list']
            original_size_list = [label.shape for label in label_list]
            # A regular forward pass is run once, on the very first sample only.
            # NOTE(review): presumably a warm-up for the wrapped engine — confirm.
            if k == 0:
                model_engine(**input_dict)
            output_ids, pred_masks, batch_seg_token_counts, mask_scores = model_engine.base_model.evaluate(images_clip, images, input_ids, resize_list, clip_resize_list, original_size_list, max_new_tokens=512, tokenizer=tokenizer)
            text_outputs = []
            for output_id in output_ids:
                _output_id = copy.deepcopy(output_id[0])
                # Ids equal to -200 are swapped for 31999 before decoding.
                # NOTE(review): 31999 is a hard-coded vocabulary id — confirm
                # it is valid for the tokenizer in use.
                _output_id[_output_id==-200] = 31999
                text_output = tokenizer.decode(_output_id, skip_special_tokens=False)
                # NOTE(review): the last replace strips every single space; with
                # multi-token "[SEG0] [SEG1] ..." strings (joined by spaces
                # below) the search would then never match. Verify the literal.
                text_output = (
                    text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "")
                    .replace("\n", "")
                    .replace(" ", "")
                )
                text_outputs.append(text_output)
            image_path = input_dict['image_paths'][0]
            print("idx:", k, "image_path:", input_dict['image_paths'][0], "text_output: ", text_outputs)
            k += 1
            # Turn per-question predicted-mask counts into slice boundaries
            # [0, c0, c0+c1, ...].
            batch_seg_token_count = batch_seg_token_counts[0]
            batch_seg_token_count = batch_seg_token_count.cumsum(-1)
            batch_seg_token_count = torch.cat(
                [torch.zeros(1).long().cuda(), batch_seg_token_count], dim=0
            )
            pred_mask = pred_masks[0]
            gt_mask = gt_masks[0]
            mask_score = mask_scores[0]
            # Normaliser for this image's accuracy: the larger of the number of
            # predicted and ground-truth masks.
            max_num = max(len(pred_masks[0]), len(gt_masks[0]))
            assigned_gt_masks = []
            assigned_pred_masks = []
            questions_list = input_dict['questions_list']
            # Entries 0..3 of the first question record are used as: question
            # text, per-question target counts, category names, prompt info.
            gt_target_count = questions_list[0][1]
            gt_category_name = questions_list[0][2]
            prompt_ins = questions_list[0][3]
            # Same boundary construction for the ground-truth masks.
            gt_target_count = torch.tensor(gt_target_count).to(batch_seg_token_count).cumsum(-1)
            gt_target_count = torch.cat(
                [torch.zeros(1).long().cuda(), gt_target_count], dim=0
            )
            assign_length = []
            assign_indice = []
            assign_acc = []
            total_pred_count = []
            pred_count = []
            assert len(batch_seg_token_count) == len(gt_target_count)
            # One iteration per question.
            for j in range(len(batch_seg_token_count) -1):
                start_i = batch_seg_token_count[j]
                end_i = batch_seg_token_count[j+1]
                q_start_i = gt_target_count[j]
                q_end_i = gt_target_count[j+1]
                question_inputs = pred_mask[start_i:end_i]
                question_mask_scores = mask_score[start_i:end_i]
                question_targets = gt_mask[q_start_i:q_end_i]
                # indice[0] indexes predictions, indice[1] the matched targets.
                indice = match_pred(question_inputs.detach(), question_targets.detach())
                assigned_pred_mask = pred_mask[start_i:end_i][indice[0]]
                assigned_pred_mask = (assigned_pred_mask > 0).int()
                assigned_gt_mask = gt_mask[q_start_i:q_end_i][indice[1]]
                # Collect the targets and predictions left without a partner.
                unassugned_indice = []
                unassugned_indice_pred = []
                for i in range(len(gt_mask[q_start_i:q_end_i])):
                    if i not in indice[1]:
                        unassugned_indice.append(i)
                for i in range(len(pred_mask[start_i:end_i])):
                    if i not in indice[0]:
                        unassugned_indice_pred.append(i)
                unassugned_indice = np.array(unassugned_indice)
                unassugned_indice_pred = np.array(unassugned_indice_pred)
                unassigned_gt_mask = gt_mask[q_start_i:q_end_i][unassugned_indice]
                unassigned_pred = pred_mask[start_i:end_i][unassugned_indice_pred]
                # Unmatched targets are paired with all-zero predictions and
                # unmatched predictions with all-zero targets, so both lists
                # stay aligned: [matched, unmatched targets, unmatched preds].
                empty_gt = torch.zeros_like(unassigned_pred)
                empty_pred = torch.zeros_like(unassigned_gt_mask)
                assigned_gt_mask = torch.cat((assigned_gt_mask, unassigned_gt_mask))
                assigned_pred_mask = torch.cat((assigned_pred_mask, empty_pred))
                assigned_gt_mask = torch.cat((assigned_gt_mask, empty_gt))
                assigned_pred_mask = torch.cat((assigned_pred_mask, unassigned_pred))
                assigned_gt_masks.append(assigned_gt_mask)
                assigned_pred_masks.append(assigned_pred_mask)
                question_gt_category_name = gt_category_name[j]
                text_output = text_outputs[j]
                # Order the matches by prediction index (the lambda's k is
                # local and does not touch the sample counter k).
                sorted_id = sorted(range(len(indice[0])), key=lambda k: indice[0][k], reverse=False)
                sorted_gt_indice = indice[1][sorted_id]
                sorted_pred_indice = indice[0][sorted_id]
                seg_token = ' '.join(['[SEG{}]'.format(str(s)) for s in range(seg_token_num*image_feature_scale_num)]) if seg_token_num*image_feature_scale_num > 1 else '[SEG]'
                _text_output = text_output
                in_count = 0
                question_gt_category_name_list = []
                # Replace each segmentation token occurrence, in order, with the
                # matched category name or a "(None [])" placeholder.
                for count in range(text_output.count(seg_token)):
                    if count in sorted_pred_indice:
                        _text_output = _text_output.replace(seg_token, question_gt_category_name[sorted_gt_indice[in_count]], 1)
                        # [1:-1] drops the first and last character of the name.
                        question_gt_category_name_list.append(question_gt_category_name[sorted_gt_indice[in_count]][1:-1])
                        in_count += 1
                    else:
                        question_gt_category_name_list.append('None []')
                        _text_output = _text_output.replace(seg_token, '(None [])', 1)
                image_pred['image_path'] = input_dict['image_paths'][0]
                image_pred['questions'] = questions_list[0][0]
                answer = _text_output.split('ASSISTANT:')[-1]
                # NOTE(review): replacing '' with '' is a no-op; the token that
                # was meant to be stripped here appears to have been lost.
                answer = answer.replace('', '')
                image_pred['answers'].append(answer)
                image_pred['question_gt_category_name'].append(question_gt_category_name_list)
                # True for matched pairs, False for the zero-padded ones.
                assign_length.extend([True]*len(indice[0]))
                assign_length.extend([False]*(len(assigned_gt_mask)-len(indice[0])))
                assign_indice.append(indice[0].tolist())
                total_pred_count.append(len(assigned_gt_mask))
                pred_count.append(len(pred_mask[start_i:end_i]))
            assigned_gt_masks = torch.cat(assigned_gt_masks)
            output_list = torch.cat(assigned_pred_masks)
            intersection, union, acc_iou = 0.0, 0.0, 0.0
            for mask_i, output_i, is_assign in zip(assigned_gt_masks, output_list, assign_length):
                intersection_i, union_i, _ = intersectionAndUnionGPU(
                    output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
                )
                intersection += intersection_i
                union += union_i
                acc_iou += intersection_i / (union_i + 1e-5)
                # A class with an empty union counts as a perfect score.
                acc_iou[union_i == 0] += 1.0
                assign_acc.append((intersection_i.tolist(), union_i.tolist()))
            image_pred['assign_length'] = assign_length
            image_pred['assign_indice'] = assign_indice
            image_pred['assign_acc'] = assign_acc
            image_pred['total_pred_count'] = total_pred_count
            image_pred['pred_count'] = pred_count
            image_pred['prompt_ins'] = prompt_ins
            pred_file.append(image_pred)
            intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
            acc_iou = acc_iou.cpu().numpy() / max_num
            # Tuple expression used only for its three update side effects.
            intersection_meter.update(intersection), union_meter.update(
                union
            ), acc_iou_meter.update(acc_iou, n=max_num)
            print(acc_iou)
            # Record: per-class accuracy values, then max_num, then image path.
            _acc_iou = acc_iou.tolist()
            _acc_iou.append(max_num)
            _acc_iou.append(input_dict['image_paths'][0])
            acc_iou_list.append(_acc_iou)
        intersection_meter.all_reduce()
        union_meter.all_reduce()
        acc_iou_meter.all_reduce()
        # The files are rewritten after every dataset with everything
        # accumulated so far (the lists are never cleared between datasets).
        with open(acc_iou_out_file, 'w') as f:
            json.dump(acc_iou_list, f)
        with open(out_file, 'w') as f:
            json.dump(pred_file, f)
        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        # Index 1 is the second of the two classes passed to intersectionAndUnionGPU.
        ciou = iou_class[1]
        giou = acc_iou_meter.avg[1]
        if args.local_rank == 0:
            writer.add_scalar("val/giou", giou, epoch)
            writer.add_scalar("val/ciou", ciou, epoch)
            print("{}, epoch: {}, giou: {:.4f}, ciou: {:.4f}".format(dataset_name, epoch, giou, ciou))
            logger.info("{}, epoch: {}, giou: {:.4f}, ciou: {:.4f}".format(dataset_name, epoch, giou, ciou))
def _select_answer(answers_list):
    """Return the answer attached to the first sample of a batch, or ``None``.

    The first entry of ``answers_list`` may be a list (its first element is
    used), a string (used as is) or any other truthy object (stringified).
    Falsy entries and empty containers yield ``None``.
    """
    answer_raw = answers_list[0] if len(answers_list) > 0 else None
    if not answer_raw:
        return None
    if isinstance(answer_raw, list):
        # A non-empty list here is guaranteed by the truthiness check above.
        return answer_raw[0]
    if isinstance(answer_raw, str):
        return answer_raw
    return str(answer_raw)


def _build_cot_texts(sampled_classes_list, answer):
    """Build the two prompt variants compared by ``validate``.

    Args:
        sampled_classes_list: iterable of class groups; each group is a list
            of strings, a list whose first element is a list of strings, or
            any other object (stringified).
        answer: answer text appended to every prompt, or a falsy value for
            "no answer".

    Returns:
        ``(texts_with_cot, texts_without_cot)``, two lists with one prompt
        per class group.
    """
    import re

    answer_with_cot = None
    answer_without_cot = None
    if answer:
        # The answer suffix is identical for every class group, so compute it once.
        answer_with_cot = str(answer)
        # NOTE(review): this pattern has no delimiters around `.*?`. On
        # Python 3.7+ `re.sub` with a bare lazy `.*?` appears to consume every
        # character, which would make the "without CoT" prompt just the class
        # text. Presumably the pattern once wrapped reasoning tags that were
        # lost -- confirm the intended pattern before relying on this split.
        answer_without_cot = re.sub(r'.*?', '', answer_with_cot, flags=re.DOTALL)
        answer_without_cot = ' '.join(answer_without_cot.split())

    texts_with_cot = []
    texts_without_cot = []
    for cls_group in sampled_classes_list:
        if isinstance(cls_group, list) and cls_group and isinstance(cls_group[0], list):
            text = " ".join(cls_group[0])
        elif isinstance(cls_group, list):
            # Also covers the empty list, which previously raised IndexError.
            text = " ".join(cls_group)
        else:
            text = str(cls_group)

        texts_with_cot.append(text + " " + answer_with_cot if answer else text)
        texts_without_cot.append(text + " " + answer_without_cot if answer_without_cot else text)
    return texts_with_cot, texts_without_cot


def validate(val_loader, model_engine, epoch, writer, args, logger, val_dataset_names,tokenizer):
    """Evaluate segmentation quality with and without the CoT answer text.

    Every batch is run through the model twice: once with the full answer
    appended to the class prompt ("with_cot") and once with the stripped
    answer ("without_cot"). Only the "with_cot" pass feeds the gIoU/cIoU
    meters, the per-reasoning-type meters and ``save_eval_artifacts``; the
    "without_cot" pass is used for the CoT comparison counters.

    NOTE(review): the file's indentation was lost, so the nesting of the
    "with_cot"-only section was reconstructed from data flow -- confirm
    against the original layout.

    Args:
        val_loader: a loader or a list of loaders, paired with
            ``val_dataset_names``.
        model_engine: model wrapper; switched to eval mode and called with
            the batch dict as keyword arguments.
        epoch: step value used for the ``writer`` scalars and log lines.
        writer: object exposing ``add_scalar``; only used when
            ``args.local_rank == 0``.
        args: namespace; ``precision`` and ``local_rank`` are read here and
            the whole object is forwarded to ``save_eval_artifacts``.
        logger: logger used for the summary lines on rank 0.
        val_dataset_names: dataset names, one per loader; names containing
            ``'NYU'`` are skipped.
        tokenizer: tokenizer passed to ``tokenizer_image_token``; its
            ``pad_token_id`` is used for padding and the attention mask.

    Returns:
        ``(giou, ciou)`` of the last evaluated dataset, or ``(0.0, 0.0)``
        when no dataset was evaluated.
    """
    from collections import defaultdict

    # Per-image mIoU thresholds (percent) used by the CoT comparison.
    correct_threshold = 25
    special_case_threshold = 50

    model_engine.eval()
    if not isinstance(val_loader, list):
        val_loader = [val_loader]

    # Defined up front so the final return is valid even if every dataset is skipped.
    giou, ciou = 0.0, 0.0

    for loader, dataset_name in zip(val_loader, val_dataset_names):
        if 'NYU' in dataset_name:
            continue

        intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
        union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
        acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
        reasoning_type_meters = defaultdict(lambda: {
            'intersection': AverageMeter("Intersec", ":6.3f", Summary.SUM),
            'union': AverageMeter("Union", ":6.3f", Summary.SUM),
            'acc_iou': AverageMeter("gIoU", ":6.3f", Summary.SUM),
            'count': 0
        })
        correct_with_cot = 0
        correct_without_cot = 0
        total_samples = 0
        cot_right_no_cot_wrong = 0
        cot_wrong_no_cot_right = 0
        total_time = 0
        num_images = 0

        for input_dict in tqdm.tqdm(loader):
            start_time = time.time()
            torch.cuda.empty_cache()
            input_dict = dict_to_cuda(input_dict)

            answers_list = input_dict.get("answers_list", [None] * len(input_dict["image_paths"]))
            answer = _select_answer(answers_list)
            texts_with_cot, texts_without_cot = _build_cot_texts(
                input_dict["sampled_classes_list"], answer
            )

            # These inputs are the same for both passes, so prepare them once.
            input_dict["inference"] = True
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            miou_with_cot = None
            miou_without_cot = None
            for texts, cot_type in [(texts_with_cot, "with_cot"), (texts_without_cot, "without_cot")]:
                input_ids_list = [
                    tokenizer_image_token(
                        text,
                        tokenizer,
                        image_token_index=-200,
                        return_tensors="pt"
                    )
                    for text in texts
                ]
                input_ids = torch.nn.utils.rnn.pad_sequence(
                    input_ids_list,
                    batch_first=True,
                    padding_value=tokenizer.pad_token_id
                )
                attention_mask = input_ids.ne(tokenizer.pad_token_id)
                input_ids = input_ids.cuda()
                attention_mask = attention_mask.cuda()

                with torch.no_grad():
                    outputs = get_language_backbone(model_engine)(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True
                    )
                embeddings = outputs.hidden_states[-1]

                # Keep only non-padding positions, then re-pad with zeros.
                text_embeddings = [
                    sample_embeddings[sample_mask]
                    for sample_embeddings, sample_mask in zip(embeddings, attention_mask)
                ]
                text_embeddings = torch.nn.utils.rnn.pad_sequence(
                    text_embeddings,
                    batch_first=True,
                    padding_value=0.0
                )
                input_dict["txt_feat"] = text_embeddings

                with torch.no_grad():
                    output_dict = model_engine(**input_dict)

                pred_masks = output_dict["pred_masks"]
                # Checked before indexing: only the first batch entry is scored.
                assert len(pred_masks) == 1
                masks_list = output_dict["gt_masks"][0].int()
                output_list = (pred_masks[0] > 0).int()

                intersection, union, acc_iou = 0.0, 0.0, 0.0
                for mask_i, output_i in zip(masks_list, output_list):
                    intersection_i, union_i, _ = intersectionAndUnionGPU(
                        output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
                    )
                    intersection += intersection_i
                    union += union_i
                    acc_iou += intersection_i / (union_i + 1e-5)
                    acc_iou[union_i == 0] += 1.0  # an empty union counts as a perfect match
                intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
                acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]

                per_image_miou = (intersection / (union + 1e-10)).mean()
                per_image_miou_percent = per_image_miou * 100
                is_correct = per_image_miou_percent > correct_threshold

                if cot_type != "with_cot":
                    miou_without_cot = per_image_miou_percent
                    if is_correct:
                        correct_without_cot += 1
                    continue

                miou_with_cot = per_image_miou_percent
                if is_correct:
                    correct_with_cot += 1

                intersection_meter.update(intersection)
                union_meter.update(union)
                acc_iou_meter.update(acc_iou, n=masks_list.shape[0])

                # Index 1 is presumably the foreground class -- TODO confirm
                # against intersectionAndUnionGPU.
                per_image_ciou = intersection[1] / (union[1] + 1e-10)
                per_image_giou = acc_iou[1]

                categories = input_dict.get("categories", ["unknown"] * len(input_dict["image_paths"]))
                if isinstance(categories, list) and len(categories) > 0:
                    type_meters = reasoning_type_meters[categories[0]]
                    type_meters['intersection'].update(intersection)
                    type_meters['union'].update(union)
                    type_meters['acc_iou'].update(acc_iou, n=masks_list.shape[0])
                    type_meters['count'] += 1

                save_eval_artifacts(
                    args=args,
                    input_dict=input_dict,
                    dataset_name=dataset_name,
                    cot_type=cot_type,
                    output_list=output_list,
                    masks_list=masks_list,
                    question_text=input_dict.get("questions_list", [[None]])[0],
                    condition_text=texts[0] if len(texts) > 0 else None,
                    answer_text=answer,
                    per_image_ciou=per_image_ciou,
                    per_image_giou=per_image_giou,
                )

            if miou_with_cot is not None and miou_without_cot is not None:
                cot_right = miou_with_cot > special_case_threshold
                no_cot_right = miou_without_cot > special_case_threshold
                if cot_right and not no_cot_right:
                    cot_right_no_cot_wrong += 1
                if not cot_right and no_cot_right:
                    cot_wrong_no_cot_right += 1
            total_samples += 1

            batch_time = time.time() - start_time
            total_time += batch_time
            num_images += 1
            print(f"Image: {input_dict['image_paths'][0]}, cIoU: {per_image_ciou:.4f}, gIoU: {per_image_giou:.4f}, Time: {batch_time:.4f}s")

        avg_time = total_time / num_images if num_images > 0 else 0
        fps = num_images / total_time if total_time > 0 else 0
        print(f"\n{'='*50}")
        print(f"Validation Speed Statistics - {dataset_name}")
        print(f"{'='*50}")
        print(f"Total images: {num_images}")
        print(f"Total time: {total_time:.4f}s")
        print(f"Average time per image: {avg_time:.4f}s")
        print(f"FPS: {fps:.2f}")
        print(f"{'='*50}\n")

        intersection_meter.all_reduce()
        union_meter.all_reduce()
        acc_iou_meter.all_reduce()

        # Collectives must be issued in the same order on every rank. Dict
        # insertion order depends on the samples each rank saw, so iterate in
        # sorted order instead.
        # NOTE(review): this still assumes every rank saw the same set of
        # reasoning types -- confirm, or gather the key set first.
        for reasoning_type in sorted(reasoning_type_meters.keys()):
            type_meters = reasoning_type_meters[reasoning_type]
            type_meters['intersection'].all_reduce()
            type_meters['union'].all_reduce()
            type_meters['acc_iou'].all_reduce()
            if dist.is_initialized():
                count_tensor = torch.tensor(type_meters['count'], dtype=torch.long, device='cuda')
                dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
                type_meters['count'] = count_tensor.item()

        # Sum the CoT comparison counters across ranks so rank 0 reports
        # global figures, as is already done for the per-type counts.
        if dist.is_initialized():
            cot_counts = torch.tensor(
                [
                    correct_with_cot,
                    correct_without_cot,
                    total_samples,
                    cot_right_no_cot_wrong,
                    cot_wrong_no_cot_right,
                ],
                dtype=torch.long,
                device='cuda',
            )
            dist.all_reduce(cot_counts, op=dist.ReduceOp.SUM)
            (
                correct_with_cot,
                correct_without_cot,
                total_samples,
                cot_right_no_cot_wrong,
                cot_wrong_no_cot_right,
            ) = cot_counts.tolist()

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        ciou = iou_class[1]
        giou = acc_iou_meter.avg[1]

        if args.local_rank == 0:
            writer.add_scalar("val/giou", giou, epoch)
            writer.add_scalar("val/ciou", ciou, epoch)
            logger.info("{}, epoch: {}, giou: {:.4f}, ciou: {:.4f}".format(dataset_name, epoch, giou, ciou))
            print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

            # Guard the percentages: this rank may have processed no samples.
            with_cot_pct = correct_with_cot / total_samples * 100 if total_samples else 0.0
            without_cot_pct = correct_without_cot / total_samples * 100 if total_samples else 0.0
            print(f"\n{'='*50}")
            print(f"CoT Comparison Results - {dataset_name}")
            print(f"{'='*50}")
            print(f"Total samples: {total_samples}")
            print(f"Correct with CoT (mIoU > 25): {correct_with_cot}/{total_samples} ({with_cot_pct:.2f}%)")
            print(f"Correct without CoT (mIoU > 25): {correct_without_cot}/{total_samples} ({without_cot_pct:.2f}%)")
            print(f"\nSpecial Cases (mIoU > 50 as threshold):")
            print(f" Cases where CoT is correct but no-CoT is wrong: {cot_right_no_cot_wrong}")
            print(f" Cases where CoT is wrong but no-CoT is correct: {cot_wrong_no_cot_right}")
            print(f"{'='*50}\n")
            logger.info("{}, CoT comparison: with_cot={}/{}, without_cot={}/{}".format(
                dataset_name, correct_with_cot, total_samples, correct_without_cot, total_samples))
            logger.info("{}, Special cases (mIoU>50): cot_right_no_cot_wrong={}, cot_wrong_no_cot_right={}".format(
                dataset_name, cot_right_no_cot_wrong, cot_wrong_no_cot_right))

            if reasoning_type_meters:
                print(f"\n{'='*50}")
                print(f"Results by Reasoning Type - {dataset_name}")
                print(f"{'='*50}")
                for reasoning_type in sorted(reasoning_type_meters.keys()):
                    meters = reasoning_type_meters[reasoning_type]
                    type_iou_class = meters['intersection'].sum / (meters['union'].sum + 1e-10)
                    type_ciou = type_iou_class[1]
                    type_giou = meters['acc_iou'].avg[1]
                    type_count = meters['count']
                    print(f"{reasoning_type}: gIoU: {type_giou:.4f}, cIoU: {type_ciou:.4f}, Count: {type_count}")
                    logger.info("{}, reasoning_type: {}, giou: {:.4f}, ciou: {:.4f}, count: {}".format(
                        dataset_name, reasoning_type, type_giou, type_ciou, type_count))
                print(f"{'='*50}\n")

    return giou, ciou
# Script entry point: forward the command-line arguments (minus the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)