"""Inference engine using Hugging Face API or local transformers."""

import logging
from typing import Any

from huggingface_hub import InferenceClient

from .config import Settings

logger = logging.getLogger(__name__)


class InferenceEngine:
    """Handles model loading and inference.

    Depending on ``settings.use_api``, inference is routed either to the
    hosted Hugging Face Inference API or to a locally loaded
    ``transformers`` pipeline.
    """

    def __init__(self, settings: Settings) -> None:
        """Initialize the inference engine.

        Args:
            settings: Runtime configuration (model name, task, API token,
                device, and the API-vs-local switch). No model is loaded
                here; call :meth:`load_model` first.
        """
        self.settings = settings
        self.client: InferenceClient | None = None
        self.pipeline = None
        self.model_loaded = False
        self.use_api = settings.use_api

    def load_model(self) -> None:
        """Load the model (API client or local pipeline)."""
        if self.use_api:
            self._init_api_client()
        else:
            self._init_local_pipeline()

    def _init_api_client(self) -> None:
        """Initialize the HF Inference API client."""
        logger.info(
            "Initializing HF Inference API client for model: %s",
            self.settings.model_name,
        )
        self.client = InferenceClient(
            model=self.settings.model_name,
            token=self.settings.api_token,
        )
        self.model_loaded = True
        logger.info("HF Inference API client ready")

    def _init_local_pipeline(self) -> None:
        """Load the model locally using transformers.

        Raises:
            ImportError: If ``transformers``/``torch`` are not installed.
        """
        try:
            from transformers import pipeline
        except ImportError as err:
            # Chain the cause so the original missing-package traceback
            # is preserved for debugging.
            raise ImportError(
                "transformers and torch are required for local inference. "
                "Install them with: pip install transformers torch"
            ) from err
        logger.info(
            "Loading local model: %s for task: %s",
            self.settings.model_name,
            self.settings.task,
        )
        self.pipeline = pipeline(
            task=self.settings.task,
            model=self.settings.model_name,
            # transformers uses device=-1 to mean CPU; any other configured
            # device value is passed through unchanged.
            device=self.settings.device if self.settings.device != "cpu" else -1,
        )
        self.model_loaded = True
        logger.info("Local model loaded successfully")

    def predict(
        self, inputs: str | list[str], parameters: dict[str, Any] | None = None
    ) -> list[Any]:
        """Run inference on the input(s).

        Args:
            inputs: A single text or a list of texts.
            parameters: Optional task-specific parameters; never mutated.

        Returns:
            A list of results, one entry per input.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called.
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded")
        if self.use_api:
            return self._predict_api(inputs, parameters)
        else:
            return self._predict_local(inputs, parameters)

    def _predict_api(
        self, inputs: str | list[str], parameters: dict[str, Any] | None = None
    ) -> list[Any]:
        """Run inference using HF Inference API."""
        # BUGFIX: copy the parameters so per-call handling in _call_api
        # (e.g. extracting "context") can never mutate the caller's dict,
        # and every input in the batch sees identical parameters.
        params = dict(parameters) if parameters else {}
        task = self.settings.task
        if isinstance(inputs, str):
            inputs_list = [inputs]
        else:
            inputs_list = inputs
        return [self._call_api(task, text, params) for text in inputs_list]

    def _call_api(self, task: str, text: str, params: dict[str, Any]) -> Any:
        """Call the appropriate API method based on task."""
        if task in ("text-classification", "sentiment-analysis"):
            return self.client.text_classification(text, **params)
        elif task == "text-generation":
            return self.client.text_generation(text, **params)
        elif task == "summarization":
            return self.client.summarization(text, **params)
        elif task == "translation":
            return self.client.translation(text, **params)
        elif task == "fill-mask":
            return self.client.fill_mask(text, **params)
        elif task == "question-answering":
            # BUGFIX: .get, not .pop — popping removed "context" from the
            # shared params dict, so every input after the first in a batch
            # was answered against an empty context.
            context = params.get("context", "")
            return self.client.question_answering(question=text, context=context)
        elif task == "feature-extraction":
            return self.client.feature_extraction(text, **params)
        else:
            # Generic post for unsupported tasks
            return self.client.post(json={"inputs": text, **params})

    def _predict_local(
        self, inputs: str | list[str], parameters: dict[str, Any] | None = None
    ) -> list[Any]:
        """Run inference using local transformers pipeline."""
        params = parameters or {}
        results = self.pipeline(inputs, **params)
        # Normalize a single-string input to a one-element result list so
        # the return shape matches the batched case.
        if isinstance(inputs, str):
            return [results]
        return results