File size: 2,063 Bytes
e6ab8f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os

from torch import nn
import torch

from huggingface_hub import snapshot_download
from transformers.trainer_utils import load_sharded_checkpoint
from transformers import AutoConfig, AutoProcessor

from qwenvl.model.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
from qwenvl.model.contextvla import LayerWrapper


# Special tokens that delimit and stand in for action spans in the tokenized
# sequence; registered on the tokenizer by `add_action_to_processor` below.
ACTION_START_TOKEN = "<|action_start|>"
ACTION_END_TOKEN = "<|action_end|>"
ACTION_PLACEHOLDER_TOKEN = "<|action_placeholder|>"


def add_action_to_processor(processor):
    """Register the action special tokens on *processor*'s tokenizer.

    Adds the start/end/placeholder markers plus 2048 discrete action tokens
    (``<|action_0|>`` ... ``<|action_2047|>``) as special tokens, mutating
    the processor's tokenizer in place.

    Args:
        processor: A HuggingFace processor exposing a ``tokenizer`` that
            supports ``add_tokens(tokens, special_tokens=True)``.

    Returns:
        The same ``processor`` instance, for call-chaining convenience.
    """
    # Build the full token list in one expression instead of a manual
    # append loop: 3 markers + 2048 discrete action tokens = 2051 total.
    custom_tokens = [
        ACTION_START_TOKEN,
        ACTION_END_TOKEN,
        ACTION_PLACEHOLDER_TOKEN,
        *(f"<|action_{i}|>" for i in range(2048)),
    ]

    num_added = processor.tokenizer.add_tokens(custom_tokens, special_tokens=True)
    print(f"Added {num_added} custom tokens")

    return processor


class ContextVLA_Qwen3VL(Qwen3VLForConditionalGeneration):
    """Qwen3-VL variant whose decoder layers are wrapped for ContextVLA.

    ``from_pretrained`` builds the base Qwen3-VL-8B-Instruct architecture
    from config, wraps every language-model decoder layer in
    ``LayerWrapper``, resizes the embedding table to cover the added action
    tokens, and finally loads a (sharded) ContextVLA checkpoint on top.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the wrapped model and load ContextVLA weights.

        Args:
            pretrained_model_name_or_path: Local directory containing the
                checkpoint shards, or a Hugging Face Hub repo id to download
                via ``snapshot_download``.
            **kwargs: Forwarded to
                ``Qwen3VLForConditionalGeneration._from_config``
                (e.g. ``torch_dtype``).

        Returns:
            The weight-loaded model. NOTE(review): this builds and returns a
            plain ``Qwen3VLForConditionalGeneration`` via ``_from_config``
            rather than an instance of ``cls`` — confirm that is intentional.
        """
        # The architecture is fixed to the 8B-Instruct config regardless of
        # the given path; only the *weights* come from that path.
        base_config = AutoConfig.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        model = Qwen3VLForConditionalGeneration._from_config(base_config, **kwargs)
        # Replace every decoder layer with a ContextVLA LayerWrapper.
        for layer_idx in range(len(model.model.language_model.layers)):
            model.model.language_model.layers[layer_idx] = LayerWrapper(
                model.model.language_model.layers[layer_idx],
                layer_idx=layer_idx,
                internal_projection=4,
                # 151652 is presumably the vision/image marker token id for
                # Qwen3-VL — TODO confirm against the tokenizer vocab.
                img_pattern=[151652],
                motion_token=1
            )
        
        # Register the action tokens and grow the embedding matrix BEFORE
        # loading the checkpoint so the embedding shapes match the weights.
        processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen3-VL-8B-Instruct",
        )
        processor = add_action_to_processor(processor)
        model.resize_token_embeddings(len(processor.tokenizer))

        if os.path.isdir(pretrained_model_name_or_path):
            local_dir = pretrained_model_name_or_path
        else:
            # Not a local directory: treat it as a Hub repo id and download.
            local_dir = snapshot_download(pretrained_model_name_or_path)

        load_sharded_checkpoint(model, local_dir)
        print(f"[ContextVLA] weights loaded from {local_dir}")
        
        return model