huiwon committed on
Commit
c6c73d4
·
verified ·
1 Parent(s): 8ee9a0f

Add modeling_contextvla.py

Browse files
Files changed (1) hide show
  1. modeling_contextvla.py +55 -0
modeling_contextvla.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from huggingface_hub import snapshot_download
4
+ from transformers.trainer_utils import load_sharded_checkpoint
5
+ from transformers import AutoConfig, AutoProcessor
6
+
7
+ from .modeling_qwen3_vl import Qwen3VLForConditionalGeneration
8
+ from .contextvla import LayerWrapper
9
+
10
+
11
# Marker tokens that delimit / stand in for an action span in the token stream.
ACTION_START_TOKEN = "<|action_start|>"
ACTION_END_TOKEN = "<|action_end|>"
ACTION_PLACEHOLDER_TOKEN = "<|action_placeholder|>"


def add_action_to_processor(processor, num_action_tokens: int = 2048):
    """Register the action vocabulary on ``processor.tokenizer``.

    The tokens passed to ``tokenizer.add_tokens`` are, in order: the three
    marker tokens (start, end, placeholder) followed by the indexed tokens
    ``<|action_0|>`` ... ``<|action_{num_action_tokens - 1}|>``. They are
    added with ``special_tokens=True``.

    Args:
        processor: Object exposing a ``tokenizer`` attribute whose
            ``add_tokens(tokens, special_tokens=...)`` method is called once.
        num_action_tokens: How many indexed ``<|action_i|>`` tokens to
            register. Defaults to 2048, the previously hard-coded value.

    Returns:
        The same ``processor`` object that was passed in (its tokenizer is
        modified in place).
    """
    custom_tokens = [ACTION_START_TOKEN, ACTION_END_TOKEN, ACTION_PLACEHOLDER_TOKEN]
    custom_tokens.extend(f"<|action_{i}|>" for i in range(num_action_tokens))

    num_added = processor.tokenizer.add_tokens(custom_tokens, special_tokens=True)
    print(f"Added {num_added} custom tokens")

    return processor
25
+
26
+
27
class ContextVLA_Qwen3VL(Qwen3VLForConditionalGeneration):
    """Qwen3-VL model whose language-model layers are wrapped by ``LayerWrapper``.

    The class only customises loading; see ``from_pretrained``.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the wrapped model and load checkpoint weights into it.

        Steps, in order:
          1. Instantiate a model from the ``Qwen/Qwen3-VL-8B-Instruct`` config
             via ``_from_config`` (``kwargs`` are forwarded to that call).
          2. Replace every entry of ``model.model.language_model.layers`` with
             a ``LayerWrapper`` around the original layer.
          3. Load the base processor, add the action tokens to its tokenizer
             and resize the model's token embeddings to the new vocab size.
          4. Resolve ``pretrained_model_name_or_path`` to a local directory
             (used as-is if it is a directory, otherwise fetched with
             ``snapshot_download``) and load it with
             ``load_sharded_checkpoint``.

        Args:
            pretrained_model_name_or_path: Local checkpoint directory, or an
                identifier passed to ``snapshot_download``.
            **kwargs: Forwarded unchanged to ``_from_config``.

        Returns:
            The model with wrapped layers and loaded weights.

        NOTE(review): the model is created from
        ``Qwen3VLForConditionalGeneration`` rather than ``cls``, so the
        returned object is an instance of the parent class — confirm this is
        intended.
        """
        base_model_id = "Qwen/Qwen3-VL-8B-Instruct"

        model = Qwen3VLForConditionalGeneration._from_config(
            AutoConfig.from_pretrained(base_model_id), **kwargs
        )

        # Swap each language-model layer for its wrapped counterpart, in place.
        # The wrapper arguments are defined by ``LayerWrapper`` (not visible
        # here); 151652 is presumably a vision-start token id — TODO confirm.
        decoder_layers = model.model.language_model.layers
        for layer_idx, layer in enumerate(decoder_layers):
            decoder_layers[layer_idx] = LayerWrapper(
                layer,
                layer_idx=layer_idx,
                internal_projection=4,
                img_pattern=[151652],
                motion_token=1,
            )

        # The embedding table must match the tokenizer after the action tokens
        # are added, and this has to happen before the weights are loaded.
        processor = add_action_to_processor(
            AutoProcessor.from_pretrained(base_model_id)
        )
        model.resize_token_embeddings(len(processor.tokenizer))

        local_dir = (
            pretrained_model_name_or_path
            if os.path.isdir(pretrained_model_name_or_path)
            else snapshot_download(pretrained_model_name_or_path)
        )

        load_sharded_checkpoint(model, local_dir)
        print(f"[ContextVLA] weights loaded from {local_dir}")

        return model