# NOTE(review): removed page-scrape residue ("Spaces / Sleeping" status banner)
# that was not part of the module source.
"""Inference engine using Hugging Face API or local transformers."""

import logging
from typing import Any

from huggingface_hub import InferenceClient

from .config import Settings

logger = logging.getLogger(__name__)
| class InferenceEngine: | |
| """Handles model loading and inference.""" | |
| def __init__(self, settings: Settings) -> None: | |
| """Initialize the inference engine.""" | |
| self.settings = settings | |
| self.client: InferenceClient | None = None | |
| self.pipeline = None | |
| self.model_loaded = False | |
| self.use_api = settings.use_api | |
| def load_model(self) -> None: | |
| """Load the model (API client or local pipeline).""" | |
| if self.use_api: | |
| self._init_api_client() | |
| else: | |
| self._init_local_pipeline() | |
| def _init_api_client(self) -> None: | |
| """Initialize the HF Inference API client.""" | |
| logger.info( | |
| "Initializing HF Inference API client for model: %s", | |
| self.settings.model_name, | |
| ) | |
| self.client = InferenceClient( | |
| model=self.settings.model_name, | |
| token=self.settings.api_token, | |
| ) | |
| self.model_loaded = True | |
| logger.info("HF Inference API client ready") | |
| def _init_local_pipeline(self) -> None: | |
| """Load the model locally using transformers.""" | |
| try: | |
| from transformers import pipeline | |
| except ImportError: | |
| raise ImportError( | |
| "transformers and torch are required for local inference. " | |
| "Install them with: pip install transformers torch" | |
| ) | |
| logger.info( | |
| "Loading local model: %s for task: %s", | |
| self.settings.model_name, | |
| self.settings.task, | |
| ) | |
| self.pipeline = pipeline( | |
| task=self.settings.task, | |
| model=self.settings.model_name, | |
| device=self.settings.device if self.settings.device != "cpu" else -1, | |
| ) | |
| self.model_loaded = True | |
| logger.info("Local model loaded successfully") | |
| def predict( | |
| self, inputs: str | list[str], parameters: dict[str, Any] | None = None | |
| ) -> list[Any]: | |
| """Run inference on the input(s).""" | |
| if not self.model_loaded: | |
| raise RuntimeError("Model not loaded") | |
| if self.use_api: | |
| return self._predict_api(inputs, parameters) | |
| else: | |
| return self._predict_local(inputs, parameters) | |
| def _predict_api( | |
| self, inputs: str | list[str], parameters: dict[str, Any] | None = None | |
| ) -> list[Any]: | |
| """Run inference using HF Inference API.""" | |
| params = parameters or {} | |
| task = self.settings.task | |
| if isinstance(inputs, str): | |
| inputs_list = [inputs] | |
| else: | |
| inputs_list = inputs | |
| results = [] | |
| for text in inputs_list: | |
| result = self._call_api(task, text, params) | |
| results.append(result) | |
| return results | |
| def _call_api(self, task: str, text: str, params: dict[str, Any]) -> Any: | |
| """Call the appropriate API method based on task.""" | |
| if task in ("text-classification", "sentiment-analysis"): | |
| return self.client.text_classification(text, **params) | |
| elif task == "text-generation": | |
| return self.client.text_generation(text, **params) | |
| elif task == "summarization": | |
| return self.client.summarization(text, **params) | |
| elif task == "translation": | |
| return self.client.translation(text, **params) | |
| elif task == "fill-mask": | |
| return self.client.fill_mask(text, **params) | |
| elif task == "question-answering": | |
| context = params.pop("context", "") | |
| return self.client.question_answering(question=text, context=context) | |
| elif task == "feature-extraction": | |
| return self.client.feature_extraction(text, **params) | |
| else: | |
| # Generic post for unsupported tasks | |
| return self.client.post(json={"inputs": text, **params}) | |
| def _predict_local( | |
| self, inputs: str | list[str], parameters: dict[str, Any] | None = None | |
| ) -> list[Any]: | |
| """Run inference using local transformers pipeline.""" | |
| params = parameters or {} | |
| results = self.pipeline(inputs, **params) | |
| if isinstance(inputs, str): | |
| return [results] | |
| return results | |