AutoLLMAnnotation / tools /examine_hico.py
ayh015's picture
Update modified code
73df34b
import os
import json
import argparse
import re
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from data.dataset_for_clean_descrip import PoseHICODetDataset
from data.convsersation import Conversation_examiner as Conversation
from dataclasses import dataclass
from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor
class StreamingJsonArrayWriter:
    """Context manager that streams items into ``output_path`` as one JSON array.

    Each item is serialized and flushed as soon as ``write`` is called, so the
    results produced so far are already on disk while the run is in progress.
    The closing ``]`` is only written by ``__exit__``.

    Example::

        with StreamingJsonArrayWriter("out.json") as writer:
            writer.write({"id": 1})
    """

    def __init__(self, output_path):
        self.output_path = output_path
        self.file = None
        self.is_first = True

    def __enter__(self):
        self.file = open(self.output_path, "w", encoding="utf-8")
        # Reset on every entry so a reused writer does not start with a stray comma.
        self.is_first = True
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        """Append one JSON-serializable ``item`` to the array and flush it.

        Raises:
            TypeError / ValueError: propagated from ``json.dumps`` when ``item``
                cannot be encoded. Nothing is written in that case, so the
                file stays well-formed.
        """
        # Serialize before touching the file: json.dump writes in chunks, so a
        # failure midway would leave a half-written item (and a dangling comma).
        encoded = json.dumps(item, ensure_ascii=False, indent=2)
        if not self.is_first:
            self.file.write(",\n")
        self.file.write(encoded)
        self.file.flush()
        self.is_first = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file is None:
            return
        try:
            # Close the array even when leaving because of an exception, so the
            # items written so far remain a parseable JSON document.
            self.file.write("\n]\n")
        finally:
            # Always release the handle, even if writing the bracket fails.
            self.file.close()
            self.file = None
def disable_torch_init():
    """Make ``reset_parameters`` a no-op on ``nn.Linear`` and ``nn.LayerNorm``.

    Layers of those types constructed afterwards skip their default random
    initialization, which speeds up building a model whose weights are
    loaded afterwards. The patch is process-wide and is never undone.
    """
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
# Text between the "Checked description:" and "Issues:" headers; DOTALL lets
# the description span multiple lines.
_CHECKED_DESCRIPTION_PATTERN = re.compile(
    r"Checked description:\s*(.*?)\s*Issues:\s*",
    flags=re.DOTALL,
)


def extract_checked_description(text):
    """Return the examiner's checked description from its raw output ``text``.

    The description is the (stripped) text after the first
    "Checked description:" and before the following "Issues:". If that
    pair of headers is not found, an empty string is returned.
    """
    found = _CHECKED_DESCRIPTION_PATTERN.search(text)
    if found is None:
        return ""
    return found.group(1).strip()
# NOTE: deliberately not a @dataclass. The class declares no fields, so the
# decorator only generated a field-less __eq__ (every instance compared equal)
# and set __hash__ to None (instances became unhashable).
class DataCollatorForSupervisedDataset:
    """Collate function that turns dataset samples into model inputs.

    Each sample's ``meta`` is rendered into a prompt with the examiner
    conversation template. The prompts and images are then packed by
    ``build_batch_tensors``.
    """

    def __init__(self, processor, data_path):
        self.processor = processor
        self.conv = Conversation(
            system='',
            data_path=data_path,
        )

    def __call__(self, data_dicts):
        """Build one batch.

        Args:
            data_dicts: sequence of samples; each must provide the
                ``'image'`` and ``'meta'`` keys read below.

        Returns:
            ``(batch_tensors, result_meta)``: the result of
            ``build_batch_tensors`` and the per-sample ``meta`` objects, in
            batch order.
        """
        batch_images = [sample['image'] for sample in data_dicts]
        result_meta = [sample['meta'] for sample in data_dicts]
        batch_prompts = [self.conv.get_prompt(meta) for meta in result_meta]
        batch_tensors = build_batch_tensors(
            processor=self.processor,
            prompts=batch_prompts,
            images=batch_images,
            system_prompt=self.conv.system,
        )
        return batch_tensors, result_meta
