File size: 4,442 Bytes
b98ed7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Inference engine using Hugging Face API or local transformers."""

import logging
from typing import Any

from huggingface_hub import InferenceClient

from .config import Settings

logger = logging.getLogger(__name__)


class InferenceEngine:
    """Handles model loading and inference.

    Exactly one backend is active per instance, selected by
    ``settings.use_api``: a remote HF Inference API client, or a local
    ``transformers`` pipeline.
    """

    def __init__(self, settings: Settings) -> None:
        """Initialize the inference engine (no model is loaded yet)."""
        self.settings = settings
        self.client: InferenceClient | None = None  # set by _init_api_client
        self.pipeline = None  # set by _init_local_pipeline
        self.model_loaded = False
        self.use_api = settings.use_api

    def load_model(self) -> None:
        """Load the model (API client or local pipeline)."""
        if self.use_api:
            self._init_api_client()
        else:
            self._init_local_pipeline()

    def _init_api_client(self) -> None:
        """Initialize the HF Inference API client."""
        logger.info(
            "Initializing HF Inference API client for model: %s",
            self.settings.model_name,
        )
        self.client = InferenceClient(
            model=self.settings.model_name,
            token=self.settings.api_token,
        )
        self.model_loaded = True
        logger.info("HF Inference API client ready")

    def _init_local_pipeline(self) -> None:
        """Load the model locally using transformers.

        Raises:
            ImportError: if transformers/torch are not installed.
        """
        try:
            from transformers import pipeline
        except ImportError as err:
            # Chain the original exception so the real import failure
            # (e.g. a missing torch backend) stays visible in the traceback.
            raise ImportError(
                "transformers and torch are required for local inference. "
                "Install them with: pip install transformers torch"
            ) from err

        logger.info(
            "Loading local model: %s for task: %s",
            self.settings.model_name,
            self.settings.task,
        )
        self.pipeline = pipeline(
            task=self.settings.task,
            model=self.settings.model_name,
            # transformers uses -1 for CPU; other device strings (e.g. "cuda")
            # are passed through unchanged.
            device=self.settings.device if self.settings.device != "cpu" else -1,
        )
        self.model_loaded = True
        logger.info("Local model loaded successfully")

    def predict(
        self, inputs: str | list[str], parameters: dict[str, Any] | None = None
    ) -> list[Any]:
        """Run inference on the input(s).

        Args:
            inputs: a single text or a list of texts.
            parameters: optional task-specific parameters; never mutated.

        Returns:
            One result per input (a single string yields a 1-element list).

        Raises:
            RuntimeError: if the model has not been loaded via load_model().
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded")

        if self.use_api:
            return self._predict_api(inputs, parameters)
        return self._predict_local(inputs, parameters)

    def _predict_api(
        self, inputs: str | list[str], parameters: dict[str, Any] | None = None
    ) -> list[Any]:
        """Run inference using HF Inference API, one call per input."""
        params = parameters or {}
        task = self.settings.task

        inputs_list = [inputs] if isinstance(inputs, str) else inputs

        # BUG FIX: pass a fresh copy of params to each call. _call_api consumes
        # keys destructively (it pops "context" for question-answering), so
        # sharing one dict dropped the context after the first input and also
        # mutated the caller's parameters dict.
        return [self._call_api(task, text, dict(params)) for text in inputs_list]

    def _call_api(self, task: str, text: str, params: dict[str, Any]) -> Any:
        """Call the appropriate API method based on task.

        `params` may be consumed destructively; callers must pass a private
        copy (see _predict_api).
        """
        if task in ("text-classification", "sentiment-analysis"):
            return self.client.text_classification(text, **params)
        elif task == "text-generation":
            return self.client.text_generation(text, **params)
        elif task == "summarization":
            return self.client.summarization(text, **params)
        elif task == "translation":
            return self.client.translation(text, **params)
        elif task == "fill-mask":
            return self.client.fill_mask(text, **params)
        elif task == "question-answering":
            # Context travels inside params; extra keys are intentionally
            # ignored for this task (the API method takes question/context).
            context = params.pop("context", "")
            return self.client.question_answering(question=text, context=context)
        elif task == "feature-extraction":
            return self.client.feature_extraction(text, **params)
        else:
            # Generic post for unsupported tasks.
            # NOTE(review): InferenceClient.post was removed in newer
            # huggingface_hub releases — confirm the pinned version supports it.
            return self.client.post(json={"inputs": text, **params})

    def _predict_local(
        self, inputs: str | list[str], parameters: dict[str, Any] | None = None
    ) -> list[Any]:
        """Run inference using local transformers pipeline."""
        params = parameters or {}
        results = self.pipeline(inputs, **params)

        # Normalize to a list so the return shape matches _predict_api.
        if isinstance(inputs, str):
            return [results]
        return results