"""
FinEE Backends - Abstract interface for LLM backends.
All LLM backends must implement this interface.
"""
from abc import ABC, abstractmethod
from typing import Optional, List, Dict, Any
import logging
logger = logging.getLogger(__name__)
class BaseBackend(ABC):
"""
Abstract base class for LLM backends.
Any backend (MLX, Transformers, llama.cpp) must implement these methods.
"""
def __init__(self, model_id: str = "Ranjit0034/finance-entity-extractor"):
"""
Initialize backend.
Args:
model_id: Hugging Face model ID or local path
"""
self.model_id = model_id
self._model = None
self._tokenizer = None
self._loaded = False
@property
def name(self) -> str:
"""Return backend name."""
return self.__class__.__name__
@abstractmethod
def is_available(self) -> bool:
"""
Check if this backend can be used on the current system.
Returns:
True if all dependencies are installed and hardware is compatible
"""
raise NotImplementedError
@abstractmethod
def load_model(self, model_path: Optional[str] = None) -> bool:
"""
Load the model into memory.
Args:
model_path: Optional local path (overrides model_id)
Returns:
True if model loaded successfully
"""
raise NotImplementedError
@abstractmethod
def generate(self, prompt: str, max_tokens: int = 200,
temperature: float = 0.1, **kwargs) -> str:
"""
Generate text from prompt.
Args:
prompt: Input prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
**kwargs: Additional generation parameters
Returns:
Generated text
"""
raise NotImplementedError
def generate_batch(self, prompts: List[str], max_tokens: int = 200,
temperature: float = 0.1, **kwargs) -> List[str]:
"""
Generate text for multiple prompts.
Default implementation calls generate() in a loop.
Backends may override for batch optimization.
Args:
prompts: List of input prompts
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
**kwargs: Additional generation parameters
Returns:
List of generated texts
"""
return [self.generate(p, max_tokens, temperature, **kwargs) for p in prompts]
def unload(self) -> None:
"""
Free model from memory.
Call this when done with the model to free GPU/system memory.
"""
self._model = None
self._tokenizer = None
self._loaded = False
logger.info(f"{self.name}: Model unloaded")
@property
def is_loaded(self) -> bool:
"""Check if model is currently loaded."""
return self._loaded
def get_info(self) -> Dict[str, Any]:
"""Get backend information."""
return {
'name': self.name,
'model_id': self.model_id,
'available': self.is_available(),
'loaded': self.is_loaded,
}
def __repr__(self) -> str:
status = "loaded" if self.is_loaded else "not loaded"
return f"{self.name}(model={self.model_id}, {status})"
class NoBackendError(Exception):
"""Raised when no LLM backend is available."""
pass
class BackendLoadError(Exception):
"""Raised when backend fails to load model."""
pass
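

# --- Example (illustrative only) -------------------------------------------
# A minimal sketch of a concrete backend, showing how the abstract interface
# above is meant to be implemented. `EchoBackend` is a hypothetical stand-in
# and not part of FinEE; a real backend (MLX, Transformers, llama.cpp) would
# load and invoke an actual model instead of echoing the prompt.

class EchoBackend(BaseBackend):
    """Toy backend that echoes its prompt; for interface demonstration only."""

    def is_available(self) -> bool:
        # No external dependencies or special hardware required.
        return True

    def load_model(self, model_path: Optional[str] = None) -> bool:
        # A real backend would load weights from model_path or self.model_id
        # here and raise BackendLoadError on failure.
        self._loaded = True
        return True

    def generate(self, prompt: str, max_tokens: int = 200,
                 temperature: float = 0.1, **kwargs) -> str:
        if not self.is_loaded:
            raise BackendLoadError(f"{self.name}: call load_model() first")
        # Echo the input back, truncated; a real backend runs inference here.
        return prompt[:max_tokens]


if __name__ == "__main__":
    # Usage sketch: probe availability, load, generate, then free memory.
    backend = EchoBackend()
    if backend.is_available() and backend.load_model():
        print(backend.generate("Extract entities: Apple Q3 revenue rose 5%."))
        backend.unload()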