File size: 3,960 Bytes
1314bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import torch
from transformers import AutoModel, AutoTokenizer

class BaseModel(ABC):
    """Common interface that every vision-language model wrapper implements."""

    def __init__(self, model_name: str, model_config: Dict[str, Any]):
        """
        Set up shared bookkeeping for a model wrapper.

        Args:
            model_name: Human-readable name of the model
            model_config: Configuration dictionary for the model
        """
        self.model_name = model_name
        self.model_config = model_config
        # A missing 'model_id' is a configuration error, so let KeyError surface.
        self.model_id = model_config['model_id']
        self.model = None
        self.tokenizer = None
        self.current_quantization = None
        self.is_loaded = False

    @abstractmethod
    def load_model(self, quantization_type: str, **kwargs) -> bool:
        """
        Load the model weights using the given quantization scheme.

        Args:
            quantization_type: Quantization scheme to apply
            **kwargs: Extra loader-specific options

        Returns:
            True on success, False on failure
        """
        pass

    @abstractmethod
    def unload_model(self) -> None:
        """Release the model and free its memory."""
        pass

    @abstractmethod
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Run the model on one image/prompt pair.

        Args:
            image_path: Filesystem path of the input image
            prompt: Text prompt to condition the model
            **kwargs: Extra inference options

        Returns:
            The model's generated text response
        """
        pass

    def is_model_loaded(self) -> bool:
        """Report whether the model weights are currently resident in memory."""
        return self.is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Return a summary dictionary describing this model and its state."""
        cfg = self.model_config
        return {
            'name': self.model_name,
            'model_id': self.model_id,
            'description': cfg.get('description', ''),
            # Memory fields come from the same source as get_memory_requirements().
            **self.get_memory_requirements(),
            'supported_quantizations': cfg.get('supported_quantizations', []),
            'default_quantization': cfg.get('default_quantization', ''),
            'is_loaded': self.is_loaded,
            'current_quantization': self.current_quantization,
        }

    def get_supported_quantizations(self) -> List[str]:
        """Return the quantization schemes this model accepts."""
        return self.model_config.get('supported_quantizations', [])

    def get_memory_requirements(self) -> Dict[str, int]:
        """Return minimum and recommended GPU memory for the model, in GB."""
        keys = ('min_gpu_memory_gb', 'recommended_gpu_memory_gb')
        return {key: self.model_config.get(key, 0) for key in keys}

    def validate_quantization(self, quantization_type: str) -> bool:
        """
        Check whether a quantization scheme is supported by this model.

        Args:
            quantization_type: Scheme name to check

        Returns:
            True if supported, False otherwise
        """
        return quantization_type in self.get_supported_quantizations()

    def __str__(self) -> str:
        """Short human-readable summary: name, quantization, load status."""
        parts = [self.model_name]
        if self.current_quantization:
            parts.append(f" ({self.current_quantization})")
        parts.append(" - loaded" if self.is_loaded else " - not loaded")
        return "".join(parts)

    def __repr__(self) -> str:
        """Unambiguous representation for debugging."""
        return (
            f"BaseModel(name={self.model_name}, "
            f"loaded={self.is_loaded}, "
            f"quantization={self.current_quantization})"
        )