# handler.py - Add to Jingzong/APAN5560 repository on HuggingFace

import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class EndpointHandler:
    """Custom handler for the Jingzong/APAN5560 fine-tuned GPT-2 model.

    Mirrors the prompt/response format used during training
    ("User: ...\\nAssistant:") and post-processes the raw generation
    into a short, clean reply.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path* and set eval mode.

        Args:
            path: Local directory or HF repo id of the fine-tuned model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()

        # GPT-2 ships without a pad token; reuse EOS (same as training code).
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    # ---------- Cleaning helpers (from training code) ----------

    @staticmethod
    def _strip_special_tokens(text: str) -> str:
        """Remove chat-template / special-token markers from *text*.

        NOTE(review): the original token list also contained several
        empty strings — ``str.replace("", "")`` is a no-op, so they
        looked like markers lost in a copy/paste.  They are dropped
        here; confirm against the training code whether additional
        markers (e.g. a system tag) should be stripped as well.
        """
        bad_tokens = ("<|user|>", "<|assistant|>", "<|endoftext|>")
        for token in bad_tokens:
            text = text.replace(token, "")
        return text

    @staticmethod
    def _shorten(text: str, max_chars: int = 220) -> str:
        """Keep at most 2 sentences and hard-limit character length.

        Newlines are flattened to spaces and whitespace is collapsed
        before sentence splitting.
        """
        text = text.replace("\r", " ").replace("\n", " ")
        text = re.sub(r"\s+", " ", text).strip()
        # re.split always returns a non-empty list, so no emptiness
        # guard is needed (the original had a dead `if not sentences`).
        sentences = re.split(r"(?<=[.!?])\s+", text)
        short = " ".join(sentences[:2])
        if len(short) > max_chars:
            # Cut on a word boundary so we never emit half a word.
            short = short[:max_chars].rsplit(" ", 1)[0] + "..."
        return short

    def _clean_answer(self, raw_answer: str) -> str:
        """Strip markers and surrounding quotes, then shorten."""
        text = self._strip_special_tokens(raw_answer)
        text = text.strip().strip('"').strip("'")
        return self._shorten(text)

    # ---------- Main handler ----------

    def __call__(self, data):
        """Process one inference request.

        Expected input format::

            {
                "inputs": "Hello, how are you?",
                "parameters": {
                    "max_new_tokens": 40,
                    "temperature": 0.8,
                    "top_p": 0.9
                }
            }

        Returns:
            A list with a single ``{"generated_text": <str>}`` dict,
            matching the HF Inference Endpoints convention.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Default sampling parameters matching the training code.
        max_new_tokens = parameters.get("max_new_tokens", 40)
        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)

        # Build the prompt in the exact format used during training.
        prompt = f"User: {inputs}\nAssistant:"

        # Tokenize (add_special_tokens=False as in training).
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )

        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Extract only the newly generated tokens by slicing the output
        # ids at the prompt's token length.  The original sliced the
        # decoded string with `decoded[len(prompt):]`, which corrupts the
        # answer whenever decode() does not reproduce the prompt text
        # byte-for-byte (a common tokenizer round-trip hazard).
        prompt_token_count = encoded["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_token_count:]
        raw_answer = self.tokenizer.decode(
            generated_ids, skip_special_tokens=False
        )

        clean_answer = self._clean_answer(raw_answer)
        return [{"generated_text": clean_answer}]