File size: 3,870 Bytes
8c69988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# handler.py - Add to Jingzong/APAN5560 repository on HuggingFace
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class EndpointHandler:
    """
    Custom handler for Jingzong/APAN5560 fine-tuned GPT-2 model.

    Matches the training/inference format from GPT2RoleplayModel: prompts
    are built as "User: <text>\nAssistant:" and the generated reply is
    stripped of special tokens, de-quoted, and shortened to at most two
    sentences before being returned.
    """

    def __init__(self, path=""):
        """Load tokenizer and model from *path* and prepare for inference.

        Args:
            path: Model directory or hub repo id supplied by the
                inference toolkit.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()

        # Use GPU when available; inputs are moved to the same device
        # in __call__ so generate() does not fail on a device mismatch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # GPT-2 has no pad token by default; reuse EOS (same as training).
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    # ---------- Cleaning helpers (from training code) ----------

    @staticmethod
    def _strip_special_tokens(text: str) -> str:
        """Remove chat-template / special tokens that leak into output."""
        bad_tokens = [
            "<s>", "</s>",
            "<|user|>", "<|assistant|>",
            "<user>", "</user>",
            "<assistant>", "</assistant>",
            "<sub>", "</sub>",
            "<|endoftext|>",
        ]
        for t in bad_tokens:
            text = text.replace(t, "")
        return text

    @staticmethod
    def _shorten(text: str, max_chars: int = 220) -> str:
        """Keep at most 1-2 sentences and hard-limit character length."""
        text = text.replace("\r", " ").replace("\n", " ")
        text = re.sub(r"\s+", " ", text).strip()

        sentences = re.split(r"(?<=[.!?])\s+", text)
        if not sentences:
            return text[:max_chars]

        short = " ".join(sentences[:2])

        if len(short) > max_chars:
            # Truncate on a word boundary so we never cut mid-word.
            short = short[:max_chars].rsplit(" ", 1)[0] + "..."

        return short

    def _clean_answer(self, raw_answer: str) -> str:
        """Full cleaning pipeline: strip tokens and quotes, then shorten."""
        text = self._strip_special_tokens(raw_answer)
        text = text.strip().strip('"').strip("'")
        text = self._shorten(text)
        return text

    # ---------- Main handler ----------

    def __call__(self, data):
        """
        Process an inference request.

        Expected input format:
        {
            "inputs": "Hello, how are you?",
            "parameters": {
                "max_new_tokens": 40,
                "temperature": 0.8,
                "top_p": 0.9
            }
        }

        Returns:
            A list with one dict: [{"generated_text": "<cleaned reply>"}].
        """
        inputs = data.get("inputs", "")
        # Guard against an explicit "parameters": null in the payload.
        parameters = data.get("parameters", {}) or {}

        # Default parameters matching training code
        max_new_tokens = parameters.get("max_new_tokens", 40)
        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)

        # Build prompt in the exact format used during training
        prompt = f"User: {inputs}\nAssistant:"

        # Tokenize (add_special_tokens=False as in training)
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # BUG FIX: slice at the *token* level rather than assuming the
        # decoded string starts byte-for-byte with the prompt text —
        # GPT-2 detokenization can normalize whitespace, which misaligned
        # the previous `decoded[len(prompt):]` slice.
        prompt_len = encoded["input_ids"].shape[1]
        raw_answer = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=False
        )

        # The model sometimes keeps role-playing and fabricates the next
        # "User:" turn; keep only the assistant's portion of the output.
        raw_answer = re.split(r"\bUser:", raw_answer, maxsplit=1)[0]

        clean_answer = self._clean_answer(raw_answer)

        return [{"generated_text": clean_answer}]