import os

import torch
from torch import nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoProcessor
from transformers.trainer_utils import load_sharded_checkpoint

from qwenvl.model.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
from qwenvl.model.contextvla import LayerWrapper

# Special tokens delimiting the action span in the token stream.
ACTION_START_TOKEN = "<|action_start|>"
ACTION_END_TOKEN = "<|action_end|>"
ACTION_PLACEHOLDER_TOKEN = "<|action_placeholder|>"

# Size of the discrete action vocabulary (<|action_0|> … <|action_2047|>).
NUM_ACTION_TOKENS = 2048


def add_action_to_processor(processor):
    """Register the action special tokens on the processor's tokenizer.

    Adds the start/end/placeholder markers plus NUM_ACTION_TOKENS discrete
    action tokens (``<|action_0|>`` … ``<|action_2047|>``) as special tokens,
    growing the tokenizer vocabulary. The caller is responsible for resizing
    the model's embedding matrix afterwards.

    Args:
        processor: a HuggingFace processor exposing a ``tokenizer`` attribute.

    Returns:
        The same processor instance, mutated in place.
    """
    custom_tokens = [ACTION_START_TOKEN, ACTION_END_TOKEN, ACTION_PLACEHOLDER_TOKEN]
    custom_tokens += [f"<|action_{i}|>" for i in range(NUM_ACTION_TOKENS)]
    num_added = processor.tokenizer.add_tokens(custom_tokens, special_tokens=True)
    print(f"Added {num_added} custom tokens")
    return processor


class ContextVLA_Qwen3VL(Qwen3VLForConditionalGeneration):
    """Qwen3-VL variant whose decoder layers are wrapped for ContextVLA.

    Overrides ``from_pretrained`` so the architecture is first rebuilt from the
    stock Qwen3-VL-8B-Instruct config with every language-model layer wrapped in
    ``LayerWrapper``, and only then loads the (wrapped-shape) sharded weights.
    Loading the checkpoint directly into the stock architecture would fail
    because the wrapped layers change the state-dict key layout.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the wrapped architecture, then load sharded weights into it.

        Args:
            pretrained_model_name_or_path: local directory containing the
                sharded checkpoint, or a Hub repo id to download.
            **kwargs: forwarded to ``_from_config`` (e.g. ``torch_dtype``).

        Returns:
            The model with ContextVLA layer wrappers and loaded weights.
        """
        # NOTE(review): the base architecture/config is hard-coded to the 8B
        # Instruct variant; checkpoints trained from a different base will not
        # load correctly. Consider parameterizing if other sizes are needed.
        base_config = AutoConfig.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        # NOTE(review): this instantiates the base class, not `cls`, so the
        # returned object is a Qwen3VLForConditionalGeneration — confirm that
        # no ContextVLA_Qwen3VL-specific methods are expected on the result.
        model = Qwen3VLForConditionalGeneration._from_config(base_config, **kwargs)

        # Wrap every decoder layer so the checkpoint's state-dict keys
        # (which include the wrapper) line up with the module tree.
        for layer_idx in range(len(model.model.language_model.layers)):
            model.model.language_model.layers[layer_idx] = LayerWrapper(
                model.model.language_model.layers[layer_idx],
                layer_idx=layer_idx,
                internal_projection=4,
                img_pattern=[151652],  # image-token id pattern — TODO confirm against tokenizer
                motion_token=1,
            )

        # Grow the embedding table to cover the added action tokens before
        # loading weights, so embedding shapes match the checkpoint.
        processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen3-VL-8B-Instruct",
        )
        processor = add_action_to_processor(processor)
        model.resize_token_embeddings(len(processor.tokenizer))

        # Resolve a local directory: use the path as-is if it exists,
        # otherwise fetch the snapshot from the Hub.
        if os.path.isdir(pretrained_model_name_or_path):
            local_dir = pretrained_model_name_or_path
        else:
            local_dir = snapshot_download(pretrained_model_name_or_path)

        load_sharded_checkpoint(model, local_dir)
        print(f"[ContextVLA] weights loaded from {local_dir}")
        return model