from typing import Union, List, Dict
import torch
from datasets import Dataset
from datasets.formatting.formatting import LazyRow
from transformers import (
pipeline,
)
from src.language_model.language_model_abstraction import LanguageModel
from src.language_model.huggingface_language_model_factory import (
hugging_face_language_model_tokenizer_factory,
)
from src.task.task import TaskType, Task
class HFLLMModel(LanguageModel):
"""
    LLM model based on Hugging Face Transformers and its pipeline mechanism. Loads a pretrained LLM
    and uses it for inference.
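
    Example (illustrative sketch; the model name and token below are placeholders):

        model = HFLLMModel(
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            token="hf_...",
        )
        predictions = model.predict(evaluation_dataset, task)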
"""
def __init__(
self,
model_name: str,
token: Union[str, None] = None,
batch_size: int = 8,
):
super().__init__(model_name)
self._model_name = model_name
self._token = token
self.model, self.tokenizer = hugging_face_language_model_tokenizer_factory(
model_name=self._model_name,
huggingface_token=self._token,
)
num_params = self.model.num_parameters()
        # Cap the batch size by model scale to avoid GPU out-of-memory (OOM) errors.
        if num_params >= 70_000_000_000:  # 70B
            batch_size = 2
        elif num_params >= 32_000_000_000:  # 32B
            batch_size = 8
        elif num_params >= 27_000_000_000:  # 27B
            batch_size = 16
        elif "gpt-oss" in self._model_name:
            batch_size = 8  # Larger batches frequently cause OOM for gpt-oss.
        # For sts22, gpt-oss hits OOM with any batch size greater than 1.
        self._batch_size_sts22 = 1
self._batch_size = batch_size
def predict(self, evaluation_dataset: Dataset, task: Task) -> List:
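        """
        Run inference over the evaluation dataset and return one prediction per row.

        Classification tasks (TaskType.INFERENCE) go through a zero-shot-classification pipeline
        constrained to the task's candidate labels; all other tasks use a text-generation pipeline.
        """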
if task.task_name == "sts22" and "gpt-oss" in self._model_name:
            # For sts22, gpt-oss hits OOM with any batch size greater than 1.
batch_size = self._batch_size_sts22
else:
batch_size = self._batch_size
if task.task_type == TaskType.INFERENCE:
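            # Classification-style tasks: use zero-shot classification so predictions are
            # constrained to the task's candidate labels.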
labels = task.dataset.possible_ground_truths
self.pipeline = pipeline(
task="zero-shot-classification",
model=self.model,
tokenizer=self.tokenizer,
batch_size=batch_size,
dtype="float16",
return_full_text=False,
max_new_tokens=16,
padding=True,
truncation=True,
max_length=4096,
candidate_labels=labels,
)
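            # Binary tasks can be classified in one batched pipeline call; tasks with more than
            # two labels are processed element by element (see infer()).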
if len(labels) == 2:
inference_fn = self.infer_binary
else:
inference_fn = self.infer
else:
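            # Generative tasks: use plain text generation and post-process the decoded output.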
self.pipeline = pipeline(
task="text-generation",
model=self.model,
tokenizer=self.tokenizer,
batch_size=batch_size,
dtype="float16",
return_full_text=False,
max_new_tokens=64,
padding=True,
truncation=True,
max_length=4096,
)
inference_fn = self.generate
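        # Map the inference function over the dataset in batches; the "text" column is consumed
        # and a "prediction" column is produced.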
process_dataset = evaluation_dataset.map(
inference_fn,
batched=True,
            batch_size=batch_size,
desc=f"Running evaluation for task: {task.task_name}",
remove_columns="text",
)
return list(process_dataset["prediction"])
def generate(self, rows: LazyRow) -> Dict:
"""
        Run text generation over a batch of rows, extract the generated text, and apply string post-processing.
"""
with torch.no_grad():
text = rows["text"]
if self._model_name.lower() == "chocolatine":
# Problem with Phi-4 generation:
# https://github.com/huggingface/transformers/issues/36071#issuecomment-3109331152
generation_args = {"use_cache": False}
outputs = self.pipeline(text, **generation_args)
else:
outputs = self.pipeline(text)
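            # Each pipeline output is a list of generated sequences for its input text; keep the
            # first sequence and strip surrounding whitespace.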
generated_texts = [
output[0]["generated_text"].strip() for output in outputs
]
return {"prediction": generated_texts}
def infer(self, rows: LazyRow) -> Dict:
"""
        Run zero-shot classification and extract the label, processing one element at a time.
        For an unclear reason, the pipeline fails when both of the following hold:
        1. The batch contains more than one element.
        2. There are more than 2 candidate labels.
        We therefore loop over the elements individually, which is slower but reliable.
"""
with torch.no_grad():
texts = rows["text"]
classifications = []
for text in texts:
output = self.pipeline(text)
classifications.append(
output["labels"][0]
) # Labels are sorted in likelihood order.
return {"prediction": classifications}
def infer_binary(self, rows: LazyRow) -> Dict:
"""
        Run binary zero-shot classification over a batch of rows and extract the most likely label.
"""
with torch.no_grad():
texts = rows["text"]
outputs = self.pipeline(texts)
classifications = [
output["labels"][0] for output in outputs
] # Labels are sorted in likelihood order.
return {"prediction": classifications}