| | import os |
| |
|
| | from torch import nn |
| | import torch |
| |
|
| | from huggingface_hub import snapshot_download |
| | from transformers.trainer_utils import load_sharded_checkpoint |
| | from transformers import AutoConfig, AutoProcessor |
| |
|
| | from qwenvl.model.modeling_qwen3_vl import Qwen3VLForConditionalGeneration |
| | from qwenvl.model.contextvla import LayerWrapper |
| |
|
| |
|
# Named marker tokens that add_action_to_processor registers in the tokenizer
# ahead of the indexed "<|action_{i}|>" tokens.
ACTION_START_TOKEN = "<|action_start|>"
ACTION_END_TOKEN = "<|action_end|>"
ACTION_PLACEHOLDER_TOKEN = "<|action_placeholder|>"


def add_action_to_processor(processor, num_action_tokens=2048):
    """Register the action tokens on ``processor.tokenizer``.

    The tokens are added in a fixed order: the three named markers
    (start, end, placeholder) followed by ``<|action_0|>`` ...
    ``<|action_{num_action_tokens - 1}|>``. All of them are passed to
    ``tokenizer.add_tokens`` with ``special_tokens=True``.

    Args:
        processor: Object exposing a ``tokenizer`` attribute with an
            ``add_tokens(tokens, special_tokens=...)`` method. It is
            modified in place.
        num_action_tokens: How many indexed ``<|action_{i}|>`` tokens to
            register. Defaults to 2048; a value ``<= 0`` registers only the
            three named markers.

    Returns:
        The same ``processor`` object that was passed in.
    """
    custom_tokens = [ACTION_START_TOKEN, ACTION_END_TOKEN, ACTION_PLACEHOLDER_TOKEN]
    custom_tokens.extend(f"<|action_{i}|>" for i in range(num_action_tokens))

    num_added = processor.tokenizer.add_tokens(custom_tokens, special_tokens=True)
    print(f"Added {num_added} custom tokens")

    return processor
| |
|
| |
|
class ContextVLA_Qwen3VL(Qwen3VLForConditionalGeneration):
    """Qwen3-VL model whose language-model layers are wrapped in ``LayerWrapper``.

    The class only customises loading: see :meth:`from_pretrained`.
    """

    # Hub id whose config defines the architecture and whose processor
    # supplies the base tokenizer used to size the embedding matrix.
    BASE_MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the wrapped model and load checkpoint weights into it.

        Steps, in order:

        1. Instantiate the model from the ``BASE_MODEL_ID`` config (no
           pretrained weights are loaded at this point).
        2. Replace every language-model layer with a ``LayerWrapper``
           around the original layer.
        3. Resize the token embeddings to the base tokenizer extended by
           ``add_action_to_processor``.
        4. Load weights from the checkpoint directory with
           ``load_sharded_checkpoint``.

        The wrapping and resizing happen before the weights are loaded, so
        the module structure and embedding size are already in their final
        form when the checkpoint is applied.

        Args:
            pretrained_model_name_or_path: Either an existing local
                directory, or a Hugging Face Hub repo id that is fetched
                with ``snapshot_download``. The directory is handed to
                ``load_sharded_checkpoint``, so it is expected to contain a
                sharded checkpoint.
            **kwargs: Forwarded unchanged to ``_from_config`` (model
                construction). They are not passed to the hub download or
                to the checkpoint loader.

        Returns:
            An instance of ``cls`` with the checkpoint weights loaded.
        """
        base_config = AutoConfig.from_pretrained(cls.BASE_MODEL_ID)
        # Build through ``cls`` so the result is an instance of the class
        # the caller asked for, not of the parent class.
        model = cls._from_config(base_config, **kwargs)

        decoder_layers = model.model.language_model.layers
        for layer_idx, layer in enumerate(decoder_layers):
            decoder_layers[layer_idx] = LayerWrapper(
                layer,
                layer_idx=layer_idx,
                # NOTE(review): the meaning of these constants is defined by
                # LayerWrapper, which is not visible here. 151652 is
                # presumably the tokenizer id that marks the start of an
                # image span -- confirm against the tokenizer.
                internal_projection=4,
                img_pattern=[151652],
                motion_token=1,
            )

        # The processor is only needed here to learn the extended vocabulary
        # size; it is not attached to the returned model.
        processor = AutoProcessor.from_pretrained(cls.BASE_MODEL_ID)
        processor = add_action_to_processor(processor)
        model.resize_token_embeddings(len(processor.tokenizer))

        if os.path.isdir(pretrained_model_name_or_path):
            local_dir = pretrained_model_name_or_path
        else:
            local_dir = snapshot_download(pretrained_model_name_or_path)

        load_sharded_checkpoint(model, local_dir)
        print(f"[ContextVLA] weights loaded from {local_dir}")

        return model