# handler.py
# Hugging Face Inference Endpoint custom handler for Mongolian GPT-2 summarization
# Input JSON:
#   {
#     "inputs": "ARTICLE TEXT ...",
#     "parameters": {
#        "max_new_tokens": 160,
#        "num_beams": 4,
#        "do_sample": false,
#        "no_repeat_ngram_size": 3,
#        "length_penalty": 1.0,
#        "temperature": 1.0,
#        "top_p": 1.0,
#        "top_k": 50,
#        "return_full_text": false
#     }
#   }
# Output JSON:
#   { "summary_text": "...", "used_new_tokens": 152, "requested_new_tokens": 160 }
#   ("used_new_tokens" is the per-example cap after fitting the prompt into the
#    model context, not a count of tokens actually generated.)
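#
# Example request against a deployed endpoint (illustrative sketch; the URL and
# token below are placeholders, not real values):
#
#   import requests
#   resp = requests.post(
#       "https://<your-endpoint>.endpoints.huggingface.cloud",
#       headers={"Authorization": "Bearer <HF_TOKEN>"},
#       json={"inputs": "ARTICLE TEXT ...", "parameters": {"max_new_tokens": 160}},
#   )
#   print(resp.json()["summary_text"])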

from typing import Any, Dict, List, Union
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Mongolian instruction + prompt template used during training.
# INSTRUCTION translates to "Summarize the following text."; the section headers
# below mean "Task", "Text", and "Summary" respectively.
INSTRUCTION = "Дараах бичвэрийг хураангуйлж бич."
PROMPT_TEMPLATE = (
    "### Даалгавар:\n"
    f"{INSTRUCTION}\n\n"
    "### Бичвэр:\n{article}\n\n"
    "### Хураангуй:\n"
)
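
# A rendered prompt (PROMPT_TEMPLATE.format(article=...)) therefore looks like:
#   ### Даалгавар:
#   Дараах бичвэрийг хураангуйлж бич.
#
#   ### Бичвэр:
#   <article text>
#
#   ### Хураангуй: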

def _select_dtype() -> torch.dtype:
    if torch.cuda.is_available():
        # Prefer bf16 if supported; otherwise use fp16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32

class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints:
    - __init__(path): loads model assets from `path`
    - __call__(data): performs generation given {"inputs": ..., "parameters": {...}}
    """
    def __init__(self, path: str = ""):
        # Device & dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = _select_dtype()

        # Load tokenizer/model from the repository directory
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        # Decoder-only model requires left padding for correct generation
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            # Eager attention is the safer path on many endpoint stacks; passing it
            # at load time (rather than mutating the config afterwards) ensures it
            # actually takes effect.
            attn_implementation="eager",
        ).to(self.device)
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.eval()

        # Read max context from config (GPT-2 default is 1024)
        self.max_context = getattr(self.model.config, "max_position_embeddings", 1024)

    def _build_prompt(self, article: str) -> str:
        return PROMPT_TEMPLATE.format(article=article.strip())

    def _prepare_inputs(
        self,
        articles: List[str],
        requested_new: int
    ):
        """
        Tokenize prompts so that prompt_len + max_new_tokens <= max_context.
        We first clamp requested_new, then tokenize with truncation=max_context - requested_new.
        """
        # Basic safety clamps
        requested_new = int(max(1, min(requested_new, 512)))
        max_len_for_prompt = max(1, self.max_context - requested_new)
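        # e.g. with GPT-2's 1024-token context and the default 160 new tokens,
        # the prompt is truncated to at most 864 tokens.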

        prompts = [self._build_prompt(a) for a in articles]
        enc = self.tokenizer(
            prompts,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len_for_prompt,
            return_tensors="pt",
            padding=True,  # uses left padding because tokenizer.padding_side="left"
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}

        # Compute per-example available space and adjust new tokens if needed
        input_lens = enc["attention_mask"].sum(dim=1).tolist()
        per_example_new = []
        for L in input_lens:
            available = max(0, self.max_context - int(L))
            per_example_new.append(max(1, min(requested_new, available)))

        return enc, per_example_new, prompts

    @torch.no_grad()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept either {"inputs": "..."} or {"inputs": ["...", "..."]}
        raw_inputs: Union[str, List[str], Dict[str, Any]] = data.get("inputs", "")
        params: Dict[str, Any] = data.get("parameters", {}) or {}

        # Default generation hyperparameters (aligned with training)
        req_new = int(params.get("max_new_tokens", 160))
        num_beams = int(params.get("num_beams", 4))
        do_sample = bool(params.get("do_sample", False))
        no_repeat = int(params.get("no_repeat_ngram_size", 3))
        length_penalty = float(params.get("length_penalty", 1.0))
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", 50))
        return_full_text = bool(params.get("return_full_text", False))

        # Normalize inputs to a list of strings
        if isinstance(raw_inputs, str):
            articles = [raw_inputs]
        elif isinstance(raw_inputs, list):
            if not all(isinstance(x, str) for x in raw_inputs):
                raise ValueError("All elements of 'inputs' must be strings.")
            articles = raw_inputs
        else:
            # As a courtesy, also accept a top-level {"article": "..."} field
            maybe_article = data.get("article")
            if isinstance(maybe_article, str):
                articles = [maybe_article]
            else:
                raise ValueError("Expect 'inputs' as a string or list of strings.")

        # Tokenize prompts and cap new tokens per example
        enc, per_example_new, prompts = self._prepare_inputs(articles, req_new)

        # Generate (batched)
        gen_out = self.model.generate(
            **enc,
            max_new_tokens=max(per_example_new),  # upper bound; actual stopping still respects EOS
            num_beams=num_beams,
            do_sample=do_sample,
            no_repeat_ngram_size=no_repeat,
            length_penalty=length_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            early_stopping=True,
        )

        # Decode and postprocess per-item (cut after the prompt if needed)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)

        results = []
        for i, text in enumerate(decoded):
            if return_full_text:
                full = text.strip()
                # Try to extract summary part for convenience too
                split_key = "### Хураангуй:\n"
                summary = full.split(split_key, 1)[-1].strip() if split_key in full else full
            else:
                # Remove the prompt prefix, return only the generated summary
                prefix = prompts[i]
                if text.startswith(prefix):
                    summary = text[len(prefix):].strip()
                else:
                    # Fallback split on the marker
                    split_key = "### Хураангуй:\n"
                    summary = text.split(split_key, 1)[-1].strip() if split_key in text else text.strip()
                full = None

            results.append({
                "summary_text": summary,
                "used_new_tokens": per_example_new[i],
                "requested_new_tokens": req_new,
                **({"full_text": full} if return_full_text else {})
            })

        # If the input was a single string, return a single object
        if isinstance(raw_inputs, str):
            return results[0]
        return {"results": results}
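
# Minimal local smoke test (an illustrative sketch, not part of the endpoint
# contract). It assumes the model and tokenizer files are in the current
# directory; on Inference Endpoints, `path` is supplied by the platform instead.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    demo_request = {
        "inputs": "ARTICLE TEXT ...",
        "parameters": {"max_new_tokens": 64, "num_beams": 4},
    }
    print(handler(demo_request))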