def _get_rank_and_world_size():
    """Return ``(global_rank, world_size)`` used to shard the dataset.

    Prefers the initialized process group. Otherwise falls back to the
    torchrun environment variables: ``RANK`` (global) when present, since
    ``LOCAL_RANK`` restarts at 0 on every node and would make nodes process
    overlapping shards and write colliding output files.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    rank = int(os.environ.get("RANK", os.environ["LOCAL_RANK"]))
    return rank, int(os.environ["WORLD_SIZE"])


@torch.no_grad()
def worker(model, processor, dataset, args):
    """Run the examiner model over this process's shard of ``dataset``.

    Samples are assigned round-robin by global rank, and results are streamed
    to ``<output_dir>/examiner_labels_<rank>.json``. Each written item is the
    sample's ``meta`` dict with two added keys:

    * ``examiner_result``: the decoded generated text;
    * ``final_description``: the text between "Checked description:" and
      "Issues:" in that output ("" if not found).

    Args:
        model: generation model, already moved to this process's GPU.
        processor: processor paired with ``model`` (used for batching/decoding).
        dataset: indexable dataset whose samples the collator consumes.
        args: parsed CLI namespace; reads ``batch_size``, ``data_path``,
            ``output_dir`` and ``max_new_tokens``.
    """
    rank, world_size = _get_rank_and_world_size()
    # Round-robin shard: this rank handles samples rank, rank + world_size, ...
    indices = list(range(rank, len(dataset), world_size))
    print("==>" + " Worker {} Started, responsible for {} images".format(rank, len(indices)))
    sub_dataset = torch.utils.data.Subset(dataset, indices)
    data_loader = DataLoader(
        sub_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path),
    )

    # An empty --output-dir means the current directory; otherwise make sure it
    # exists (exist_ok tolerates other ranks creating it concurrently).
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f'examiner_labels_{rank}.json')

    with StreamingJsonArrayWriter(output_path) as writer:
        for batch_tensors, result_meta in tqdm(data_loader, desc=f"rank {rank}"):
            # Only tensor entries are moved to the GPU and passed to generate().
            batch_tensors = {
                k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)
            }
            # Reuse the already-moved prompt ids instead of copying them to the GPU again.
            input_ids = batch_tensors['input_ids']
            with torch.inference_mode():
                # Only `sequences` is consumed below, so per-step scores/logits are
                # not requested (they would keep large vocab-sized tensors on the GPU).
                output_dict = model.generate(
                    do_sample=False,
                    return_dict_in_generate=True,
                    max_new_tokens=args.max_new_tokens,
                    **batch_tensors,
                )
            output_ids = output_dict['sequences']
            for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
                # Sanity check: each returned sequence is expected to start with its prompt ids.
                input_token_len = input_id.shape[0]
                n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
                if n_diff_input_output > 0:
                    print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
                output = decode_generated_text(processor, output_id, input_id)
                meta['examiner_result'] = output
                meta['final_description'] = extract_checked_description(output)
                writer.write(meta)
def eval_model(args):
    """Per-process entry point for a torchrun launch.

    Binds the process to GPU ``LOCAL_RANK``, joins the NCCL process group,
    loads the model/processor, builds the ``PoseHICODetDataset``, and runs
    ``worker`` on this process's shard. The process group is destroyed on
    exit, including when a step raises.

    Args:
        args: parsed CLI namespace (model path/backend/dtype, data and
            annotation paths, ``max_samples``, batch and output settings).
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Select the GPU before creating the NCCL group so its communicators use
    # this process's device.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend='nccl')
    print('Init process group: world_size: {}, rank: {}'.format(world_size, rank))
    try:
        # Must run before the model is constructed for the patch to take effect.
        disable_torch_init()
        backend_name, model, processor = load_model_and_processor(
            model_path=args.model_path,
            backend=args.model_backend,
            torch_dtype=args.torch_dtype,
            trust_remote_code=True,
        )
        print(f'Using model backend: {backend_name}')
        model = model.cuda()
        model.eval()
        dataset = PoseHICODetDataset(
            data_path=args.data_path,
            multimodal_cfg=dict(
                image_folder=os.path.join(args.data_path, 'Images/images/train2015'),
                data_augmentation=False,
                image_size=336,
            ),
            annotation_path=args.annotation_path,
            max_samples=args.max_samples,
        )
        worker(model, processor, dataset, args)
    finally:
        # Release NCCL resources; the original code left the group alive at exit.
        dist.destroy_process_group()
if __name__ == "__main__":
    # Command-line flags as (flag, type, default); every flag is optional.
    cli_options = (
        ("--model-path", str, "facebook/opt-350m"),
        ("--data-path", str, ""),
        ("--annotation-path", str, "./outputs/merged_labels.json"),
        ("--output-dir", str, ""),
        ("--batch-size", int, 8),
        ("--max-new-tokens", int, 512),
        ("--max-samples", int, 0),
        ("--model-backend", str, "auto"),
        ("--torch-dtype", str, "bfloat16"),
    )
    parser = argparse.ArgumentParser()
    for flag, flag_type, default in cli_options:
        parser.add_argument(flag, type=flag_type, default=default)
    args = parser.parse_args()
    eval_model(args)