File size: 4,567 Bytes
c90fe04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# handler.py
from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# A chat transcript, e.g. [{"role": "user|assistant|system", "content": "..."}, ...].
Messages = List[Dict[str, str]]

# A JSON object: the shape of the handler's request body and of each response item.
Json = Dict[str, Any]


def _is_messages(x: Any) -> bool:
    return (
        isinstance(x, list)
        and len(x) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
    )


class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.

    Expects:
      - request body is a dict
      - always contains `inputs`
      - may contain `parameters` for generation

    `inputs` may be a prompt string, a chat messages list
    (`[{"role": ..., "content": ...}, ...]`), a dict `{"messages": [...]}`, or a
    list (batch) of any of those. A single input returns
    `{"generated_text": ...}` containing only the newly generated text; a batch
    returns a list of such dicts, one per item, generated one after another.
    """

    def __init__(self, model_dir: str):
        """Load tokenizer and model from *model_dir* and put the model in eval mode."""
        self.model_dir = model_dir

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = self._select_dtype(self.device)

        # IMPORTANT: trust_remote_code=True because repo contains AsteriskForCausalLM.py + auto_map
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=True,
        )

        # generate() is always given pad_token_id; fall back to EOS when the tokenizer
        # defines no pad token (your config uses pad_token_id=2 which equals
        # eos_token_id in many llama-like models).
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            # On GPU let transformers place the weights; on CPU load normally and move below.
            device_map="auto" if self.device == "cuda" else None,
        )

        if self.device != "cuda":
            self.model.to(self.device)

        self.model.eval()

    @staticmethod
    def _select_dtype(device: str) -> torch.dtype:
        """Return the dtype to load the model weights in for *device*.

        CPU uses float32. On CUDA, bfloat16 is used on GPUs with compute
        capability >= 8.0 (Ampere and newer, e.g. A100/H100), which support it
        natively; older GPUs (e.g. T4/V100) get float16 instead.
        """
        if device != "cuda":
            return torch.float32
        major, _minor = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16

    @staticmethod
    def _generation_kwargs(params: Json) -> Json:
        """Translate request `parameters` into keyword arguments for `model.generate()`.

        Defaults: max_new_tokens=256, temperature=0.7, top_p=0.95, top_k=0
        (no top-k filtering), repetition_penalty=1.0, num_beams=1, and
        do_sample=True whenever temperature > 0. A key explicitly set to JSON
        null is treated as omitted. Sampling-only knobs are sent as None when
        sampling is off.
        """

        def pick(key: str, default: Any, cast: Any) -> Any:
            # An explicit null falls back to the default instead of crashing in cast(None).
            value = params.get(key)
            return default if value is None else cast(value)

        max_new_tokens = pick("max_new_tokens", 256, int)
        temperature = pick("temperature", 0.7, float)
        top_p = pick("top_p", 0.95, float)
        top_k = pick("top_k", 0, int)
        repetition_penalty = pick("repetition_penalty", 1.0, float)
        do_sample = pick("do_sample", temperature > 0, bool)
        num_beams = pick("num_beams", 1, int)

        # transformers rejects sampling with a non-positive temperature; treat
        # temperature <= 0 as a request for greedy decoding instead of failing.
        if temperature <= 0:
            do_sample = False

        return {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "temperature": temperature if do_sample else None,
            "top_p": top_p if do_sample else None,
            "top_k": top_k if do_sample and top_k > 0 else None,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
        }

    def _encode(self, item: Any) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize one request item; return (input_ids, attention_mask) on the model's device.

        Accepts a prompt string (other values are str()-converted), a messages
        list, or a dict {"messages": [...]}; messages go through the
        tokenizer's chat template with a generation prompt appended.
        """
        if isinstance(item, dict) and "messages" in item:
            item = item["messages"]

        if _is_messages(item):
            # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
            encoded = self.tokenizer.apply_chat_template(
                item,
                return_tensors="pt",
                add_generation_prompt=True,
            )
            # This may be a bare id tensor or a dict-like BatchEncoding depending
            # on the transformers version/defaults; normalize to a mapping.
            if isinstance(encoded, torch.Tensor):
                encoded = {"input_ids": encoded}
        else:
            if not isinstance(item, str):
                item = str(item)
            encoded = self.tokenizer(item, return_tensors="pt")

        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            # One unpadded sequence: every position is a real token.
            attention_mask = torch.ones_like(input_ids)

        device = self.model.device
        return input_ids.to(device), attention_mask.to(device)

    @torch.inference_mode()
    def __call__(self, data: Json) -> Union[Json, List[Json]]:
        """Run generation for `data["inputs"]` using `data["parameters"]` (optional)."""
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}
        gen_kwargs = self._generation_kwargs(params)

        def _one(item: Any) -> Json:
            input_ids, attention_mask = self._encode(item)
            input_len = input_ids.shape[-1]

            gen_ids = self.model.generate(
                input_ids=input_ids,
                # Passed explicitly: with pad == eos the mask cannot be inferred from the ids.
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **gen_kwargs,
            )

            # Only return newly generated tokens
            new_tokens = gen_ids[0, input_len:]
            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return {"generated_text": text}

        # Batch support: a list that is not itself a single transcript is a batch.
        if isinstance(inputs, list) and not _is_messages(inputs):
            return [_one(x) for x in inputs]
        return _one(inputs)