Update modeling_internvideo2.py
Browse files- modeling_internvideo2.py +4 -17
modeling_internvideo2.py
CHANGED
|
@@ -1056,11 +1056,6 @@ def pretrain_internvideo2_1b_patch14_224(config):
|
|
| 1056 |
clip_return_layer=config.vision_encoder.clip_return_layer,
|
| 1057 |
clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
|
| 1058 |
)
|
| 1059 |
-
|
| 1060 |
-
# if config.vision_encoder.pretrained is not None:
|
| 1061 |
-
# state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
|
| 1062 |
-
# interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
|
| 1063 |
-
# message = model.load_state_dict(state_dict, strict=False)
|
| 1064 |
|
| 1065 |
return model
|
| 1066 |
|
|
@@ -1071,8 +1066,10 @@ def pretrain_internvideo2_6b_patch14_224(config):
|
|
| 1071 |
embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
|
| 1072 |
clip_embed_dim=config.vision_encoder.clip_embed_dim,
|
| 1073 |
attn_pool_num_heads=16, qkv_bias=False,
|
| 1074 |
-
drop_path_rate=0.3,
|
| 1075 |
-
init_values=0.00001,
|
|
|
|
|
|
|
| 1076 |
qk_normalization=True,
|
| 1077 |
use_flash_attn=config.vision_encoder.use_flash_attn,
|
| 1078 |
use_fused_rmsnorm=config.vision_encoder.use_fused_rmsnorm,
|
|
@@ -1091,12 +1088,6 @@ def pretrain_internvideo2_6b_patch14_224(config):
|
|
| 1091 |
clip_return_layer=config.vision_encoder.clip_return_layer,
|
| 1092 |
clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
|
| 1093 |
)
|
| 1094 |
-
|
| 1095 |
-
# if config.vision_encoder.pretrained is not None:
|
| 1096 |
-
|
| 1097 |
-
# state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
|
| 1098 |
-
# interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
|
| 1099 |
-
# msg = model.load_state_dict(state_dict, strict=False)
|
| 1100 |
|
| 1101 |
return model
|
| 1102 |
|
|
@@ -3155,7 +3146,6 @@ class InternVideo2_Stage2(
|
|
| 3155 |
|
| 3156 |
def __init__(self,
|
| 3157 |
config: InternVideo2_Stage2_Config,
|
| 3158 |
-
# tokenizer,
|
| 3159 |
is_pretrain: bool=True):
|
| 3160 |
|
| 3161 |
super(InternVideo2_Stage2, self).__init__(config)
|
|
@@ -3172,10 +3162,7 @@ class InternVideo2_Stage2(
|
|
| 3172 |
|
| 3173 |
# create modules.
|
| 3174 |
self.vision_encoder = self.build_vision_encoder()
|
| 3175 |
-
self.freeze_vision()
|
| 3176 |
-
|
| 3177 |
self.text_encoder = self.build_text_encoder()
|
| 3178 |
-
self.freeze_text()
|
| 3179 |
|
| 3180 |
self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
|
| 3181 |
self.text_proj = nn.Linear(self.text_width, self.embed_dim)
|
|
|
|
| 1056 |
clip_return_layer=config.vision_encoder.clip_return_layer,
|
| 1057 |
clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
|
| 1058 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1059 |
|
| 1060 |
return model
|
| 1061 |
|
|
|
|
| 1066 |
embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
|
| 1067 |
clip_embed_dim=config.vision_encoder.clip_embed_dim,
|
| 1068 |
attn_pool_num_heads=16, qkv_bias=False,
|
| 1069 |
+
# drop_path_rate=0.3,
|
| 1070 |
+
# init_values=0.00001,
|
| 1071 |
+
drop_path_rate=0,
|
| 1072 |
+
init_values=None,
|
| 1073 |
qk_normalization=True,
|
| 1074 |
use_flash_attn=config.vision_encoder.use_flash_attn,
|
| 1075 |
use_fused_rmsnorm=config.vision_encoder.use_fused_rmsnorm,
|
|
|
|
| 1088 |
clip_return_layer=config.vision_encoder.clip_return_layer,
|
| 1089 |
clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
|
| 1090 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1091 |
|
| 1092 |
return model
|
| 1093 |
|
|
|
|
| 3146 |
|
| 3147 |
def __init__(self,
|
| 3148 |
config: InternVideo2_Stage2_Config,
|
|
|
|
| 3149 |
is_pretrain: bool=True):
|
| 3150 |
|
| 3151 |
super(InternVideo2_Stage2, self).__init__(config)
|
|
|
|
| 3162 |
|
| 3163 |
# create modules.
|
| 3164 |
self.vision_encoder = self.build_vision_encoder()
|
|
|
|
|
|
|
| 3165 |
self.text_encoder = self.build_text_encoder()
|
|
|
|
| 3166 |
|
| 3167 |
self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
|
| 3168 |
self.text_proj = nn.Linear(self.text_width, self.embed_dim)
|