# Star import kept exactly as in the original config. Names that this file
# only mentions inside "${...}" placeholder strings (e.g. `available_corpus`)
# are not defined here; presumably they come from this module -- TODO confirm.
from tasks.train.instruction_data import *  # noqa: F401,F403

# ========================= data selection =========================
# Name of the training corpus. It is used as the lookup key inside the
# `train_file` placeholder below.
train_corpus = "videochat2_instruction_full"

# Nested "${...}" placeholder stored as a plain string: it indexes
# `available_corpus` by the value of `train_corpus`. Python does not expand
# it; presumably the config loader does -- verify against the loader.
train_file = "${available_corpus[${train_corpus}]}"

# Both left empty: no test files / test types are configured in this file.
test_file = dict()
test_types = []

# ========================= runtime / checkpoint cadence =========================
# Worker count (presumably for the data loaders -- TODO confirm consumer).
num_workers = 8

# Two step-based intervals are defined. NOTE(review): which one writes full
# saves and which one writes intermediate checkpoints is decided by the
# trainer, not visible here -- confirm before changing either.
save_steps = 10000
ckpt_steps = 1000

# No stop key configured.
stop_key = None

# DeepSpeed integration flag, disabled in this config.
deepspeed = False

# ========================= shared hyper-parameters =========================
# These scalars are referenced elsewhere in this file through "${name}"
# placeholder strings (inputs, model and preprocess sections), so changing
# a value here is meant to propagate to every section that references it.
num_frames = 16       # referenced by inputs.video_input, model, preprocess
num_frames_test = 1   # referenced by inputs.video_input.num_frames_test

batch_size = 1        # referenced by inputs.batch_size / inputs.batch_size_test
# Presumably batch_size * gradient_accumulation_steps samples contribute to
# each optimizer step per device -- TODO confirm with the training loop.
gradient_accumulation_steps = 16

max_txt_l = 512       # referenced by inputs.max_txt_l and preprocess.max_txt_l
max_train_steps = None
pre_text = False

# NOTE(review): this value is a dead store. `gradient_checkpointing` is
# assigned again further down in this file (next to `fp16`) with the value
# True, and that later assignment is the one that takes effect.
gradient_checkpointing = False

# ========================= input ========================
# Input settings. Every "${...}" value is a placeholder string that names a
# top-level setting of this file; Python leaves it untouched (presumably the
# config loader substitutes the real value -- TODO confirm).
inputs = dict(
    image_res=336,
    video_input=dict(
        num_frames="${num_frames}",
        sample_type="rand",
        num_frames_test="${num_frames_test}",
        sample_type_test="middle",
        random_aug=False,
    ),
    # Image and video share the same text-length and batch-size settings.
    max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
    batch_size=dict(image="${batch_size}", video="${batch_size}"),
    # The test batch size reuses the training `batch_size` placeholder.
    batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
)

# ========================= model ==========================
# Model settings: base checkpoint, which parts are frozen, LoRA parameters,
# and the frame-pooling configuration.
model = dict(
    repo_id="llava-hf/llava-v1.6-vicuna-7b-hf",
    pretrained_path=None,
    # Loading from "origin" weights is disabled, so both origin paths are empty.
    load_from_origin=False,
    origin_vision="",
    origin_llm="",
    vision_encoder=dict(
        name="vit_l14",
    ),
    torch_dtype="bfloat16",
    # Only the projector is left unfrozen; the language model and the vision
    # tower are frozen. LoRA is enabled on top of that (see use_lora below).
    freeze_projector=False,
    freeze_lm=True,
    freeze_vision_tower=True,
    lora_target_modules=["q_proj", "v_proj"],
    use_lora=True,
    # NOTE(review): alpha (32) is smaller than the rank (128); under the usual
    # alpha / r LoRA scaling that gives 0.25 -- confirm this is intended.
    lora_r=128,
    lora_alpha=32,
    lora_dropout=0.05,
    num_frames="${num_frames}",
    pooling_method="avg",
    use_pooling=True,
    # Presumably 24 x 24 is the patch grid of one frame (image_res 336 with a
    # patch size of 14) -- TODO confirm.
    frame_shape=(24, 24),
    # NOTE(review): the first entry (16) equals the current top-level
    # `num_frames` but is hard-coded here; presumably the two must be kept in
    # sync by hand -- verify against the pooling code.
    pooling_shape=(16, 8, 8),
)

