| from typing import List, Tuple |
|
|
|
|
| def _get_transformers(): |
| import transformers |
| return transformers |
|
|
|
|
def resolve_torch_dtype(dtype_name):
    """Translate a dtype specifier into a ``torch.dtype`` (or ``"auto"``).

    Args:
        dtype_name: One of
            * the string ``"auto"``, which is returned unchanged;
            * the name of a dtype exposed on the ``torch`` module, such as
              ``"bfloat16"`` or ``"float16"``;
            * an existing ``torch.dtype`` instance, which is returned as is.

    Returns:
        The string ``"auto"`` or the matching ``torch.dtype`` object.

    Raises:
        ValueError: If ``dtype_name`` does not identify a ``torch.dtype``.
            This includes names that exist on ``torch`` but are not dtypes
            (for example ``"nn"``), and values that are neither a string
            nor a ``torch.dtype``.
    """
    # Imported at call time rather than at module import time.
    import torch

    if isinstance(dtype_name, torch.dtype):
        return dtype_name
    if dtype_name == "auto":
        return "auto"

    resolved = getattr(torch, dtype_name, None) if isinstance(dtype_name, str) else None
    # A bare attribute-existence check would also accept submodules and
    # functions living on ``torch``; only genuine dtype objects are valid.
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {dtype_name}")
    return resolved
|
|
|
|
def infer_model_backend(model_path, backend="auto", trust_remote_code=True):
    """Choose the backend identifier for the checkpoint at ``model_path``.

    Args:
        model_path: Passed to ``AutoConfig.from_pretrained`` when the backend
            has to be inferred.
        backend: Any value other than ``"auto"`` is returned untouched and
            no configuration is loaded.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        ``backend`` itself when it is not ``"auto"``. Otherwise the first
        matching name from the ordered rules below, applied to the
        lower-cased, space-joined ``architectures`` list and the lower-cased
        ``model_type`` of the config:

        1. ``"qwen3_vl_moe"``: architectures contain ``qwen3vlmoe``, or the
           model type contains ``qwen`` and architectures contain ``moe``.
        2. ``"qwen3_vl"``: architectures contain ``qwen3vl``, or the model
           type contains both ``qwen`` and ``vl``.
        3. ``"llava"``: either field contains ``llava``.
        4. ``"deepseek_vl"``: either field contains ``deepseek`` or ``janus``.
        5. ``"hf_vision2seq"``: nothing above matched.
    """
    if backend != "auto":
        return backend

    hf = _get_transformers()
    cfg = hf.AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

    declared_archs = getattr(cfg, "architectures", None) or []
    arch_text = " ".join(name.lower() for name in declared_archs)
    model_type = str(getattr(cfg, "model_type", "")).lower()

    def mentioned(token):
        # True when the token shows up in either descriptive field.
        return token in arch_text or token in model_type

    qwen_family = "qwen" in model_type

    # NOTE(review): the second Qwen rule fires for any model type containing
    # both "qwen" and "vl", not only Qwen3-VL; confirm that is intended.
    ordered_rules = (
        ("qwen3_vl_moe", "qwen3vlmoe" in arch_text or (qwen_family and "moe" in arch_text)),
        ("qwen3_vl", "qwen3vl" in arch_text or (qwen_family and "vl" in model_type)),
        ("llava", mentioned("llava")),
        ("deepseek_vl", mentioned("deepseek") or mentioned("janus")),
    )
    for backend_name, matched in ordered_rules:
        if matched:
            return backend_name
    return "hf_vision2seq"
