"""Helpers for loading Hugging Face vision-language models and running batched prompts.

``transformers`` and ``torch`` are imported lazily inside the functions that need
them, so importing this module does not import either library.
"""

from typing import List, Tuple


def _get_transformers():
    """Import and return the ``transformers`` module on first use."""
    import transformers

    return transformers


def resolve_torch_dtype(dtype_name):
    """Translate a dtype spec into a value for ``from_pretrained(torch_dtype=...)``.

    Args:
        dtype_name: ``"auto"``, a ``torch.dtype``, or the name of a torch dtype
            attribute such as ``"bfloat16"`` or ``"torch.float16"``.

    Returns:
        The string ``"auto"`` unchanged, otherwise the matching ``torch.dtype``.

    Raises:
        ValueError: if the spec does not name a ``torch.dtype``.
    """
    import torch

    if isinstance(dtype_name, torch.dtype):
        return dtype_name
    if dtype_name == "auto":
        return "auto"
    name = str(dtype_name)
    if name.startswith("torch."):
        name = name[len("torch."):]
    dtype = getattr(torch, name, None)
    # A plain hasattr() check would also accept modules/functions such as
    # ``torch.nn`` or ``torch.tensor``; only real dtype objects are valid here.
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported torch dtype: {dtype_name}")
    return dtype


def _backend_from_config(architectures, model_type):
    """Map a config's ``architectures`` list and ``model_type`` to a backend name."""
    arch_text = " ".join(str(arch).lower() for arch in (architectures or []))
    model_type = str(model_type or "").lower()

    # Only genuine Qwen3-VL checkpoints go to the Qwen3-VL classes; other Qwen
    # VL generations (e.g. qwen2_5_vl) and text-only Qwen MoE models must not.
    is_qwen3_vl = "qwen3vl" in arch_text or (
        "qwen3" in model_type and "vl" in model_type
    )
    if is_qwen3_vl:
        if "moe" in arch_text or "moe" in model_type:
            return "qwen3_vl_moe"
        return "qwen3_vl"
    if "llava" in arch_text or "llava" in model_type:
        return "llava"
    if any(
        key in text
        for key in ("deepseek", "janus")
        for text in (arch_text, model_type)
    ):
        return "deepseek_vl"
    return "hf_vision2seq"


def infer_model_backend(model_path, backend="auto", trust_remote_code=True):
    """Return ``backend`` unchanged, or infer it from the model config when ``"auto"``.

    Inference loads the checkpoint's config via ``AutoConfig.from_pretrained`` and
    matches on its ``architectures`` and ``model_type`` fields.
    """
    if backend != "auto":
        return backend
    transformers = _get_transformers()
    config = transformers.AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    return _backend_from_config(
        getattr(config, "architectures", None),
        getattr(config, "model_type", ""),
    )


def _resolve_model_class(transformers, backend):
    """Return the ``transformers`` model class used to load ``backend``.

    Raises:
        ValueError: for an unknown backend name.
    """
    if backend == "qwen3_vl":
        return transformers.Qwen3VLForConditionalGeneration
    if backend == "qwen3_vl_moe":
        return transformers.Qwen3VLMoeForConditionalGeneration
    if backend == "llava":
        model_cls = getattr(transformers, "LlavaForConditionalGeneration", None)
        if model_cls is None:
            return transformers.AutoModelForVision2Seq
        return model_cls
    if backend == "deepseek_vl":
        # DeepSeek multimodal checkpoints often rely on trust_remote_code and may
        # expose custom causal-LM style classes instead of Vision2Seq classes.
        return transformers.AutoModelForCausalLM
    if backend == "hf_vision2seq":
        return transformers.AutoModelForVision2Seq
    if backend == "hf_causal_vlm":
        return transformers.AutoModelForCausalLM
    raise ValueError(f"Unsupported model backend: {backend}")


