from unittest import mock
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest

from llm_studio.app_utils.default_datasets import (
    prepare_default_dataset_causal_language_modeling,
)
from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigNLPCausalLMDataset,
    ConfigNLPCausalLMTokenizer,
    ConfigProblemBase,
)
from llm_studio.src.datasets.text_causal_language_modeling_ds import CustomDataset


def test_prepare_default_dataset(tmp_path):
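    """Smoke test for preparing the bundled default dataset.

    Note: the expected shape (13026, 6) is tied to the dataset snapshot that
    ships with LLM Studio; a refreshed upstream copy may change the row count.
    """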
    df = prepare_default_dataset_causal_language_modeling(tmp_path)
    assert isinstance(df, pd.DataFrame)
    assert set(df.keys()) == set(
        ["instruction", "output", "id", "parent_id", "lang", "rank"]
    )
    assert df.shape == (13026, 6)


def test_clean_output():
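    """clean_output should truncate each prediction at the first occurrence of
    any configured stop word and strip the surrounding whitespace."""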
    output = {
        "predicted_text": np.array(
            [
                "This is a test",
                "This is a test <stop> This is a test",
                "This is a test <stop2> This is a test",
                "This is a test <stop3> <stop> This is a test",
                "<stop2> <stop> This is a test",
                "This is a test <stop>",
            ]
        )
    }

    cfg = mock.MagicMock()
    cfg.tokenizer._stop_words = ["<stop>", "<stop2>", "<stop3>"]

    predicted_text_clean = CustomDataset.clean_output(output=output, cfg=cfg)[
        "predicted_text"
    ]
    assert predicted_text_clean == [
        "This is a test",
        "This is a test",
        "This is a test",
        "This is a test",
        "",
        "This is a test",
    ]


def test_sanity_check_raises_error():
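    """Valid conversation trees pass sanity_check; a frame where a row is its
    own parent, or where no row has an empty parent_id (no conversation
    start), must raise an AssertionError."""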
    mock_config = MagicMock()
    mock_config.dataset.parent_id_column = "parent_id"
    mock_config.dataset.answer_column = "answer"

    df_1 = pd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "parent_id": [2, None, 4, 1],
            "answer": ["a", "b", "c", "d"],
            "other_data": ["a", "b", "c", "d"],
        }
    )
    CustomDataset.sanity_check(df_1, mock_config)

    df_2 = pd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "parent_id": [None, None, None, None],
            "answer": ["a", "b", "c", "d"],
            "other_data": ["a", "b", "c", "d"],
        }
    )
    CustomDataset.sanity_check(df_2, mock_config)

    invalid_df_1 = pd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "parent_id": [1, 2, 3, 4],
            "answer": ["a", "b", "c", "d"],
            "other_data": ["a", "b", "c", "d"],
        }
    )
    with pytest.raises(
        AssertionError, match="Parent id column is the same as id column for some rows"
    ):
        CustomDataset.sanity_check(invalid_df_1, mock_config)

    invalid_df_2 = pd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "parent_id": [2, 3, 4, 1],
            "other_data": ["a", "b", "c", "d"],
        }
    )
    with pytest.raises(
        AssertionError,
        match="Did not find any conversation start. "
        "Please ensure that some parent ids are empty.",
    ):
        CustomDataset.sanity_check(invalid_df_2, mock_config)


@pytest.fixture
def mock_auto_tokenizer():
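    # Patch AutoTokenizer.from_pretrained (autospec'd) so that constructing a
    # CustomDataset never downloads a real tokenizer from the Hugging Face Hub.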
    with patch(
        "transformers.AutoTokenizer.from_pretrained", autospec=True
    ) as mock_from_pretrained:
        yield mock_from_pretrained


def test_init(mock_auto_tokenizer):
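    """CustomDataset should keep the dataframe as-is and default to "train"
    mode; tokenizer loading is mocked out by the fixture."""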
    df = pd.DataFrame(
        {
            "col_A": [1, 2, 3],
            "col_B": [4, 5, 6],
        }
    )
    cfg = mock.MagicMock()
    cfg.dataset.prompt_column = "col_A"
    cfg.dataset.answer_column = "col_B"
    cfg.dataset.parent_id_column = "None"
    cfg.dataset.system_column = "None"

    cfg.dataset.text_system_start = ""
    cfg.dataset.text_prompt_start = ""
    cfg.dataset.text_answer_separator = ""

    dataset = CustomDataset(df, cfg)

    assert dataset.df.equals(df)
    assert dataset.mode == "train"


def test_getitem():
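    """With limit_chained_samples=True, three rows linked via parent_id
    collapse into a single sample: input_ids holds all turns, while
    prompt_input_ids ends right before the final answer.

    The decoded strings assume the EleutherAI/pythia-2.8b-deduped tokenizer,
    whose <|endoftext|> token serves as EOS/separator and as (left) padding up
    to max_length=513.
    """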
    df = pd.DataFrame(
        {
            "prompt": ["prompt 1", "prompt 2", "prompt 3"],
            "answer": ["answer 1", "answer 2", "answer 3"],
            "parent_id": [None, 0, 1],
            "system": ["system 1", "system 2", "system 3"],
            "id": [0, 1, 2],
        }
    )

    cfg = ConfigProblemBase(
        dataset=ConfigNLPCausalLMDataset(
            prompt_column=("prompt",),
            answer_column="answer",
            parent_id_column="parent_id",
            system_column="system",
            text_system_start="System:",
            text_prompt_start="Prompt:",
            text_answer_separator="Answer:",
            add_eos_token_to_answer=True,
            limit_chained_samples=True,
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(max_length=513),
    )

    cfg.llm_backbone = "EleutherAI/pythia-2.8b-deduped"

    dataset = CustomDataset(df, cfg)
    assert len(dataset) == 1

    result = dataset[0]
    assert isinstance(result, dict)
    assert set(result.keys()) == {
        "labels",
        "input_ids",
        "attention_mask",
        "prompt_input_ids",
        "prompt_attention_mask",
        "answer_input_ids",
        "answer_attention_mask",
    }

    assert (
        dataset.tokenizer.decode(result["input_ids"], skip_special_tokens=True)
        == "System:system 1"
        "Prompt:prompt 1"
        "Answer:answer 1"
        "Prompt:prompt 2"
        "Answer:answer 2"
        "Prompt:prompt 3"
        "Answer:answer 3"
    )

    assert (
        dataset.tokenizer.decode(result["prompt_input_ids"], skip_special_tokens=True)
        == "System:system 1"
        "Prompt:prompt 1"
        "Answer:answer 1"
        "Prompt:prompt 2"
        "Answer:answer 2"
        "Prompt:prompt 3"
        "Answer:"
    )

    assert (
        dataset.tokenizer.decode(result["input_ids"], skip_special_tokens=False)
        == "<|endoftext|>" * 475 + "System:system 1"
        "<|endoftext|>"
        "Prompt:prompt 1"
        "<|endoftext|>"
        "Answer:answer 1"
        "<|endoftext|>"
        "Prompt:prompt 2"
        "<|endoftext|>"
        "Answer:answer 2"
        "<|endoftext|>"
        "Prompt:prompt 3"
        "<|endoftext|>"
        "Answer:answer 3"
        "<|endoftext|>"
    )

    assert result["input_ids"].shape == (513,)
    assert result["prompt_input_ids"].shape == (513,)


def test_getitem_no_chaining():
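    """With parent_id_column="None" there is no chaining: each row becomes an
    independent sample with its own system/prompt/answer triple."""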
    df = pd.DataFrame(
        {
            "prompt": ["prompt 1", "prompt 2", "prompt 3"],
            "answer": ["answer 1", "answer 2", "answer 3"],
            "parent_id": [None, 0, 1],
            "system": ["system 1", "system 2", "system 3"],
            "id": [0, 1, 2],
        }
    )

    cfg = ConfigProblemBase(
        dataset=ConfigNLPCausalLMDataset(
            prompt_column=("prompt",),
            answer_column="answer",
            parent_id_column="None",
            system_column="system",
            text_system_start="System:",
            text_prompt_start="Prompt:",
            text_answer_separator="Answer:",
            add_eos_token_to_answer=True,
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(max_length=513),
    )

    cfg.llm_backbone = "EleutherAI/pythia-2.8b-deduped"

    dataset = CustomDataset(df, cfg)
    assert len(dataset) == 3

    for i in range(3):
        result = dataset[i]
        assert isinstance(result, dict)
        assert set(result.keys()) == {
            "labels",
            "input_ids",
            "attention_mask",
            "prompt_input_ids",
            "prompt_attention_mask",
            "answer_input_ids",
            "answer_attention_mask",
        }

        assert (
            dataset.tokenizer.decode(result["input_ids"], skip_special_tokens=True)
            == f"System:system {i+1}"
            f"Prompt:prompt {i+1}"
            f"Answer:answer {i+1}"
        )

        assert (
            dataset.tokenizer.decode(
                result["prompt_input_ids"], skip_special_tokens=True
            )
            == f"System:system {i+1}"
            f"Prompt:prompt {i+1}"
            "Answer:"
        )