import random
from contextlib import contextmanager
from dataclasses import dataclass
from unittest.mock import patch

import pandas as pd
import pytest
import torch
import torch.nn as nn

from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigNLPCausalLMPrediction,
    ConfigNLPCausalLMTokenizer,
)
from llm_studio.python_configs.text_dpo_modeling_config import (
    ConfigDPODataset,
    ConfigProblemBase,
)
from llm_studio.src.datasets.text_dpo_modeling_ds import CustomDataset
from llm_studio.src.models.text_dpo_modeling_model import Model
from llm_studio.src.utils.data_utils import batch_padding
from train import run_eval


@pytest.fixture
def df():
    prompt = """when ordering your sandstones, you select which colour scale you would want.
it could be e.g. a 100% from grey/sand mix, or 80% fra beige/yellow mixed with 20% from black/brown.
This is all lower case. Can you fix that?"""
    system = """You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.
While performing the task think step-by-step and justify your steps."""
    answer = """When ordering your sandstones, you select which color scale you would want. It could be, for example, a 100% from grey/sand mix, or 80% from beige/yellow mixed with 20% from black/brown.

Step 1: Capitalize the first letter of the sentence.

Step 2: Correct the spelling of "color" (assuming American English usage).

Step 3: Replace ", e.g." with "for example" to clarify the sentence.

Step 4: Capitalize "a" in "100% from a grey/sand mix"

Step 5: Ensure the proper usage of words and punctuation throughout the revised sentence."""
    return pd.DataFrame(
        {
            "prompt": [prompt],
            "system": [system],
            "answer": [answer],
            "rejected_answer": ["I cannot do that."],
        }
    )


