import os

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer

from vlm_fo1.model import *

def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
    """
    Load a pretrained VLM-FO1 model along with its vision towers (and the
    associated image processors).

    Supports 8-bit/4-bit quantized loading and explicit device placement.

    Args:
        model_path (str): Path to the pretrained model directory.
        load_8bit (bool): Whether to load the model in 8-bit mode.
        load_4bit (bool): Whether to load the model in 4-bit mode.
        device (str): Device to load the model onto, e.g. "cuda" or "cpu".

    Returns:
        tuple: (tokenizer, model, image_processor), where image_processor is a
        (primary, auxiliary) pair.
    """
    kwargs = {"device_map": device}

    # Set model loading parameters for quantization or floating point.
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
    else:
        kwargs['torch_dtype'] = torch.bfloat16
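    # Note (assumption, not in the original file): recent transformers
    # releases deprecate passing load_in_8bit/load_in_4bit directly and
    # prefer a quantization config instead, e.g.:
    #     from transformers import BitsAndBytesConfig
    #     kwargs['quantization_config'] = BitsAndBytesConfig(load_in_4bit=True)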
    # Only proceed for VLM-FO1 models; fail fast otherwise so that `tokenizer`
    # and `model` below are never referenced while undefined.
    if 'vlm-fo1' not in model_path.lower():
        raise ValueError(f"Unsupported model path: {model_path}; expected a VLM-FO1 checkpoint.")

    # Load tokenizer (slow tokenizer enforced).
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    is_qwen25_vl = 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower()
    if is_qwen25_vl:
        # Load the Qwen2.5-VL variant with additional kwargs. With
        # output_loading_info=True, from_pretrained returns a
        # (model, loading_info) pair; loading_info lists missing/unexpected
        # keys and is useful when debugging checkpoint mismatches.
        model, loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            output_loading_info=True,
            attn_implementation="flash_attention_2",
            cache_dir='./resources',
            **kwargs,
        )
    else:
        # Other VLM-FO1 variants would need their own loading branch here.
        raise NotImplementedError(f"No loading branch for the VLM-FO1 variant at {model_path}.")
    # --- Vision Tower Loading ---
    # Load the main vision tower weights from model_path if not yet loaded.
    primary_vision_tower = model.get_vision_tower()
    if primary_vision_tower and not primary_vision_tower.is_loaded:
        primary_vision_tower.load_model(model_path=model_path, is_train=False)
        primary_vision_tower.to(device=device, dtype=torch.bfloat16)  # Move to the target device/dtype.

    # Grab the primary image processor from the vision tower, if present.
    primary_image_processor = primary_vision_tower.image_processor if primary_vision_tower else None
    # --- Auxiliary Vision Tower Handling (Qwen2.5-VL case only) ---
    aux_image_processor = None  # Stays None when there is no auxiliary tower.
    if is_qwen25_vl:
        # Fall back to 768 if aux_image_size is missing from the config.
        aux_image_size = getattr(model.config, 'aux_image_size', 768)
        aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
        aux_vision_tower = model.get_vision_tower_aux()
        # Only load if not already loaded.
        if aux_vision_tower and not aux_vision_tower.is_loaded:
            aux_vision_tower.load_model(image_size=aux_image_size, is_train=False, aspect_ratio=aux_image_aspect_ratio)
            aux_vision_tower.to(device=device, dtype=torch.bfloat16)
        # Get the auxiliary image processor if there is an aux vision tower.
        if aux_vision_tower:
            aux_image_processor = aux_vision_tower.image_processor

    # image_processor is returned as a (primary, aux) pair.
    image_processor = (primary_image_processor, aux_image_processor)
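    # A hedged sketch of how the pair might be consumed downstream, assuming
    # both follow the Hugging Face image-processor call convention (this is
    # an illustration, not part of the original file):
    #     pixel_values = primary_image_processor(images=img, return_tensors="pt")["pixel_values"]
    #     aux_pixel_values = aux_image_processor(images=img, return_tensors="pt")["pixel_values"]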
    # --- Ensure vision_tower and vision_tower_aux are loaded with weights from model_path ---
    # if 'vlm-fo1' in model_path.lower():
    #     print(f"Loading weights from {model_path} to ensure vision_tower uses the correct weights...")
    #     # --- Gather all safetensors files in the model path (for sharded checkpoints) ---
    #     state_dict = {}
    #     safetensor_files = [f for f in os.listdir(model_path) if f.endswith('.safetensors')]
    #     if safetensor_files:
    #         for safetensor_file in safetensor_files:
    #             file_path = os.path.join(model_path, safetensor_file)
    #             shard_state_dict = load_file(file_path, device="cpu")
    #             state_dict.update(shard_state_dict)
    #     else:
    #         # Fall back to a legacy .bin checkpoint if no safetensors are found.
    #         state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
    #     # --- Filter out only vision_tower and vision_tower_aux related weights ---
    #     vision_tower_keys = [k for k in state_dict.keys() if "vision_tower." in k]
    #     vision_tower_state_dict = {k: state_dict[k] for k in vision_tower_keys if k in state_dict}
    #     if vision_tower_keys:
    #         # Load weights into the main vision tower, stripping the
    #         # "model.vision_tower." prefix for submodule compatibility.
    #         if primary_vision_tower and primary_vision_tower.is_loaded:
    #             missing_keys, unexpected_keys = primary_vision_tower.load_state_dict(
    #                 {k.replace("model.vision_tower.", ""): v for k, v in vision_tower_state_dict.items()
    #                  if k.startswith("model.vision_tower.")},
    #                 strict=True
    #             )
    #             print(f"vision_tower weights loaded, missing keys: {missing_keys}, unexpected keys: {unexpected_keys}")
    #         # If there is an aux vision tower (Qwen2.5-VL), load its weights as well.
    #         if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
    #             if aux_vision_tower and aux_vision_tower.is_loaded:
    #                 vision_tower_aux_keys = [k for k in state_dict.keys() if "vision_tower_aux." in k]
    #                 if vision_tower_aux_keys:
    #                     vision_tower_aux_state_dict = {k: state_dict[k] for k in vision_tower_aux_keys if k in state_dict}
    #                     # Strip the "model.vision_tower_aux." prefix before loading.
    #                     missing_keys, unexpected_keys = aux_vision_tower.load_state_dict(
    #                         {k.replace("model.vision_tower_aux.", ""): v for k, v in vision_tower_aux_state_dict.items()
    #                          if k.startswith("model.vision_tower_aux.")},
    #                         strict=True
    #                     )
    #                     print(f"vision_tower_aux weights loaded, missing keys: {missing_keys}, unexpected keys: {unexpected_keys}")
    #     else:
    #         # If no vision tower weights are found, raise an error.
    #         print("No vision_tower weights found")
    #         raise Exception("No vision_tower weights found")
    # Set the model to eval mode and move it to the target device before returning.
    model.eval()
    if not (load_8bit or load_4bit):
        # Casting a bitsandbytes-quantized model to bfloat16 would raise an
        # error, so only move/cast full-precision models here.
        model.to(device=device, dtype=torch.bfloat16)
    return tokenizer, model, image_processor
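
# --- Hypothetical usage sketch (not part of the original file) ---
# Illustrates the expected call pattern under the assumption that a local
# VLM-FO1 Qwen2.5-VL checkpoint exists at the path below; the checkpoint
# directory name is illustrative only.
if __name__ == "__main__":
    tokenizer, model, image_processor = load_pretrained_model(
        "./resources/VLM-FO1-Qwen2.5-VL-7B",  # hypothetical checkpoint directory
        device="cuda",
    )
    primary_processor, aux_processor = image_processor
    print(f"model: {type(model).__name__}")
    print(f"primary processor: {type(primary_processor).__name__}")
    print(f"aux processor: {type(aux_processor).__name__}")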