File size: 5,249 Bytes
73df34b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | from typing import List, Tuple
def _get_transformers():
    """Import the ``transformers`` package on demand and return the module.

    The import statement sits inside the function body, so it is executed
    when the function is called rather than when this file is loaded.
    """
    import transformers as hf_module

    return hf_module
def resolve_torch_dtype(dtype_name):
    """Translate a dtype specifier into a value usable as ``torch_dtype``.

    Args:
        dtype_name: ``"auto"``, the attribute name of a ``torch`` dtype such
            as ``"bfloat16"`` or ``"float16"``, or an existing
            ``torch.dtype`` instance (returned unchanged).

    Returns:
        The literal string ``"auto"`` when that was passed, otherwise the
        matching ``torch.dtype`` object.

    Raises:
        ValueError: If ``dtype_name`` does not identify a ``torch.dtype``.
    """
    import torch

    # Already a dtype object: nothing to look up.
    if isinstance(dtype_name, torch.dtype):
        return dtype_name
    if dtype_name == "auto":
        return "auto"

    # Only strings can be attribute names; anything else is unsupported.
    resolved = getattr(torch, dtype_name, None) if isinstance(dtype_name, str) else None

    # ``torch`` exposes many non-dtype attributes ("nn", "tensor", ...), so
    # the mere existence of the attribute is not enough -- check its type.
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {dtype_name}")
    return resolved
def infer_model_backend(model_path, backend="auto", trust_remote_code=True):
    """Pick the backend identifier used to load the checkpoint at ``model_path``.

    An explicit ``backend`` (anything other than ``"auto"``) is returned
    unchanged without touching the checkpoint. Otherwise the config is loaded
    through ``AutoConfig.from_pretrained`` and its ``architectures`` and
    ``model_type`` fields are lower-cased and searched for known substrings.

    Args:
        model_path: Forwarded to ``AutoConfig.from_pretrained``.
        backend: ``"auto"`` to infer, or a backend name to pass through.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The caller-supplied ``backend``, or when inferring one of
        ``"qwen3_vl_moe"``, ``"qwen3_vl"``, ``"llava"``, ``"deepseek_vl"``,
        ``"hf_vision2seq"`` (the default when nothing matches).
    """
    if backend != "auto":
        return backend

    hf = _get_transformers()
    cfg = hf.AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )

    declared_archs = getattr(cfg, "architectures", None) or []
    arch_blob = " ".join([name.lower() for name in declared_archs])
    kind = str(getattr(cfg, "model_type", "")).lower()

    def mentions(keyword):
        # True when the keyword shows up in either config field.
        return keyword in arch_blob or keyword in kind

    qwen_family = "qwen" in kind

    # The MoE test must run before the plain Qwen-VL test: "qwen3vl" is a
    # substring of "qwen3vlmoe", so the order decides which one wins.
    if "qwen3vlmoe" in arch_blob or (qwen_family and "moe" in arch_blob):
        return "qwen3_vl_moe"
    if "qwen3vl" in arch_blob or (qwen_family and "vl" in kind):
        return "qwen3_vl"
    if mentions("llava"):
        return "llava"
    if mentions("deepseek") or mentions("janus"):
        return "deepseek_vl"
    return "hf_vision2seq"
def load_model_and_processor(
    model_path,
    backend="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,
):
    """Load a vision-language model together with its processor.

    Args:
        model_path: Forwarded to the ``from_pretrained`` calls for the model
            and the processor (and to backend inference).
        backend: ``"auto"`` to infer the backend via ``infer_model_backend``,
            or one of ``"qwen3_vl"``, ``"qwen3_vl_moe"``, ``"llava"``,
            ``"deepseek_vl"``, ``"hf_vision2seq"``, ``"hf_causal_vlm"``.
        torch_dtype: Dtype specifier resolved by ``resolve_torch_dtype``.
        trust_remote_code: Forwarded to every ``from_pretrained`` call.

    Returns:
        A ``(backend, model, processor)`` tuple, where ``backend`` is the
        resolved backend name and the processor has been passed through
        ``_configure_processor``.

    Raises:
        ValueError: If the resolved backend is not one of the names above, or
            if ``torch_dtype`` is rejected by ``resolve_torch_dtype``.
    """
    transformers = _get_transformers()
    backend = infer_model_backend(
        model_path=model_path,
        backend=backend,
        trust_remote_code=trust_remote_code,
    )
    dtype = resolve_torch_dtype(torch_dtype)

    def generic_vision_cls():
        # ``AutoModelForVision2Seq`` is deprecated upstream in favour of
        # ``AutoModelForImageTextToText``. Keep using the legacy class where
        # it still exists (unchanged behaviour) and only fall back to the
        # successor when the old name is no longer exported.
        legacy_cls = getattr(transformers, "AutoModelForVision2Seq", None)
        if legacy_cls is not None:
            return legacy_cls
        return transformers.AutoModelForImageTextToText

    if backend == "qwen3_vl":
        model_cls = transformers.Qwen3VLForConditionalGeneration
    elif backend == "qwen3_vl_moe":
        model_cls = transformers.Qwen3VLMoeForConditionalGeneration
    elif backend == "llava":
        # Prefer the dedicated LLaVA class; use the generic one if this
        # transformers build does not export it.
        model_cls = getattr(transformers, "LlavaForConditionalGeneration", None)
        if model_cls is None:
            model_cls = generic_vision_cls()
    elif backend == "deepseek_vl":
        # DeepSeek multimodal checkpoints often rely on trust_remote_code and may expose
        # custom causal-LM style classes instead of Vision2Seq classes.
        model_cls = transformers.AutoModelForCausalLM
    elif backend == "hf_vision2seq":
        model_cls = generic_vision_cls()
    elif backend == "hf_causal_vlm":
        model_cls = transformers.AutoModelForCausalLM
    else:
        raise ValueError(f"Unsupported model backend: {backend}")

    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
    processor = transformers.AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    _configure_processor(processor)
    return backend, model, processor
