# NOTE(review): the lines previously here were non-Python scrape residue
# (page chrome: "Spaces:", "Runtime error", a file-size banner, a commit
# hash, and a flattened line-number gutter). They were not part of the
# module source and would have been a SyntaxError; converted to comments.
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import torch
from transformers import AutoModel, AutoTokenizer
class BaseModel(ABC):
    """Abstract base class for all vision-language models.

    Concrete subclasses must implement :meth:`load_model`,
    :meth:`unload_model`, and :meth:`inference`. Shared bookkeeping
    (load state, current quantization, config-derived metadata) lives here.
    """

    def __init__(self, model_name: str, model_config: Dict[str, Any]):
        """
        Initialize the base model.

        Args:
            model_name: Human-readable name of the model.
            model_config: Configuration dictionary for the model. Must
                contain the key ``'model_id'``.

        Raises:
            KeyError: If ``model_config`` lacks ``'model_id'`` — fails
                fast at construction rather than at load time.
        """
        self.model_name = model_name
        self.model_config = model_config
        # Mandatory identifier; direct indexing so a missing key raises
        # KeyError immediately.
        self.model_id = model_config['model_id']
        self.model: Optional[Any] = None
        self.tokenizer: Optional[Any] = None
        self.current_quantization: Optional[str] = None
        self.is_loaded: bool = False

    @abstractmethod
    def load_model(self, quantization_type: str, **kwargs) -> bool:
        """
        Load the model with the specified quantization.

        Args:
            quantization_type: Type of quantization to use.
            **kwargs: Additional arguments for model loading.

        Returns:
            True if loading succeeded, False otherwise.
        """

    @abstractmethod
    def unload_model(self) -> None:
        """Unload the model from memory."""

    @abstractmethod
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file.
            prompt: Text prompt for the model.
            **kwargs: Additional inference parameters.

        Returns:
            The model's text response.
        """

    def is_model_loaded(self) -> bool:
        """Return True if the model is currently loaded."""
        return self.is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Return a metadata snapshot of the model's config and load state."""
        return {
            'name': self.model_name,
            'model_id': self.model_id,
            'description': self.model_config.get('description', ''),
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
            'supported_quantizations': self.model_config.get('supported_quantizations', []),
            'default_quantization': self.model_config.get('default_quantization', ''),
            'is_loaded': self.is_loaded,
            'current_quantization': self.current_quantization,
        }

    def get_supported_quantizations(self) -> List[str]:
        """Return the list of supported quantization methods (may be empty)."""
        return self.model_config.get('supported_quantizations', [])

    def get_memory_requirements(self) -> Dict[str, int]:
        """Return min and recommended GPU memory in GB (0 if unspecified)."""
        return {
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
        }

    def validate_quantization(self, quantization_type: str) -> bool:
        """
        Validate whether the quantization type is supported.

        Args:
            quantization_type: Quantization type to validate.

        Returns:
            True if supported, False otherwise.
        """
        # Delegate to the public accessor instead of duplicating the
        # config lookup.
        return quantization_type in self.get_supported_quantizations()

    def __str__(self) -> str:
        """Short human-readable summary: name, quantization, load status."""
        status = "loaded" if self.is_loaded else "not loaded"
        quant = f" ({self.current_quantization})" if self.current_quantization else ""
        return f"{self.model_name}{quant} - {status}"

    def __repr__(self) -> str:
        """Detailed representation.

        Uses ``type(self).__name__`` (not a hard-coded literal) so that
        concrete subclasses report their own class name instead of
        misreporting themselves as ``BaseModel``.
        """
        return (f"{type(self).__name__}(name={self.model_name}, "
                f"loaded={self.is_loaded}, quantization={self.current_quantization})")