File size: 2,928 Bytes
a47e733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac63a19
a47e733
 
 
 
 
 
 
 
 
 
 
 
ac63a19
a47e733
ac63a19
a47e733
 
 
 
 
 
c4c6335
a47e733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a95e79a
a47e733
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from email.policy import default
import os

import sys
# Make the directory two levels above this file (presumably the repository
# root) importable by appending it to sys.path. os.path.split(p)[0] is exactly
# the value os.path.dirname(p) returns.
BASE_DIR = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
sys.path.append(BASE_DIR)

import cv2  # type: ignore

import argparse
import json
import os
from typing import Any, Dict, List

# Dataset layout expected under --data_dir. It is interpolated into the
# --data_dir help text below. Note that argparse's default HelpFormatter
# re-wraps help text, so this tree is not shown line-by-line in --help output.
file_arch = """
./REFAVS/data
    - /media
    - /gt_mask
    - /metadata.csv
    - /audio_embed
    - /image_embed
"""

parser = argparse.ArgumentParser(description="SimToken")

# Pretrained components.
parser.add_argument(
    "--vision_pretrained",
    type=str,
    default="path/to/segment_anything/sam_vit_h_4b8939.pth",
    help="Vision checkpoint path. The default is a placeholder and must be overridden.",
)
parser.add_argument("--vision_tower", type=str, default="openai/clip-vit-large-patch14")
parser.add_argument("--mllm", type=str, default="Chat-UniVi/Chat-UniVi-7B-v1.5")

parser.add_argument("--conv_template", type=int, default=1)
parser.add_argument("--ct_weight", type=float, default=0.1)
parser.add_argument("--input_type", type=str, default="refer")
# NOTE: inverted flag. args.compress is True unless the flag is given, in which
# case it becomes False. Kept this way so existing command lines keep working.
parser.add_argument(
    "--compress",
    action="store_false",
    default=True,
    help="On by default; passing --compress turns it OFF (sets compress=False).",
)
parser.add_argument("--start", type=int, default=0)

parser.add_argument("--name", type=str, default="testrun")
# Path to the Ref-AVS dataset.
parser.add_argument(
    "--data_dir",
    type=str,
    default="data",
    help=f"The data parent dir. File arch should be: {file_arch}",
)
# Path to the pretrained checkpoint.
parser.add_argument(
    "--saved_model",
    type=str,
    default="trained_simtoken.pth",
    help="the pretrained simtoken pth",
)

parser.add_argument(
    "--log_root",
    type=str,
    default="log",
    help="where to save log during training",
)
parser.add_argument(
    "--checkpoint_root",
    type=str,
    default="checkpoints",
    help="where to save trained checkpoints during training",
)
parser.add_argument(
    "--visualization_root",
    type=str,
    default="visualization",
    help="where to save visualization result during test",
)
parser.add_argument(
    "--eval_splits",
    type=str,
    default="test_s,test_u,test_n",
    help="comma-separated eval splits for load_model.py: test_s,test_u,test_n",
)

# Optimisation.
parser.add_argument("--lr", type=float, default=5e-5, help="lr to fine tuning adapters.")
parser.add_argument("--epochs", type=int, default=10, help="epochs to fine tuning adapters.")
parser.add_argument("--batch_size", type=int, default=8)

parser.add_argument(
    "--gpu_id",
    type=str,
    default="0",
    help="Value exported as CUDA_VISIBLE_DEVICES (e.g. '0' or '0,1').",
)

parser.add_argument("--run", type=str, default="train", help="train, test")

parser.add_argument(
    "--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10."
)
parser.add_argument(
    "--text_max_len", type=int, default=25, help="Maximum textual reference length."
)

# parse_known_args (rather than parse_args): argv entries this parser does not
# define are returned in the discarded second value instead of aborting the
# program. Side effects: misspelled options are silently ignored, and
# argparse's default allow_abbrev=True lets unambiguous prefixes
# (e.g. --epo for --epochs) match.
args, _ = parser.parse_known_args()

# Runs at import time; it only takes effect if nothing has initialised CUDA yet.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES already
# present in the environment. When --gpu_id is omitted, the default "0" wins.
# CUDA_DEVICE_ORDER is not set here, so device indices follow CUDA's default
# ordering; setting it to "PCI_BUS_ID" would make them match nvidia-smi.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id