File size: 4,407 Bytes
b56b8d4
a72fa33
b56b8d4
b5b1746
a72fa33
 
 
 
 
 
 
b56b8d4
 
b5b1746
 
 
 
 
 
b56b8d4
 
b5b1746
 
b56b8d4
a72fa33
b56b8d4
a72fa33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6923ac9
a72fa33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6923ac9
b56b8d4
 
b5b1746
 
a72fa33
b5b1746
 
a72fa33
6923ac9
 
a72fa33
b5b1746
a72fa33
 
 
 
 
 
 
 
 
 
 
 
 
 
b56b8d4
a72fa33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b56b8d4
b5b1746
a72fa33
b56b8d4
b5b1746
a72fa33
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM

# System instruction used whenever the request does not provide its own
# "system" message: restricts the model to answering from the given context.
DEFAULT_SYSTEM_PROMPT = " ".join(
    [
        "You are a QA assistant.",
        "Use only the provided context.",
        "If the answer is not present in the context, say so clearly.",
    ]
)

class EndpointHandler:
    """Inference-endpoint handler for chat-style, context-grounded QA.

    Loads either a PEFT adapter checkpoint (when ``adapter_config.json`` is
    present in the model directory) or a plain causal-LM checkpoint, and
    answers requests shaped like ``{"inputs": ..., "parameters": {...}}``.
    """

    # Upper bound on the number of prompt tokens passed to ``generate``.
    MAX_PROMPT_TOKENS = 4096
    # Upper bound on the ``max_new_tokens`` a caller may request.
    MAX_NEW_TOKENS_LIMIT = 512

    def __init__(self, path: str = ""):
        """Load the tokenizer and model.

        Args:
            path: Checkpoint directory; an empty string falls back to
                ``/repository``.
        """
        model_dir = path or "/repository"

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
        )

        # generate() needs a pad id; reuse EOS when the tokenizer has none.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32

        # An adapter_config.json marks a PEFT adapter checkpoint; otherwise
        # the directory is loaded as a regular causal LM. Both loaders take
        # the same keyword arguments, so pick the class and call it once.
        adapter_config_path = os.path.join(model_dir, "adapter_config.json")
        model_cls = (
            AutoPeftModelForCausalLM
            if os.path.exists(adapter_config_path)
            else AutoModelForCausalLM
        )
        self.model = model_cls.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto" if use_cuda else None,
        )

        self.model.eval()

    @staticmethod
    def _to_bool(value):
        """Interpret a request parameter as a boolean.

        ``bool("false")`` is ``True``, so string values are parsed explicitly
        ("1", "true", "yes", "on" — case-insensitive — mean ``True``); any
        other type falls back to ``bool(value)``.
        """
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return bool(value)

    def _build_messages(self, inputs):
        """Normalize request ``inputs`` into a chat message list.

        Accepted shapes:
          * a list — used as-is as chat messages (presumably dicts with
            ``role``/``content`` keys);
          * a dict with both ``context`` and ``question`` — rendered into a
            single user turn;
          * anything else — stringified into a single user turn.

        A system message with ``DEFAULT_SYSTEM_PROMPT`` is prepended whenever
        the resulting list has none. The caller's list is never mutated.
        """
        if isinstance(inputs, list):
            messages = inputs
        elif isinstance(inputs, dict) and "context" in inputs and "question" in inputs:
            messages = [
                {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": f"Context:\n{inputs['context']}\n\nQuestion: {inputs['question']}",
                },
            ]
        else:
            messages = [
                {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
                {"role": "user", "content": str(inputs)},
            ]

        has_system = any(message.get("role") == "system" for message in messages)
        if not has_system:
            messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages

        return messages

    def __call__(self, data):
        """Run one generation request.

        Args:
            data: Request payload. ``data["inputs"]`` is passed to
                ``_build_messages``; ``data["parameters"]`` may set
                ``max_new_tokens`` (clamped to [1, MAX_NEW_TOKENS_LIMIT]),
                ``temperature``, ``top_p``, ``do_sample``,
                ``repetition_penalty``, ``no_repeat_ngram_size`` and
                ``debug``.

        Returns:
            ``{"generated_text": str}``; with ``debug`` enabled also the
            rendered ``prompt`` and the ``messages`` list.
        """
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Clamp to [1, limit]: generate() rejects non-positive max_new_tokens.
        requested_tokens = int(params.get("max_new_tokens", 128))
        max_new_tokens = max(1, min(requested_tokens, self.MAX_NEW_TOKENS_LIMIT))
        temperature = float(params.get("temperature", 0.0))
        top_p = float(params.get("top_p", 1.0))
        do_sample = self._to_bool(params.get("do_sample", False))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        no_repeat_ngram_size = int(params.get("no_repeat_ngram_size", 0))
        debug = self._to_bool(params.get("debug", False))

        messages = self._build_messages(inputs)

        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The chat template already emits any BOS/special tokens the model
        # expects; letting the tokenizer add them again would duplicate them.
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )

        # model_max_length may be None; treat that as "no tokenizer limit".
        tokenizer_limit = getattr(self.tokenizer, "model_max_length", None) or self.MAX_PROMPT_TOKENS
        max_prompt_tokens = min(tokenizer_limit, self.MAX_PROMPT_TOKENS)
        # Keep the *tail* of an over-long prompt: the question and the
        # assistant generation prompt sit at the end of the templated text,
        # and dropping them would make the model continue the context
        # instead of answering.
        if enc["input_ids"].shape[-1] > max_prompt_tokens:
            enc = {key: value[:, -max_prompt_tokens:] for key, value in enc.items()}

        if torch.cuda.is_available():
            enc = {key: value.to(self.model.device) for key, value in enc.items()}

        generate_kwargs = dict(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        if do_sample:
            # A zero temperature is invalid when sampling; use a tiny floor.
            generate_kwargs["temperature"] = max(temperature, 1e-5)
            generate_kwargs["top_p"] = top_p

        if no_repeat_ngram_size > 0:
            generate_kwargs["no_repeat_ngram_size"] = no_repeat_ngram_size

        with torch.no_grad():
            out = self.model.generate(**generate_kwargs)

        # generate() returns prompt + continuation; decode only the new part.
        prompt_len = enc["input_ids"].shape[-1]
        generated_ids = out[0][prompt_len:]
        text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        response = {"generated_text": text}
        if debug:
            response["prompt"] = prompt
            response["messages"] = messages
        return response