|
|
import os |
|
|
import json |
|
|
import argparse |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms as T |
|
|
|
|
|
from data.pose_hicodet import PoseHICODetDataset |
|
|
from data.convsersation import Conversation |
|
|
|
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from transformers import Qwen3VLForConditionalGeneration |
|
|
from transformers import AutoTokenizer, AutoConfig, AutoProcessor |
|
|
|
|
|
def disable_torch_init():
    """Skip PyTorch's default parameter initialization for Linear and LayerNorm.

    ``reset_parameters`` on ``torch.nn.Linear`` and ``torch.nn.LayerNorm`` is
    replaced, class-wide and for the rest of the process, by a function that
    does nothing. Layers constructed afterwards therefore skip their default
    initialization. This speeds up model creation when the weights are about
    to be overwritten anyway, e.g. right before loading a pretrained
    checkpoint.
    """
    def _skip_reset(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", _skip_reset)
|
|
|
|
|
import os, json |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
|
|
|
def gather_labels_and_save(labels, output_path):
    """Collect every rank's ``labels`` list and write the union as JSON from rank 0.

    Must be called by all ranks of an initialized default process group:
    ``all_gather_object`` and the final ``barrier`` are collectives.

    Args:
        labels: this rank's list of JSON-serializable records.
        output_path: destination file; only rank 0 writes it. The output is
            UTF-8 with ``ensure_ascii=False`` and ``indent=2``.
    """
    per_rank = [None] * dist.get_world_size()
    dist.all_gather_object(per_rank, labels)

    if dist.get_rank() == 0:
        # Concatenate the gathered lists in the order all_gather_object returned them.
        combined = [record for chunk in per_rank for record in chunk]
        with open(output_path, "w", encoding="utf-8") as fh:
            json.dump(combined, fh, ensure_ascii=False, indent=2)

    # Keep every rank here until rank 0 has finished writing the file.
    dist.barrier()
|
|
|
|
|
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Turn a list of dataset items into a processor batch of chat prompts and images.

    Each item must be a dict with an ``'image'`` entry, which is passed to the
    processor unchanged. It must also have a ``'meta'`` entry, which is fed to
    ``Conversation.get_prompt`` and returned alongside the batch.
    """

    # NOTE(review): the class declares no dataclass fields and defines its own
    # __init__, so @dataclass contributes little here; kept to avoid changing
    # existing behavior (e.g. generated __repr__/__eq__).

    def __init__(self, processor, data_path):
        """Store the processor and build the prompt-producing Conversation.

        Args:
            processor: multimodal processor providing ``apply_chat_template``
                and callable as ``processor(text=..., images=..., ...)``.
            data_path: forwarded to ``Conversation``.
        """
        self.processor = processor
        self.conv = Conversation(
            system='',
            data_path=data_path
        )

    def __call__(self, data_dicts):
        """Collate examples for supervised fine-tuning.

        Args:
            data_dicts: list of dataset items, each with ``'image'`` and ``'meta'``.

        Returns:
            Tuple ``(batch_tensors, result_meta)``. ``batch_tensors`` is the
            processor output for the whole batch, padded. ``result_meta`` is the
            list of per-item ``meta`` objects, in input order.
        """
        batch_prompts = []
        batch_images = []
        result_meta = []
        for data_dict in data_dicts:
            meta = data_dict['meta']
            batch_images.append(data_dict['image'])
            batch_prompts.append(self.conv.get_prompt(meta))
            result_meta.append(meta)

        # Iterate the prompts directly. The previous `zip(batch_prompts)` yielded
        # 1-tuples, which put a tuple instead of a string into the text content.
        messages = [
            [
                {"role": "system",
                 "content": [{"type": "text", "text": self.conv.system}]},
                {"role": "user",
                 "content": [{"type": "image"},
                             {"type": "text", "text": prompt}]},
            ]
            for prompt in batch_prompts
        ]

        # Render each conversation to a string ending with the assistant turn
        # header, so the model continues from there during generation.
        prompts = [
            self.processor.apply_chat_template(
                m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]

        batch_tensors = self.processor(
            text=prompts,
            images=batch_images,
            return_tensors="pt",
            padding=True
        )
        return batch_tensors, result_meta
|
|
|
|
|
@torch.no_grad()
def worker(model, processor, dataset, args, output_dir):
    """Generate descriptions for this process's shard of ``dataset`` and save them as JSON.

    Samples are assigned round-robin, so this process handles dataset indices
    ``rank, rank + world_size, rank + 2 * world_size, ...``. Each sample is:
    collated by ``DataCollatorForSupervisedDataset``, decoded greedily with
    ``model.generate`` (at most 1600 new tokens), and turned into a record
    combining fields of the sample's ``meta`` with the generated text. The
    records for this rank are written to ``<output_dir>/labels_<rank>.json``.

    Args:
        model: generation model, expected to already be on the current CUDA device.
        processor: processor used for collation and (via ``processor.tokenizer``)
            for decoding the generated tokens.
        dataset: indexable dataset whose items are dicts with ``'image'`` and
            ``'meta'`` entries (as read by the collator).
        args: parsed CLI namespace; ``args.data_path`` is forwarded to the collator.
        output_dir: directory for the per-rank JSON file; created if missing.
            An empty string means the current directory.
    """
    # Shard by the *global* rank. LOCAL_RANK restarts at 0 on every node, so
    # pairing it with the global WORLD_SIZE would make nodes repeat the same
    # indices and skip the others in multi-node runs. Fall back to LOCAL_RANK
    # when RANK is not set (single-node behavior is unchanged).
    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    indices = list(range(rank, len(dataset), world_size))
    print("==>" + " Worker {} Started, responsible for {} images".format(rank, len(indices)))

    sub_dataset = torch.utils.data.Subset(dataset, indices)
    data_loader = DataLoader(
        sub_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path),
    )

    labels = []
    for batch_tensors, result_meta in tqdm(data_loader):
        # Move every tensor entry to the GPU once; non-tensor entries are dropped.
        batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
        input_ids = batch_tensors['input_ids']

        with torch.inference_mode():
            # Only the generated sequences are used. Per-step scores/logits are
            # not requested: the transformers docs describe them as one
            # (batch, vocab) tensor per generated token, kept for up to 1600 steps.
            output_dict = model.generate(
                do_sample=False,
                return_dict_in_generate=True,
                max_new_tokens=1600,
                **batch_tensors,
            )
        output_ids = output_dict['sequences']

        for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
            # The returned sequence starts with the (left-padded) prompt; check
            # that, then decode only the newly generated tail.
            input_token_len = input_id.shape[0]
            n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
            output = processor.tokenizer.decode(output_id[input_token_len:], skip_special_tokens=True)

            # NOTE(review): joints_3d / joints_3d_vis are presumably numpy arrays
            # (they are reshaped and .tolist()'d) — confirm against PoseHICODetDataset.
            labels.append({
                'file_name': meta['file_name'],
                'image_id': meta['image_id'],
                'instance_id': meta['instance_id'],
                'keypoints': meta['joints_3d'].reshape(-1).tolist(),
                'vis': meta['joints_3d_vis'].reshape(-1).tolist(),
                'im_height': meta['hoi_obj']['height'],
                'im_width': meta['hoi_obj']['width'],
                'human_bbox': meta['hoi_obj']['human_bbox'],
                'object_bbox': meta['hoi_obj']['object_bbox'],
                'action_labels': meta['hoi_obj']['action_labels'],
                'description': output,
            })

    # os.makedirs("") raises, so only create a directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'labels_{rank}.json')
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(labels, f, ensure_ascii=False, indent=2)
|
|
|
|
|
def eval_model(args):
    """Set up this process for distributed inference and run ``worker`` on its shard.

    Steps:
    - Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment, raising
      ``KeyError`` if either is missing.
    - Binds the process to GPU ``LOCAL_RANK`` and joins an NCCL process group.
    - Loads the Qwen3-VL model (bfloat16) and its processor from ``args.model_path``.
    - Builds the PoseHICODet dataset from ``args.data_path``.
    - Runs ``worker``, which writes its results under ``args.output_dir``.
    - Destroys the process group once ``worker`` returns.

    Args:
        args: parsed CLI namespace with ``model_path``, ``data_path`` and
            ``output_dir`` attributes.
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Select this process's GPU before creating the NCCL group, so the group
    # and the later .cuda() calls all target the same device.
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(backend='nccl')
    print('Init process group: world_size: {}, rank: {}'.format(world_size, rank))

    # Skip default Linear/LayerNorm init; the weights come from the checkpoint.
    disable_torch_init()
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    model = model.cuda()
    model.eval()

    processor = AutoProcessor.from_pretrained(
        args.model_path,
        trust_remote_code=True)
    # With left padding, each prompt ends right where generation begins.
    # worker relies on this when it slices off the first len(input_ids) tokens.
    processor.tokenizer.padding_side = "left"
    processor.tokenizer.pad_token = processor.tokenizer.eos_token

    dataset = PoseHICODetDataset(
        data_path=args.data_path,
        multimodal_cfg=dict(image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
                            data_augmentation=False,
                            image_size=336,),)

    worker(model, processor, dataset, args, args.output_dir)

    # Release the process group explicitly rather than leaving it to
    # interpreter shutdown (recent PyTorch versions warn about that).
    dist.destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. eval_model reads LOCAL_RANK / WORLD_SIZE from the
    # environment, so this is presumably launched once per GPU by torchrun or a
    # similar launcher.
    parser = argparse.ArgumentParser()
    # NOTE(review): "facebook/opt-350m" is not a Qwen3-VL checkpoint, so
    # Qwen3VLForConditionalGeneration is unlikely to load it; pass --model-path.
    cli_options = (
        ("--model-path", "facebook/opt-350m"),
        ("--data-path", ""),
        ("--output-dir", ""),
    )
    for flag, fallback in cli_options:
        parser.add_argument(flag, type=str, default=fallback)
    args = parser.parse_args()

    eval_model(args)
|
|
|