| | import logging |
| | from typing import Any, Dict |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| |
|
| | from llm_studio.src.datasets.text_causal_language_modeling_ds import ( |
| | CustomDataset as TextCausalLanguageModelingCustomDataset, |
| | ) |
| | from llm_studio.src.utils.exceptions import LLMDataException |
| |
|
# Module-level logger; used below to warn about label sets that do not form
# a contiguous range starting at 0.
logger = logging.getLogger(__name__)
| |
|
| |
|
class CustomDataset(TextCausalLanguageModelingCustomDataset):
    """Causal-LM dataset variant that attaches an integer class label per row.

    Labels are read from ``cfg.dataset.answer_column`` and validated against
    ``cfg.dataset.num_classes`` when the dataset is constructed.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Initialise the parent dataset, then validate the label column.

        Args:
            df: Frame holding the samples, including the answer column.
            cfg: Experiment configuration; ``cfg.dataset.answer_column``,
                ``cfg.dataset.num_classes`` and
                ``cfg.dataset.parent_id_column`` are read here.
            mode: Forwarded unchanged to the parent constructor.

        Raises:
            LLMDataException: If an answer is not castable to int, the labels
                do not fit ``num_classes``, a label is negative, or a parent
                ID column is configured.
        """
        super().__init__(df=df, cfg=cfg, mode=mode)
        check_for_non_int_answers(cfg, df)

        labels = df[cfg.dataset.answer_column].astype(int).values.tolist()
        self.answers_int = labels
        num_classes = cfg.dataset.num_classes

        if 1 < num_classes:
            # Multi-class: every label must be a valid index below num_classes.
            largest = max(labels)
            if num_classes <= largest:
                raise LLMDataException(
                    "Number of classes is smaller than max label "
                    f"{largest}. Please increase the setting accordingly."
                )
        elif num_classes == 1:
            # A single output unit means binary classification (labels 0/1).
            largest = max(labels)
            if largest > 1:
                raise LLMDataException(
                    "For binary classification, max label should be 1 but is "
                    f"{largest}."
                )

        smallest = min(labels)
        if smallest < 0:
            raise LLMDataException(
                "Labels should be non-negative but min label is "
                f"{smallest}."
            )

        # Gaps in the label range are tolerated but reported.
        distinct = sorted(set(labels))
        starts_at_zero = distinct[0] == 0
        is_contiguous = distinct[-1] == len(distinct) - 1
        if not (starts_at_zero and is_contiguous):
            logger.warning(
                "Labels should start at 0 and be continuous but are "
                f"{distinct}."
            )

        if cfg.dataset.parent_id_column != "None":
            raise LLMDataException(
                "Parent ID column is not supported for classification datasets."
            )

    def __getitem__(self, idx: int) -> Dict:
        """Return the parent's sample for ``idx`` with ``class_label`` added."""
        item = super().__getitem__(idx)
        item["class_label"] = self.answers_int[idx]
        return item

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        """Turn logits into predicted label strings, then defer to the parent.

        With ``num_classes == 1`` a logit above 0 maps to "1" and anything
        else to "0"; otherwise the argmax over dim 1 is taken. The flattened
        string array is stored under ``output["predicted_text"]``.
        """
        logits = output["logits"].float()
        output["logits"] = logits

        if cfg.dataset.num_classes == 1:
            # Cast bool -> int first so the strings are "0"/"1", not "False"/"True".
            predicted = np.array(logits > 0.0).astype(int)
        else:
            predicted = np.array(torch.argmax(logits, dim=1))

        output["predicted_text"] = predicted.astype(str).reshape(-1)
        return super().postprocess_output(cfg, df, output)

    def clean_output(self, output, cfg):
        """Return ``output`` untouched; no cleaning is applied here."""
        return output

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Verify that every entry of the answer column is castable to int.

        Raises:
            LLMDataException: If any answer cannot be converted.
        """
        check_for_non_int_answers(cfg, df)
| |
|
| |
|
def check_for_non_int_answers(cfg, df):
    """Raise if the configured answer column holds values not castable to int.

    Args:
        cfg: Configuration object; only ``cfg.dataset.answer_column`` is read.
        df: Frame containing the answer column.

    Raises:
        LLMDataException: If at least one value fails ``is_castable_to_int``;
            the message shows up to five offending values.
    """
    column = cfg.dataset.answer_column
    offending = [
        value for value in df[column].values if not is_castable_to_int(value)
    ]
    if not offending:
        return

    raise LLMDataException(
        f"Column {column} contains non int items. "
        f"Sample values: {offending[:5]}."
    )
| |
|
| |
|
def is_castable_to_int(s):
    """Return True if ``int(s)`` succeeds, False otherwise.

    Besides malformed strings (``ValueError``), this also treats values that
    ``int`` rejects with ``TypeError`` (e.g. ``None``) or ``OverflowError``
    (e.g. infinity) as not castable, so callers get a plain boolean instead
    of an unexpected exception.

    Note that floats with a fractional part still count as castable, because
    ``int`` truncates them rather than raising.
    """
    try:
        int(s)
    except (ValueError, TypeError, OverflowError):
        return False
    return True
| |
|