Patryk Studzinski committed on
Commit
b50a781
·
1 Parent(s): a7fd202

model-lazy-loading

Browse files
README.md CHANGED
@@ -11,9 +11,9 @@ This service provides an API for generating enhanced descriptions using multiple
11
  | Model | Size | Polish Support | Type |
12
  |-------|------|----------------|------|
13
  | Bielik-1.5B | 1.5B | Excellent | Local |
 
 
14
  | PLLuM-12B | 12B | Excellent | API |
15
- | Mistral-Small-3 | 24B | Good | API |
16
- | Gemma-2-9B | 9B | Medium | API |
17
 
18
  ## API Endpoints
19
 
@@ -25,6 +25,13 @@ This service provides an API for generating enhanced descriptions using multiple
25
  | `GET` | `/health` | API health check and model status |
26
  | `GET` | `/models` | List all available models |
27
 
 
 
 
 
 
 
 
28
  ### Generation
29
 
30
  | Method | Endpoint | Description |
@@ -34,18 +41,37 @@ This service provides an API for generating enhanced descriptions using multiple
34
 
35
  ---
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ## Endpoint Details
38
 
39
  ### `GET /health`
40
 
41
- Check API status and model initialization.
42
 
43
  **Response:**
44
  ```json
45
  {
46
  "status": "ok",
47
- "local_models_initialized": true,
48
- "available_models": 4
 
49
  }
50
  ```
51
 
@@ -53,7 +79,7 @@ Check API status and model initialization.
53
 
54
  ### `GET /models`
55
 
56
- List all available models with their details.
57
 
58
  **Response:**
59
  ```json
@@ -64,13 +90,55 @@ List all available models with their details.
64
  "type": "local",
65
  "polish_support": "excellent",
66
  "size": "1.5B",
67
- "initialized": true
 
 
 
 
 
 
 
 
 
 
68
  }
69
  ]
70
  ```
71
 
72
  ---
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  ### `POST /enhance-description`
75
 
76
  Generate enhanced description using a single model.
@@ -119,7 +187,7 @@ Compare outputs from multiple models for the same input.
119
  "features": ["nawigacja", "klimatyzacja"],
120
  "condition": "bardzo dobry"
121
  },
122
- "models": ["bielik-1.5b", "pllum-12b", "gemma-2-9b"]
123
  }
124
  ```
125
 
 
11
  | Model | Size | Polish Support | Type |
12
  |-------|------|----------------|------|
13
  | Bielik-1.5B | 1.5B | Excellent | Local |
14
+ | Qwen2.5-3B | 3B | Good | Local |
15
+ | Gemma-2-2B | 2B | Medium | Local |
16
  | PLLuM-12B | 12B | Excellent | API |
 
 
17
 
18
  ## API Endpoints
19
 
 
25
  | `GET` | `/health` | API health check and model status |
26
  | `GET` | `/models` | List all available models |
27
 
28
+ ### Model Management (Lazy Loading)
29
+
30
+ | Method | Endpoint | Description |
31
+ |--------|----------|-------------|
32
+ | `POST` | `/models/{name}/load` | Load a model into memory |
33
+ | `POST` | `/models/{name}/unload` | Unload a model from memory |
34
+
35
  ### Generation
36
 
37
  | Method | Endpoint | Description |
 
41
 
42
  ---
43
 
44
+ ## Lazy Loading
45
+
46
+ Models are **not loaded at startup** to conserve memory. Instead:
47
+ - Models are loaded **on first request** (lazy loading)
48
+ - Only **one local model** is loaded at a time
49
+ - Switching to a different local model **automatically unloads** the previous one
50
+ - API models (PLLuM) don't affect local model memory
51
+
52
+ ### Example: Load/Unload Flow
53
+ ```
54
+ 1. Request with bielik-1.5b → Loads Bielik (first use)
55
+ 2. Request with qwen2.5-3b → Unloads Bielik, loads Qwen
56
+ 3. Request with pllum-12b → Qwen stays loaded (API model doesn't affect local)
57
+ 4. POST /models/qwen2.5-3b/unload → Manually free memory
58
+ ```
59
+
60
+ ---
61
+
62
  ## Endpoint Details
63
 
64
  ### `GET /health`
65
 
66
+ Check API status and loaded models.
67
 
68
  **Response:**
69
  ```json
70
  {
71
  "status": "ok",
72
+ "available_models": 4,
73
+ "loaded_models": ["bielik-1.5b"],
74
+ "active_local_model": "bielik-1.5b"
75
  }
76
  ```
77
 
 
79
 
80
  ### `GET /models`
81
 
82
+ List all available models with their load status.
83
 
84
  **Response:**
85
  ```json
 
90
  "type": "local",
91
  "polish_support": "excellent",
92
  "size": "1.5B",
93
+ "loaded": true,
94
+ "active": true
95
+ },
96
+ {
97
+ "name": "qwen2.5-3b",
98
+ "model_id": "Qwen/Qwen2.5-3B-Instruct",
99
+ "type": "local",
100
+ "polish_support": "good",
101
+ "size": "3B",
102
+ "loaded": false,
103
+ "active": false
104
  }
