# PengLiu
# push inference code
# 56ef371
from transformers import AutoTokenizer
import torch
from vlm_fo1.model import *
from safetensors.torch import load_file
import os
def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
    """Load a VLM-FO1 (Qwen2.5-VL) checkpoint with its tokenizer and image processors.

    The variant is detected from ``model_path`` (case-insensitive substring
    match), so the path must contain ``vlm-fo1`` and either ``qwen2.5-vl`` or
    ``qwen2_5_vl``. These are the only combinations this loader knows how to
    build a model for.

    Args:
        model_path (str): Path (or identifier) of the pretrained model.
        load_8bit (bool): Pass ``load_in_8bit=True`` to ``from_pretrained``.
            Takes precedence over ``load_4bit``.
        load_4bit (bool): Pass ``load_in_4bit=True`` to ``from_pretrained``.
        device (str): Used as ``device_map`` for the model and as the target
            device for the vision towers, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        tuple: ``(tokenizer, model, image_processor)`` where
        ``image_processor`` is itself a tuple
        ``(primary_image_processor, aux_image_processor)``. An entry is
        ``None`` when the model has no corresponding vision tower.

    Raises:
        ValueError: If ``model_path`` does not name a supported
            VLM-FO1 / Qwen2.5-VL checkpoint.
    """
    # Validate the variant up front: without this, an unsupported path only
    # surfaced later as an UnboundLocalError on `model` / `tokenizer`.
    path_lower = model_path.lower()
    if 'vlm-fo1' not in path_lower:
        raise ValueError(
            f"Unsupported model path {model_path!r}: expected a 'vlm-fo1' checkpoint."
        )
    if 'qwen2.5-vl' not in path_lower and 'qwen2_5_vl' not in path_lower:
        raise ValueError(
            f"Unsupported vlm-fo1 variant {model_path!r}: only Qwen2.5-VL "
            "('qwen2.5-vl' / 'qwen2_5_vl') checkpoints can be loaded."
        )

    # Quantized loading and an explicit floating-point dtype are mutually exclusive.
    kwargs = {"device_map": device}
    quantized = bool(load_8bit or load_4bit)
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
    else:
        kwargs['torch_dtype'] = torch.bfloat16

    # Slow tokenizer is enforced.
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    # `output_loading_info=True` makes from_pretrained return (model, info);
    # the info dict is only useful for debugging and is otherwise ignored.
    model, _loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        output_loading_info=True,
        attn_implementation="flash_attention_2",
        cache_dir='./resources',
        **kwargs,
    )

    # --- Primary vision tower ---
    # The processor defaults to None so the returned tuple is always well
    # formed, even when the model exposes no primary tower.
    primary_image_processor = None
    primary_vision_tower = model.get_vision_tower()
    if primary_vision_tower:
        if not primary_vision_tower.is_loaded:
            primary_vision_tower.load_model(model_path=model_path, is_train=False)
            primary_vision_tower.to(device=device, dtype=torch.bfloat16)
        primary_image_processor = primary_vision_tower.image_processor

    # --- Auxiliary vision tower ---
    aux_image_processor = None
    # Fall back to 768 when the config does not define aux_image_size.
    aux_image_size = getattr(model.config, 'aux_image_size', 768)
    aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
    aux_vision_tower = model.get_vision_tower_aux()
    if aux_vision_tower:
        if not aux_vision_tower.is_loaded:
            aux_vision_tower.load_model(
                image_size=aux_image_size,
                is_train=False,
                aspect_ratio=aux_image_aspect_ratio,
            )
            aux_vision_tower.to(device=device, dtype=torch.bfloat16)
        aux_image_processor = aux_vision_tower.image_processor

    # image_processor is returned as a (primary, aux) tuple.
    image_processor = (primary_image_processor, aux_image_processor)

    # NOTE: a disabled (commented-out) step that re-loaded the `vision_tower.*`
    # and `vision_tower_aux.*` weights directly from the checkpoint's
    # .safetensors / pytorch_model.bin files used to live here. It was never
    # executed and has been dropped; see version control if it is needed again.

    model.eval()
    if not quantized:
        # bitsandbytes-quantized models must not be cast with `.to(dtype=...)`
        # (transformers raises for that); they are already placed through
        # `device_map`, so only the floating-point path is moved/cast here.
        model.to(device=device, dtype=torch.bfloat16)
    return tokenizer, model, image_processor