import os
import json
import argparse
import re
from dataclasses import dataclass

import numpy as np
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchvision import transforms as T

from data.pose_hicodet import PoseHICODetDataset
# NOTE(review): module name is spelled "convsersation"; presumably it matches the file on disk.
from data.convsersation import Conversation
from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor


# Upper bound on the number of tokens generated per sample.
MAX_NEW_TOKENS = 1600


def disable_torch_init():
    """Disable the redundant torch default initialization to accelerate model creation.

    Patches ``reset_parameters`` on ``nn.Linear`` and ``nn.LayerNorm`` to a no-op
    for the rest of the process. It is called right before the model is loaded,
    presumably because the checkpoint overwrites those weights anyway.
    """
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


class StreamingJsonArrayWriter:
    """Context manager that writes items into a JSON array file one at a time.

    Each ``write`` appends and flushes immediately, so results produced so far
    are on disk even if the run stops early. The closing ``]`` is written in
    ``__exit__`` (also when an exception propagates), so the file stays a valid
    JSON array unless the failure happened in the middle of a ``write`` call.
    """

    def __init__(self, output_path):
        self.output_path = output_path
        self.file = None
        self.is_first = True

    def __enter__(self):
        self.file = open(self.output_path, "w", encoding="utf-8")
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        """Serialize ``item`` as the next array element and flush it to disk."""
        if not self.is_first:
            self.file.write(",\n")
        json.dump(item, self.file, ensure_ascii=False, indent=2)
        self.file.flush()
        self.is_first = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is not None:
            self.file.write("\n]\n")
            self.file.close()
            self.file = None  # make a second __exit__ a no-op
        # Returning None lets any in-flight exception propagate.


def gather_labels_and_save(labels, output_path):
    """Gather ``labels`` lists from every rank and write the merged list on rank 0.

    Requires an initialized default process group (torchrun / deepspeed /
    accelerate usually set this up). All ranks must call this function, because
    it performs collective operations.

    Args:
        labels: this rank's list of records (must be picklable for all_gather_object).
        output_path: destination JSON file, written by rank 0 only.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, labels)  # gathered[i] is the labels list from rank i

    if rank == 0:
        merged = []
        for part in gathered:
            merged.extend(part)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(merged, f, ensure_ascii=False, indent=2)

    dist.barrier()  # keep other ranks alive until rank 0 has finished writing


class DataCollatorForSupervisedDataset:
    """Collate dataset samples into model inputs plus their raw metadata.

    Args:
        processor: the processor returned by ``load_model_and_processor``.
        data_path: dataset root, forwarded to ``Conversation``.
    """

    def __init__(self, processor, data_path):
        self.processor = processor
        self.conv = Conversation(system='', data_path=data_path)

    def __call__(self, data_dicts):
        """Build one batch from a list of dataset items.

        Each item must provide ``'image'`` and ``'meta'`` keys. Returns
        ``(batch_tensors, result_meta)``, where ``batch_tensors`` is whatever
        ``build_batch_tensors`` produces and ``result_meta`` is the list of
        per-sample ``meta`` objects in batch order.
        """
        batch_images = [item['image'] for item in data_dicts]
        result_meta = [item['meta'] for item in data_dicts]
        batch_prompts = [self.conv.get_prompt(meta) for meta in result_meta]

        # build_batch_tensors receives the raw prompts and the system prompt;
        # presumably it applies the chat template itself (confirm in tools.vlm_backend).
        batch_tensors = build_batch_tensors(
            processor=self.processor,
            prompts=batch_prompts,
            images=batch_images,
            system_prompt=self.conv.system,
        )
        return batch_tensors, result_meta


def _dist_rank_and_world_size():
    """Return ``(global_rank, world_size)`` used to shard the dataset across processes."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    # Fall back to the torchrun environment variables when no group is initialized.
    return int(os.environ.get("RANK", 0)), int(os.environ.get("WORLD_SIZE", 1))


