File size: 3,824 Bytes
20efac6
5c4dcc9
 
da0780a
5c4dcc9
 
 
 
 
 
 
e229a67
5c4dcc9
da0780a
5c4dcc9
20efac6
 
 
5c4dcc9
da0780a
5c4dcc9
 
da0780a
 
 
5c4dcc9
 
da0780a
5c4dcc9
da0780a
 
 
 
5c4dcc9
da0780a
c584727
e229a67
c584727
 
5c4dcc9
 
3cc9791
 
 
 
 
 
 
 
 
 
 
 
 
d6978ab
 
 
 
 
 
 
da0780a
3cc9791
 
 
 
 
 
 
 
 
5c4dcc9
20efac6
5c4dcc9
 
 
 
da0780a
5c4dcc9
 
20efac6
5c4dcc9
20efac6
5c4dcc9
20efac6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import Dict, Any, List
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from torch.cuda.amp import autocast


class EndpointHandler:
    """Inference endpoint that runs a seq2seq model over titled, dated text chunks.

    The tokenizer and model are loaded once at construction time; each request
    is served through ``__call__``, which validates the payload and delegates the
    batched beam-search generation to ``process_chunks``.
    """

    def __init__(self, path="chentong00/propositionizer-wiki-flan-t5-large"):
        """
        Load the tokenizer and model and place the model on the best available device.

        Args:
            path (str): Model id or local directory passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSeq2SeqLM.from_pretrained(path)
        if self.device.type == "cuda":
            # Half precision only on GPU: many CPU kernels have no float16
            # implementation, so an unconditional .half() can break CPU inference.
            # NOTE(review): T5-family models are reported to overflow in float16;
            # consider bfloat16 if outputs degrade — confirm on target hardware.
            model = model.half()
        self.model = model.to(self.device).eval()

    def process_chunks(
        self, chunks: List[str], titles: List[str], dates: List[str]
    ) -> List[str]:
        """
        Generate model output for each (chunk, title, date) triple.

        Each triple is rendered as ``"Title: {title}. Date: {date}. Content: {chunk}"``,
        the batch is tokenized with padding and truncation to 1024 tokens, and
        decoded with 5-beam search (up to 512 new tokens).

        Args:
            chunks (list): Text content to process.
            titles (list): Document titles, aligned index-by-index with ``chunks``.
            dates (list): Document dates, aligned index-by-index with ``chunks``.

        Returns:
            list: One decoded output string per input triple, in input order;
            an empty list when ``chunks`` is empty.

        Raises:
            ValueError: If ``chunks``, ``titles`` and ``dates`` differ in length
                (previously ``zip`` would have silently dropped the extras).
        """
        if not (len(chunks) == len(titles) == len(dates)):
            raise ValueError("chunks, titles and dates must have the same length.")
        if not chunks:
            return []

        input_texts = [
            f"Title: {t}. Date: {d}. Content: {c}"
            for c, t, d in zip(chunks, titles, dates)
        ]
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        )
        input_ids = encoded["input_ids"].to(self.device)
        # Pass the mask explicitly so padded positions of shorter sequences in
        # the batch are ignored, instead of relying on generate() to infer it.
        attention_mask = encoded["attention_mask"].to(self.device)

        use_cuda = self.device.type == "cuda"
        # Pre-bound so the cleanup in `finally` cannot itself raise (and mask
        # the real error) when generate() fails before assigning it.
        outputs = None
        try:
            # Mixed-precision autocast on CUDA only; with enabled=False it is a no-op.
            with torch.no_grad(), torch.autocast(
                device_type="cuda", enabled=use_cuda
            ):
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=512,
                    no_repeat_ngram_size=5,
                    length_penalty=1.2,
                    num_beams=5,
                )
            predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        finally:
            # Drop references to the device tensors, then let CUDA release the
            # cached blocks once all queued kernels have finished.
            del input_ids, attention_mask, outputs
            if use_cuda:
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

        return predictions

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Handle one inference request.

        Args:
            data (dict): Request payload. ``data["inputs"]`` must be a list of
                dicts, each containing the keys ``"id"``, ``"chunk"``, ``"title"``
                and ``"date"``. A missing ``"inputs"`` key is treated as an
                empty list.

        Returns:
            dict: ``{"results": [{"id": ..., "generated_text": ...}, ...]}`` with
            one entry per input, in input order (``{"results": []}`` for an
            empty batch).

        Raises:
            ValueError: If ``inputs`` is not a list of dicts, or an item lacks
                one of the required keys.
        """
        inputs = data.get("inputs", [])

        if not isinstance(inputs, list) or not all(isinstance(i, dict) for i in inputs):
            raise ValueError("The inputs must be a list of dictionaries.")

        # Skip the model entirely for an empty batch rather than sending an
        # empty tensor through tokenization and generate().
        if not inputs:
            return {"results": []}

        for item in inputs:
            for key in ("id", "chunk", "title", "date"):
                if key not in item:
                    raise ValueError(f"Each input must contain the key: {key}.")

        ids = [item["id"] for item in inputs]
        predictions = self.process_chunks(
            [item["chunk"] for item in inputs],
            [item["title"] for item in inputs],
            [item["date"] for item in inputs],
        )
        return {
            "results": [
                {"id": id_, "generated_text": prediction}
                for id_, prediction in zip(ids, predictions)
            ]
        }