File size: 5,092 Bytes
032e687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import argparse
import copy
import os.path as osp
import torch
from torch.utils.data import DataLoader
from mmengine import Config
from mmengine.dist import init_dist, get_dist_info
from mmengine.utils.dl_utils import set_multi_processing
from transformers import GenerationConfig
from xtuner.configs import cfgs_name_path
from xtuner.registry import BUILDER
from xtuner.tools.chat import TORCH_DTYPE_MAP
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import PROMPT_TEMPLATE
from projects.llava_sam2.datasets import video_lisa_collate_fn
def parse_args():
    """Build and parse the command-line arguments for RefCoco-style
    video segmentation evaluation.

    Returns:
        argparse.Namespace: the parsed CLI options.
    """
    parser = argparse.ArgumentParser(description='RefCocoSeg')

    # Positional: a registered config name (resolved via cfgs_name_path)
    # or a direct path to a config file.
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('--pth_model', default=None, help='pth model file')
    parser.add_argument('--split', default='val', help='Specify a split')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--bits', type=int, choices=[4, 8, None], default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens', type=int, default=100,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--seed', type=int, default=0,
        help='Random seed for reproducible text generation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

    return parser.parse_args()
def main():
    """Run generation over the ReVOS video dataset with a
    VideoLLaVASAMModel checkpoint.

    Loads the config, builds the model (optionally restoring a .pth
    state dict), builds the validation split of ``video_revos_dataset``
    and runs ``model.mllm.generate`` over every batch. Generation
    outputs are currently discarded (this script is an inference
    smoke-test / scaffold rather than a metric computation).
    """
    args = parse_args()

    # Distributed setup: under a launcher each process binds to its own
    # CUDA device; otherwise fall back to a single-process run.
    if args.launcher != 'none':
        set_multi_processing(distributed=True)
        init_dist(args.launcher)
        rank, world_size = get_dist_info()
        torch.cuda.set_device(rank)
    else:
        rank = 0
        world_size = 1

    print(f'Rank: {rank} / World size: {world_size}')

    # Resolve a registered config name to its file path if the argument
    # is not already an existing file.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    cfg = Config.fromfile(args.config)
    model_name = cfg.model.type if isinstance(cfg.model.type, str) else cfg.model.type.__name__
    assert model_name in ('VideoLLaVASAMModel',)

    model = BUILDER.build(cfg.model)
    if args.pth_model is not None:
        # strict=False: the checkpoint may only cover a subset of the
        # built model's parameters (e.g. adapter weights).
        state_dict = torch.load(args.pth_model, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
    # Fix: use the current CUDA device (set per-rank above) instead of
    # pinning every process to 'cuda:0'.
    model.cuda()
    model.eval()

    # define some pointers
    tokenizer = model.tokenizer

    # Greedy decoding; fall back to EOS for padding when the tokenizer
    # defines no pad token.
    gen_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )

    stop_words = args.stop_words
    if args.prompt_template:
        template = PROMPT_TEMPLATE[args.prompt_template]
        stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(
        tokenizer=tokenizer, stop_words=stop_words)

    # Point the dataset at the requested split's expression file.
    # Fix 1: honor --split instead of hard-coding 'val'.
    # Fix 2: build from the modified data_cfg — the original built from
    # cfg.video_revos_dataset, silently discarding this override.
    data_cfg = copy.deepcopy(cfg.video_revos_dataset)
    data_cfg.update(
        expression_file=data_cfg.expression_file.replace('train', args.split))
    dataset = BUILDER.build(data_cfg)

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        collate_fn=video_lisa_collate_fn,
    )

    for data_item in dataloader:
        data_item = model.data_preprocessor(data_item)
        inputs, data_samples = data_item['data'], data_item['data_samples']
        # Segmentation-specific tensors are not consumed by generate().
        g_pixel_values = inputs.pop('g_pixel_values', None)
        gt_masks = inputs.pop('masks', None)
        # NOTE(review): the generate output (hidden states included) is
        # currently unused — presumably mask decoding is to follow.
        output = model.mllm.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            visual_features=None,
            generation_config=gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
# Script entry point: only run evaluation when executed directly.
if __name__ == '__main__':
    main()
|