| """ |
| Custom HF Inference Endpoint handler for humarin/chatgpt_paraphraser_on_T5_base. |
| Uses explicit T5ForConditionalGeneration + diverse beam search. |
| """ |
| import torch |
| from typing import Any, Dict, List |
| from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
# Model identifier passed to `from_pretrained` for both tokenizer and weights.
MODEL_ID = "humarin/chatgpt_paraphraser_on_T5_base"


class EndpointHandler:
    """Paraphrase text with a T5 model using diverse (group) beam search.

    The handler is constructed once and then called once per request with
    the decoded request body.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model and move the model to the device.

        Args:
            path: Accepted for interface compatibility. It is not read here:
                tokenizer and weights are always loaded from ``MODEL_ID``.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate paraphrases for the text(s) in a request body.

        Request body shape::

            {
              "inputs": "string or [list of strings]",
              "parameters": {
                "num_beams": 5,
                "num_beam_groups": 5,
                "num_return_sequences": 5,
                "diversity_penalty": 3.0,
                "repetition_penalty": 10.0,
                "no_repeat_ngram_size": 2,
                "max_length": 128,
                "add_prefix": true
              }
            }

        ``add_prefix`` prepends ``"paraphrase: "`` to every input that does
        not already start with it (case-insensitive). Decoding always runs
        with ``do_sample=False``, so sampling options such as
        ``temperature`` are not read and have no effect. The generation
        parameters are converted to ``int``/``float`` and forwarded to
        ``generate`` without further validation.

        Args:
            data: Request body. The text is read from ``"inputs"`` (or
                ``"input"`` as a fallback) and the options from
                ``"parameters"``.

        Returns:
            For exactly one input text: a list of
            ``{"generated_text": <paraphrase>}`` dicts. For several input
            texts: one ``{"input": <prepared text>, "paraphrases": [...]}``
            dict per input, in request order. For an empty input list: an
            empty list.

        Raises:
            TypeError: If ``inputs`` is neither a string nor a list/tuple of
                strings.
        """
        inputs = data.get("inputs", data.get("input", ""))
        params = data.get("parameters") or {}

        texts = self._normalize_inputs(inputs)
        if not texts:
            # Nothing to paraphrase; do not hand a zero-size batch to the
            # tokenizer and model.
            return []

        prepared = self._apply_prefix(texts, params.get("add_prefix", True))
        gen_kwargs = self._build_generation_kwargs(params)

        enc = self.tokenizer(
            prepared,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                **gen_kwargs,
            )

        decoded = self.tokenizer.batch_decode(out, skip_special_tokens=True)

        # The decoded sequences are laid out input by input, with
        # `num_return_sequences` consecutive candidates per input.
        n_ret = gen_kwargs["num_return_sequences"]
        result: List[Dict[str, Any]] = [
            {
                "input": text,
                "paraphrases": decoded[i * n_ret:(i + 1) * n_ret],
            }
            for i, text in enumerate(prepared)
        ]

        # A single input is flattened into a list of generated_text dicts.
        if len(result) == 1:
            return [{"generated_text": c} for c in result[0]["paraphrases"]]
        return result

    @staticmethod
    def _normalize_inputs(inputs: Any) -> List[str]:
        """Return ``inputs`` as a list of strings, rejecting other shapes."""
        if isinstance(inputs, str):
            return [inputs]
        if not isinstance(inputs, (list, tuple)):
            raise TypeError(
                "'inputs' must be a string or a list of strings, got "
                f"{type(inputs).__name__}"
            )
        for item in inputs:
            if not isinstance(item, str):
                raise TypeError(
                    "every element of 'inputs' must be a string, got "
                    f"{type(item).__name__}"
                )
        return list(inputs)

    @staticmethod
    def _apply_prefix(texts: List[str], add_prefix: Any) -> List[str]:
        """Strip each text and prepend the task prefix where requested."""
        prefix = "paraphrase: "
        prepared = []
        for text in texts:
            text = text.strip()
            if add_prefix and not text.lower().startswith("paraphrase:"):
                text = prefix + text
            prepared.append(text)
        return prepared

    @staticmethod
    def _build_generation_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the keyword arguments for ``generate`` from request params."""
        return {
            "num_beams": int(params.get("num_beams", 5)),
            "num_beam_groups": int(params.get("num_beam_groups", 5)),
            "num_return_sequences": int(params.get("num_return_sequences", 5)),
            "diversity_penalty": float(params.get("diversity_penalty", 3.0)),
            "repetition_penalty": float(params.get("repetition_penalty", 10.0)),
            "no_repeat_ngram_size": int(params.get("no_repeat_ngram_size", 2)),
            "max_length": int(params.get("max_length", 128)),
            "early_stopping": True,
            # Sampling stays off: the beam-search settings above are the
            # only decoding controls this handler exposes.
            "do_sample": False,
        }
|
|