huiwon committed (verified)
Commit d4c783c · Parent: ff268db

Upload modeling_contextvla.py

Files changed (1)
  1. modeling_contextvla.py +64 -0
modeling_contextvla.py ADDED
# modeling_contextvla.py
import os

import torch
from torch import nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig
from transformers.modeling_utils import load_sharded_checkpoint

from src.models.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from src.models.layer_wrapper import LayerWrapper


class IndexContext:
    # Shared, mutable container handed to every LayerWrapper; its fields
    # are populated at runtime.
    batch_indices: int
    gather_indices: int


class ContextVLA_Qwen2_5_VL(Qwen2_5_VLForConditionalGeneration):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Instantiate the base Qwen2.5-VL-3B architecture from its config;
        # the actual ContextVLA weights are loaded at the end.
        base_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
        model = Qwen2_5_VLForConditionalGeneration._from_config(base_config, **kwargs)

        # Wrap every decoder layer; all wrappers share one IndexContext.
        index_context = IndexContext()
        for layer_idx in range(len(model.model.layers)):
            model.model.layers[layer_idx] = LayerWrapper(
                model.model.layers[layer_idx],
                layer_idx=layer_idx,
                internal_projection=2,
                num_frames=8,
                num_views=3,
                index_context=index_context,
                img_pattern=[151652],  # Qwen2.5-VL <|vision_start|> token id
                motion_token=1,
            )

        # Expand the vocabulary to 153713 entries, copying over the first
        # 151664 rows of the original input embedding and LM head. The new
        # rows are filled in when the checkpoint is loaded below.
        old_weight = model.model.embed_tokens.weight.data
        new_embedding = nn.Embedding(153713, old_weight.shape[1])
        with torch.no_grad():
            new_embedding.weight[:151664].copy_(old_weight[:151664])
        model.model.embed_tokens = new_embedding

        old_head = model.lm_head
        new_head = nn.Linear(old_head.weight.data.shape[1], 153713, bias=False)
        with torch.no_grad():
            new_head.weight[:151664].copy_(old_head.weight[:151664])
        model.lm_head = new_head
        model.vocab_size = model.config.vocab_size = 153713

        # Resolve a local checkpoint directory, downloading the Hub snapshot
        # if a repo id was given, then load the sharded weights.
        if os.path.isdir(pretrained_model_name_or_path):
            local_dir = pretrained_model_name_or_path
        else:
            local_dir = snapshot_download(pretrained_model_name_or_path)

        load_sharded_checkpoint(model, local_dir)
        print(f"[ContextVLA] weights loaded from {local_dir}")

        return model
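
For reference, a minimal loading sketch. The checkpoint's repo id is not stated on this page, so the path below is a placeholder, and the sketch assumes the src.models package from the ContextVLA codebase is importable:

# Illustrative only — "<namespace>/<contextvla-repo>" is a hypothetical id;
# substitute the actual Hub repo or a local checkpoint directory.
import torch

from modeling_contextvla import ContextVLA_Qwen2_5_VL

model = ContextVLA_Qwen2_5_VL.from_pretrained(
    "<namespace>/<contextvla-repo>",
    torch_dtype=torch.bfloat16,  # forwarded to _from_config via **kwargs
)
model.eval()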