"""
FinEE MLX Backend - Apple Silicon optimized backend.
Uses mlx-lm for fast inference on M1/M2/M3 chips.
"""
import gc
import logging
import platform
from typing import Optional

from .base import BaseBackend, BackendLoadError

logger = logging.getLogger(__name__)

# Check for MLX availability (importing mlx.core confirms the native runtime)
try:
    import mlx.core as mx  # noqa: F401 - presence check only
    from mlx_lm import load, generate

    HAS_MLX = True
except ImportError:
    HAS_MLX = False


class MLXBackend(BaseBackend):
    """
    Apple Silicon (MLX) backend for fast local inference.

    Requirements:
        - Apple Silicon Mac (M1/M2/M3)
        - mlx-lm package installed
    """

    def __init__(self, model_id: str = "Ranjit0034/finance-entity-extractor",
                 adapter_path: str = "adapters"):
        """
        Initialize MLX backend.

        Args:
            model_id: Hugging Face model ID
            adapter_path: Path to LoRA adapters (relative to model)
        """
        super().__init__(model_id)
        self.adapter_path = adapter_path

    def is_available(self) -> bool:
        """Check if MLX is available on this system."""
        if not HAS_MLX:
            return False
        # MLX only runs natively on Apple Silicon (macOS on arm).
        # platform.processor() reports 'arm' on Apple Silicon; a non-arm
        # value (e.g. 'i386' under Rosetta) means MLX is unusable here.
        try:
            if platform.system() != 'Darwin':
                return False
            if platform.processor() not in ('arm', 'arm64'):
                return False
            return True
        except Exception:
            return False

    def load_model(self, model_path: Optional[str] = None) -> bool:
        """
        Load model with MLX.

        Args:
            model_path: Optional local path (overrides model_id)

        Returns:
            True if successful
        """
        if not HAS_MLX:
            raise BackendLoadError("MLX not installed. Run: pip install mlx-lm")
        path = model_path or self.model_id
        try:
            logger.info(f"Loading model with MLX: {path}")
            # Load base weights plus LoRA adapters in one call
            self._model, self._tokenizer = load(
                path,
                adapter_path=self.adapter_path,
            )
            self._loaded = True
            logger.info("MLX model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load MLX model: {e}")
            raise BackendLoadError(f"MLX model load failed: {e}") from e

    def generate(self, prompt: str, max_tokens: int = 200,
                 temperature: float = 0.1, **kwargs) -> str:
        """
        Generate text using MLX.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated text (empty string on failure)
        """
        if not self._loaded:
            self.load_model()
        try:
            # Note: recent mlx-lm releases replaced the `temp` kwarg with a
            # sampler argument; pin mlx-lm or adapt this call to your version.
            response = generate(
                self._model,
                self._tokenizer,
                prompt=prompt,
                max_tokens=max_tokens,
                temp=temperature,
                verbose=False,
            )
            return response
        except Exception as e:
            # Swallow generation errors and return an empty string so the
            # caller can fall back to another backend.
            logger.error(f"MLX generation failed: {e}")
            return ""

    def unload(self) -> None:
        """Free MLX model from memory."""
        super().unload()
        # gc is stdlib and always importable; an explicit collect helps
        # release MLX's Metal buffers promptly after dropping references.
        gc.collect()
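

# --- Usage sketch (illustrative, not part of the backend) ---
# A minimal example of driving MLXBackend end to end, using only the methods
# defined above. It assumes the Hugging Face model and a local "adapters"
# directory exist; the prompt below is a made-up placeholder.
if __name__ == "__main__":
    backend = MLXBackend()
    if not backend.is_available():
        raise SystemExit("MLX backend requires an Apple Silicon Mac with mlx-lm")
    backend.load_model()  # raises BackendLoadError on failure
    text = backend.generate(
        "Extract entities: Apple reported Q3 revenue of $81.8B.",
        max_tokens=128,
        temperature=0.1,
    )
    print(text)
    backend.unload()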