|
|
import argparse |
|
|
import copy |
|
|
import os.path as osp |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from mmengine import Config |
|
|
from mmengine.dist import init_dist, get_dist_info |
|
|
from mmengine.utils.dl_utils import set_multi_processing |
|
|
from transformers import GenerationConfig |
|
|
from xtuner.configs import cfgs_name_path |
|
|
from xtuner.registry import BUILDER |
|
|
from xtuner.tools.chat import TORCH_DTYPE_MAP |
|
|
from xtuner.tools.utils import get_stop_criteria |
|
|
from xtuner.utils import PROMPT_TEMPLATE |
|
|
|
|
|
|
|
|
from projects.llava_sam2.datasets import video_lisa_collate_fn |
|
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for this evaluation script.

    Returns:
        argparse.Namespace: parsed arguments (config path, checkpoint,
        generation settings, and distributed launcher choice).
    """
    p = argparse.ArgumentParser(description='RefCocoSeg')
    # Positional: either a registered config name or a path to a config file.
    p.add_argument('config', help='config file name or path.')
    p.add_argument('--pth_model', default=None, help='pth model file')
    p.add_argument('--split', default='val', help='Specify a split')
    p.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')
    p.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    p.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    p.add_argument(
        '--bits', type=int, choices=[4, 8, None], default=None,
        help='LLM bits')
    p.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    p.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    p.add_argument(
        '--max-new-tokens',
        type=int,
        default=100,
        help='Maximum number of new tokens allowed in generated text')
    p.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    p.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    return p.parse_args()
|
|
|
|
|
def main():
    """Run generation-based evaluation of a VideoLLaVASAMModel on ReVOS.

    Builds the model and the validation dataset from the given config,
    optionally loads a checkpoint, then iterates the dataloader and runs
    `model.mllm.generate` on each batch. Supports single-GPU execution and
    the standard mmengine distributed launchers.
    """
    args = parse_args()
    if args.launcher != 'none':
        set_multi_processing(distributed=True)
        init_dist(args.launcher)

        rank, world_size = get_dist_info()
        # NOTE(review): uses the global rank as the CUDA device index —
        # assumes one process per GPU on a single node; confirm for
        # multi-node launches.
        torch.cuda.set_device(rank)
    else:
        rank = 0
        world_size = 1
    print(f'Rank: {rank} / World size: {world_size}')

    # Resolve a registered config name to its file path if a plain file
    # path was not given.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    cfg = Config.fromfile(args.config)
    model_name = cfg.model.type if isinstance(cfg.model.type, str) else cfg.model.type.__name__
    assert model_name in ('VideoLLaVASAMModel',)

    model = BUILDER.build(cfg.model)

    if args.pth_model is not None:
        state_dict = torch.load(args.pth_model, map_location='cpu')
        # strict=False: the checkpoint may contain only a subset of the
        # model's parameters (e.g. fine-tuned adapters).
        model.load_state_dict(state_dict, strict=False)

    # Fix: place the model on this process's own GPU instead of a
    # hard-coded 'cuda:0' (identical behavior for rank 0 / single GPU,
    # correct for distributed runs after set_device(rank) above).
    model.to(f'cuda:{rank}')
    model.eval()

    tokenizer = model.tokenizer

    gen_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        # Fall back to EOS when the tokenizer defines no PAD token.
        pad_token_id=tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )

    # Copy the CLI list so we never mutate argparse's shared default ([]).
    stop_words = list(args.stop_words)
    if args.prompt_template:
        template = PROMPT_TEMPLATE[args.prompt_template]
        stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(
        tokenizer=tokenizer, stop_words=stop_words)

    # Derive the validation dataset config by switching the expression
    # file from the train split to the val split.
    data_cfg = copy.deepcopy(cfg.video_revos_dataset)
    data_cfg.update(
        expression_file=data_cfg.expression_file.replace('train', 'val'))
    # Fix: build the updated val config; the original built the unmodified
    # train config (cfg.video_revos_dataset), silently discarding data_cfg.
    dataset = BUILDER.build(data_cfg)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        collate_fn=video_lisa_collate_fn,
    )

    for data_item in dataloader:
        data_item = model.data_preprocessor(data_item)
        inputs, data_samples = data_item['data'], data_item['data_samples']
        # Pop mask-branch inputs so only the MLLM inputs reach generate();
        # they are kept around for downstream mask evaluation.
        g_pixel_values = inputs.pop('g_pixel_values', None)
        gt_masks = inputs.pop('masks', None)

        output = model.mllm.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            visual_features=None,
            generation_config=gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
        # Removed leftover `print(1)` debug statements.
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()
|
|
|