from transformers import AutoTokenizer
import torch
from vlm_fo1.model import *
from safetensors.torch import load_file
import os


def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
    """
    Load a VLM-FO1 (Qwen2.5-VL variant) model with its tokenizer and image processors.

    The main vision tower and, when present, the auxiliary vision tower are
    loaded (if the checkpoint did not already load them) and their image
    processors are returned alongside the model.

    Args:
        model_path (str): Path to the pretrained model directory. Must contain
            'vlm-fo1' and 'qwen2.5-vl' / 'qwen2_5_vl' (case-insensitive).
        load_8bit (bool): Whether to load the model in 8bit mode.
        load_4bit (bool): Whether to load the model in 4bit mode (ignored if
            load_8bit is set).
        device (str): Device to load the model onto, e.g. "cuda" or "cpu".

    Returns:
        tuple: (tokenizer, model, image_processor), where image_processor is a
        tuple (primary_image_processor, aux_image_processor). Either element is
        None when the corresponding vision tower does not exist.

    Raises:
        ValueError: If model_path does not name a supported VLM-FO1 Qwen2.5-VL
            checkpoint.
    """
    lowered_path = model_path.lower()
    if 'vlm-fo1' not in lowered_path:
        raise ValueError(f"Unsupported model path (expected a 'vlm-fo1' checkpoint): {model_path}")
    if 'qwen2.5-vl' not in lowered_path and 'qwen2_5_vl' not in lowered_path:
        # Only the Qwen2.5-VL variant has a loading path implemented below.
        raise ValueError(f"Unsupported vlm-fo1 variant (only Qwen2.5-VL is supported): {model_path}")

    # Loading parameters: bitsandbytes quantization, or bfloat16 weights.
    kwargs = {"device_map": device}
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
    else:
        kwargs['torch_dtype'] = torch.bfloat16
    quantized = bool(load_8bit or load_4bit)

    # Slow tokenizer is enforced.
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    model, _loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        output_loading_info=True,
        attn_implementation="flash_attention_2",
        cache_dir='./resources',
        **kwargs,
    )

    # --- Primary vision tower: load weights from model_path if not yet loaded ---
    primary_image_processor = None
    primary_vision_tower = model.get_vision_tower()
    if primary_vision_tower:
        if not primary_vision_tower.is_loaded:
            primary_vision_tower.load_model(model_path=model_path, is_train=False)
            primary_vision_tower.to(device=device, dtype=torch.bfloat16)
        primary_image_processor = primary_vision_tower.image_processor

    # --- Auxiliary vision tower (Qwen2.5-VL) ---
    aux_image_processor = None
    aux_vision_tower = model.get_vision_tower_aux()
    if aux_vision_tower:
        if not aux_vision_tower.is_loaded:
            # Fall back to 768 when the config does not define aux_image_size.
            aux_image_size = getattr(model.config, 'aux_image_size', 768)
            aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
            aux_vision_tower.load_model(
                image_size=aux_image_size,
                is_train=False,
                aspect_ratio=aux_image_aspect_ratio,
            )
            aux_vision_tower.to(device=device, dtype=torch.bfloat16)
        aux_image_processor = aux_vision_tower.image_processor

    image_processor = (primary_image_processor, aux_image_processor)

    model.eval()
    if not quantized:
        # bitsandbytes-quantized models reject dtype casts via `.to`; their
        # placement is already handled by device_map at load time.
        model.to(device=device, dtype=torch.bfloat16)
    return tokenizer, model, image_processor