# handler.py
# Hugging Face Inference Endpoint custom handler for Mongolian GPT-2 summarization
# Input JSON:
# {
#   "inputs": "ARTICLE TEXT ...",
#   "parameters": {
#     "max_new_tokens": 160,
#     "num_beams": 4,
#     "do_sample": false,
#     "no_repeat_ngram_size": 3,
#     "length_penalty": 1.0,
#     "temperature": 1.0,
#     "top_p": 1.0,
#     "top_k": 50,
#     "return_full_text": false
#   }
# }
# Output JSON:
# { "summary_text": "...", "used_new_tokens": 152, "requested_new_tokens": 160 }
from typing import Any, Dict, List, Union
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Mongolian instruction + prompt template used during training.
# Translation of the instruction: "Summarize the following text."
# Section headers: "Даалгавар" = task, "Бичвэр" = text, "Хураангуй" = summary.
INSTRUCTION = "Дараах бичвэрийг хураангуйлж бич."
PROMPT_TEMPLATE = (
    "### Даалгавар:\n"
    f"{INSTRUCTION}\n\n"
    "### Бичвэр:\n{article}\n\n"
    "### Хураангуй:\n"
)
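# A prompt rendered from this template therefore looks like (article body elided):
#
#   ### Даалгавар:
#   Дараах бичвэрийг хураангуйлж бич.
#
#   ### Бичвэр:
#   <article text>
#
#   ### Хураангуй: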


def _select_dtype() -> torch.dtype:
    if torch.cuda.is_available():
        # Prefer bf16 if supported; otherwise use fp16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints:
      - __init__(path): loads model assets from `path`
      - __call__(data): performs generation given {"inputs": ..., "parameters": {...}}
    """

    def __init__(self, path: str = ""):
        # Device & dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = _select_dtype()
        # Load tokenizer/model from the repository directory
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        # Decoder-only model requires left padding for correct generation
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            # Safer attention path on many endpoint stacks; must be selected at
            # load time to take effect (setting it on the config afterwards does
            # not swap the attention implementation).
            attn_implementation="eager",
        ).to(self.device)
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.eval()
        # Read max context from config (GPT-2 default is 1024)
        self.max_context = getattr(self.model.config, "max_position_embeddings", 1024)

    def _build_prompt(self, article: str) -> str:
        return PROMPT_TEMPLATE.format(article=article.strip())

    def _prepare_inputs(
        self,
        articles: List[str],
        requested_new: int,
    ):
        """
        Tokenize prompts so that prompt_len + max_new_tokens <= max_context.
        We first clamp requested_new, then tokenize with
        max_length = max_context - requested_new.
        """
        # Basic safety clamps
        requested_new = int(max(1, min(requested_new, 512)))
        max_len_for_prompt = max(1, self.max_context - requested_new)
        prompts = [self._build_prompt(a) for a in articles]
        enc = self.tokenizer(
            prompts,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len_for_prompt,
            return_tensors="pt",
            padding=True,  # uses left padding because tokenizer.padding_side="left"
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        # Compute per-example available space and adjust new tokens if needed
        input_lens = enc["attention_mask"].sum(dim=1).tolist()
        per_example_new = []
        for prompt_len in input_lens:
            available = max(0, self.max_context - int(prompt_len))
            per_example_new.append(max(1, min(requested_new, available)))
        return enc, per_example_new, prompts

    @torch.no_grad()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept either {"inputs": "..."} or {"inputs": ["...", "..."]}
        raw_inputs: Union[str, List[str], Dict[str, Any]] = data.get("inputs", "")
        params: Dict[str, Any] = data.get("parameters", {}) or {}
        # Default generation hyperparameters (aligned with training)
        req_new = int(params.get("max_new_tokens", 160))
        num_beams = int(params.get("num_beams", 4))
        do_sample = bool(params.get("do_sample", False))
        no_repeat = int(params.get("no_repeat_ngram_size", 3))
        length_penalty = float(params.get("length_penalty", 1.0))
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", 50))
        return_full_text = bool(params.get("return_full_text", False))
        # Normalize inputs to a list of strings
        if isinstance(raw_inputs, str):
            articles = [raw_inputs]
        elif isinstance(raw_inputs, list):
            if not all(isinstance(x, str) for x in raw_inputs):
                raise ValueError("All elements of 'inputs' must be strings.")
            articles = raw_inputs
        else:
            # Accept {"article": "..."} as a courtesy
            maybe_article = data.get("article")
            if isinstance(maybe_article, str):
                articles = [maybe_article]
            else:
                raise ValueError("Expected 'inputs' to be a string or a list of strings.")
        # Tokenize prompts and cap new tokens per example
        enc, per_example_new, prompts = self._prepare_inputs(articles, req_new)
        # Generate (batched)
        gen_out = self.model.generate(
            **enc,
            max_new_tokens=max(per_example_new),  # upper bound; actual stopping still respects EOS
            num_beams=num_beams,
            do_sample=do_sample,
            no_repeat_ngram_size=no_repeat,
            length_penalty=length_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            early_stopping=True,
        )
        # Decode and postprocess per item (cut after the prompt if needed)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        results = []
        for i, text in enumerate(decoded):
            if return_full_text:
                full = text.strip()
                # Try to extract the summary part for convenience too
                split_key = "### Хураангуй:\n"
                summary = full.split(split_key, 1)[-1].strip() if split_key in full else full
            else:
                # Remove the prompt prefix, return only the generated summary
                prefix = prompts[i]
                if text.startswith(prefix):
                    summary = text[len(prefix):].strip()
                else:
                    # Fallback: split on the marker
                    split_key = "### Хураангуй:\n"
                    summary = text.split(split_key, 1)[-1].strip() if split_key in text else text.strip()
                full = None
            results.append({
                "summary_text": summary,
                "used_new_tokens": per_example_new[i],
                "requested_new_tokens": req_new,
                **({"full_text": full} if return_full_text else {}),
            })
        # If the input was a single string, return a single object
        if isinstance(raw_inputs, str):
            return results[0]
        return {"results": results}