Spaces:
Sleeping
Sleeping
| """Load MOMENT-1-large once and cache as a module-level singleton. | |
| The momentfm wrapper hard-pins old transformers/numpy in its package metadata, | |
| but the actual code works on modern stacks — install with ``--no-deps``. | |
| Public API: | |
| get_model() -> MomentBundle | |
| Returns the cached (model, device_info) bundle, loading on first call. | |
| reset_model() | |
| Drops the cached model — useful in tests. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import structlog | |
| import torch | |
| from src.inference.rocm_check import DeviceInfo, detect_device | |
| log = structlog.get_logger() | |
| MODEL_NAME = os.getenv("MOMENT_MODEL", "AutonLab/MOMENT-1-large") | |
| TASK_NAME = "reconstruction" # MOMENT anomaly detection uses reconstruction error | |
| class MomentBundle: | |
| model: Any | |
| device: DeviceInfo | |
| dtype: torch.dtype | |
| _bundle: MomentBundle | None = None | |
| def _select_dtype(device: DeviceInfo) -> torch.dtype: | |
| # fp16 only pays off on real GPUs; CPU/MPS prefer fp32 for stability. | |
| if device.backend in ("rocm", "cuda"): | |
| return torch.float16 | |
| return torch.float32 | |
| def _load() -> MomentBundle: | |
| from momentfm import MOMENTPipeline # imported lazily; heavy dep | |
| device = detect_device() | |
| dtype = _select_dtype(device) | |
| log.info( | |
| "model_load_start", | |
| component="inference.model_loader", | |
| model=MODEL_NAME, | |
| device=device.torch_device, | |
| dtype=str(dtype), | |
| ) | |
| started = time.perf_counter() | |
| try: | |
| model = MOMENTPipeline.from_pretrained( | |
| MODEL_NAME, | |
| model_kwargs={"task_name": TASK_NAME}, | |
| ) | |
| model.init() | |
| model.to(device.torch_device).to(dtype) | |
| model.eval() | |
| except Exception as exc: | |
| log.error( | |
| "model_load_failed", | |
| component="inference.model_loader", | |
| model=MODEL_NAME, | |
| error=str(exc), | |
| ) | |
| raise | |
| elapsed_ms = (time.perf_counter() - started) * 1000.0 | |
| log.info( | |
| "model_load_complete", | |
| component="inference.model_loader", | |
| model=MODEL_NAME, | |
| device=device.torch_device, | |
| load_ms=round(elapsed_ms, 1), | |
| ) | |
| return MomentBundle(model=model, device=device, dtype=dtype) | |
| def get_model() -> MomentBundle: | |
| global _bundle | |
| if _bundle is None: | |
| _bundle = _load() | |
| return _bundle | |
| def reset_model() -> None: | |
| global _bundle | |
| _bundle = None | |