# llm_clients/shared_models.py
"""
Shared model manager to avoid loading the same model multiple times.
This significantly improves memory usage and startup time.
"""

from typing import Optional, Dict, Any
import threading
import os

class SharedModelManager:
    """Singleton class to manage shared model instances"""
    
    _instance = None
    _lock = threading.RLock()  # re-entrant: the getters below nest acquisitions of this lock
    _models: Dict[str, Any] = {}
    _model_components: Dict[str, Dict[str, Any]] = {}  # Store actual model components
    
    def __new__(cls):
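        # Double-checked locking: the unlocked check is a cheap fast path once
        # the singleton exists; the locked re-check keeps two threads from
        # racing to create it.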
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance
    
    def get_finetuned_model_components(self, model_name: str = "zazaman/fmb") -> Optional[Dict[str, Any]]:
        """
        Get or load shared model components (model, tokenizer, classifier).
        
        Args:
            model_name: Name of the model to load
            
        Returns:
            Dict with 'model', 'tokenizer', 'classifier' components, or None if
            loading fails. Failures are cached, so later calls return None
            immediately instead of retrying the download.
        """
        model_key = f"finetuned_components_{model_name}"
        
        # Serialize loading: without the lock, two threads could both see the
        # key missing and load the same model twice, defeating the purpose of
        # this manager. The lock is re-entrant, so get_finetuned_guard_client()
        # can call this method while already holding it.
        with self._lock:
            if model_key not in self._model_components:
                try:
                    print(f"πŸ”„ Loading shared finetuned model components: {model_name}")
                    
                    # Import here to avoid circular imports
                    from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
                    import torch
                    
                    # Set up cache directory for HF Spaces compatibility
                    if not os.getenv('HF_HOME'):
                        cache_dir = os.path.expanduser("~/.cache/huggingface")
                        os.environ['HF_HOME'] = cache_dir
                        os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, 'transformers')
                        
                        # Create cache directories if they don't exist
                        os.makedirs(cache_dir, exist_ok=True)
                        os.makedirs(os.path.join(cache_dir, 'transformers'), exist_ok=True)
                        print(f"   πŸ“ Using cache directory: {cache_dir}")
                    
                    # Disable torch.compile / TorchDynamo (the model runs in
                    # eager mode on CPU) and TensorFlow's oneDNN optimizations
                    torch._dynamo.config.suppress_errors = True
                    torch._dynamo.config.disable = True
                    os.environ["TORCH_COMPILE_DISABLE"] = "1"
                    os.environ["TORCHDYNAMO_DISABLE"] = "1"
                    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
                    
                    print(f"   πŸ“₯ Downloading model from Hugging Face: {model_name}")
                    
                    # Load model and tokenizer with an explicit cache directory
                    model = AutoModelForSequenceClassification.from_pretrained(
                        model_name,
                        torch_dtype=torch.float32,
                        device_map=None,
                        cache_dir=os.environ.get('TRANSFORMERS_CACHE'),
                        local_files_only=False,  # Allow downloading
                        trust_remote_code=False  # Security best practice
                    )
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_name,
                        cache_dir=os.environ.get('TRANSFORMERS_CACHE'),
                        local_files_only=False,
                        trust_remote_code=False
                    )
                    
                    # Disable compilation
                    if hasattr(model, '_compiler_config'):
                        model._compiler_config = None
                    
                    # Move to CPU
                    device = "cpu"
                    model = model.to(device)
                    
                    print("   🧠 Creating classifier pipeline...")
                    
                    # Create classifier pipeline
                    classifier = pipeline(
                        "text-classification",
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        framework="pt",
                        torch_dtype=torch.float32
                    )
                    
                    # Store components
                    self._model_components[model_key] = {
                        "model": model,
                        "tokenizer": tokenizer,
                        "classifier": classifier,
                        "device": device,
                        "model_name": model_name
                    }
                    
                    print(f"βœ… Shared finetuned model components loaded successfully: {model_name}")
                    print(f"   Device: {device}")
                    print(f"   Cache: {os.environ.get('TRANSFORMERS_CACHE', 'default')}")
                    
                except PermissionError as e:
                    print(f"❌ Permission error loading model {model_name}: {e}")
                    print("   This might be a cache directory issue in the deployment environment.")
                    print("   Suggestion: Check HF_HOME and cache directory permissions.")
                    self._model_components[model_key] = None
                    return None
                except Exception as e:
                    print(f"❌ Failed to load shared finetuned model components {model_name}: {e}")
                    print(f"   Error type: {type(e).__name__}")
                    if "connection" in str(e).lower() or "network" in str(e).lower():
                        print("   This appears to be a network issue. Check internet connectivity.")
                    elif "disk" in str(e).lower() or "space" in str(e).lower():
                        print("   This appears to be a disk space issue.")
                    self._model_components[model_key] = None
                    return None
        
        return self._model_components[model_key]
    
    def get_finetuned_guard_client(self, model_name: str = "zazaman/fmb") -> Optional[Any]:
        """
        Get or create a shared FinetunedGuardClient instance that uses the
        shared model components.
        
        Args:
            model_name: Name of the model to load
            
        Returns:
            FinetunedGuardClient instance, or None if loading fails
        """
        model_key = f"finetuned_guard_{model_name}"
        
        # The lock is re-entrant, so the nested call to
        # get_finetuned_model_components() below can acquire it again.
        with self._lock:
            if model_key not in self._models:
                try:
                    # Get shared model components
                    components = self.get_finetuned_model_components(model_name)
                    if not components:
                        return None
                    
                    from .finetuned_guard import FinetunedGuardClient
                    
                    print(f"   πŸ” Creating FinetunedGuardClient with shared model components: {model_name}")
                    
                    model_config = {
                        "model_name": model_name
                    }
                    
                    # Create client that will use the shared components
                    client = FinetunedGuardClient(model_config, "", shared_components=components)
                    self._models[model_key] = client
                    
                    print(f"βœ… Shared finetuned guard client created successfully: {model_name}")
                    
                except Exception as e:
                    print(f"❌ Failed to create shared finetuned guard client {model_name}: {e}")
                    self._models[model_key] = None
                    return None
        
        return self._models[model_key]
    
    def clear_models(self):
        """Clear all cached models (useful for testing)"""
        self._models.clear()
        self._model_components.clear()
    
    def get_model_info(self) -> Dict[str, bool]:
        """Report which cached entries (clients and model components) loaded successfully"""
        return {
            model_key: entry is not None
            for model_key, entry in {**self._models, **self._model_components}.items()
        }

# Global singleton instance
shared_model_manager = SharedModelManager()
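
# Usage sketch (illustrative only): assumes the "zazaman/fmb" checkpoint is
# reachable and that FinetunedGuardClient accepts the shared_components
# keyword, as used above. Every call below hits the same cached objects, so
# the model weights are held in memory exactly once per process.
if __name__ == "__main__":
    # Every instantiation returns the same singleton object.
    assert SharedModelManager() is shared_model_manager

    components = shared_model_manager.get_finetuned_model_components()
    if components:
        # The classifier pipeline is ready to use directly.
        print(components["classifier"]("This is a harmless test prompt."))

    client = shared_model_manager.get_finetuned_guard_client()
    print(shared_model_manager.get_model_info())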