def generate_causal_lm_model_text(df):
    from llm_studio.python_configs.text_causal_language_modeling_config import (
        ConfigNLPCausalLMDataset,
    )
    from llm_studio.python_configs.text_causal_language_modeling_config import (
        ConfigProblemBase as ConfigCausalLMProblemBase,
    )
    from llm_studio.src.datasets.text_causal_language_modeling_ds import (
        CustomDataset as CausalLMCustomDataset,
    )
    from llm_studio.src.models.text_causal_language_modeling_model import (
        Model as CausalLMModel,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg = ConfigCausalLMProblemBase(
        llm_backbone="h2oai/llama2-0b-unit-test",
        dataset=ConfigNLPCausalLMDataset(
            system_column="system",
            prompt_column=("prompt",),
            answer_column="answer",
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(
            max_length_prompt=256, max_length_answer=256, max_length=512
        ),
    )
    cfg.architecture.backbone_dtype = "float32"

    dataset = CausalLMCustomDataset(df, cfg, mode="train")
    model = CausalLMModel(cfg).to(device).eval()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

    batch = next(iter(dataloader))
    batch = {k: v.to(device) for k, v in batch.items()}
    batch_padding(
        cfg,
        batch,
        mask_key="prompt_attention_mask",
        pad_keys=[
            "prompt_input_ids",
            "prompt_attention_mask",
            "prompt_special_tokens_mask",
        ],
    )
    with torch.no_grad():
        generated_text = dataset.tokenizer.decode(model.generate(batch, cfg)[0])

    return generated_text


def test_generation_is_the_same_as_for_causal_language_modeling(df):
    """
    DPO model should generate the same output text as causal language modeling
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generated_text_causal_lm = generate_causal_lm_model_text(df)

    cfg = ConfigProblemBase(
        llm_backbone="h2oai/llama2-0b-unit-test",
        dataset=ConfigDPODataset(
            system_column="system",
            prompt_column=("prompt",),
            answer_column="answer",
            rejected_answer_column="rejected_answer",
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(
            max_length_prompt=256, max_length_answer=256, max_length=512
        ),
    )
    cfg.architecture.backbone_dtype = "float32"

    dataset = CustomDataset(df, cfg, mode="train")
    model = Model(cfg).eval().to(device)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

    batch = next(iter(dataloader))
    batch = {k: v.to(device) for k, v in batch.items()}
    batch_padding(
        cfg,
        batch,
        mask_key="prompt_attention_mask",
        pad_keys=[
            "prompt_input_ids",
            "prompt_attention_mask",
            "prompt_special_tokens_mask",
        ],
    )
    with torch.no_grad():
        generated_text = dataset.tokenizer.decode(model.generate(batch, cfg)[0])

    assert generated_text == generated_text_causal_lm, (
        "Generated text is not the same as from causal LM model:\n"
        f"{generated_text}\n{generated_text_causal_lm}"
    )


@pytest.fixture
def df2():
    # all lowercase ascii letters
    alphabet = [chr(i) for i in range(97, 123)]

    # random 10-character strings as minimal prompt/system/answer data
    prompts = ["".join(random.choice(alphabet) for _ in range(10)) for _ in range(10)]
    systems = ["".join(random.choice(alphabet) for _ in range(10)) for _ in range(10)]
    answers = ["".join(random.choice(alphabet) for _ in range(10)) for _ in range(10)]
    rejected_answers = [
        "".join(random.choice(alphabet) for _ in range(10)) for _ in range(10)
    ]

    return pd.DataFrame(
        {
            "prompt": prompts,
            "system": systems,
            "answer": answers,
            "rejected_answer": rejected_answers,
        }
    )


def test_dpo_perplexity_metric(tmp_path, df2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg = ConfigProblemBase(
        output_directory=str(tmp_path),
        llm_backbone="MaxJeblick/llama2-0b-unit-test",
        dataset=ConfigDPODataset(
            system_column="system",
            prompt_column=("prompt",),
            answer_column="answer",
            rejected_answer_column="answer",
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(
            max_length_prompt=256, max_length_answer=256, max_length=512
        ),
        prediction=ConfigNLPCausalLMPrediction(metric="Perplexity"),
    )
    cfg.architecture.gradient_checkpointing = False
    cfg.environment._device = device

    cfg.environment.mixed_precision_dtype = "float16"

    dataset = CustomDataset(df2, cfg, mode="train")
    model = Model(cfg).eval().to(device)
    vocab_size = model.backbone.config.vocab_size

    class MockBackbone(nn.Module):
        """
        Chosen and rejected logits are the same.
        Chosen reference and rejected reference logits are the same,
        but different from chosen and rejected logits.
        As answer_column and rejected_answer_column are the same,

        -> perplexity and rejected_perplexity should be the same
        -> chosen_rewards and rejected_rewards should be the same
        -> chosen_cross_entropy_loss and rejected_cross_entropy_loss should be the same
        -> reward margin should be 0
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.seed = 0

        def disable_adapter(self):
            # the DPO model disables the (LoRA-style) adapter to compute reference-model
            # logits; flip the seed so the mocked reference logits differ from the
            # policy logits
            @contextmanager
            def flip_seed():
                self.seed = 1
                yield None
                self.seed = 0

            return flip_seed()

        def forward(self, input_ids, attention_mask):
            @dataclass
            class Result:
                logits: torch.Tensor

            bs, seq_len = input_ids.shape
            # deterministic "random" logits that depend only on the current seed
            torch.manual_seed(self.seed)
            logits = torch.rand((bs, seq_len, vocab_size)).to(input_ids.device)
            return Result(logits=logits)

    class ListLogger:
        def __init__(self):
            self.logs = {}

        def log(self, subset: str, name: str, value: str | float, step: float = None):
            self.logs[name] = self.logs.get(name, []) + [value]

    with patch.object(target=model, attribute="backbone", new_callable=MockBackbone):
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

        # capture logged metrics so the assertions below can inspect them
        cfg.logging._logger = ListLogger()

        run_eval(
            cfg,
            model=model,
            val_dataloader=dataloader,
            val_df=df2,
            mode="validation",
        )

        log_dict = cfg.logging._logger.logs
        assert log_dict["Perplexity"] == log_dict["rejected_perplexity"]
        assert log_dict["chosen_rewards"] == log_dict["rejected_rewards"]
        assert (
            log_dict["chosen_cross_entropy_loss"]
            == log_dict["rejected_cross_entropy_loss"]
        )
        assert log_dict["reward_margin"] == [0] * len(log_dict["reward_margin"])