def _build_record(meta, description):
    """Flatten one sample's metadata and its generated text into an output record dict."""
    hoi_obj = meta['hoi_obj']
    return {
        'file_name': meta['file_name'],
        'image_id': meta['image_id'],
        'instance_id': meta['instance_id'],
        'keypoints': meta['joints_3d'].reshape(-1).tolist(),
        'vis': meta['joints_3d_vis'].reshape(-1).tolist(),
        'im_height': hoi_obj['height'],
        'im_width': hoi_obj['width'],
        'hoi_id': hoi_obj['hoi_id'],
        'human_bbox': hoi_obj['human_bbox'],
        'object_bbox': hoi_obj['object_bbox'],
        'action_labels': hoi_obj['action_labels'],
        'description': description,
    }


@torch.no_grad()
def worker(model, processor, dataset, args, output_dir):
    """Generate descriptions for this process's shard of ``dataset``.

    Samples are assigned round-robin by *global* rank (rank r handles indices
    r, r + world_size, ...). This keeps shards disjoint across nodes, which
    LOCAL_RANK does not. Results are streamed to ``<output_dir>/labels_<rank>.json``.

    Args:
        model: a generation model exposing ``generate``; it must already be on the GPU.
        processor: processor paired with ``model``.
        dataset: indexable dataset whose items provide ``'image'`` and ``'meta'``.
        args: parsed CLI args; only ``args.data_path`` is read here.
        output_dir: directory for the per-rank output file (created if missing).
    """
    rank, world_size = _dist_rank_and_world_size()
    indices = list(range(rank, len(dataset), world_size))
    print("==> Worker {} Started, responsible for {} images".format(rank, len(indices)))

    sub_dataset = torch.utils.data.Subset(dataset, indices)
    data_loader = DataLoader(
        sub_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path),
    )

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'labels_{rank}.json')

    with StreamingJsonArrayWriter(output_path) as writer:
        for batch_tensors, result_meta in tqdm(data_loader, desc=f"rank {rank}"):
            # Move tensor inputs to the GPU once; non-tensor entries are dropped.
            batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
            input_ids = batch_tensors['input_ids']

            with torch.inference_mode():
                # Greedy decoding. Per-step scores/logits are not requested because they
                # are never used and would hold vocab-sized tensors for every step.
                output_dict = model.generate(
                    do_sample=False,
                    return_dict_in_generate=True,
                    max_new_tokens=MAX_NEW_TOKENS,
                    **batch_tensors,
                )
            output_ids = output_dict['sequences']

            for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
                # Sanity check: the returned sequence is expected to start with the prompt tokens.
                input_token_len = input_id.shape[0]
                n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
                if n_diff_input_output > 0:
                    print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')

                output = decode_generated_text(processor, output_id, input_id)
                writer.write(_build_record(meta, output))


def eval_model(args):
    """Initialize NCCL, load the model and dataset, and run ``worker`` on this rank."""
    local_rank = int(os.environ["LOCAL_RANK"])
    # Bind this process to its GPU before NCCL init so collectives use the right device.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    try:
        rank, world_size = _dist_rank_and_world_size()
        print('Init process group: world_size: {}, rank: {}'.format(world_size, rank))

        disable_torch_init()
        backend_name, model, processor = load_model_and_processor(
            model_path=args.model_path,
            backend=args.model_backend,
            torch_dtype=args.torch_dtype,
            trust_remote_code=True,
        )
        print(f'Using model backend: {backend_name}')
        model = model.cuda()
        model.eval()

        dataset = PoseHICODetDataset(
            data_path=args.data_path,
            multimodal_cfg=dict(
                image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
                data_augmentation=False,
                image_size=336,
            ),
            max_samples=args.max_samples,
        )
        worker(model, processor, dataset, args, args.output_dir)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--data-path", type=str, default="")
    parser.add_argument("--output-dir", type=str, default="")
    parser.add_argument("--max-samples", type=int, default=0)
    parser.add_argument("--model-backend", type=str, default="auto")
    parser.add_argument("--torch-dtype", type=str, default="bfloat16")
    args = parser.parse_args()

    eval_model(args)