# Uploaded to the Hugging Face Hub by singoojiang ("Upload 3 files"),
# commit 47ed012 (verified), 3.58 kB.
"""
Custom HF Inference Endpoint handler for humarin/chatgpt_paraphraser_on_T5_base.
Uses explicit T5ForConditionalGeneration + diverse beam search.
"""
import torch
from typing import Any, Dict, List
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Hub repository the handler always loads its tokenizer and weights from;
# the endpoint-provided local ``path`` is deliberately not used.
MODEL_ID = "humarin/chatgpt_paraphraser_on_T5_base"
class EndpointHandler:
    """Inference Endpoint handler that paraphrases text with a T5 model.

    The tokenizer and model for ``MODEL_ID`` are loaded once at construction;
    each request then runs deterministic diverse (grouped) beam search and
    returns several paraphrases per input string.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model and move the model to the best device.

        Args:
            path: Local repository path supplied by the endpoint runtime. It is
                intentionally ignored: the weights are always fetched from the
                Hub via ``MODEL_ID`` because this repo does not ship them.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only: switch off training-time behavior such as dropout.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate paraphrases for the request's input text(s).

        Request body shape::

            {
              "inputs": "string or [list of strings]",   # "input" also accepted
              "parameters": {
                "num_beams": 5,
                "num_beam_groups": 5,          # must divide num_beams
                "num_return_sequences": 5,     # must be <= num_beams
                "diversity_penalty": 3.0,
                "repetition_penalty": 10.0,
                "no_repeat_ngram_size": 2,
                "max_length": 128,             # max generated length
                "max_input_length": 256,       # input truncation length
                "add_prefix": true             # prepend "paraphrase: " if absent
              }
            }

        Decoding is always deterministic beam search (``do_sample=False``), so
        sampling options such as ``temperature`` have no effect and are ignored.

        Returns:
            * exactly one input: ``[{"generated_text": str}, ...]`` with one
              entry per returned sequence;
            * several inputs: ``[{"input": str, "paraphrases": [str, ...]}, ...]``
              where ``"input"`` is the prepared (possibly prefixed) text;
            * no inputs (empty list or ``None``): ``[]``.

        Raises:
            ValueError: if the beam-search parameters are inconsistent.
            TypeError: if an input item is not a string.
        """
        params = data.get("parameters", {}) or {}
        # Validate parameters first so bad requests fail consistently,
        # even when there is nothing to generate.
        gen_kwargs = self._build_gen_kwargs(params)
        prepared = self._prepare_inputs(
            data.get("inputs", data.get("input", "")),
            add_prefix=params.get("add_prefix", True),
        )
        if not prepared:
            # An empty batch would crash the tokenizer / generate call.
            return []

        enc = self.tokenizer(
            prepared,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=int(params.get("max_input_length", 256)),
        ).to(self.device)
        with torch.no_grad():
            out = self.model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                **gen_kwargs,
            )
        decoded = self.tokenizer.batch_decode(out, skip_special_tokens=True)

        # generate() returns num_return_sequences consecutive rows per input.
        n_ret = gen_kwargs["num_return_sequences"]
        result: List[Dict[str, Any]] = [
            {"input": text, "paraphrases": decoded[i * n_ret : (i + 1) * n_ret]}
            for i, text in enumerate(prepared)
        ]
        # If single input, flatten the response a bit for convenience.
        if len(result) == 1:
            return [{"generated_text": c} for c in result[0]["paraphrases"]]
        return result

    @staticmethod
    def _prepare_inputs(inputs: Any, add_prefix: Any = True) -> List[str]:
        """Normalize raw ``inputs`` into a list of stripped, optionally prefixed prompts."""
        if inputs is None:
            return []
        if isinstance(inputs, str):
            inputs = [inputs]
        prepared: List[str] = []
        for item in inputs:
            if not isinstance(item, str):
                raise TypeError(
                    f"each input must be a string, got {type(item).__name__}"
                )
            text = item.strip()
            # The model expects the "paraphrase: " task prefix; skip it if the
            # caller already supplied one (case-insensitive).
            if add_prefix and not text.lower().startswith("paraphrase:"):
                text = "paraphrase: " + text
            prepared.append(text)
        return prepared

    @staticmethod
    def _build_gen_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request parameters into validated ``generate()`` keyword arguments."""
        num_beams = int(params.get("num_beams", 5))
        num_beam_groups = int(params.get("num_beam_groups", 5))
        num_return_sequences = int(params.get("num_return_sequences", 5))
        if min(num_beams, num_beam_groups, num_return_sequences) < 1:
            raise ValueError(
                "num_beams, num_beam_groups and num_return_sequences must be >= 1"
            )
        if num_beams % num_beam_groups != 0:
            raise ValueError(
                f"num_beams ({num_beams}) must be divisible by "
                f"num_beam_groups ({num_beam_groups})"
            )
        if num_return_sequences > num_beams:
            raise ValueError(
                f"num_return_sequences ({num_return_sequences}) must be <= "
                f"num_beams ({num_beams})"
            )
        return {
            "num_beams": num_beams,
            "num_beam_groups": num_beam_groups,
            "num_return_sequences": num_return_sequences,
            "diversity_penalty": float(params.get("diversity_penalty", 3.0)),
            "repetition_penalty": float(params.get("repetition_penalty", 10.0)),
            "no_repeat_ngram_size": int(params.get("no_repeat_ngram_size", 2)),
            "max_length": int(params.get("max_length", 128)),
            "early_stopping": True,
            # Beam search is deterministic when do_sample=False.
            "do_sample": False,
        }