manikumargouni commited on
Commit
6c658b1
·
verified ·
1 Parent(s): 3557a12

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +15 -4
pipeline.py CHANGED
@@ -42,6 +42,7 @@ Supported HF deployment surfaces
42
 
43
  from __future__ import annotations
44
 
 
45
  import sys
46
  from pathlib import Path
47
  from typing import Union
@@ -230,8 +231,14 @@ class AdmeshIntentPipeline(_HFPipeline):
230
  stacklevel=2,
231
  )
232
  else:
233
- from multitask_runtime import get_multitask_runtime # noqa: PLC0415
234
- from model_runtime import get_head # noqa: PLC0415
 
 
 
 
 
 
235
 
236
  rt = get_multitask_runtime()
237
  if rt._model is not None:
@@ -296,8 +303,12 @@ class AdmeshIntentPipeline(_HFPipeline):
296
 
297
  def _ensure_loaded(self) -> None:
298
  if self._classify_fn is None:
299
- from combined_inference import classify_query # noqa: PLC0415
300
- self._classify_fn = classify_query
 
 
 
 
301
 
302
  def __repr__(self) -> str:
303
  state = "loaded" if self._classify_fn is not None else "not yet loaded"
 
42
 
43
  from __future__ import annotations
44
 
45
+ import importlib
46
  import sys
47
  from pathlib import Path
48
  from typing import Union
 
231
  stacklevel=2,
232
  )
233
  else:
234
+ if __package__:
235
+ get_multitask_runtime = importlib.import_module(
236
+ f"{__package__}.multitask_runtime"
237
+ ).get_multitask_runtime
238
+ get_head = importlib.import_module(f"{__package__}.model_runtime").get_head
239
+ else:
240
+ get_multitask_runtime = importlib.import_module("multitask_runtime").get_multitask_runtime
241
+ get_head = importlib.import_module("model_runtime").get_head
242
 
243
  rt = get_multitask_runtime()
244
  if rt._model is not None:
 
303
 
304
  def _ensure_loaded(self) -> None:
305
  if self._classify_fn is None:
306
+ if __package__:
307
+ self._classify_fn = importlib.import_module(
308
+ f"{__package__}.combined_inference"
309
+ ).classify_query
310
+ else:
311
+ self._classify_fn = importlib.import_module("combined_inference").classify_query
312
 
313
  def __repr__(self) -> str:
314
  state = "loaded" if self._classify_fn is not None else "not yet loaded"