Spaces:
Runtime error
Runtime error
| from transformers import AutoTokenizer | |
| import torch | |
| from vlm_fo1.model import * | |
| from safetensors.torch import load_file | |
| import os | |
def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
    """
    Load a VLM-FO1 (Qwen2.5-VL variant) model, its tokenizer and its image processors.

    The model variant is detected from ``model_path`` (case-insensitive): the
    path must contain ``vlm-fo1`` and either ``qwen2.5-vl`` or ``qwen2_5_vl``.

    Args:
        model_path (str): Path to the pretrained model directory.
        load_8bit (bool): Whether to load the model in 8bit mode. Takes
            precedence over ``load_4bit`` when both are set.
        load_4bit (bool): Whether to load the model in 4bit mode.
        device (str): Device to load model onto, e.g., "cuda" or "cpu".

    Returns:
        tuple: ``(tokenizer, model, image_processor)`` where ``image_processor``
        is itself a tuple ``(primary_image_processor, aux_image_processor)``.
        Either element is ``None`` when the model has no corresponding vision tower.

    Raises:
        ValueError: If ``model_path`` does not name a supported model variant.
    """
    # Validate the variant up front so an unsupported path fails with a clear
    # message instead of an UnboundLocalError on `model`/`tokenizer` later on.
    path_lower = model_path.lower()
    if 'vlm-fo1' not in path_lower:
        raise ValueError(
            f"Unsupported model path {model_path!r}: expected a path containing 'vlm-fo1'."
        )
    if 'qwen2.5-vl' not in path_lower and 'qwen2_5_vl' not in path_lower:
        raise ValueError(
            f"Unsupported vlm-fo1 variant {model_path!r}: expected a path containing "
            "'qwen2.5-vl' or 'qwen2_5_vl'."
        )

    # Set model loading parameters for quantization or floating point
    kwargs = {"device_map": device}
    quantized = bool(load_8bit or load_4bit)
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
    else:
        kwargs['torch_dtype'] = torch.bfloat16

    # Load tokenizer (slow tokenizer enforced)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    # NOTE(review): flash_attention_2 is hard-coded; presumably this needs the
    # flash-attn package and a CUDA device -- confirm before using device="cpu".
    model, _loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        output_loading_info=True,
        attn_implementation="flash_attention_2",
        **kwargs
    )

    # --- Primary vision tower ---
    # Load its weights from model_path only if it is not loaded yet.
    primary_image_processor = None
    primary_vision_tower = model.get_vision_tower()
    if primary_vision_tower:
        if not primary_vision_tower.is_loaded:
            primary_vision_tower.load_model(model_path=model_path, is_train=False)
            primary_vision_tower.to(device=device, dtype=torch.bfloat16)
        primary_image_processor = primary_vision_tower.image_processor

    # --- Auxiliary vision tower ---
    # Fall back to 768 when aux_image_size is missing from the config.
    aux_image_size = getattr(model.config, 'aux_image_size', 768)
    aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
    aux_image_processor = None
    aux_vision_tower = model.get_vision_tower_aux()
    if aux_vision_tower:
        if not aux_vision_tower.is_loaded:
            aux_vision_tower.load_model(
                image_size=aux_image_size,
                is_train=False,
                aspect_ratio=aux_image_aspect_ratio,
            )
            aux_vision_tower.to(device=device, dtype=torch.bfloat16)
        aux_image_processor = aux_vision_tower.image_processor

    # image_processor returned as a tuple of (primary, aux)
    image_processor = (primary_image_processor, aux_image_processor)

    model.eval()
    # Quantized (bitsandbytes) models cannot be cast to another dtype via
    # `.to(...)`; they are already placed through `device_map`, so only the
    # floating-point path is moved/cast explicitly.
    if not quantized:
        model.to(device=device, dtype=torch.bfloat16)
    return tokenizer, model, image_processor