|
|
""" |
|
|
FinEE MLX Backend - Apple Silicon optimized backend. |
|
|
|
|
|
Uses mlx-lm for fast inference on M1/M2/M3 chips. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from typing import Optional |
|
|
|
|
|
from .base import BaseBackend, BackendLoadError |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
try: |
|
|
import mlx.core as mx |
|
|
from mlx_lm import load, generate |
|
|
HAS_MLX = True |
|
|
except ImportError: |
|
|
HAS_MLX = False |
|
|
|
|
|
|
|
|
class MLXBackend(BaseBackend): |
|
|
""" |
|
|
Apple Silicon (MLX) backend for fast local inference. |
|
|
|
|
|
Requirements: |
|
|
- Apple Silicon Mac (M1/M2/M3) |
|
|
- mlx-lm package installed |
|
|
""" |
|
|
|
|
|
def __init__(self, model_id: str = "Ranjit0034/finance-entity-extractor", |
|
|
adapter_path: str = "adapters"): |
|
|
""" |
|
|
Initialize MLX backend. |
|
|
|
|
|
Args: |
|
|
model_id: Hugging Face model ID |
|
|
adapter_path: Path to LoRA adapters (relative to model) |
|
|
""" |
|
|
super().__init__(model_id) |
|
|
self.adapter_path = adapter_path |
|
|
|
|
|
def is_available(self) -> bool: |
|
|
"""Check if MLX is available on this system.""" |
|
|
if not HAS_MLX: |
|
|
return False |
|
|
|
|
|
|
|
|
try: |
|
|
import platform |
|
|
if platform.system() != 'Darwin': |
|
|
return False |
|
|
if platform.processor() not in ('arm', 'arm64'): |
|
|
return False |
|
|
return True |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
def load_model(self, model_path: Optional[str] = None) -> bool: |
|
|
""" |
|
|
Load model with MLX. |
|
|
|
|
|
Args: |
|
|
model_path: Optional local path (overrides model_id) |
|
|
|
|
|
Returns: |
|
|
True if successful |
|
|
""" |
|
|
if not HAS_MLX: |
|
|
raise BackendLoadError("MLX not installed. Run: pip install mlx-lm") |
|
|
|
|
|
path = model_path or self.model_id |
|
|
|
|
|
try: |
|
|
logger.info(f"Loading model with MLX: {path}") |
|
|
|
|
|
|
|
|
self._model, self._tokenizer = load( |
|
|
path, |
|
|
adapter_path=self.adapter_path |
|
|
) |
|
|
|
|
|
self._loaded = True |
|
|
logger.info("MLX model loaded successfully") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load MLX model: {e}") |
|
|
raise BackendLoadError(f"MLX model load failed: {e}") |
|
|
|
|
|
def generate(self, prompt: str, max_tokens: int = 200, |
|
|
temperature: float = 0.1, **kwargs) -> str: |
|
|
""" |
|
|
Generate text using MLX. |
|
|
|
|
|
Args: |
|
|
prompt: Input prompt |
|
|
max_tokens: Maximum tokens to generate |
|
|
temperature: Sampling temperature |
|
|
|
|
|
Returns: |
|
|
Generated text |
|
|
""" |
|
|
if not self._loaded: |
|
|
self.load_model() |
|
|
|
|
|
try: |
|
|
response = generate( |
|
|
self._model, |
|
|
self._tokenizer, |
|
|
prompt=prompt, |
|
|
max_tokens=max_tokens, |
|
|
temp=temperature, |
|
|
verbose=False, |
|
|
) |
|
|
|
|
|
return response |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"MLX generation failed: {e}") |
|
|
return "" |
|
|
|
|
|
def unload(self) -> None: |
|
|
"""Free MLX model from memory.""" |
|
|
super().unload() |
|
|
|
|
|
|
|
|
try: |
|
|
import gc |
|
|
gc.collect() |
|
|
except Exception: |
|
|
pass |
|
|
|