|
|
|
|
def load_model_and_processor(
    model_path,
    backend="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,
):
    """Load a model and its processor from ``model_path``.

    Args:
        model_path: Passed to every ``from_pretrained`` call.
        backend: Backend identifier, or ``"auto"`` to have
            ``infer_model_backend`` choose one.
        torch_dtype: Dtype specifier handed to ``resolve_torch_dtype``; the
            result is passed to the model's ``from_pretrained`` as
            ``torch_dtype``.
        trust_remote_code: Forwarded to backend inference and to both the
            model and the processor ``from_pretrained`` calls.

    Returns:
        A ``(backend, model, processor)`` tuple, where ``backend`` is the
        resolved identifier and the processor has already been passed
        through ``_configure_processor``.

    Raises:
        ValueError: If the resolved backend is not one of the identifiers
            handled below.
    """
    hf = _get_transformers()
    backend = infer_model_backend(
        model_path=model_path,
        backend=backend,
        trust_remote_code=trust_remote_code,
    )
    dtype = resolve_torch_dtype(torch_dtype)

    # Backend -> attribute name on ``transformers``. Names are resolved
    # lazily so only the class for the chosen backend is ever looked up.
    class_name_by_backend = (
        ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
        ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
        ("deepseek_vl", "AutoModelForCausalLM"),
        ("hf_vision2seq", "AutoModelForVision2Seq"),
        ("hf_causal_vlm", "AutoModelForCausalLM"),
    )

    if backend == "llava":
        # The dedicated LLaVA class is optional; fall back to the generic
        # vision-to-sequence auto class when it is not exposed.
        model_cls = getattr(hf, "LlavaForConditionalGeneration", None)
        if model_cls is None:
            model_cls = hf.AutoModelForVision2Seq
    else:
        for known_backend, class_name in class_name_by_backend:
            if backend == known_backend:
                model_cls = getattr(hf, class_name)
                break
        else:
            raise ValueError(f"Unsupported model backend: {backend}")

    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
    processor = hf.AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    _configure_processor(processor)
    return backend, model, processor
|
|
|
|
| def _configure_processor(processor): |
| tokenizer = getattr(processor, "tokenizer", None) |
| if tokenizer is None: |
| return |
| if getattr(tokenizer, "padding_side", None) is not None: |
| tokenizer.padding_side = "left" |
| if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
def build_batch_tensors(processor, prompts: List[str], images, system_prompt=""):
    """Render chat prompts and run them, with ``images``, through ``processor``.

    For every prompt a two-turn conversation is built: a ``system`` turn
    holding ``system_prompt`` as text, and a ``user`` turn holding one image
    placeholder followed by the prompt text.

    The conversations are rendered with ``apply_chat_template`` (called with
    ``tokenize=False, add_generation_prompt=True``) taken from the processor
    if it has one, otherwise from ``processor.tokenizer`` if that has one.
    When neither offers a chat template the raw ``prompts`` are used as is.

    Args:
        processor: Callable accepting ``text``, ``images`` and
            ``return_tensors`` keyword arguments.
        prompts: User prompt strings, one per sample.
        images: Forwarded to the processor unchanged.
        system_prompt: Text of the system turn; defaults to an empty string.

    Returns:
        Whatever ``processor(...)`` returns, requested with
        ``return_tensors="pt"``.
    """

    def as_conversation(user_text):
        system_turn = {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        }
        user_turn = {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": user_text}],
        }
        return [system_turn, user_turn]

    conversations = [as_conversation(user_text) for user_text in prompts]

    # Decide which object, if any, owns a usable chat template.
    if hasattr(processor, "apply_chat_template"):
        template_owner = processor
    else:
        tok = getattr(processor, "tokenizer", None)
        if tok is not None and hasattr(tok, "apply_chat_template"):
            template_owner = tok
        else:
            template_owner = None

    if template_owner is None:
        texts = prompts
    else:
        texts = [
            template_owner.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
            )
            for conversation in conversations
        ]

    shared_kwargs = {"text": texts, "images": images, "return_tensors": "pt"}
    try:
        return processor(**shared_kwargs, padding=True)
    except TypeError:
        # Retry without ``padding``. Any TypeError raised by the first call
        # triggers this retry, whatever its cause.
        return processor(**shared_kwargs)
|
|
|
|
def decode_generated_text(processor, output_ids, prompt_input_ids):
    """Decode the tokens generated after the prompt into a single string.

    The first ``prompt_input_ids.shape[0]`` entries of ``output_ids`` are
    dropped, the remainder gets a leading batch axis via ``unsqueeze(0)``,
    and the result is decoded with ``batch_decode`` while skipping special
    tokens. ``processor.tokenizer`` does the decoding when present,
    otherwise ``processor`` itself.

    Args:
        processor: Object with a ``tokenizer`` attribute, or one that itself
            provides ``batch_decode``.
        output_ids: Token ids for one sample, prompt followed by generation.
            Assumed to be a 1-D tensor; TODO confirm against callers.
        prompt_input_ids: Token ids of the prompt for the same sample; only
            the size of its first axis is used.

    Returns:
        The first (and only) string produced by ``batch_decode``.
    """
    prompt_len = prompt_input_ids.shape[0]
    generated_ids = output_ids[prompt_len:]

    decoder = getattr(processor, "tokenizer", processor)
    decoded_batch = decoder.batch_decode(
        generated_ids.unsqueeze(0),
        skip_special_tokens=True,
    )
    return decoded_batch[0]
|
|