"""
AdmeshIntentPipeline — transformers.Pipeline subclass for
admesh/agentic-intent-classifier.

Because config.json declares "pt": [], the transformers pipeline() loader
skips AutoModel.from_pretrained() entirely and passes model=None straight
to this class.  All model loading is handled internally via combined_inference,
which resolves paths relative to __file__ so it works wherever HF downloads
the repo (Inference Endpoints, Spaces, local snapshot_download, etc.).

Supported HF deployment surfaces
---------------------------------
1. transformers.pipeline() direct call (trust_remote_code=True):

       from transformers import pipeline
       clf = pipeline(
           "admesh-intent",
           model="admesh/agentic-intent-classifier",
           trust_remote_code=True,
       )
       result = clf("Which laptop should I buy for college?")

2. HF Inference Endpoints — Standard (PyTorch, trust_remote_code=True):
   Deploy from https://ui.endpoints.huggingface.co — no custom container
   needed; HF loads this pipeline class automatically.

3. HF Spaces (Gradio / Streamlit):

       import sys
       from huggingface_hub import snapshot_download
       local_dir = snapshot_download("admesh/agentic-intent-classifier", repo_type="model")
       sys.path.insert(0, local_dir)
       from pipeline import AdmeshIntentPipeline
       clf = AdmeshIntentPipeline()
       result = clf("I need a CRM for a 5-person startup")

4. Anywhere via from_pretrained():

       from pipeline import AdmeshIntentPipeline
       clf = AdmeshIntentPipeline.from_pretrained("admesh/agentic-intent-classifier")
"""

from __future__ import annotations

import sys
from pathlib import Path
from typing import Union

# ── try to import transformers.Pipeline; fall back gracefully if absent ───────
try:
    from transformers import Pipeline as _HFPipeline
    _TRANSFORMERS_AVAILABLE = True
except ImportError:
    _HFPipeline = object  # bare object as base when transformers is not installed
    _TRANSFORMERS_AVAILABLE = False


class AdmeshIntentPipeline(_HFPipeline):
    """
    Full intent + IAB classification pipeline.

    Inherits from ``transformers.Pipeline`` so it works natively with
    ``pipeline()``, HF Inference Endpoints (standard mode), and HF Spaces.

    When ``transformers`` is not installed, it falls back to a plain callable
    class, so the same code works in minimal environments too.

    Parameters
    ----------
    model:
        Ignored — we load all models internally.  Present only to satisfy
        the ``transformers.Pipeline`` interface when HF calls
        ``PipelineClass(model=None, ...)``.
    **kwargs:
        Forwarded to ``transformers.Pipeline.__init__`` if transformers is
        available, otherwise ignored.
    """

    # ── init ──────────────────────────────────────────────────────────────────

    def __init__(self, model=None, tokenizer=None, **kwargs):
        # Ensure this repo's directory is on sys.path so all relative imports
        # in combined_inference / config / model_runtime resolve correctly.
        # Path(__file__) points to wherever HF cached the repo snapshot.
        _repo_dir = Path(__file__).resolve().parent
        if str(_repo_dir) not in sys.path:
            sys.path.insert(0, str(_repo_dir))

        if _TRANSFORMERS_AVAILABLE:
            import torch

            # transformers.Pipeline requires certain attributes to be set.
            # Because config.json has "pt": [], HF passes model=None here —
            # we satisfy the interface by setting the minimum required attrs
            # manually instead of calling super().__init__(model=None, ...)
            # which would raise inside infer_framework_load_model().
            self.task = kwargs.pop("task", "admesh-intent")
            self.model = model          # None — unused, kept for interface compat
            self.tokenizer = tokenizer  # None — unused
            self.feature_extractor = None
            self.image_processor = None
            self.modelcard = None
            self.framework = "pt"
            self.device = torch.device(kwargs.pop("device", "cpu"))
            self.binary_output = kwargs.pop("binary_output", False)
            self.call_count = 0
            self._batch_size = kwargs.pop("batch_size", 1)
            self._num_workers = kwargs.pop("num_workers", 0)
            self._preprocess_params: dict = {}
            self._forward_params: dict = {}
            self._postprocess_params: dict = {}
        # else: plain object, no init needed

        self._classify_fn = None  # lazy-loaded on first __call__

    # ── transformers.Pipeline abstract methods ────────────────────────────────
    # These are required by the ABC but our __call__ override bypasses them.
    # They are still implemented in case a caller invokes them directly.

    def _sanitize_parameters(self, **kwargs):
        forward_kwargs = {}
        if "threshold_overrides" in kwargs:
            forward_kwargs["threshold_overrides"] = kwargs["threshold_overrides"]
        if "force_iab_placeholder" in kwargs:
            forward_kwargs["force_iab_placeholder"] = kwargs["force_iab_placeholder"]
        return {}, forward_kwargs, {}

    def preprocess(self, inputs):
        return {"text": inputs if isinstance(inputs, str) else str(inputs)}

    def _forward(self, model_inputs, threshold_overrides=None, force_iab_placeholder=False):
        self._ensure_loaded()
        return self._classify_fn(
            model_inputs["text"],
            threshold_overrides=threshold_overrides,
            force_iab_placeholder=force_iab_placeholder,
        )

    def postprocess(self, model_outputs):
        return model_outputs

    # ── __call__ override ─────────────────────────────────────────────────────
    # We bypass Pipeline's preprocess→_forward→postprocess chain entirely so
    # we never touch self.model and keep full control over batching logic.

    def __call__(
        self,
        inputs: Union[str, list[str]],
        *,
        threshold_overrides: dict[str, float] | None = None,
        force_iab_placeholder: bool = False,
    ) -> Union[dict, list[dict]]:
        """
        Classify one or more query strings.

        Parameters
        ----------
        inputs:
            A single query string or a list of query strings.
        threshold_overrides:
            Optional per-head confidence threshold overrides, e.g.
            ``{"intent_type": 0.5, "iab_content": 0.3}``.
        force_iab_placeholder:
            If ``True``, skip the IAB classifier and return placeholder
            values (faster, but without IAB predictions).

        Returns
        -------
        dict or list[dict]:
            Full classification payload matching the combined_inference schema.
            Returns a single dict for a string input and a list of dicts for
            a list input.

        Examples
        --------
        ::

            clf = pipeline("admesh-intent", model="admesh/agentic-intent-classifier",
                           trust_remote_code=True)

            # single
            result = clf("Which laptop should I buy for college?")

            # batch
            results = clf(["Best running shoes", "How does TCP work?"])

            # custom thresholds
            result = clf("Buy headphones", threshold_overrides={"intent_type": 0.6})
        """
        self._ensure_loaded()

        single = isinstance(inputs, str)
        texts: list[str] = [inputs] if single else list(inputs)

        results = [
            self._classify_fn(
                text,
                threshold_overrides=threshold_overrides,
                force_iab_placeholder=force_iab_placeholder,
            )
            for text in texts
        ]
        return results[0] if single else results

    # ── warm-up / compile ─────────────────────────────────────────────────────

    def warm_up(self, compile: bool = False) -> "AdmeshIntentPipeline":
        """
        Pre-load all models and optionally compile them with torch.compile().

        Call once after instantiation so the first real request pays no
        model-load cost.  HF Inference Endpoints automatically sends a
        warm-up probe before routing live traffic, so this is optional there.

        Parameters
        ----------
        compile:
            If ``True``, call ``torch.compile()`` on the DistilBERT encoder
            and IAB classifier (requires PyTorch >= 2.0).  Gives ~15-30 %
            CPU speedup after the first traced call.
        """
        self._ensure_loaded()

        if compile:
            import torch  # noqa: PLC0415
            if not hasattr(torch, "compile"):
                import warnings
                warnings.warn(
                    "torch.compile() is not available (PyTorch >= 2.0 required). "
                    "Skipping.",
                    stacklevel=2,
                )
            else:
                try:
                    from .multitask_runtime import get_multitask_runtime  # type: ignore
                    from .model_runtime import get_head  # type: ignore
                except ImportError:
                    from multitask_runtime import get_multitask_runtime
                    from model_runtime import get_head

                rt = get_multitask_runtime()
                if rt._model is not None:
                    rt._model = torch.compile(rt._model)
                iab_head = get_head("iab_content")
                if iab_head._model is not None:
                    iab_head._model = torch.compile(iab_head._model)

        # Dry run — triggers any remaining lazy init (calibration JSON reads, etc.)
        self("warm up query for intent classification", force_iab_placeholder=True)
        return self
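
    # Usage sketch (hypothetical bootstrap for a long-lived process): call
    # warm_up once at startup so the first real request pays no load cost.
    #
    #     clf = AdmeshIntentPipeline().warm_up(compile=True)
    #     result = clf("Best budget mechanical keyboard?")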

    # ── factory ───────────────────────────────────────────────────────────────

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "admesh/agentic-intent-classifier",
        *,
        revision: str | None = None,
        token: str | None = None,
    ) -> "AdmeshIntentPipeline":
        """
        Download the model bundle from HF Hub and return a ready-to-use instance.

        Parameters
        ----------
        repo_id:
            HF Hub model id.
        revision:
            Optional git commit hash to pin a specific release.
        token:
            Optional HF auth token for private repos.

        Example
        -------
        ::

            from pipeline import AdmeshIntentPipeline
            clf = AdmeshIntentPipeline.from_pretrained("admesh/agentic-intent-classifier")
            print(clf("I need a CRM for a 5-person startup"))
        """
        try:
            from huggingface_hub import snapshot_download  # noqa: PLC0415
        except ImportError as exc:
            raise ImportError(
                "huggingface_hub is required. Install: pip install huggingface_hub"
            ) from exc

        kwargs: dict = {"repo_type": "model"}
        if revision:
            kwargs["revision"] = revision
        if token:
            kwargs["token"] = token

        local_dir = snapshot_download(repo_id=repo_id, **kwargs)
        if str(local_dir) not in sys.path:
            sys.path.insert(0, str(local_dir))
        return cls()

    # ── internal ──────────────────────────────────────────────────────────────

    def _ensure_loaded(self) -> None:
        if self._classify_fn is None:
            try:
                from .combined_inference import classify_query  # type: ignore
            except ImportError:
                from combined_inference import classify_query
            self._classify_fn = classify_query

    def __repr__(self) -> str:
        state = "loaded" if self._classify_fn is not None else "not yet loaded"
        return f"AdmeshIntentPipeline(classify_fn={state})"
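

if __name__ == "__main__":
    # Minimal smoke test (sketch; assumes network access so snapshot_download
    # can fetch the model bundle).  Guarded behind __main__ so importing this
    # module stays free of side effects.
    clf = AdmeshIntentPipeline.from_pretrained()
    print(clf("Which laptop should I buy for college?"))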