# APAN5560 / handler.py — custom inference handler for the Jingzong/APAN5560 repo
# (HuggingFace commit 8c69988, "create handler.py to deploy inference endpoint")
# handler.py - Add to Jingzong/APAN5560 repository on HuggingFace
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class EndpointHandler:
    """
    Custom handler for the Jingzong/APAN5560 fine-tuned GPT-2 model.

    Matches the training/inference format from GPT2RoleplayModel:
    prompts are built as ``"User: {text}\\nAssistant:"`` and the raw
    continuation is stripped of special tokens and shortened to at most
    two sentences before being returned.
    """

    def __init__(self, path=""):
        """Load tokenizer and model from *path* and prepare for inference.

        Args:
            path: Local directory or hub id of the fine-tuned checkpoint
                (HF Inference Endpoints passes the repository path here).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()
        # GPT-2 ships without a pad token; mirror the training setup by
        # falling back to EOS so generate() does not warn/fail on padding.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    # ---------- Cleaning helpers (from training code) ----------
    @staticmethod
    def _strip_special_tokens(text: str) -> str:
        """Remove chat/markup special tokens the model may emit verbatim."""
        bad_tokens = [
            "<s>", "</s>",
            "<|user|>", "<|assistant|>",
            "<user>", "</user>",
            "<assistant>", "</assistant>",
            "<sub>", "</sub>",
            "<|endoftext|>",
        ]
        for t in bad_tokens:
            text = text.replace(t, "")
        return text

    @staticmethod
    def _shorten(text: str, max_chars: int = 220) -> str:
        """Keep at most 1-2 sentences and hard-limit character length."""
        # Collapse all whitespace (incl. newlines) to single spaces first.
        text = text.replace("\r", " ").replace("\n", " ")
        text = re.sub(r"\s+", " ", text).strip()
        # Split on sentence-ending punctuation followed by whitespace.
        sentences = re.split(r"(?<=[.!?])\s+", text)
        if not sentences:
            return text[:max_chars]
        short = " ".join(sentences[:2])
        if len(short) > max_chars:
            # Truncate on a word boundary and mark the cut with an ellipsis.
            short = short[:max_chars].rsplit(" ", 1)[0] + "..."
        return short

    def _clean_answer(self, raw_answer: str) -> str:
        """Strip special tokens, surrounding quotes, and over-long tails."""
        text = self._strip_special_tokens(raw_answer)
        text = text.strip().strip('"').strip("'")
        text = self._shorten(text)
        return text

    # ---------- Main handler ----------
    def __call__(self, data):
        """
        Process an inference request.

        Expected input format:
            {
                "inputs": "Hello, how are you?",
                "parameters": {
                    "max_new_tokens": 40,
                    "temperature": 0.8,
                    "top_p": 0.9
                }
            }

        Returns:
            A single-element list ``[{"generated_text": <cleaned answer>}]``
            (the standard HF Inference Endpoints text-generation shape).
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Default parameters matching the training code.
        max_new_tokens = parameters.get("max_new_tokens", 40)
        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)

        # Build the prompt in the exact format used during training.
        prompt = f"User: {inputs}\nAssistant:"

        # Tokenize (add_special_tokens=False as in training).
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )

        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # BUGFIX: slice at the *token* level instead of character level.
        # Decoding the full sequence and slicing with len(prompt) is fragile:
        # the tokenizer round-trip is not guaranteed to reproduce the prompt
        # byte-for-byte (whitespace normalization, BPE artifacts), which can
        # clip the answer or leak prompt text into it. generate() returns
        # [prompt tokens | new tokens], so decode only the continuation.
        prompt_len = encoded["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_len:]
        raw_answer = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        clean_answer = self._clean_answer(raw_answer)

        return [{"generated_text": clean_answer}]