AutoLLMAnnotation / tools /vlm_backend.py
ayh015's picture
Update modified code
73df34b
from typing import List, Tuple
def _get_transformers():
    """Import and return the ``transformers`` module.

    The import happens at call time rather than at module import time, so
    importing this module alone does not pull in ``transformers``.
    """
    import transformers as hf_module

    return hf_module
def resolve_torch_dtype(dtype_name):
    """Map a dtype name such as ``"bfloat16"`` to the matching ``torch.dtype``.

    Args:
        dtype_name: ``"auto"`` (returned unchanged), the name of a dtype
            attribute of ``torch`` (``"float16"``, ``"bfloat16"``, ...), or an
            existing ``torch.dtype`` (returned as is).

    Returns:
        The string ``"auto"`` or a ``torch.dtype``.

    Raises:
        ValueError: if ``dtype_name`` does not name a ``torch.dtype``.
    """
    import torch

    if isinstance(dtype_name, torch.dtype):
        return dtype_name
    if dtype_name == "auto":
        return "auto"
    # ``hasattr(torch, name)`` alone would accept any torch attribute (e.g.
    # "nn" or "tensor"), so verify the looked-up object really is a dtype.
    dtype = getattr(torch, dtype_name, None) if isinstance(dtype_name, str) else None
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {dtype_name}")
    return dtype
def infer_model_backend(model_path, backend="auto", trust_remote_code=True):
    """Return the loader backend name to use for ``model_path``.

    When ``backend`` is anything other than ``"auto"`` it is returned unchanged
    and no config is read. Otherwise the checkpoint config is loaded with
    ``transformers.AutoConfig`` and classified from its ``architectures`` list
    and ``model_type`` string.

    Args:
        model_path: Path or hub id accepted by ``AutoConfig.from_pretrained``.
        backend: Explicit backend name, or ``"auto"`` to detect it.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        When auto-detecting, one of ``"qwen3_vl_moe"``, ``"qwen3_vl"``,
        ``"llava"``, ``"deepseek_vl"`` or the generic ``"hf_vision2seq"``.
    """
    if backend != "auto":
        return backend

    transformers = _get_transformers()
    config = transformers.AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    architectures = [arch.lower() for arch in (getattr(config, "architectures", None) or [])]
    arch_text = " ".join(architectures)
    # ``or ""`` so a missing/None model_type does not turn into the string "none".
    model_type = str(getattr(config, "model_type", None) or "").lower()

    # Restrict the Qwen3-VL backends to Qwen3-VL checkpoints. A bare
    # "qwen" + "vl" test also matched Qwen2-VL / Qwen2.5-VL ("qwen2_vl",
    # "qwen2_5_vl"), and "qwen" + "moe" matched text-only Qwen MoE models;
    # neither should be loaded with the Qwen3VL*ForConditionalGeneration classes.
    is_qwen3_vl = "qwen3vl" in arch_text or model_type.startswith("qwen3_vl")
    if is_qwen3_vl and ("moe" in arch_text or "moe" in model_type):
        return "qwen3_vl_moe"
    if is_qwen3_vl:
        return "qwen3_vl"
    if "llava" in arch_text or "llava" in model_type:
        return "llava"
    if any(tag in arch_text or tag in model_type for tag in ("deepseek", "janus")):
        return "deepseek_vl"
    return "hf_vision2seq"
def _resolve_model_class(transformers, class_names):
    """Return the first class named in ``class_names`` that ``transformers`` exposes.

    Raises:
        AttributeError: if the installed transformers provides none of them.
    """
    for class_name in class_names:
        model_cls = getattr(transformers, class_name, None)
        if model_cls is not None:
            return model_cls
    version = getattr(transformers, "__version__", "unknown")
    raise AttributeError(
        f"transformers {version} provides none of {', '.join(class_names)}; "
        "install a version that supports this backend or choose another backend."
    )


def load_model_and_processor(
    model_path,
    backend="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,
):
    """Load a multimodal model and its processor for ``model_path``.

    Args:
        model_path: Path or hub id passed to ``from_pretrained``.
        backend: Backend name, or ``"auto"`` to detect it via
            ``infer_model_backend``.
        torch_dtype: Dtype name (or ``"auto"``) resolved by
            ``resolve_torch_dtype`` and passed to the model's ``from_pretrained``.
        trust_remote_code: Forwarded to every ``from_pretrained`` call.

    Returns:
        ``(backend, model, processor)``, where ``backend`` is the resolved
        backend name and the processor has been through ``_configure_processor``.

    Raises:
        ValueError: for an unsupported backend or dtype.
        AttributeError: if the installed transformers lacks the model class
            required by the backend.
    """
    transformers = _get_transformers()
    backend = infer_model_backend(
        model_path=model_path,
        backend=backend,
        trust_remote_code=trust_remote_code,
    )
    dtype = resolve_torch_dtype(torch_dtype)

    # Candidate classes per backend, in order of preference. AutoModelForVision2Seq
    # is deprecated in favour of AutoModelForImageTextToText; the latter is only
    # used as a fallback when the former is missing, so older installs keep
    # their current behavior.
    vision2seq_classes = ("AutoModelForVision2Seq", "AutoModelForImageTextToText")
    candidates_by_backend = {
        "qwen3_vl": ("Qwen3VLForConditionalGeneration",),
        "qwen3_vl_moe": ("Qwen3VLMoeForConditionalGeneration",),
        "llava": ("LlavaForConditionalGeneration",) + vision2seq_classes,
        # DeepSeek multimodal checkpoints often rely on trust_remote_code and may expose
        # custom causal-LM style classes instead of Vision2Seq classes.
        "deepseek_vl": ("AutoModelForCausalLM",),
        "hf_vision2seq": vision2seq_classes,
        "hf_causal_vlm": ("AutoModelForCausalLM",),
    }
    if backend not in candidates_by_backend:
        raise ValueError(f"Unsupported model backend: {backend}")
    model_cls = _resolve_model_class(transformers, candidates_by_backend[backend])

    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
    processor = transformers.AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    _configure_processor(processor)
    return backend, model, processor
