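"""Generate text responses with a VideoLLaVASAMModel over the video ReVOS
dataset defined in an xtuner config (expression file switched to the 'val'
split). The model and dataset are built from the config, an optional .pth
checkpoint is loaded, and each sample is run through greedy text generation.
"""
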
import argparse
import copy
import os.path as osp

import torch
from torch.utils.data import DataLoader
from mmengine import Config
from mmengine.dist import init_dist, get_dist_info
from mmengine.utils.dl_utils import set_multi_processing
from transformers import GenerationConfig

from xtuner.configs import cfgs_name_path
from xtuner.registry import BUILDER
from xtuner.tools.chat import TORCH_DTYPE_MAP
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import PROMPT_TEMPLATE

from projects.llava_sam2.datasets import video_lisa_collate_fn


def parse_args():
    parser = argparse.ArgumentParser(description='RefCocoSeg')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('--pth_model', default=None, help='pth model file')
    parser.add_argument(
        '--split',
        default='val',
        help='Specify a split')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=100,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.launcher != 'none':
        set_multi_processing(distributed=True)
        init_dist(args.launcher)
        rank, world_size = get_dist_info()
        torch.cuda.set_device(rank)
    else:
        rank = 0
        world_size = 1
    print(f'Rank: {rank} / World size: {world_size}')

    # Resolve the config argument as a file path or a registered config name.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')
    cfg = Config.fromfile(args.config)

    model_name = (cfg.model.type if isinstance(cfg.model.type, str)
                  else cfg.model.type.__name__)
    assert model_name in ('VideoLLaVASAMModel',)

    model = BUILDER.build(cfg.model)
    if args.pth_model is not None:
        state_dict_pth = args.pth_model
        state_dict = torch.load(state_dict_pth, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
    # Use the CUDA device selected above (cuda:0 in the non-distributed case).
    model.cuda()
    model.eval()

    # Handle to the tokenizer used for decoding and for stop-word criteria.
    tokenizer = model.tokenizer
    # Greedy-decoding generation config.
    gen_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
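
    # Stop generation on the user-provided stop words plus the prompt
    # template's STOP_WORDS.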
    stop_words = args.stop_words
    if args.prompt_template:
        template = PROMPT_TEMPLATE[args.prompt_template]
        stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(
        tokenizer=tokenizer, stop_words=stop_words)
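
    # Build the video ReVOS dataset from the config, switching the expression
    # annotation file from the 'train' split to the 'val' split.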
    data_cfg = copy.deepcopy(cfg.video_revos_dataset)
    data_cfg.update(
        expression_file=data_cfg.expression_file.replace('train', 'val'))
    dataset = BUILDER.build(data_cfg)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        collate_fn=video_lisa_collate_fn,
    )
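
    # Run inference: generate one text response per sample.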
    for data_item in dataloader:
        data_item = model.data_preprocessor(data_item)
        inputs, data_samples = data_item['data'], data_item['data_samples']
        g_pixel_values = inputs.pop('g_pixel_values', None)
        gt_masks = inputs.pop('masks', None)
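
        # Generate a text response for this sample (greedy decoding, see
        # gen_config above).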
        output = model.mllm.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            visual_features=None,
            generation_config=gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
        # Decode and print the generated text (may still include the prompt
        # tokens, depending on the mllm implementation).
        print(tokenizer.decode(output.sequences[0], skip_special_tokens=True))


if __name__ == '__main__':
    main()