# Test-Prompt/backend/models/internvl/internvl_model.py
# First commit (1314bf5) by abhiman181025
import gc
import os
from typing import Any, Callable, Dict, Optional, Union

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

from ..base_model import BaseModel
from ...config.config_manager import ConfigManager
from ...utils.image_processing import load_image
class InternVLModel(BaseModel):
    """InternVL3 vision-language model implementation.

    Wraps download, multi-GPU device mapping, loading (16-bit or 8-bit
    quantized), chat-style image inference, and unloading of an InternVL
    checkpoint obtained via ``transformers`` with ``trust_remote_code``.
    """

    def __init__(self, model_name: str, model_config: Dict[str, Any], config_manager: ConfigManager):
        """
        Initialize the InternVL model.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
            config_manager: Configuration manager instance
        """
        super().__init__(model_name, model_config)
        self.config_manager = config_manager
        # Reduce CUDA memory fragmentation when models are repeatedly
        # loaded/unloaded in the same process.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    def check_model_exists_locally(self) -> bool:
        """Check if the model's config.json is already in the Hugging Face cache.

        Returns:
            True if ``config.json`` resolves from the local cache, False otherwise.
        """
        try:
            from transformers.utils import cached_file
            cached_file(self.model_id, "config.json", local_files_only=True)
            return True
        except Exception:
            # FIX: was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt. Any lookup failure simply means
            # the model is not cached yet.
            return False

    def download_model_with_progress(self, progress_callback: Optional[Callable] = None) -> bool:
        """
        Pre-fetch the tokenizer and config into the local cache with progress messages.

        NOTE(review): despite the "model weights" progress message, only the
        tokenizer and config are fetched here; the actual weights are
        downloaded lazily by ``from_pretrained`` inside ``load_model``.

        Args:
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise
        """
        try:
            if progress_callback:
                progress_callback("πŸ“₯ Downloading tokenizer...")
            # Called for the cache-population side effect only; the returned
            # objects are discarded (load_model re-creates them).
            AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                use_fast=False
            )
            if progress_callback:
                progress_callback("πŸ“₯ Downloading model weights... This may take several minutes...")
            AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            if progress_callback:
                progress_callback("βœ… Model downloaded successfully!")
            return True
        except Exception as e:
            if progress_callback:
                progress_callback(f"❌ Download failed: {str(e)}")
            return False

    def split_model(self) -> Union[Dict[str, int], str]:
        """
        Distribute LLM layers across GPUs, keeping vision encoder on GPU 0.

        GPU 0 is treated as having roughly 0.3 of a GPU's capacity for LLM
        layers because it also hosts the vision encoder and projector.

        Returns:
            A module-name -> GPU-index device map, or the string ``"auto"``
            when fewer than two GPUs are available (let transformers decide).
            (FIX: the original annotation claimed ``Dict[str, int]`` only,
            but the ``"auto"`` string return path already existed.)
        """
        device_map = {}
        world_size = torch.cuda.device_count()
        if world_size < 2:
            return "auto"  # let transformers decide
        cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        num_layers = cfg.llm_config.num_hidden_layers  # type: ignore[attr-defined]
        # More aggressive distribution - treat GPU 0 as 0.3 GPU capacity due to vision model
        effective_gpus = world_size - 0.7  # More conservative for GPU 0
        layers_per_gpu = num_layers / effective_gpus
        # Calculate layer distribution: GPU 0 gets ~30% of a normal share.
        gpu_layers = []
        for i in range(world_size):
            if i == 0:
                # GPU 0 gets fewer layers due to vision model
                gpu_layers.append(max(1, int(layers_per_gpu * 0.3)))
            else:
                gpu_layers.append(int(layers_per_gpu))
        # Adjust if the integer truncation above under/over-shoots num_layers.
        total_assigned = sum(gpu_layers)
        diff = num_layers - total_assigned
        if diff > 0:
            # Add remaining layers to non-zero GPUs
            for i in range(1, min(world_size, diff + 1)):
                gpu_layers[i] += 1
        elif diff < 0:
            # Remove excess layers from GPU 0
            gpu_layers[0] = max(1, gpu_layers[0] + diff)
        # Assign transformer layers to devices in contiguous runs.
        layer_cnt = 0
        for gpu_id, num_layers_on_gpu in enumerate(gpu_layers):
            for _ in range(num_layers_on_gpu):
                if layer_cnt < num_layers:
                    device_map[f'language_model.model.layers.{layer_cnt}'] = gpu_id
                    layer_cnt += 1
        # Distribute non-layer components.
        last_gpu = world_size - 1
        # Vision model and projector must stay on GPU 0 (co-located with embeddings).
        device_map['vision_model'] = 0
        device_map['mlp1'] = 0
        # Input embeddings on GPU 0; output-side modules on the last GPU.
        device_map['language_model.model.tok_embeddings'] = 0
        device_map['language_model.model.embed_tokens'] = 0
        device_map['language_model.model.norm'] = last_gpu  # Move to last GPU
        device_map['language_model.model.rotary_emb'] = 1 if world_size > 1 else 0  # Move to GPU 1
        device_map['language_model.output'] = last_gpu  # Move to last GPU
        device_map['language_model.lm_head'] = last_gpu  # Move to last GPU
        # Keep the last layer on the same GPU as output layers for compatibility
        # (intentionally overrides the assignment made in the loop above).
        device_map[f'language_model.model.layers.{num_layers - 1}'] = last_gpu
        print(f"Layer distribution: {gpu_layers}")
        print(f"Total layers: {num_layers}, Assigned: {sum(gpu_layers)}")
        return device_map

    def load_model(self, quantization_type: str, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load the model with specified quantization.

        Args:
            quantization_type: Type of quantization to use
            progress_callback: Callback function for progress updates

        Returns:
            True if successful, False otherwise

        Raises:
            ValueError: If the quantization type is unsupported for this model.
            Exception: If download or loading fails (original cause chained).
        """
        if not self.validate_quantization(quantization_type):
            raise ValueError(f"Quantization type '{quantization_type}' not supported for {self.model_name}")
        # If model is already loaded with the same quantization, return early.
        if (self.model is not None and self.tokenizer is not None and
                self.current_quantization == quantization_type):
            if progress_callback:
                progress_callback(f"βœ… {self.model_name} already loaded!")
            return True
        print(f"Loading {self.model_name} with {quantization_type} quantization...")
        if progress_callback:
            progress_callback(f"πŸ”„ Loading {self.model_name} with {quantization_type} quantization...")
        try:
            # Check if model exists locally
            model_exists = self.check_model_exists_locally()
            if not model_exists:
                if progress_callback:
                    progress_callback(f"πŸ“₯ {self.model_name} not found locally. Starting download...")
                print(f"Model {self.model_name} not found locally. Starting download...")
                success = self.download_model_with_progress(progress_callback)
                if not success:
                    raise Exception(f"Failed to download {self.model_name}")
            else:
                if progress_callback:
                    progress_callback(f"βœ… {self.model_name} found locally.")
            # Clear existing model if any (e.g. loaded with a different quantization).
            if self.model is not None:
                self.unload_model()
            # Print memory before loading
            self._print_gpu_memory("before loading")
            if progress_callback:
                progress_callback(f"πŸš€ Loading {self.model_name} tokenizer...")
            # Load tokenizer (slow tokenizer required by InternVL remote code).
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                use_fast=False
            )
            # Load model based on quantization type
            if "non-quantized" in quantization_type:
                if progress_callback:
                    progress_callback(f"πŸš€ Loading {self.model_name} model in 16-bit precision...")
                device_map = self.split_model()
                print(f"Device map for multi-GPU: {device_map}")
                # Try loading with custom device_map, fallback to "auto" if it fails.
                # Some InternVL models (e.g., InternVL3_5) don't support custom
                # device_map due to missing 'all_tied_weights_keys' attribute.
                try:
                    self.model = AutoModel.from_pretrained(
                        self.model_id,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        use_flash_attn=True,
                        trust_remote_code=True,
                        device_map=device_map,
                    ).eval()
                except (AttributeError, TypeError, RuntimeError, ValueError) as e:
                    error_str = str(e).lower()
                    # Check for device_map related errors, especially
                    # all_tied_weights_keys: a known issue with some InternVL
                    # models that don't expose the attribute required for a
                    # custom device_map.
                    if ("all_tied_weights_keys" in error_str or
                            "tied_weights" in error_str or
                            ("device_map" in error_str and "attribute" in error_str)):
                        print(f"⚠️ Custom device_map failed ({str(e)}), falling back to 'auto' device_map...")
                        if progress_callback:
                            progress_callback(f"⚠️ Using automatic device mapping...")
                        self.model = AutoModel.from_pretrained(
                            self.model_id,
                            torch_dtype=torch.bfloat16,
                            low_cpu_mem_usage=True,
                            use_flash_attn=True,
                            trust_remote_code=True,
                            device_map="auto",
                        ).eval()
                    else:
                        # Re-raise if it's a different error
                        raise
            else:  # quantized (8bit)
                if progress_callback:
                    progress_callback(f"πŸš€ Loading {self.model_name} model with 8-bit quantization...")
                print("Loading with 8-bit quantization to reduce memory usage...")
                self.model = AutoModel.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.bfloat16,
                    load_in_8bit=True,
                    low_cpu_mem_usage=True,
                    use_flash_attn=True,
                    trust_remote_code=True,
                    device_map="auto"  # Let transformers handle device mapping for quantized model
                ).eval()
            # Verify model and tokenizer are properly loaded
            if self.model is None:
                raise Exception(f"Model failed to load for {self.model_name}")
            if self.tokenizer is None:
                raise Exception(f"Tokenizer failed to load for {self.model_name}")
            self.current_quantization = quantization_type
            self.is_loaded = True
            success_msg = f"βœ… {self.model_name} loaded successfully with {quantization_type} quantization!"
            print(success_msg)
            if progress_callback:
                progress_callback(success_msg)
            # Print GPU memory usage after loading
            self._print_gpu_memory("after loading")
            return True
        except Exception as e:
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            print(error_msg)
            if progress_callback:
                progress_callback(f"❌ {error_msg}")
            # Reset on failure
            self.unload_model()
            # FIX: chain the original exception so the root-cause traceback
            # is preserved instead of being discarded.
            raise Exception(error_msg) from e

    def unload_model(self) -> None:
        """Unload the model from memory and release GPU cache."""
        if self.model is not None:
            print("🧹 Clearing model from memory...")
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        self.current_quantization = None
        self.is_loaded = False
        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Force garbage collection, then clear again so blocks freed by the
        # collector are actually returned to the CUDA allocator.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear again after gc
        print("βœ… Model unloaded successfully")

    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters (currently unused)

        Returns:
            Model's text response, or an error-message string on failure
            (NOTE(review): errors are returned, not raised — callers cannot
            distinguish a model answer from a failure message).

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_name} is not loaded. Call load_model() first.")
        try:
            # Load and preprocess image using default settings from original app.py
            pixel_values = load_image(image_path, input_size=448, max_num=12).to(torch.bfloat16)
            # Move pixel_values to the same device as the model
            if torch.cuda.is_available():
                # Get the device of the first model parameter
                model_device = next(self.model.parameters()).device
                pixel_values = pixel_values.to(model_device)
            else:
                # Fallback to CPU if no CUDA available
                pixel_values = pixel_values.cpu()
            # Prepend the image placeholder token expected by InternVL chat.
            formatted_prompt = f"<image>\n{prompt}" if prompt else "<image>\n"
            # Generation configuration - using same settings as original app.py
            gen_cfg = dict(max_new_tokens=1024, do_sample=True)
            # Perform inference
            response = self.model.chat(self.tokenizer, pixel_values, formatted_prompt, gen_cfg)
            return response
        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            print(error_msg)
            return error_msg

    def _print_gpu_memory(self, stage: str) -> None:
        """Print per-GPU allocated/reserved memory (in GB) for debugging.

        Args:
            stage: Free-text label for when the snapshot was taken.
        """
        if torch.cuda.is_available():
            print(f"Memory {stage}:")
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                print(f"GPU {i}: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")