def _configure_processor(processor):
    """Adjust ``processor.tokenizer`` in place for batched generation.

    If the tokenizer has a non-None ``padding_side`` it is switched to
    ``"left"``. If it has no pad token but does have an EOS token, the EOS token
    is reused as the pad token. A processor without a ``tokenizer`` attribute
    (or with ``tokenizer=None``) is left untouched.
    """
    tok = getattr(processor, "tokenizer", None)
    if tok is not None:
        if getattr(tok, "padding_side", None) is not None:
            tok.padding_side = "left"
        if getattr(tok, "pad_token", None) is None:
            eos = getattr(tok, "eos_token", None)
            if eos is not None:
                tok.pad_token = eos
def _build_chat_messages(prompts, system_prompt):
    """Return one conversation per prompt: a system turn, then a user turn with one image slot."""
    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        for prompt in prompts
    ]


def _render_prompts(processor, prompts, messages):
    """Render ``messages`` to strings with the first available ``apply_chat_template``.

    Prefers ``processor.apply_chat_template``, then
    ``processor.tokenizer.apply_chat_template``; when neither exists the raw
    ``prompts`` are returned unchanged.
    """
    if hasattr(processor, "apply_chat_template"):
        render = processor.apply_chat_template
    else:
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
            return prompts
        render = tokenizer.apply_chat_template
    return [
        render(message, tokenize=False, add_generation_prompt=True)
        for message in messages
    ]


def build_batch_tensors(processor, prompts: List[str], images, system_prompt=""):
    """Build processor inputs for a batch of single-image prompts.

    Each prompt becomes a chat with a system turn (``system_prompt``, possibly
    empty) and a user turn holding one image placeholder followed by the
    prompt text. The chats are rendered with the processor's (or its
    tokenizer's) chat template when available, then passed to the processor
    together with ``images``.

    Args:
        processor: Callable processor, ideally with ``apply_chat_template``.
        prompts: Prompt strings, one per sample (any iterable is accepted).
        images: Images for the batch; when a list/tuple, it must contain
            exactly one entry per prompt.
        system_prompt: Text of the system turn added to every chat.

    Returns:
        Whatever ``processor(text=..., images=..., return_tensors="pt", ...)``
        returns.

    Raises:
        ValueError: if ``images`` is a list/tuple whose length differs from the
            number of prompts.
    """
    # Materialize once: prompts are used for both the messages and the
    # raw-prompt fallback, which would break for a one-shot iterator.
    prompts = list(prompts)
    # Every chat carries exactly one image placeholder, so a list of images must
    # line up one-to-one with the prompts; fail early instead of handing the
    # processor misaligned inputs.
    if isinstance(images, (list, tuple)) and len(images) != len(prompts):
        raise ValueError(
            f"Expected one image per prompt, got {len(images)} images "
            f"for {len(prompts)} prompts"
        )

    messages = _build_chat_messages(prompts, system_prompt)
    rendered_prompts = _render_prompts(processor, prompts, messages)
    try:
        return processor(
            text=rendered_prompts,
            images=images,
            return_tensors="pt",
            padding=True,
        )
    except TypeError:
        # Presumably for processors whose __call__ does not accept ``padding``.
        return processor(
            text=rendered_prompts,
            images=images,
            return_tensors="pt",
        )
def decode_generated_text(processor, output_ids, prompt_input_ids):
    """Decode only the tokens generated after the prompt for a single sample.

    Args:
        processor: Object with a ``tokenizer`` attribute, or a tokenizer itself;
            its ``batch_decode`` is called with ``skip_special_tokens=True``.
        output_ids: Generated ids for one sample as a 1-D ``(seq_len,)`` or
            2-D ``(1, seq_len)`` tensor, assumed to begin with the prompt ids.
        prompt_input_ids: The prompt ids for that sample, 1-D or 2-D; only its
            last dimension (the prompt length) is used.

    Returns:
        The decoded continuation as a string.
    """
    tokenizer = getattr(processor, "tokenizer", processor)
    # Use the last dimension so both (seq_len,) and (1, seq_len) prompts give the
    # prompt length; ``shape[0]`` would be 1 for the 2-D form.
    prompt_len = prompt_input_ids.shape[-1]
    generated = output_ids[..., prompt_len:]
    if generated.dim() == 1:
        # batch_decode expects a batch of sequences.
        generated = generated.unsqueeze(0)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]