lukhsaankumar committed on
Commit
3c57c7f
·
1 Parent(s): d0d4075

Deploy DeepFake Detector API - 2026-04-20 01:46:18

Browse files
Files changed (1) hide show
  1. app/services/model_registry.py +36 -20
app/services/model_registry.py CHANGED
@@ -152,17 +152,23 @@ class ModelRegistry:
152
  details={"repo_id": fusion_repo_id}
153
  )
154
 
155
- # Load submodels concurrently with a small bound to avoid
156
- # overwhelming the container while still reducing cold-start wall time.
 
 
 
 
 
 
157
  max_concurrent_loads = 2
158
  semaphore = asyncio.Semaphore(max_concurrent_loads)
159
 
160
- async def load_with_limit(repo_id: str):
161
  async with semaphore:
162
- return await self._load_submodel(repo_id)
163
 
164
  load_results = await asyncio.gather(
165
- *(load_with_limit(submodel_repo_id) for submodel_repo_id in submodel_repos),
166
  return_exceptions=True
167
  )
168
 
@@ -194,29 +200,39 @@ class ModelRegistry:
194
  self._is_loaded = True
195
  logger.info(f"Successfully loaded {len(self._submodels)} submodels and fusion model")
196
 
197
- async def _load_submodel(self, repo_id: str) -> BaseSubmodelWrapper:
198
  """
199
- Download and load a single submodel.
200
-
201
- Uses the config to determine the correct wrapper class.
202
-
203
- Args:
204
- repo_id: Hugging Face repository ID for the submodel
205
  """
206
- logger.info(f"Loading submodel: {repo_id}")
207
-
208
- # Download the repo
209
  local_path = await asyncio.to_thread(
210
  self._hf_service.download_repo, repo_id
211
  )
212
-
213
- # Read config
214
  config = self._read_config(local_path)
215
-
216
- # Select appropriate wrapper class based on config
217
  wrapper_class = get_wrapper_class(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  logger.info(f"Using wrapper class {wrapper_class.__name__} for {repo_id}")
219
-
220
  # Create and load wrapper
221
  wrapper = wrapper_class(
222
  repo_id=repo_id,
 
152
  details={"repo_id": fusion_repo_id}
153
  )
154
 
155
+ # Prepare submodels sequentially to avoid concurrent Hugging Face
156
+ # download contention, then load the already-downloaded artifacts in parallel.
157
+ prepared_submodels = []
158
+ for submodel_repo_id in submodel_repos:
159
+ prepared_submodels.append(
160
+ await self._prepare_submodel(submodel_repo_id)
161
+ )
162
+
163
  max_concurrent_loads = 2
164
  semaphore = asyncio.Semaphore(max_concurrent_loads)
165
 
166
+ async def load_with_limit(prepared_submodel):
167
  async with semaphore:
168
+ return await self._load_prepared_submodel(prepared_submodel)
169
 
170
  load_results = await asyncio.gather(
171
+ *(load_with_limit(prepared_submodel) for prepared_submodel in prepared_submodels),
172
  return_exceptions=True
173
  )
174
 
 
200
  self._is_loaded = True
201
  logger.info(f"Successfully loaded {len(self._submodels)} submodels and fusion model")
202
 
203
+ async def _prepare_submodel(self, repo_id: str) -> Dict[str, Any]:
204
  """
205
+ Download a submodel repository and prepare metadata for loading.
206
+
207
+ This stays sequential to avoid concurrent Hugging Face download issues.
 
 
 
208
  """
209
+ logger.info(f"Preparing submodel: {repo_id}")
210
+
 
211
  local_path = await asyncio.to_thread(
212
  self._hf_service.download_repo, repo_id
213
  )
 
 
214
  config = self._read_config(local_path)
 
 
215
  wrapper_class = get_wrapper_class(config)
216
+
217
+ return {
218
+ "repo_id": repo_id,
219
+ "local_path": local_path,
220
+ "config": config,
221
+ "wrapper_class": wrapper_class,
222
+ }
223
+
224
+ async def _load_prepared_submodel(self, prepared_submodel: Dict[str, Any]) -> BaseSubmodelWrapper:
225
+ """
226
+ Load a submodel that has already been downloaded and prepared.
227
+ """
228
+ repo_id = prepared_submodel["repo_id"]
229
+ local_path = prepared_submodel["local_path"]
230
+ config = prepared_submodel["config"]
231
+ wrapper_class = prepared_submodel["wrapper_class"]
232
+
233
+ logger.info(f"Loading submodel: {repo_id}")
234
  logger.info(f"Using wrapper class {wrapper_class.__name__} for {repo_id}")
235
+
236
  # Create and load wrapper
237
  wrapper = wrapper_class(
238
  repo_id=repo_id,