from typing import Union, List, Dict

import torch
from datasets import Dataset
from datasets.formatting.formatting import LazyBatch
from transformers import (
    pipeline,
)

from src.language_model.language_model_abstraction import LanguageModel
from src.language_model.huggingface_language_model_factory import (
    hugging_face_language_model_tokenizer_factory,
)
from src.task.task import TaskType, Task


class HFLLMModel(LanguageModel):
    """
    LLM model based on Hugging Face Transformers and its pipeline mechanism: loads a
    pretrained LLM and uses it for inference.
    """

    def __init__(
        self,
        model_name: str,
        token: Union[str, None] = None,
        batch_size: int = 8,
    ):
        super().__init__(model_name)
        self._model_name = model_name
        self._token = token
        self.model, self.tokenizer = hugging_face_language_model_tokenizer_factory(
            model_name=self._model_name,
            huggingface_token=self._token,
        )
        num_params = self.model.num_parameters()
        # Cap the batch size according to model size to avoid running out of memory.
        if num_params >= 70_000_000_000:  # 70B
            batch_size = 2
        elif num_params >= 32_000_000_000:  # 32B
            batch_size = 8
        elif num_params >= 27_000_000_000:  # 27B
            batch_size = 16
        elif "gpt-oss" in self._model_name:
            batch_size = 8  # Larger batches frequently cause OOM for GPT-OSS.
        # For sts22, GPT-OSS runs out of memory with a batch size higher than 1.
        self._batch_size_sts22 = 1
        self._batch_size = batch_size

    def predict(self, evaluation_dataset: Dataset, task: Task) -> List:
        if task.task_name == "sts22" and "gpt-oss" in self._model_name:
            # For sts22, GPT-OSS runs out of memory with a batch size higher than 1.
            batch_size = self._batch_size_sts22
        else:
            batch_size = self._batch_size
        if task.task_type == TaskType.INFERENCE:
            labels = task.dataset.possible_ground_truths
            self.pipeline = pipeline(
                task="zero-shot-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                batch_size=batch_size,
                torch_dtype="float16",
                return_full_text=False,
                max_new_tokens=16,
                padding=True,
                truncation=True,
                max_length=4096,
                candidate_labels=labels,
            )
            if len(labels) == 2:
                inference_fn = self.infer_binary
            else:
                inference_fn = self.infer
        else:
            self.pipeline = pipeline(
                task="text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                batch_size=batch_size,
                torch_dtype="float16",
                return_full_text=False,
                max_new_tokens=64,
                padding=True,
                truncation=True,
                max_length=4096,
            )
            inference_fn = self.generate
        process_dataset = evaluation_dataset.map(
            inference_fn,
            batched=True,
            batch_size=self._batch_size,
            desc=f"Running evaluation for task: {task.task_name}",
            remove_columns="text",
        )
        return process_dataset["prediction"]

    def generate(self, rows: LazyBatch) -> Dict:
        """
        Run generation over a batch of rows, extract the generated text, and apply string
        post-processing.
        """
        with torch.no_grad():
            text = rows["text"]
            if self._model_name.lower() == "chocolatine":
                # Workaround for a Phi-4 generation problem:
                # https://github.com/huggingface/transformers/issues/36071#issuecomment-3109331152
                generation_args = {"use_cache": False}
                outputs = self.pipeline(text, **generation_args)
            else:
                outputs = self.pipeline(text)
            generated_texts = [
                output[0]["generated_text"].strip() for output in outputs
            ]
            return {"prediction": generated_texts}

    def infer(self, rows: LazyBatch) -> Dict:
        """
        Run zero-shot classification and extract the label, processing one element at a time.
        For an unclear reason, the pipeline does not work when both of the following hold:
        1. A batch of more than one element is passed at once.
        2. There are more than two candidate labels.
        Thus, we have to loop over the elements one by one.
        """
        with torch.no_grad():
            texts = rows["text"]
            classifications = []
            for text in texts:
                output = self.pipeline(text)
                classifications.append(
                    output["labels"][0]
                )  # Labels are sorted by likelihood, most likely first.
            return {"prediction": classifications}

    def infer_binary(self, rows: LazyBatch) -> Dict:
        """
        Run binary zero-shot classification and extract the label using batched generation.
        """
        with torch.no_grad():
            texts = rows["text"]
            outputs = self.pipeline(texts)
            classifications = [
                output["labels"][0] for output in outputs
            ]  # Labels are sorted by likelihood, most likely first.
            return {"prediction": classifications}
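

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of this module): the model
# name below is an arbitrary public checkpoint and `example_task` stands in
# for a Task configured elsewhere in this repository, so both are assumptions.
#
#     evaluation_dataset = Dataset.from_dict({"text": ["An example input."]})
#     model = HFLLMModel(model_name="meta-llama/Llama-3.1-8B-Instruct")
#     predictions = model.predict(evaluation_dataset, example_task)
#
# predict() builds either a zero-shot-classification or a text-generation
# pipeline depending on example_task.task_type, maps it over the "text"
# column, and returns the list of predictions.
# ---------------------------------------------------------------------------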