File size: 3,582 Bytes
47ed012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Custom HF Inference Endpoint handler for humarin/chatgpt_paraphraser_on_T5_base.
Uses explicit T5ForConditionalGeneration + diverse beam search.
"""
import torch
from typing import Any, Dict, List
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Hugging Face Hub id of the paraphrase model. EndpointHandler always loads the
# tokenizer and weights from this id (the weights are not shipped in this repo).
MODEL_ID = "humarin/chatgpt_paraphraser_on_T5_base"


class EndpointHandler:
    """Inference Endpoint handler that paraphrases text with a T5 model.

    The tokenizer and model for ``MODEL_ID`` are loaded once at construction.
    Each request runs diverse (grouped) beam search and returns several
    paraphrase candidates per input string.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model and move the model to GPU if available.

        Args:
            path: Presumably the local repository path passed in by the
                endpoint runtime. It is ignored: the weights are always
                pulled from the Hub via ``MODEL_ID`` because this repo does
                not ship them.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _prepare_inputs(inputs: Any, add_prefix: Any = True) -> List[str]:
        """Normalize the ``inputs`` payload into a list of model-ready strings.

        ``None`` becomes an empty list and a single string becomes a
        one-element list. Every string is stripped. When ``add_prefix`` is
        truthy, ``"paraphrase: "`` is prepended unless the text already
        starts with ``"paraphrase:"`` (case-insensitive).
        """
        if inputs is None:
            return []
        if isinstance(inputs, str):
            inputs = [inputs]

        prepared: List[str] = []
        for text in inputs:
            text = text.strip()
            if add_prefix and not text.lower().startswith("paraphrase:"):
                text = "paraphrase: " + text
            prepared.append(text)
        return prepared

    @staticmethod
    def _build_gen_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request ``parameters`` into ``model.generate`` kwargs.

        Missing keys fall back to the defaults below.

        Raises:
            ValueError: if a value cannot be converted to a number, or if the
                beam settings are inconsistent: any count below 1,
                ``num_beams`` not divisible by ``num_beam_groups``, or
                ``num_return_sequences`` greater than ``num_beams``.
        """
        gen_kwargs: Dict[str, Any] = {
            "num_beams": int(params.get("num_beams", 5)),
            "num_beam_groups": int(params.get("num_beam_groups", 5)),
            "num_return_sequences": int(params.get("num_return_sequences", 5)),
            "diversity_penalty": float(params.get("diversity_penalty", 3.0)),
            "repetition_penalty": float(params.get("repetition_penalty", 10.0)),
            "no_repeat_ngram_size": int(params.get("no_repeat_ngram_size", 2)),
            "max_length": int(params.get("max_length", 128)),
            "early_stopping": True,
            # Beam search without sampling is deterministic, which is why a
            # "temperature" request parameter is deliberately not forwarded.
            "do_sample": False,
        }

        # Fail fast with a clear message, before any tokenization or model work.
        num_beams = gen_kwargs["num_beams"]
        groups = gen_kwargs["num_beam_groups"]
        n_ret = gen_kwargs["num_return_sequences"]
        if num_beams < 1 or groups < 1 or n_ret < 1:
            raise ValueError(
                "num_beams, num_beam_groups and num_return_sequences must all be >= 1"
            )
        if num_beams % groups:
            raise ValueError(
                f"num_beams ({num_beams}) must be divisible by num_beam_groups ({groups})"
            )
        if n_ret > num_beams:
            raise ValueError(
                f"num_return_sequences ({n_ret}) must be <= num_beams ({num_beams})"
            )
        return gen_kwargs

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Paraphrase one or more input strings.

        Expected request body::

            {
              "inputs": "a string" | ["a", "list", "of", "strings"],
              "parameters": {                 # all optional; defaults shown
                "num_beams": 5,
                "num_beam_groups": 5,         # must divide num_beams
                "num_return_sequences": 5,    # must be <= num_beams
                "diversity_penalty": 3.0,
                "repetition_penalty": 10.0,
                "no_repeat_ngram_size": 2,
                "max_length": 128,            # limit on generated output
                "max_input_length": 256,      # inputs are truncated to this
                "add_prefix": true            # prepend "paraphrase: " if absent
              }
            }

        ``"input"`` is accepted as a fallback key for ``"inputs"``. Any
        ``"temperature"`` parameter is ignored because decoding is
        deterministic beam search (``do_sample=False``).

        Returns:
            ``[]`` when there is nothing to paraphrase (``inputs`` is empty
            or null). For exactly one input (a string or a one-element
            list), a flat list of ``{"generated_text": ...}`` dicts.
            Otherwise, one ``{"input": <prepared input>, "paraphrases":
            [...]}`` dict per input, in request order.

        Raises:
            ValueError: if the beam-search parameters are inconsistent.
        """
        inputs = data.get("inputs", data.get("input", ""))
        params = data.get("parameters", {}) or {}

        gen_kwargs = self._build_gen_kwargs(params)
        prepared = self._prepare_inputs(inputs, params.get("add_prefix", True))
        if not prepared:
            # Nothing to do; avoid calling the tokenizer/model with an empty batch.
            return []

        enc = self.tokenizer(
            prepared,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=int(params.get("max_input_length", 256)),
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                **gen_kwargs,
            )

        decoded = self.tokenizer.batch_decode(out, skip_special_tokens=True)

        # The slicing assumes generate() emits the n_ret candidates for each
        # input contiguously: input i owns rows [i * n_ret, (i + 1) * n_ret).
        n_ret = gen_kwargs["num_return_sequences"]
        result: List[Dict[str, Any]] = [
            {"input": text, "paraphrases": decoded[i * n_ret : (i + 1) * n_ret]}
            for i, text in enumerate(prepared)
        ]

        # A single input (string or one-element list) gets a flatter shape.
        if len(result) == 1:
            return [{"generated_text": c} for c in result[0]["paraphrases"]]
        return result