# roboalign_contextvla_oxe_sft_2epoch / modeling_contextvla.py
# Uploaded via huggingface_hub by huiwon (commit e6ab8f6, verified)
import os
from torch import nn
import torch
from huggingface_hub import snapshot_download
from transformers.trainer_utils import load_sharded_checkpoint
from transformers import AutoConfig, AutoProcessor
from qwenvl.model.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
from qwenvl.model.contextvla import LayerWrapper
ACTION_START_TOKEN = "<|action_start|>"
ACTION_END_TOKEN = "<|action_end|>"
ACTION_PLACEHOLDER_TOKEN = "<|action_placeholder|>"


def add_action_to_processor(processor):
    """Register the ContextVLA action tokens on *processor*'s tokenizer.

    Adds the action start/end/placeholder markers plus 2048 discrete
    action tokens ``<|action_0|>`` .. ``<|action_2047|>`` as special
    tokens, then returns the same processor object (mutated in place).

    Args:
        processor: any object exposing a ``tokenizer`` with an
            ``add_tokens(tokens, special_tokens=...)`` method (e.g. a
            transformers processor).

    Returns:
        The same *processor*, with its tokenizer vocabulary extended.
    """
    # Build the full token list in one expression instead of a manual
    # append loop (same order: markers first, then the 2048 bins).
    custom_tokens = [
        ACTION_START_TOKEN,
        ACTION_END_TOKEN,
        ACTION_PLACEHOLDER_TOKEN,
        *(f"<|action_{i}|>" for i in range(2048)),
    ]
    num_added = processor.tokenizer.add_tokens(custom_tokens, special_tokens=True)
    print(f"Added {num_added} custom tokens")
    return processor
class ContextVLA_Qwen3VL(Qwen3VLForConditionalGeneration):
    """Qwen3-VL model with every decoder layer wrapped in a ContextVLA
    ``LayerWrapper`` and the action-token vocabulary installed."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the wrapped model and load fine-tuned weights.

        The base architecture/config always comes from the Qwen3-VL
        repo (not from *pretrained_model_name_or_path*); the checkpoint
        path/repo supplies only the fine-tuned weights, loaded as a
        sharded checkpoint after the layer wrapping and embedding
        resize so tensor shapes match.

        Args:
            pretrained_model_name_or_path: local directory containing
                the sharded checkpoint, or a Hub repo id to download.
            **kwargs: forwarded to ``_from_config``. The following keys
                are consumed here as backward-compatible overrides:
                ``base_model_name`` (default ``"Qwen/Qwen3-VL-8B-Instruct"``),
                ``internal_projection`` (default ``4``),
                ``img_pattern`` (default ``[151652]``),
                ``motion_token`` (default ``1``).

        Returns:
            The fully assembled model with checkpoint weights loaded.
        """
        # Optional overrides; defaults reproduce the original behavior.
        base_model_name = kwargs.pop("base_model_name", "Qwen/Qwen3-VL-8B-Instruct")
        internal_projection = kwargs.pop("internal_projection", 4)
        # NOTE(review): 151652 presumably is the image-start token id for
        # Qwen3-VL — confirm against the tokenizer config.
        img_pattern = kwargs.pop("img_pattern", [151652])
        motion_token = kwargs.pop("motion_token", 1)

        base_config = AutoConfig.from_pretrained(base_model_name)
        model = Qwen3VLForConditionalGeneration._from_config(base_config, **kwargs)

        # Wrap each decoder layer in place with the ContextVLA adapter.
        layers = model.model.language_model.layers
        for layer_idx, layer in enumerate(layers):
            layers[layer_idx] = LayerWrapper(
                layer,
                layer_idx=layer_idx,
                internal_projection=internal_projection,
                img_pattern=img_pattern,
                motion_token=motion_token,
            )

        # Extend the vocabulary with the action tokens and resize the
        # embeddings so they line up with the fine-tuned checkpoint.
        processor = AutoProcessor.from_pretrained(base_model_name)
        processor = add_action_to_processor(processor)
        model.resize_token_embeddings(len(processor.tokenizer))

        # Resolve a local checkpoint directory (download from the Hub
        # if a repo id was given) and load the sharded weights.
        if os.path.isdir(pretrained_model_name_or_path):
            local_dir = pretrained_model_name_or_path
        else:
            local_dir = snapshot_download(pretrained_model_name_or_path)
        load_sharded_checkpoint(model, local_dir)
        print(f"[ContextVLA] weights loaded from {local_dir}")
        return model