File size: 1,985 Bytes
863713b
 
b347258
863713b
 
 
 
 
b347258
 
 
fa67548
 
 
 
863713b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Custom handler for HuggingFace Inference Endpoints — TextSight T5 Humanizer"""
from typing import Dict, Any
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch


class EndpointHandler:
    """Inference Endpoints handler that rewrites text with a fine-tuned T5 model.

    The handler is constructed once with the local model repository path and
    then called with request payloads shaped like
    ``{"inputs": str, "parameters": {...}}``.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model and move the model to the chosen device.

        Args:
            path: Local path of the repository holding the fine-tuned T5
                weights.
        """
        # Load tokenizer from HF hub (avoids local spiece.model path issues)
        self.tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-large")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load model weights from the local repo path
        self.model = T5ForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=self._select_dtype(self.device),
        )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _select_dtype(device: torch.device) -> torch.dtype:
        """Return the weight dtype to load the model in for *device*.

        T5 activations are prone to overflowing float16's narrow range, so
        bfloat16 (which shares float32's exponent range) is preferred on GPUs
        that support it. float16 is only a fallback for older GPUs, and CPU
        inference stays in float32.
        """
        if device.type != "cuda":
            return torch.float32
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16

    @staticmethod
    def _generation_kwargs(params: Any) -> Dict[str, Any]:
        """Build ``model.generate`` keyword arguments from request parameters.

        Args:
            params: The request's ``"parameters"`` value. ``None`` or an empty
                dict yields the defaults below.

        Returns:
            Keyword arguments for ``generate``. ``temperature``, ``top_p`` and
            ``top_k`` are included only when ``do_sample`` is true, because
            they only apply to sampling.
        """
        # Clients may send "parameters": null; treat that like "not given".
        params = params or {}
        do_sample = params.get("do_sample", True)
        kwargs: Dict[str, Any] = {
            "max_new_tokens": params.get("max_new_tokens", 512),
            "num_beams": params.get("num_beams", 4),
            "do_sample": do_sample,
            "repetition_penalty": params.get("repetition_penalty", 2.5),
            "no_repeat_ngram_size": params.get("no_repeat_ngram_size", 3),
            "early_stopping": True,
        }
        if do_sample:
            kwargs.update(
                temperature=params.get("temperature", 1.1),
                top_p=params.get("top_p", 0.92),
                top_k=params.get("top_k", 50),
            )
        return kwargs

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Rewrite ``data["inputs"]`` with the model.

        Args:
            data: Request payload. ``"inputs"`` is the text to rewrite. The
                optional ``"parameters"`` dict overrides generation settings
                (``max_new_tokens``, ``num_beams``, ``do_sample``,
                ``temperature``, ``top_p``, ``top_k``,
                ``repetition_penalty``, ``no_repeat_ngram_size``).

        Returns:
            ``{"generated_text": str}`` on success. Returns ``{"error": str}``
            when the input is missing or blank, or is not a string.
        """
        inputs = data.get("inputs", "")

        if not inputs or (isinstance(inputs, str) and not inputs.strip()):
            return {"error": "No input text provided"}
        if not isinstance(inputs, str):
            # Without this check a list/dict would be silently stringified
            # into the prompt (e.g. "humanize: ['a', 'b']").
            return {"error": "'inputs' must be a string"}

        # Task prefix prepended for T5
        input_text = f"humanize: {inputs}"

        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,  # prompts longer than 512 tokens are cut off
            padding=True,
        ).to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                **tokens,
                **self._generation_kwargs(data.get("parameters")),
            )

        # A single prompt yields a single returned sequence.
        result = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return {"generated_text": result}