Spaces:
Sleeping
Sleeping
File size: 1,931 Bytes
ed85fe4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | import os
import glob
import importlib.util
from typing import Optional
# Project directory: two levels above the folder that contains this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
# Module-level cache for the object returned by get_model(); stays None until
# a model has been built.
_MODEL = None
def _import_model_definition(path: str):
try:
spec = importlib.util.spec_from_file_location('model_definition', path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
except Exception:
return None
def _find_weights() -> Optional[str]:
candidates = glob.glob(os.path.join(PROJECT_ROOT, 'models', '**', '*.h5'), recursive=True)
candidates += glob.glob(os.path.join(PROJECT_ROOT, 'models', '**', '*.hdf5'), recursive=True)
return candidates[0] if candidates else None
def _try_load_sciann_model():
try:
import sciann as sn
except Exception:
return None
model_def_path = os.path.join(PROJECT_ROOT, 'model_definition.py')
if os.path.exists(model_def_path):
mod = _import_model_definition(model_def_path)
if mod and hasattr(mod, 'create_model'):
try:
return mod.create_model()
except Exception:
pass
loader_path = os.path.join(PROJECT_ROOT, 'models', 'load_model.py')
if os.path.exists(loader_path):
mod = _import_model_definition(loader_path)
if mod and hasattr(mod, 'load_model'):
try:
return mod.load_model(PROJECT_ROOT)
except Exception:
pass
return None
def get_model():
    """Return the module-level cached model, building it on first use.

    The model comes from ``_try_load_sciann_model()``. If that yields
    ``None``, nothing is cached and ``None`` is returned, so a later call
    tries again. Otherwise, if ``_find_weights()`` finds a file and the model
    has a ``load_weights`` method, those weights are loaded. The model is then
    cached and returned on every subsequent call.

    Returns:
        The cached model object, or ``None`` if no model could be built.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    model = _try_load_sciann_model()
    if model is None:
        # _MODEL is still None here; leaving it unset lets the next call retry.
        return None

    weights = _find_weights()
    if weights and hasattr(model, 'load_weights'):
        try:
            model.load_weights(weights)
        except Exception:
            # Still best effort (the model is served anyway), but make it
            # visible that it may not carry the trained weights.
            import logging
            logging.getLogger(__name__).warning(
                "could not load weights from %s; the model may be missing "
                "its trained weights", weights, exc_info=True)
    _MODEL = model
    return _MODEL
|