# First commit — 1314bf5 (abhiman181025)
import torch
import gc
import os
from typing import Dict, Any, Optional, Callable
from transformers import AutoModelForCausalLM, AutoTokenizer
from ..base_model import BaseModel
from ...config.config_manager import ConfigManager
class QwenModel(BaseModel):
    """Qwen2.5 model implementation (text-only chat model)."""

    def __init__(self, model_name: str, model_config: Dict[str, Any], config_manager: ConfigManager):
        """
        Initialize the Qwen model wrapper; no weights are loaded here.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
            config_manager: Configuration manager instance
        """
        super().__init__(model_name, model_config)
        self.config_manager = config_manager
        # Let the CUDA caching allocator grow segments on demand, which
        # reduces fragmentation when large models are loaded and unloaded.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
def check_model_exists_locally(self) -> bool:
"""Check if model exists locally in Hugging Face cache."""
try:
from transformers.utils import cached_file
cached_file(self.model_id, "config.json", local_files_only=True)
return True
except:
return False
def download_model_with_progress(self, progress_callback: Optional[Callable] = None) -> bool:
"""
Download model with progress tracking.
Args:
progress_callback: Callback function for progress updates
Returns:
True if successful, False otherwise
"""
try:
if progress_callback:
progress_callback("πŸ“₯ Downloading tokenizer...")
# Download tokenizer first (smaller)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
if progress_callback:
progress_callback("πŸ“₯ Downloading model weights... This may take several minutes...")
# Download model config and weights by trying to load config
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
torch_dtype="auto",
device_map="cpu", # Just download, don't load to GPU yet
low_cpu_mem_usage=True
)
# Clean up the test loading
del model
if progress_callback:
progress_callback("βœ… Model downloaded successfully!")
return True
except Exception as e:
if progress_callback:
progress_callback(f"❌ Download failed: {str(e)}")
return False
def load_model(self, quantization_type: str, progress_callback: Optional[Callable] = None) -> bool:
"""
Load the model with specified quantization.
Args:
quantization_type: Type of quantization to use
progress_callback: Callback function for progress updates
Returns:
True if successful, False otherwise
"""
if not self.validate_quantization(quantization_type):
raise ValueError(f"Quantization type '{quantization_type}' not supported for {self.model_name}")
# If model is already loaded with the same quantization, return
if (self.model is not None and self.tokenizer is not None and
self.current_quantization == quantization_type):
if progress_callback:
progress_callback(f"βœ… {self.model_name} already loaded!")
return True
print(f"Loading {self.model_name} with {quantization_type} quantization...")
if progress_callback:
progress_callback(f"πŸ”„ Loading {self.model_name} with {quantization_type} quantization...")
try:
# Check if model exists locally
model_exists = self.check_model_exists_locally()
if not model_exists:
if progress_callback:
progress_callback(f"πŸ“₯ {self.model_name} not found locally. Starting download...")
print(f"Model {self.model_name} not found locally. Starting download...")
success = self.download_model_with_progress(progress_callback)
if not success:
raise Exception(f"Failed to download {self.model_name}")
else:
if progress_callback:
progress_callback(f"βœ… {self.model_name} found locally.")
# Clear existing model if any
if self.model is not None:
self.unload_model()
# Print memory before loading
self._print_gpu_memory("before loading")
if progress_callback:
progress_callback(f"πŸš€ Loading {self.model_name} tokenizer...")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
# Load model based on quantization type
if progress_callback:
progress_callback(f"πŸš€ Loading {self.model_name} model...")
if "non-quantized" in quantization_type:
# Load with auto dtype and device mapping
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id,
torch_dtype="auto",
device_map="auto",
low_cpu_mem_usage=True
)
else: # quantized (8bit)
print("Loading with 8-bit quantization to reduce memory usage...")
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id,
torch_dtype="auto",
load_in_8bit=True,
device_map="auto",
low_cpu_mem_usage=True
)
# Verify model and tokenizer are properly loaded
if self.model is None:
raise Exception(f"Model failed to load for {self.model_name}")
if self.tokenizer is None:
raise Exception(f"Tokenizer failed to load for {self.model_name}")
self.current_quantization = quantization_type
self.is_loaded = True
success_msg = f"βœ… {self.model_name} loaded successfully with {quantization_type} quantization!"
print(success_msg)
if progress_callback:
progress_callback(success_msg)
# Print GPU memory usage after loading
self._print_gpu_memory("after loading")
return True
except Exception as e:
error_msg = f"Failed to load model {self.model_name}: {str(e)}"
print(error_msg)
if progress_callback:
progress_callback(f"❌ {error_msg}")
# Reset on failure
self.unload_model()
raise Exception(error_msg)
def unload_model(self) -> None:
"""Unload the model from memory."""
if self.model is not None:
print("🧹 Clearing model from memory...")
del self.model
self.model = None
if self.tokenizer is not None:
del self.tokenizer
self.tokenizer = None
self.current_quantization = None
self.is_loaded = False
# Clear GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Force garbage collection
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache() # Clear again after gc
print("βœ… Model unloaded successfully")
def inference(self, image_path: str, prompt: str, **kwargs) -> str:
"""
Perform inference with a text prompt.
Note: Qwen2.5 is a text-only model, so image_path is ignored.
Args:
image_path: Path to the image file (ignored for text-only models)
prompt: Text prompt for the model
**kwargs: Additional inference parameters
Returns:
Model's text response
"""
if not self.is_loaded:
raise RuntimeError(f"Model {self.model_name} is not loaded. Call load_model() first.")
if not prompt or not prompt.strip():
return "Error: No prompt provided"
try:
# Prepare messages for chat format
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": prompt}
]
# Apply chat template
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Tokenize input
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
# Generate response
generated_ids = self.model.generate(
**model_inputs,
max_new_tokens=kwargs.get('max_new_tokens', 512),
do_sample=kwargs.get('do_sample', True),
temperature=kwargs.get('temperature', 0.7),
top_p=kwargs.get('top_p', 0.9),
pad_token_id=self.tokenizer.eos_token_id
)
# Extract only the generated part (remove input tokens)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
# Decode response
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
except Exception as e:
error_msg = f"Error processing prompt: {str(e)}"
print(error_msg)
return error_msg
def _print_gpu_memory(self, stage: str) -> None:
"""Print GPU memory usage for debugging."""
if torch.cuda.is_available():
print(f"Memory {stage}:")
for i in range(torch.cuda.device_count()):
allocated = torch.cuda.memory_allocated(i) / 1024**3
reserved = torch.cuda.memory_reserved(i) / 1024**3
print(f"GPU {i}: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")