105
  ]
106
  ```
107
 
108
  ---
109
 
110
+ ### `POST /models/{name}/load`
111
+
112
+ Explicitly load a model. For local models, unloads the previous one first.
113
+
114
+ **Response:**
115
+ ```json
116
+ {
117
+ "status": "loaded",
118
+ "model": {
119
+ "name": "bielik-1.5b",
120
+ "loaded": true,
121
+ "active": true
122
+ }
123
+ }
124
+ ```
125
+
126
+ ---
127
+
128
+ ### `POST /models/{name}/unload`
129
+
130
+ Explicitly unload a model to free memory.
131
+
132
+ **Response:**
133
+ ```json
134
+ {
135
+ "status": "unloaded",
136
+ "model": "bielik-1.5b"
137
+ }
138
+ ```
139
+
140
+ ---
141
+
142
  ### `POST /enhance-description`
143
 
144
  Generate enhanced description using a single model.
 
187
  "features": ["nawigacja", "klimatyzacja"],
188
  "condition": "bardzo dobry"
189
  },
190
+ "models": ["bielik-1.5b", "qwen2.5-3b", "gemma-2-2b", "pllum-12b"]
191
  }
192
  ```
193
 
app/main.py CHANGED
@@ -38,14 +38,12 @@ app.add_middleware(
38
 
39
  @app.on_event("startup")
40
  async def startup_event():
41
- """Initialize local models at startup."""
42
- print("Starting up and initializing local models...")
43
- try:
44
- await registry.initialize_local_models()
45
- print("Local models initialized successfully.")
46
- except Exception as e:
47
- print(f"Error during model initialization: {e}")
48
- raise
49
 
50
  # --- Helper function to load domain logic ---
51
  def get_domain_config(domain: str):
@@ -65,18 +63,49 @@ async def read_root():
65
  async def health_check():
66
  """Check API health and model status."""
67
  models = registry.list_models()
68
- local_initialized = any(m["initialized"] for m in models if m["type"] == "local")
 
69
  return {
70
  "status": "ok",
71
- "local_models_initialized": local_initialized,
72
  "available_models": len(models),
 
 
73
  }
74
 
75
  @app.get("/models", response_model=List[ModelInfo])
76
  async def list_models():
77
- """List all available models."""
78
  return registry.list_models()
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  @app.post("/enhance-description", response_model=EnhancedDescriptionResponse)
81
  async def enhance_description(
82
  domain: str = Body(..., embed=True),
 
38
 
39
  @app.on_event("startup")
40
  async def startup_event():
41
+ """
42
+ Startup event - models are loaded lazily on first request.
43
+ No models are pre-loaded to conserve memory.
44
+ """
45
+ print("Application started. Models will be loaded lazily on first request.")
46
+ print(f"Available models: {registry.get_available_model_names()}")
 
 
47
 
48
  # --- Helper function to load domain logic ---
49
  def get_domain_config(domain: str):
 
63
  async def health_check():
64
  """Check API health and model status."""
65
  models = registry.list_models()
66
+ loaded_models = registry.get_loaded_models()
67
+ active_model = registry.get_active_model()
68
  return {
69
  "status": "ok",
 
70
  "available_models": len(models),
71
+ "loaded_models": loaded_models,
72
+ "active_local_model": active_model,
73
  }
74
 
75
  @app.get("/models", response_model=List[ModelInfo])
76
  async def list_models():
77
+ """List all available models with their load status."""
78
  return registry.list_models()
79
 
80
+ @app.post("/models/{model_name}/load")
81
+ async def load_model(model_name: str):
82
+ """
83
+ Explicitly load a model into memory.
84
+ For local models: unloads any previously loaded local model first.
85
+ """
86
+ if model_name not in registry.get_available_model_names():
87
+ raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
88
+
89
+ try:
90
+ info = await registry.load_model(model_name)
91
+ return {"status": "loaded", "model": info}
92
+ except Exception as e:
93
+ raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
94
+
95
+ @app.post("/models/{model_name}/unload")
96
+ async def unload_model(model_name: str):
97
+ """
98
+ Explicitly unload a model from memory to free resources.
99
+ """
100
+ if model_name not in registry.get_available_model_names():
101
+ raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
102
+
103
+ try:
104
+ result = await registry.unload_model(model_name)
105
+ return result
106
+ except Exception as e:
107
+ raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")
108
+
109
  @app.post("/enhance-description", response_model=EnhancedDescriptionResponse)
110
  async def enhance_description(
111
  domain: str = Body(..., embed=True),
app/models/huggingface_local.py CHANGED
@@ -131,3 +131,19 @@ class HuggingFaceLocal(BaseLLM):
131
  "initialized": self._initialized,
132
  "device": self.device
133
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  "initialized": self._initialized,
132
  "device": self.device
133
  }
134
+
135
+ async def cleanup(self) -> None:
136
+ """Release model from memory."""
137
+ if self.pipeline is not None:
138
+ del self.pipeline
139
+ self.pipeline = None
140
+ if self.tokenizer is not None:
141
+ del self.tokenizer
142
+ self.tokenizer = None
143
+ self._initialized = False
144
+
145
+ # Force CUDA cache clear if available
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+
149
+ print(f"[{self.name}] Model unloaded from memory")
app/models/registry.py CHANGED
@@ -1,8 +1,10 @@
1
  """
2
  Model Registry - Central configuration and factory for all LLM models.
 
3
  """
4
 
5
  import os
 
6
  from typing import Dict, List, Any, Optional
7
 
8
  from app.models.base_llm import BaseLLM
@@ -10,7 +12,7 @@ from app.models.huggingface_local import HuggingFaceLocal
10
  from app.models.huggingface_inference_api import HuggingFaceInferenceAPI
11
 
12
 
13
- # Model configuration
14
  MODEL_CONFIG = {
15
  "bielik-1.5b": {
16
  "id": "speakleash/Bielik-1.5B-v3.0-Instruct",
@@ -18,24 +20,24 @@ MODEL_CONFIG = {
18
  "polish_support": "excellent",
19
  "size": "1.5B",
20
  },
 
 
 
 
 
 
 
 
 
 
 
 
21
  "pllum-12b": {
22
  "id": "CYFRAGOVPL/PLLuM-12B-instruct",
23
  "type": "inference_api",
24
  "polish_support": "excellent",
25
  "size": "12B",
26
  },
27
- "mistral-small-3": {
28
- "id": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
29
- "type": "inference_api",
30
- "polish_support": "good",
31
- "size": "24B",
32
- },
33
- "gemma-2-9b": {
34
- "id": "google/gemma-2-9b-it",
35
- "type": "inference_api",
36
- "polish_support": "medium",
37
- "size": "9B",
38
- },
39
  }
40
 
41
  # For local model override (when model is pre-downloaded in container)
@@ -45,12 +47,14 @@ LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH", "/app/pretrain_model")
45
  class ModelRegistry:
46
  """
47
  Central registry for managing all LLM models.
48
- Handles model instantiation, initialization, and access.
 
49
  """
50
 
51
  def __init__(self):
52
  self._models: Dict[str, BaseLLM] = {}
53
  self._config = MODEL_CONFIG.copy()
 
54
 
55
  def _create_model(self, name: str) -> BaseLLM:
56
  """Factory method to create model instance."""
@@ -80,44 +84,119 @@ class ModelRegistry:
80
  else:
81
  raise ValueError(f"Unknown model type: {model_type}")
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  async def get_model(self, name: str) -> BaseLLM:
84
- """Get or create and initialize a model."""
 
 
 
 
 
 
85
 
86
- if name not in self._models:
87
- model = self._create_model(name)
88
- await model.initialize()
89
- self._models[name] = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  return self._models[name]
92
 
93
- async def initialize_model(self, name: str) -> None:
94
- """Pre-initialize a specific model."""
 
 
 
95
  await self.get_model(name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- async def initialize_local_models(self) -> None:
98
- """Initialize all local models at startup."""
99
- for name, config in self._config.items():
100
- if config["type"] == "local":
101
- await self.initialize_model(name)
 
 
 
 
 
 
 
 
 
 
102
 
103
  def list_models(self) -> List[Dict[str, Any]]:
104
  """List all available models with their info."""
105
- models = []
106
- for name, config in self._config.items():
107
- model_info = {
108
- "name": name,
109
- "model_id": config["id"],
110
- "type": config["type"],
111
- "polish_support": config["polish_support"],
112
- "size": config["size"],
113
- "initialized": name in self._models and self._models[name].is_initialized,
114
- }
115
- models.append(model_info)
116
- return models
117
 
118
  def get_available_model_names(self) -> List[str]:
119
  """Get list of available model names."""
120
  return list(self._config.keys())
 
 
 
 
 
 
 
 
121
 
122
 
123
  # Global registry instance
 
1
  """
2
  Model Registry - Central configuration and factory for all LLM models.
3
+ Supports lazy loading and on/off mechanism for memory management.
4
  """
5
 
6
  import os
7
+ import gc
8
  from typing import Dict, List, Any, Optional
9
 
10
  from app.models.base_llm import BaseLLM
 
12
  from app.models.huggingface_inference_api import HuggingFaceInferenceAPI
13
 
14
 
15
+ # Model configuration - 3 local + 1 API for Polish language comparison
16
  MODEL_CONFIG = {
17
  "bielik-1.5b": {
18
  "id": "speakleash/Bielik-1.5B-v3.0-Instruct",
 
20
  "polish_support": "excellent",
21
  "size": "1.5B",
22
  },
23
+ "qwen2.5-3b": {
24
+ "id": "Qwen/Qwen2.5-3B-Instruct",
25
+ "type": "local",
26
+ "polish_support": "good",
27
+ "size": "3B",
28
+ },
29
+ "gemma-2-2b": {
30
+ "id": "google/gemma-2-2b-it",
31
+ "type": "local",
32
+ "polish_support": "medium",
33
+ "size": "2B",
34
+ },
35
  "pllum-12b": {
36
  "id": "CYFRAGOVPL/PLLuM-12B-instruct",
37
  "type": "inference_api",
38
  "polish_support": "excellent",
39
  "size": "12B",
40
  },
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
42
 
43
  # For local model override (when model is pre-downloaded in container)
 
47
  class ModelRegistry:
48
  """
49
  Central registry for managing all LLM models.
50
+ Supports lazy loading (load on first request) and unloading for memory management.
51
+ Only one local model is loaded at a time to conserve memory.
52
  """
53
 
54
  def __init__(self):
55
  self._models: Dict[str, BaseLLM] = {}
56
  self._config = MODEL_CONFIG.copy()
57
+ self._active_local_model: Optional[str] = None
58
 
59
  def _create_model(self, name: str) -> BaseLLM:
60
  """Factory method to create model instance."""
 
84
  else:
85
  raise ValueError(f"Unknown model type: {model_type}")
86
 
87
+ async def _unload_model(self, name: str) -> None:
88
+ """Unload a model from memory."""
89
+ if name in self._models:
90
+ model = self._models[name]
91
+ # Call cleanup if available
92
+ if hasattr(model, 'cleanup'):
93
+ await model.cleanup()
94
+ del self._models[name]
95
+ gc.collect() # Force garbage collection
96
+ print(f"Model '{name}' unloaded from memory.")
97
+
98
+ async def _unload_all_local_models(self) -> None:
99
+ """Unload all local models to free memory."""
100
+ local_models = [
101
+ name for name, config in self._config.items()
102
+ if config["type"] == "local" and name in self._models
103
+ ]
104
+ for name in local_models:
105
+ await self._unload_model(name)
106
+ self._active_local_model = None
107
+
108
  async def get_model(self, name: str) -> BaseLLM:
109
+ """
110
+ Get a model (lazy loading).
111
+ For local models: unloads any previously loaded local model first.
112
+ For API models: always available without affecting local models.
113
+ """
114
+ if name not in self._config:
115
+ raise ValueError(f"Unknown model: {name}")
116
 
117
+ config = self._config[name]
118
+
119
+ # If it's a local model, ensure only one is loaded at a time
120
+ if config["type"] == "local":
121
+ # Unload current local model if different
122
+ if self._active_local_model and self._active_local_model != name:
123
+ print(f"Switching from '{self._active_local_model}' to '{name}'...")
124
+ await self._unload_model(self._active_local_model)
125
+
126
+ # Load the requested model if not already loaded
127
+ if name not in self._models:
128
+ print(f"Loading model '{name}'...")
129
+ model = self._create_model(name)
130
+ await model.initialize()
131
+ self._models[name] = model
132
+ self._active_local_model = name
133
+ print(f"Model '{name}' loaded successfully.")
134
+
135
+ # For API models, just create/return (no memory concern)
136
+ elif config["type"] == "inference_api":
137
+ if name not in self._models:
138
+ print(f"Initializing API model '{name}'...")
139
+ model = self._create_model(name)
140
+ await model.initialize()
141
+ self._models[name] = model
142
 
143
  return self._models[name]
144
 
145
+ async def load_model(self, name: str) -> Dict[str, Any]:
146
+ """
147
+ Explicitly load a model (unloads other local models first).
148
+ Returns model info.
149
+ """
150
  await self.get_model(name)
151
+ return self.get_model_info(name)
152
+
153
+ async def unload_model(self, name: str) -> Dict[str, str]:
154
+ """
155
+ Explicitly unload a model from memory.
156
+ """
157
+ if name not in self._config:
158
+ raise ValueError(f"Unknown model: {name}")
159
+
160
+ if name not in self._models:
161
+ return {"status": "not_loaded", "model": name}
162
+
163
+ await self._unload_model(name)
164
+ if self._active_local_model == name:
165
+ self._active_local_model = None
166
+
167
+ return {"status": "unloaded", "model": name}
168
 
169
+ def get_model_info(self, name: str) -> Dict[str, Any]:
170
+ """Get info about a specific model."""
171
+ if name not in self._config:
172
+ raise ValueError(f"Unknown model: {name}")
173
+
174
+ config = self._config[name]
175
+ return {
176
+ "name": name,
177
+ "model_id": config["id"],
178
+ "type": config["type"],
179
+ "polish_support": config["polish_support"],
180
+ "size": config["size"],
181
+ "loaded": name in self._models,
182
+ "active": name == self._active_local_model if config["type"] == "local" else None,
183
+ }
184
 
185
  def list_models(self) -> List[Dict[str, Any]]:
186
  """List all available models with their info."""
187
+ return [self.get_model_info(name) for name in self._config.keys()]
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  def get_available_model_names(self) -> List[str]:
190
  """Get list of available model names."""
191
  return list(self._config.keys())
192
+
193
+ def get_active_model(self) -> Optional[str]:
194
+ """Get the currently active (loaded) local model name."""
195
+ return self._active_local_model
196
+
197
+ def get_loaded_models(self) -> List[str]:
198
+ """Get list of currently loaded model names."""
199
+ return list(self._models.keys())
200
 
201
 
202
  # Global registry instance
app/schemas/schemas.py CHANGED
@@ -15,7 +15,8 @@ class ModelInfo(BaseModel):
15
  type: str
16
  polish_support: str
17
  size: str
18
- initialized: bool
 
19
 
20
 
21
  class CompareRequest(BaseModel):
 
15
  type: str
16
  polish_support: str
17
  size: str
18
+ loaded: bool
19
+ active: Optional[bool] = None # Only for local models
20
 
21
 
22
  class CompareRequest(BaseModel):