Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,903 Bytes
46861c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# --------------------------------------------------------
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License")
# Grasp Any Region Project
# Written by Haochen Wang
# --------------------------------------------------------
import argparse
import ast
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, GenerationConfig
from evaluation.eval_dataset import MultiRegionDataset
# Lookup from the --data_type CLI choice to the torch dtype used for the
# model weights and the dataset tensors.
TORCH_DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}
def parse_args():
    """Parse the command-line options for single-image multi-region inference.

    Returns:
        argparse.Namespace with ``model_name_or_path``, ``image_path``,
        ``mask_paths`` (the raw, unparsed string), ``question_str``,
        ``data_type`` (one of "fp16"/"bf16"/"fp32") and ``seed`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Inference of Grasp Any Region models on DLC-Bench."
    )
    parser.add_argument(
        "--model_name_or_path",
        help="HF model name or path",
        default="HaochenWang/GAR-8B",
    )
    parser.add_argument(
        "--image_path",
        help="path to the input image",
        required=True,
    )
    parser.add_argument(
        "--mask_paths",
        # The value is later parsed as a Python literal, so a bare path is not
        # the expected format; document the list syntax explicitly.
        help=(
            "Python list literal of mask image paths, one mask per region, "
            "e.g. \"['mask0.png', 'mask1.png']\""
        ),
        required=True,
    )
    parser.add_argument(
        "--question_str",
        help="question / instruction text sent to the model",
        required=True,
    )
    parser.add_argument(
        "--data_type",
        help="torch dtype for the model weights and inputs",
        type=str,
        choices=["fp16", "bf16", "fp32"],
        default="bf16",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducible text generation",
    )
    args = parser.parse_args()
    return args
def select_ann(coco, img_id, area_min=None, area_max=None):
    """Return the annotation ids of one image, optionally filtered by area.

    Args:
        coco: annotation index exposing ``getCatIds``, ``getAnnIds`` and an
            ``anns`` mapping of id -> dict with an ``"area"`` entry
            (presumably a pycocotools ``COCO`` object — confirm with callers).
        img_id: image id passed to ``getAnnIds`` as ``imgIds=[img_id]``.
        area_min: if not None, keep only annotations with ``area >= area_min``.
        area_max: if not None, keep only annotations with ``area <= area_max``.

    Returns:
        The ids from ``getAnnIds`` unchanged when no bound is given, otherwise
        a new list with the matching ids in their original order.
    """
    all_categories = coco.getCatIds()
    candidates = coco.getAnnIds(imgIds=[img_id], catIds=all_categories, iscrowd=None)
    if area_min is None and area_max is None:
        return candidates

    def _within_bounds(ann_id):
        # Written as negated ">=" / "<=" so non-comparable-as-true values
        # (e.g. NaN areas) are rejected exactly as a direct ">=" / "<=" would.
        if area_min is not None and not (coco.anns[ann_id]["area"] >= area_min):
            return False
        if area_max is not None and not (coco.anns[ann_id]["area"] <= area_max):
            return False
        return True

    return [ann_id for ann_id in candidates if _within_bounds(ann_id)]
def _parse_mask_paths(raw):
    """Convert the ``--mask_paths`` CLI value into a list of mask file paths.

    The expected format is a Python literal of paths such as
    ``"['mask0.png', 'mask1.png']"``. A quoted single string literal, or a bare
    path that is not a valid Python literal (e.g. ``masks/a.png``), is treated
    as a single mask instead of crashing inside ``ast.literal_eval``.
    """
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # Not a Python literal (typically an unquoted path): one mask.
        return [raw]
    if isinstance(parsed, str):
        # Iterating a str would yield single characters, not paths.
        return [parsed]
    try:
        return list(parsed)
    except TypeError:
        # A non-iterable literal, e.g. a numeric file name such as "123".
        return [raw]


def _load_mask(mask_path):
    """Load one mask image as a boolean array; every non-zero grey pixel is True."""
    # The context manager closes the file handle Image.open keeps open lazily.
    with Image.open(mask_path) as mask_img:
        return np.array(mask_img.convert("L")).astype(bool)


def _run_inference(args, data_dtype):
    """Load model and processor, build the multi-region sample, print the answer."""
    # build HF model
    model = AutoModel.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=data_dtype,
        device_map="cuda:0",
    ).eval()
    processor = AutoProcessor.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # Keep the source image open for the whole run (the dataset/sample may
    # still reference it) and close its file handle once generation is done.
    with Image.open(args.image_path) as img:
        masks = [_load_mask(path) for path in _parse_mask_paths(args.mask_paths)]

        # Visual-prompt tokens <Prompt0>..<Prompt{n-1}> plus <NO_Prompt>,
        # where n is read from the model config.
        prompt_number = model.config.prompt_numbers
        prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + [
            "<NO_Prompt>"
        ]
        dataset = MultiRegionDataset(
            image=img,
            masks=masks,
            question_str=args.question_str
            + "\nAnswer with the correct option's letter directly.",
            processor=processor,
            prompt_number=prompt_number,
            visual_prompt_tokens=prompt_tokens,
            data_dtype=data_dtype,
        )
        data_sample = dataset[0]

        with torch.no_grad():
            generate_ids = model.generate(
                **data_sample,
                generation_config=GenerationConfig(
                    max_new_tokens=1024,
                    do_sample=False,  # no sampling: deterministic decoding
                    eos_token_id=processor.tokenizer.eos_token_id,
                    pad_token_id=processor.tokenizer.pad_token_id,
                ),
                return_dict=True,
            )

    outputs = processor.tokenizer.decode(
        generate_ids.sequences[0], skip_special_tokens=True
    ).strip()
    print(outputs)  # Print model output for this image


def main():
    """Parse CLI arguments, run multi-region inference once and print the answer."""
    args = parse_args()
    data_dtype = TORCH_DTYPE_MAP[args.data_type]
    torch.manual_seed(args.seed)

    # init distribution for dispatch_modules in LLM. init_process_group uses the
    # default env:// rendezvous, so launch with torchrun (or export MASTER_ADDR,
    # MASTER_PORT, RANK and WORLD_SIZE).
    torch.cuda.set_device(0)
    created_process_group = not torch.distributed.is_initialized()
    if created_process_group:
        torch.distributed.init_process_group(backend="nccl")
    try:
        _run_inference(args, data_dtype)
    finally:
        # Tear down only the group created here, releasing NCCL resources
        # even when inference raises.
        if created_process_group:
            torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
|