"""Custom Inference Endpoints handler that serves meta-llama/Meta-Llama-3.1-8B-Instruct as a chat model."""

import os
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


class EndpointHandler:
    def __init__(self, path: str = "") -> None:
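        # Endpoint secrets may surface the token under any of these env var names.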
        token = (
            os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGING_FACE_HUB_TOKEN")
            or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        )
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a secret on the Inference Endpoint "
                "so the handler can download the gated meta-llama/Meta-Llama-3.1-8B-Instruct weights."
            )

        # Prefer the tokenizer bundled with the endpoint repo (path) so any custom
        # chat template is honored; the fallback is the gated base repo, which
        # needs the token just like the weights do.
        tokenizer_source = path or BASE_MODEL
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            token=token,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()

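        # Llama 3.1 ships without a pad token; reuse EOS so generate() can pad.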
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
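        # Accept both {"inputs": {"messages": [...]}} and a top-level {"messages": [...]} payload.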
        inputs_payload = data.get("inputs", data)
        messages = (
            inputs_payload.get("messages")
            if isinstance(inputs_payload, dict)
            else None
        ) or data.get("messages")

        if not messages:
            raise ValueError(
                "Request payload must include a 'messages' list, e.g. "
                '{"inputs": {"messages": [{"role": "user", "content": "hi"}]}}.'
            )

        parameters: Dict[str, Any] = data.get("parameters") or {}
        max_new_tokens = int(parameters.get("max_new_tokens", 256))
        do_sample = bool(parameters.get("do_sample", False))
        temperature = float(parameters.get("temperature", 0.7))
        top_p = float(parameters.get("top_p", 0.9))

        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
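        # Forward the sampling knobs only when sampling is enabled; transformers
        # warns if temperature/top_p are passed alongside do_sample=False.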
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **generate_kwargs)

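        # Slice off the prompt so only the newly generated tokens are decoded.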
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )

        return {"generated_text": decoded}
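

# Minimal local smoke test: a hedged sketch, not part of the Endpoint contract.
# It assumes HF_TOKEN is exported and a GPU with enough memory for the 8B
# weights is available; the payload mirrors what __call__ expects above.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "inputs": {"messages": [{"role": "user", "content": "Say hello."}]},
            "parameters": {"max_new_tokens": 32},
        }
    )
    print(result["generated_text"])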