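"""
Custom handler for Hugging Face Inference Endpoints: wraps a causal language
model behind an {"inputs": ..., "parameters": {...}} interface to produce
Mongolian text summaries.
"""
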
from typing import Any, Dict, List, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# The instruction is Mongolian for "Summarize the following text."
INSTRUCTION = "Дараах бичвэрийг хураангуйлж бич."

# Alpaca-style prompt with Mongolian section headers
# ("### Task:", "### Text:", "### Summary:"); {article} is filled in by .format().
PROMPT_TEMPLATE = (
    "### Даалгавар:\n"
    f"{INSTRUCTION}\n\n"
    "### Бичвэр:\n{article}\n\n"
    "### Хураангуй:\n"
)


def _select_dtype() -> torch.dtype:
    # Prefer bf16 where the GPU supports it, otherwise fp16; fall back to fp32 on CPU.
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints:
    - __init__(path): loads the tokenizer and model assets from `path`
    - __call__(data): performs generation given {"inputs": ..., "parameters": {...}}

    A local usage sketch is included in the __main__ block at the bottom of this file.
    """

    def __init__(self, path: str = ""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = _select_dtype()

        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        # Left padding so batched prompts end exactly where generation starts.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # attn_implementation must be passed at load time; setting it on the config
        # afterwards has no effect on the already-instantiated attention modules.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
            attn_implementation="eager",
        ).to(self.device)

        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.eval()

        # Upper bound for prompt + generated tokens.
        self.max_context = getattr(self.model.config, "max_position_embeddings", 1024)

    def _build_prompt(self, article: str) -> str:
        return PROMPT_TEMPLATE.format(article=article.strip())

    def _prepare_inputs(self, articles: List[str], requested_new: int):
        """
        Tokenize prompts so that prompt_len + max_new_tokens <= max_context:
        requested_new is clamped first, then prompts are truncated to
        max_context - requested_new tokens.
        """
        # Clamp the requested generation budget, then reserve the rest of the
        # context window for the prompt itself.
        requested_new = int(max(1, min(requested_new, 512)))
        max_len_for_prompt = max(1, self.max_context - requested_new)

        prompts = [self._build_prompt(a) for a in articles]
        enc = self.tokenizer(
            prompts,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len_for_prompt,
            return_tensors="pt",
            padding=True,
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}

        # Per-example budget: never request more new tokens than the context can hold.
        input_lens = enc["attention_mask"].sum(dim=1).tolist()
        per_example_new = []
        for prompt_len in input_lens:
            available = max(0, self.max_context - int(prompt_len))
            per_example_new.append(max(1, min(requested_new, available)))

        return enc, per_example_new, prompts

    @torch.no_grad()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        raw_inputs: Union[str, List[str], Dict[str, Any]] = data.get("inputs", "")
        params: Dict[str, Any] = data.get("parameters", {}) or {}

        # Generation parameters with conservative defaults; the sampling knobs
        # (temperature, top_p, top_k) only take effect when do_sample is true.
        req_new = int(params.get("max_new_tokens", 160))
        num_beams = int(params.get("num_beams", 4))
        do_sample = bool(params.get("do_sample", False))
        no_repeat = int(params.get("no_repeat_ngram_size", 3))
        length_penalty = float(params.get("length_penalty", 1.0))
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", 50))
        return_full_text = bool(params.get("return_full_text", False))

        # Accept a single string, a list of strings, or an {"article": ...} payload.
        if isinstance(raw_inputs, str):
            articles = [raw_inputs]
        elif isinstance(raw_inputs, list):
            if not all(isinstance(x, str) for x in raw_inputs):
                raise ValueError("All elements of 'inputs' must be strings.")
            articles = raw_inputs
        else:
            maybe_article = data.get("article")
            if isinstance(maybe_article, str):
                articles = [maybe_article]
            else:
                raise ValueError("Expected 'inputs' to be a string or a list of strings.")

        enc, per_example_new, prompts = self._prepare_inputs(articles, req_new)

        # generate() accepts a single max_new_tokens for the whole batch, so use the
        # largest per-example budget here; per-example values are reported back below.
        gen_out = self.model.generate(
            **enc,
            max_new_tokens=max(per_example_new),
            num_beams=num_beams,
            do_sample=do_sample,
            no_repeat_ngram_size=no_repeat,
            length_penalty=length_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            early_stopping=True,
        )

        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)

        results = []
        for i, text in enumerate(decoded):
            split_key = "### Хураангуй:\n"
            if return_full_text:
                full = text.strip()
                # Report the full prompt + completion, but still isolate the summary.
                summary = full.split(split_key, 1)[-1].strip() if split_key in full else full
            else:
                # Strip the prompt prefix; if decoding did not reproduce the prompt
                # verbatim, fall back to splitting on the summary header.
                prefix = prompts[i]
                if text.startswith(prefix):
                    summary = text[len(prefix):].strip()
                else:
                    summary = text.split(split_key, 1)[-1].strip() if split_key in text else text.strip()
                full = None

            results.append({
                "summary_text": summary,
                "used_new_tokens": per_example_new[i],
                "requested_new_tokens": req_new,
                **({"full_text": full} if return_full_text else {}),
            })

        # Mirror the input shape: a bare result dict for a single string input,
        # a wrapped list for batched inputs.
        if isinstance(raw_inputs, str):
            return results[0]
        return {"results": results}
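

# Minimal local usage sketch (not part of the endpoint contract). It assumes model
# assets live at the hypothetical path "./model" and exercises the handler the same
# way Inference Endpoints would, with an {"inputs": ..., "parameters": {...}} payload.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local checkpoint path
    payload = {
        "inputs": "Урт монгол бичвэр энд орно ...",  # placeholder article text
        "parameters": {"max_new_tokens": 96, "num_beams": 4},
    }
    print(handler(payload)["summary_text"])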