File size: 7,820 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""

Model Cache Module

Model caching system that avoids loading duplicate copies of a model, with optional reuse of the shared Qwen model.

"""
import torch
import logging
from typing import Dict, Any, Optional
import sys
import os

logger = logging.getLogger(__name__)

# Optional integration with the shared Qwen model.
#
# Looks for "../Shared Model/shared_model.py" relative to this file and, if it
# exists, loads it directly by path (executing its top-level code) and exposes
# its SharedModel, SharedModelConfig, get_shared_model and get_shared_tokenizer
# at module level. If the file is absent or no loader spec can be built, nothing
# is loaded and no message is logged. Any exception raised while loading is
# logged at DEBUG. In every failure case the three defaults below stay as they
# are (SHARED_MODEL_AVAILABLE False, accessors None).
#
# NOTE(review): SharedModel / SharedModelConfig are only bound when loading
# succeeds; any code referencing them must check SHARED_MODEL_AVAILABLE first.
# NOTE(review): the loaded module is not inserted into sys.modules before
# exec_module (the importlib "import a source file directly" recipe does this);
# confirm shared_model.py does not rely on being importable by name (e.g. for
# pickling or dataclass annotation resolution).
SHARED_MODEL_AVAILABLE = False
get_shared_model_func = None
get_shared_tokenizer_func = None

try:
    import importlib.util
    shared_model_path = os.path.join(os.path.dirname(__file__), '..', 'Shared Model', 'shared_model.py')
    if os.path.exists(shared_model_path):
        spec = importlib.util.spec_from_file_location("shared_model", shared_model_path)
        if spec and spec.loader:
            shared_model_module = importlib.util.module_from_spec(spec)
            # Runs shared_model.py's top-level code now, at import time of this module.
            spec.loader.exec_module(shared_model_module)
            SharedModel = shared_model_module.SharedModel
            SharedModelConfig = shared_model_module.SharedModelConfig
            get_shared_model_func = shared_model_module.get_shared_model
            get_shared_tokenizer_func = shared_model_module.get_shared_tokenizer
            # Only flipped once every attribute above was found.
            SHARED_MODEL_AVAILABLE = True
except Exception as e:
    # Deliberately best-effort: the shared model is optional.
    logger.debug(f"Shared model not available: {e}")

class ModelCache:
    """
    Shared model cache to prevent memory duplication.

    When enabled and available, the module-level shared Qwen model/tokenizer
    accessors are used so callers reuse a single instance. Otherwise models
    are loaded with ``transformers`` and memoized per
    (model name, model type, device, kwargs).

    Attributes:
        use_shared_model: When True, ``get_shared_model`` first tries the
            module-level shared model/tokenizer accessors.
        shared_model_name: Name of the shared model.
            NOTE(review): stored but never consulted in this class. The shared
            model is returned regardless of the ``model_name`` requested.
        shared_models: Cache of loaded transformer models, keyed by cache key.
        shared_tokenizers: Cache of loaded tokenizers, keyed by cache key.
    """

    # Model types that get_shared_model knows how to load.
    _SUPPORTED_MODEL_TYPES = ("transformer", "tokenizer")

    # Cache-control / model-only kwargs that are never forwarded verbatim to
    # ``from_pretrained`` (for models or tokenizers).
    _NON_FORWARDED_KWARGS = frozenset({"use_4bit", "dtype", "use_4bit_quantization"})

    def __init__(self, use_shared_model: bool = True, shared_model_name: str = "Qwen/Qwen3-0.6B"):
        self.use_shared_model = use_shared_model
        self.shared_model_name = shared_model_name
        self.shared_models: Dict[str, Any] = {}
        self.shared_tokenizers: Dict[str, Any] = {}

        logger.debug("ModelCache initialized")

    @staticmethod
    def _make_cache_key(model_name: str, model_type: str, device: str,
                        kwargs: Dict[str, Any]) -> str:
        """Build the cache key for one (model, type, device, kwargs) combination.

        kwargs are sorted by name, so the order of keyword arguments at the call
        site does not create distinct entries. The kwargs part is ``hash()`` of
        a string. That value is stable within one process but not across
        processes (str hashing is randomized), which is fine for an in-memory
        cache.
        """
        return f"{model_name}_{model_type}_{device}_{hash(str(sorted(kwargs.items())))}"

    @staticmethod
    def _resolve_placement(device: str) -> tuple:
        """Return ``(use_cuda, device_map, torch_dtype)`` for loading onto *device*.

        CUDA placement is used only when *device* names a CUDA device AND CUDA
        is available. Otherwise the model is loaded on CPU in float32. A bare
        ``"cuda"`` keeps ``device_map="auto"``. An explicit index such as
        ``"cuda:1"`` pins the whole model to that device.
        """
        device_str = str(device)
        use_cuda = torch.cuda.is_available() and device_str.startswith("cuda")
        if not use_cuda:
            return False, None, torch.float32
        device_map = "auto" if device_str == "cuda" else {"": device_str}
        return True, device_map, torch.float16

    @staticmethod
    def _try_shared(accessor: Any, label: str) -> Any:
        """Return the instance produced by a shared-model accessor, or None.

        Best-effort: a missing accessor, a None result, or any exception raised
        by the accessor all return None, so the caller falls through to the
        regular cached-loading path.
        """
        if accessor is None:
            return None
        try:
            instance = accessor()
        except Exception as e:  # deliberate best-effort fallback
            logger.debug("[CACHE] Shared %s not available: %s", label, e)
            return None
        if instance is not None:
            logger.info("[CACHE] Using shared Qwen %s (zero memory overhead)", label)
        return instance

    def get_shared_model(self, model_name: str, model_type: str = "transformer",
                         device: Optional[str] = None, **kwargs) -> Any:
        """
        Get or create a shared model instance.

        If ``use_shared_model`` is set and the module-level shared model (for
        ``"transformer"``) or shared tokenizer (for ``"tokenizer"``) is
        available, that instance is returned directly. Otherwise the model is
        loaded once via ``transformers`` and cached.

        Args:
            model_name: Name of the model to load. NOTE(review): ignored when
                the shared model/tokenizer is returned.
            model_type: ``"transformer"`` or ``"tokenizer"``.
            device: Device to load on. Defaults to ``"cuda"`` when available,
                else ``"cpu"``. It is part of the cache key and, for
                transformers, decides where the model is placed.
            **kwargs: Extra ``from_pretrained`` parameters, plus
                ``use_4bit_quantization`` (default True) to toggle 4-bit
                loading on CUDA.

        Returns:
            The shared or cached model/tokenizer instance.

        Raises:
            ValueError: If ``model_type`` is not supported.
            Exception: Whatever the underlying loader raises (logged, then re-raised).
        """
        # Validate up front so an unsupported type never reaches the loaders.
        if model_type not in self._SUPPORTED_MODEL_TYPES:
            raise ValueError(f"Unknown model type: {model_type}")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Prefer the process-wide shared instance when enabled and available.
        if self.use_shared_model and SHARED_MODEL_AVAILABLE:
            if model_type == "transformer":
                shared = self._try_shared(get_shared_model_func, "model")
            else:
                shared = self._try_shared(get_shared_tokenizer_func, "tokenizer")
            if shared is not None:
                return shared

        # Fallback: this cache's own per-key instances.
        cache_key = self._make_cache_key(model_name, model_type, device, kwargs)
        cache_dict = self.shared_tokenizers if model_type == "tokenizer" else self.shared_models

        if cache_key in cache_dict:
            logger.debug("[CACHE] Using cached %s model: %s", model_type, cache_key)
            return cache_dict[cache_key]

        logger.info("[CACHE] Loading %s model: %s", model_type, model_name)
        loader = (self._load_transformer_model if model_type == "transformer"
                  else self._load_tokenizer_model)
        try:
            model = loader(model_name, device, **kwargs)
        except Exception as e:
            logger.error("[CACHE] Failed to load %s model %s: %s", model_type, model_name, e)
            raise

        cache_dict[cache_key] = model
        logger.info("[CACHE] %s model cached: %s", model_type, cache_key)
        return model

    def _load_transformer_model(self, model_name: str, device: str, **kwargs) -> Any:
        """Load a causal LM with memory optimizations.

        Placement follows *device* (see ``_resolve_placement``). *device* is
        also part of the cache key, so each cached entry lives where its key
        says. On CUDA, 4-bit NF4 quantization is attempted unless
        ``use_4bit_quantization=False`` is passed. Remaining kwargs, except
        ``_NON_FORWARDED_KWARGS``, override the defaults below.
        """
        from transformers import AutoModelForCausalLM, BitsAndBytesConfig

        use_cuda, device_map, torch_dtype = self._resolve_placement(device)
        load_config: Dict[str, Any] = {
            "torch_dtype": torch_dtype,
            "device_map": device_map,
            "trust_remote_code": True,
            # NOTE(review): eager attention is not generally more memory-efficient
            # than transformers' SDPA default; confirm why eager is required here.
            "attn_implementation": "eager",
        }

        # 4-bit quantization only makes sense for CUDA placement.
        if use_cuda and kwargs.get("use_4bit_quantization", True):
            try:
                import bitsandbytes as bnb
                # NOTE(review): confirm this attribute exists in the installed
                # bitsandbytes version; when absent, quantization is skipped.
                if hasattr(bnb, "libbitsandbytes_cuda"):
                    load_config["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                else:
                    logger.debug("[CACHE] BitsAndBytes CUDA support not available, skipping quantization")
            except Exception as e:  # best-effort: fall back to unquantized loading
                logger.debug("[CACHE] 4-bit quantization not available: %s", e)

        # User-provided kwargs win over the defaults above.
        load_config.update(
            {k: v for k, v in kwargs.items() if k not in self._NON_FORWARDED_KWARGS}
        )

        return AutoModelForCausalLM.from_pretrained(model_name, **load_config)

    def _load_tokenizer_model(self, model_name: str, device: str, **kwargs) -> Any:
        """Load a tokenizer.

        *device* is accepted for signature parity with the model loader but is
        not used. Model-only kwargs are filtered with the same set as
        ``_load_transformer_model``.
        """
        from transformers import AutoTokenizer

        load_config: Dict[str, Any] = {"trust_remote_code": True}
        load_config.update(
            {k: v for k, v in kwargs.items() if k not in self._NON_FORWARDED_KWARGS}
        )

        return AutoTokenizer.from_pretrained(model_name, **load_config)

    def clear_cache(self) -> None:
        """Drop all cached models/tokenizers and release cached CUDA memory.

        Only entries held by this cache are dropped. The module-level shared
        model is not stored here and is unaffected. Memory is actually freed
        only once no other references to the evicted objects remain.
        """
        import gc

        self.shared_models.clear()
        self.shared_tokenizers.clear()
        # Collect reference cycles first so evicted tensors become unreachable,
        # then return the CUDA caching allocator's unused blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("[CACHE] Model cache cleared")

    def get_stats(self) -> Dict[str, Any]:
        """Return the cached model/tokenizer keys and the shared-model flag."""
        return {
            'shared_models': list(self.shared_models.keys()),
            'shared_tokenizers': list(self.shared_tokenizers.keys()),
            'use_shared_model': self.use_shared_model
        }