def load_model_and_processor(
    model_path,
    backend="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,
) -> Tuple[str, object, object]:
    """Load a model and its processor for ``model_path``.

    Args:
        model_path: Hub id or local directory of the checkpoint.
        backend: A backend name, or ``"auto"`` to infer it from the config.
        torch_dtype: Dtype spec accepted by :func:`resolve_torch_dtype`.
        trust_remote_code: Forwarded to every ``from_pretrained`` call; when True
            the checkpoint's repository code is executed.

    Returns:
        ``(backend, model, processor)``. The processor's tokenizer is configured
        for left padding (see :func:`_configure_processor`).

    Raises:
        ValueError: for an unsupported dtype or backend.
    """
    # Validate the dtype first so a typo fails before any config is fetched.
    dtype = resolve_torch_dtype(torch_dtype)
    transformers = _get_transformers()
    backend = infer_model_backend(
        model_path=model_path,
        backend=backend,
        trust_remote_code=trust_remote_code,
    )
    model_cls = _resolve_model_class(transformers, backend)
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
    processor = transformers.AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
    )
    _configure_processor(processor)
    return backend, model, processor


def _configure_processor(processor):
    """Left-pad the tokenizer and default its pad token to EOS when unset.

    Left padding keeps every prompt flush against the generated tokens, which
    :func:`decode_generated_text` relies on when it strips the prompt prefix.
    Like :func:`decode_generated_text`, the processor itself is used as the
    tokenizer when it has no ``tokenizer`` attribute.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is None:
        tokenizer = processor
    if tokenizer is None:
        return
    if getattr(tokenizer, "padding_side", None) is not None:
        tokenizer.padding_side = "left"
    if (
        getattr(tokenizer, "pad_token", None) is None
        and getattr(tokenizer, "eos_token", None) is not None
    ):
        tokenizer.pad_token = tokenizer.eos_token


def _build_messages(prompt, system_prompt):
    """Return one sample's conversation: an optional system turn plus an image+text user turn."""
    conversation = []
    # An empty system turn renders an empty system block (or is rejected by
    # templates without a system role), so it is only added when non-empty.
    if system_prompt:
        conversation.append(
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            }
        )
    conversation.append(
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    )
    return conversation


def _render_prompts(processor, messages, prompts):
    """Render conversations with the processor's (else tokenizer's) chat template.

    Falls back to the raw ``prompts`` when neither exposes ``apply_chat_template``.
    """
    if hasattr(processor, "apply_chat_template"):
        templater = processor
    else:
        templater = getattr(processor, "tokenizer", None)
        if not hasattr(templater, "apply_chat_template"):
            return list(prompts)
    return [
        templater.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
        )
        for message in messages
    ]


def build_batch_tensors(processor, prompts: List[str], images, system_prompt=""):
    """Build model inputs for a batch of single-image chat prompts.

    Args:
        processor: A Hugging Face processor (or compatible callable) accepting
            ``text=``, ``images=`` and ``return_tensors=`` keyword arguments.
        prompts: User prompt text, one per sample.
        images: Passed through unchanged as ``images=``; presumably one image
            per prompt — TODO confirm against callers.
        system_prompt: Optional system message; no system turn when empty.

    Returns:
        Whatever ``processor(...)`` returns, requested with ``return_tensors="pt"``.
    """
    messages = [_build_messages(prompt, system_prompt) for prompt in prompts]
    rendered_prompts = _render_prompts(processor, messages, prompts)
    try:
        return processor(
            text=rendered_prompts,
            images=images,
            return_tensors="pt",
            padding=True,
        )
    except TypeError:
        # Some processors do not accept a ``padding`` keyword.
        return processor(
            text=rendered_prompts,
            images=images,
            return_tensors="pt",
        )


def decode_generated_text(processor, output_ids, prompt_input_ids):
    """Decode the tokens generated after the prompt for a single sample.

    Args:
        processor: Processor with a ``tokenizer``, or a tokenizer itself.
        output_ids: Generated ids for one sample, shape ``(seq,)`` or ``(1, seq)``,
            including the (left-padded) prompt prefix.
        prompt_input_ids: The prompt ids fed to ``generate``, shape ``(prompt,)``
            or ``(1, prompt)``; only its last dimension (its length) is used.

    Returns:
        The decoded completion text with special tokens skipped. If a 2-D
        ``output_ids`` has several rows, only the first row is returned.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is None:
        tokenizer = processor
    prompt_len = prompt_input_ids.shape[-1]
    if output_ids.ndim == 1:
        output_ids = output_ids.unsqueeze(0)
    return tokenizer.batch_decode(
        output_ids[:, prompt_len:], skip_special_tokens=True
    )[0]