def _configure_processor(processor):
    """Adjust the processor's tokenizer settings in place.

    The tokenizer is looked up the same way ``decode_generated_text`` does
    it: ``processor.tokenizer`` when that attribute exists, otherwise the
    processor object itself (covering the case where a bare tokenizer is
    passed). Two adjustments are made, each only when applicable:

    * ``padding_side`` is forced to ``"left"`` if the object already has a
      non-``None`` ``padding_side``.
    * ``pad_token`` is set to ``eos_token`` if no pad token is defined and an
      EOS token is available.

    Objects exposing neither attribute are left untouched.

    Args:
        processor: A processor (with an optional ``tokenizer`` attribute) or
            a tokenizer-like object.

    Returns:
        None. The tokenizer is mutated in place.
    """
    tokenizer = getattr(processor, "tokenizer", processor)
    if tokenizer is None:
        return

    if getattr(tokenizer, "padding_side", None) is not None:
        tokenizer.padding_side = "left"

    eos_token = getattr(tokenizer, "eos_token", None)
    if getattr(tokenizer, "pad_token", None) is None and eos_token is not None:
        tokenizer.pad_token = eos_token
def build_batch_tensors(processor, prompts: List[str], images, system_prompt=""):
    """Render one chat per prompt and run the processor over the whole batch.

    Each prompt becomes a two-turn conversation: a system turn carrying
    ``system_prompt`` and a user turn with an image placeholder followed by
    the prompt text. Conversations are rendered to strings with
    ``apply_chat_template`` (``tokenize=False``, ``add_generation_prompt=True``)
    taken from the processor if it has one, otherwise from
    ``processor.tokenizer``; when neither offers it, the raw ``prompts`` are
    used as-is.

    Args:
        processor: Callable accepting ``text``, ``images``, ``return_tensors``
            and, optionally, ``padding``.
        prompts: User prompt strings, one per conversation.
        images: Passed through unchanged as the processor's ``images``.
        system_prompt: Text of the system turn (empty by default).

    Returns:
        Whatever ``processor(...)`` returns for the batch, requested with
        ``return_tensors="pt"``.
    """

    def as_conversation(user_text):
        return [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": system_prompt},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_text},
                ],
            },
        ]

    conversations = [as_conversation(user_text) for user_text in prompts]

    # Find the object that can render chat templates: processor first, then
    # its tokenizer, else nothing.
    template_owner = None
    if hasattr(processor, "apply_chat_template"):
        template_owner = processor
    else:
        candidate = getattr(processor, "tokenizer", None)
        if candidate is not None and hasattr(candidate, "apply_chat_template"):
            template_owner = candidate

    if template_owner is None:
        texts = prompts
    else:
        texts = [
            template_owner.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
            )
            for conversation in conversations
        ]

    call_kwargs = {"text": texts, "images": images, "return_tensors": "pt"}
    try:
        return processor(**call_kwargs, padding=True)
    except TypeError:
        # Retry without ``padding`` for processors that reject that keyword.
        return processor(**call_kwargs)
def decode_generated_text(processor, output_ids, prompt_input_ids):
    """Decode the part of one output sequence that follows the prompt.

    The first ``prompt_input_ids.shape[0]`` entries of ``output_ids`` are
    dropped and the remainder is decoded with special tokens skipped.
    NOTE(review): the slicing indexes the first axis of both arguments, so
    this presumably expects un-batched, one-sequence inputs -- confirm with
    the caller.

    Args:
        processor: Object with a ``tokenizer`` attribute, or a tokenizer-like
            object that itself provides ``batch_decode``.
        output_ids: Token ids produced for one sample, prompt included.
        prompt_input_ids: Token ids of the prompt for the same sample.

    Returns:
        The decoded text for the remaining tokens.
    """
    decoder = getattr(processor, "tokenizer", processor)
    prompt_len = prompt_input_ids.shape[0]
    continuation = output_ids[prompt_len:]
    # ``batch_decode`` works on a batch, so wrap the single sequence and
    # unwrap the single result.
    batch_of_one = continuation.unsqueeze(0)
    decoded = decoder.batch_decode(batch_of_one, skip_special_tokens=True)
    return decoded[0]
|