# ========================= preprocess ==========================
# Conversation-template and preprocessing settings. All string values are
# reproduced exactly as in the original config.
preprocess = dict(
    system="",
    mm_alone=True,
    random_shuffle=True,
    add_second_msg=True,
    # Role prefixes and turn delimiters used to build the prompt text.
    roles=["USER:", "ASSISTANT:"],
    end_signal=(" ", "</s>"),
    begin_signal="",
    # Placeholder markers expected in the dataset text for image / video slots.
    dataset_image_placeholder="<Image></Image>",
    dataset_video_placeholder="<Video></Video>",
    # NOTE(review): hard-coded token id; presumably the image token of the
    # tokenizer that ships with model.repo_id -- verify against the tokenizer.
    image_token_index=32000,
    max_txt_l="${max_txt_l}",
    # Presumably the label value excluded from the loss -- TODO confirm.
    ignore_index=-100,
    center_pad=False,
    longest_edge=762,
    shortest_edge=336,
    clip_transform=False,
    num_frames="${num_frames}",
)

# ========================= optimizer ==========================
# Optimizer settings.
optimizer = dict(
    opt="adamW",
    lr=2e-5,
    opt_betas=[0.9, 0.999],
    weight_decay=0.02,
    # NOTE(review): a negative value presumably means "no gradient clipping";
    # confirm how the trainer interprets -1.
    max_grad_norm=-1,
    # Per-module learning-rate override, disabled here (no module names).
    different_lr=dict(enable=False, module_names=[], lr=1e-3),
)

# ========================= scheduler ==========================
# Learning-rate schedule settings.
scheduler = dict(
    is_videochat2_custom=False,
    sched="cosine",
    epochs=2,
    # Presumably the fraction of training used for warm-up -- TODO confirm.
    warmup_ratio=0.2,
    # Presumably the floor of the schedule as a multiple of the base lr
    # -- TODO confirm.
    min_lr_multi=0.25,
)

# ========================= evaluation ==========================
# Top-level switches, both off in this config.
evaluate = False
deep_fusion = False

# Evaluation settings. NOTE(review): `evaluate` above is False, so these are
# presumably unused for this run -- confirm with the entry point.
evaluation = dict(
    eval_frame_ensemble="concat",
    eval_x_only=False,
    k_test=128,
    eval_offload=True,
)

# ========================= precision / memory ==========================
# NOTE(review): fp16 is enabled here while model.torch_dtype in this file is
# "bfloat16"; confirm which of the two the trainer actually honours.
fp16 = True

# This is the effective value: it overrides the earlier
# `gradient_checkpointing = False` assignment made higher up in this file.
gradient_checkpointing = True

# ========================= wandb ==========================
# Weights & Biases settings; disabled in this config.
wandb = dict(
    enable=False,
    entity="user",
    project="videochat2",
)

# ========================= distributed / device ==========================
# "env://" presumably selects environment-variable initialisation of the
# distributed process group -- TODO confirm with the launcher.
dist_url = "env://"
device = "cuda"

# Run mode identifier ("it"); its meaning is defined by the code that reads it.
mode = "it"

# ========================= others ==========================
# No output directory is set in this file.
output_dir = None
resume = False
debug = False

# Logging cadence and smoothing-window size.
log_freq = 5
metric_window_size = 10

seed = 42
report_to = "tensorboard"

save_latest = True
auto_resume = True

# NOTE(review): this top-level value is an empty string, while
# model.pretrained_path in this file is None; confirm that every consumer
# treats "" and None the same way.
pretrained_path = ""