File size: 5,435 Bytes
77f021b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from typing import Union, List, Dict

import torch
from datasets import Dataset
from datasets.formatting.formatting import LazyRow
from transformers import (
    pipeline,
)

from src.language_model.language_model_abstraction import LanguageModel
from src.language_model.huggingface_language_model_factory import (
    hugging_face_language_model_tokenizer_factory,
)
from src.task.task import TaskType, Task


class HFLLMModel(LanguageModel):
    """
    LLM model based on Hugging Face Transformers and the ``pipeline`` mechanism.

    Loads a pretrained LLM (model + tokenizer) via the project factory and uses
    it for inference: zero-shot classification for inference-type tasks, free
    text generation otherwise.
    """

    def __init__(
        self,
        model_name: str,
        token: Union[str, None] = None,
        batch_size: int = 8,
    ):
        """
        Load the model/tokenizer and pick a batch size that fits in memory.

        :param model_name: Hugging Face model identifier.
        :param token: optional Hugging Face authentication token.
        :param batch_size: requested batch size; may be lowered below based on
            model size to avoid out-of-memory errors.
        """
        super().__init__(model_name)
        self._model_name = model_name
        self._token = token

        self.model, self.tokenizer = hugging_face_language_model_tokenizer_factory(
            model_name=self._model_name,
            huggingface_token=self._token,
        )

        num_params = self.model.num_parameters()

        # Cap the batch size for large models to limit OOM errors.
        if num_params >= 70_000_000_000:  # 70B
            batch_size = 2
        elif num_params >= 32_000_000_000:  # 32B
            batch_size = 8
        elif num_params >= 27_000_000_000:  # 27B
            batch_size = 16
        elif "gpt-oss" in self._model_name:
            batch_size = 8  # Otherwise a lot of OOM.
        self._batch_size = batch_size

        # For sts22, gpt-oss models get OOM for any batch size above 1.
        # Defined unconditionally (not only in the "gpt-oss" branch above) so
        # predict() can always read it: previously a gpt-oss model with >= 27B
        # parameters took a size-based branch, the attribute was never set, and
        # predict() raised AttributeError on the sts22 task.
        self._batch_size_sts22 = 1

    def predict(self, evaluation_dataset: Dataset, task: Task) -> List:
        """
        Run the model over ``evaluation_dataset`` for ``task``.

        Builds a zero-shot-classification pipeline for inference tasks and a
        text-generation pipeline otherwise, then maps the matching inference
        function over the dataset in batches.

        :param evaluation_dataset: dataset with a "text" column.
        :param task: task describing the expected kind of prediction.
        :return: list of predictions, one per dataset row.
        """
        if task.task_name == "sts22" and "gpt-oss" in self._model_name:
            # For sts22, gpt-oss models get OOM for batch sizes higher than 1.
            batch_size = self._batch_size_sts22
        else:
            batch_size = self._batch_size

        if task.task_type == TaskType.INFERENCE:
            labels = task.dataset.possible_ground_truths
            self.pipeline = pipeline(
                task="zero-shot-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                batch_size=batch_size,
                dtype="float16",
                return_full_text=False,
                max_new_tokens=16,
                padding=True,
                truncation=True,
                max_length=4096,
                candidate_labels=labels,
            )
            # Binary classification supports batched pipeline calls; the
            # multi-label case does not (see infer()).
            if len(labels) == 2:
                inference_fn = self.infer_binary
            else:
                inference_fn = self.infer
        else:
            self.pipeline = pipeline(
                task="text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                batch_size=batch_size,
                dtype="float16",
                return_full_text=False,
                max_new_tokens=64,
                padding=True,
                truncation=True,
                max_length=4096,
            )
            inference_fn = self.generate

        process_dataset = evaluation_dataset.map(
            inference_fn,
            batched=True,
            # Use the task-adjusted batch size; the previous code passed
            # self._batch_size here, bypassing the sts22 workaround above.
            batch_size=batch_size,
            desc=f"Running evaluation for task: {task.task_name}",
            remove_columns="text",
        )

        return list(process_dataset["prediction"])

    def generate(self, rows: LazyRow) -> Dict:
        """
        Generate text for a batch of rows and post-process the outputs.

        :param rows: batched rows with a "text" column.
        :return: dict with a "prediction" column of stripped generated texts.
        """
        with torch.no_grad():
            text = rows["text"]

            # Substring match (consistent with the "gpt-oss" checks) instead of
            # strict equality: Hugging Face model names are usually namespaced
            # (e.g. "org/Chocolatine-..."), so the previous equality check
            # never matched and the workaround below was dead code.
            if "chocolatine" in self._model_name.lower():
                # Problem with Phi-4 generation:
                # https://github.com/huggingface/transformers/issues/36071#issuecomment-3109331152
                generation_args = {"use_cache": False}
                outputs = self.pipeline(text, **generation_args)
            else:
                outputs = self.pipeline(text)

            generated_texts = [
                output[0]["generated_text"].strip() for output in outputs
            ]

        return {"prediction": generated_texts}

    def infer(self, rows: LazyRow) -> Dict:
        """
        Zero-shot classification with more than two candidate labels.

        For an unclear reason, the pipeline does not work when combining:
        1. Batched generation of more than one element.
        2. More than 2 labels.

        Thus we must loop over the elements one by one.

        :param rows: batched rows with a "text" column.
        :return: dict with a "prediction" column of the most likely labels.
        """
        with torch.no_grad():
            texts = rows["text"]

            classifications = []
            for text in texts:
                output = self.pipeline(text)
                classifications.append(
                    output["labels"][0]
                )  # Labels are sorted in likelihood order.

        return {"prediction": classifications}

    def infer_binary(self, rows: LazyRow) -> Dict:
        """
        Binary zero-shot classification over a whole batch in one pipeline call.

        :param rows: batched rows with a "text" column.
        :return: dict with a "prediction" column of the most likely labels.
        """
        with torch.no_grad():
            texts = rows["text"]

            outputs = self.pipeline(texts)
            classifications = [
                output["labels"][0] for output in outputs
            ]  # Labels are sorted in likelihood order.

        return {"prediction": classifications}