| import os |
| import re |
| from typing import cast, Any |
| from datasets import load_dataset, Dataset as HfDataset |
| from unfat.extract import Extractor |
| from unfat.client import OpenAiCompatClient |
| from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit |
| from unfat.together import llama_3_1_70b_together |
| from unfat.lora import LoraSettings |
|
|
def gen_prompts(
    ds_name: str,
    text_field: str,
    start_regex: re.Pattern | None = None,
    end_regex: re.Pattern | None = None,
):
    """Build a ``Prompts`` source from a single text column of a Hub dataset.

    Loads the ``train`` split of *ds_name* and lazily yields the value of
    *text_field* for each row, optionally stripping matches of the given
    patterns (used by callers to remove leading/trailing speaker tags such
    as ``"User: "`` / ``"Bot:"``).

    Args:
        ds_name: Hugging Face Hub dataset identifier (e.g. ``"org/name"``).
        text_field: Name of the column holding the prompt text.
        start_regex: If given, all matches are removed from the text first.
        end_regex: If given, all matches are removed after *start_regex*.

    Returns:
        A ``Prompts`` object; ``items`` re-iterates the loaded dataset on
        each call, and ``count`` reports the dataset length.
    """
    ds = cast(HfDataset, load_dataset(ds_name, split="train"))

    def items():
        for row in ds:
            casted = cast(dict[Any, Any], row)
            text = casted[text_field]
            # Apply each cleanup pattern that was provided, start first —
            # equivalent to the previous four-way if/elif branching.
            for pattern in (start_regex, end_regex):
                if pattern:
                    text = pattern.sub("", text)
            yield text

    return Prompts(
        # NOTE: ds_name contains a "/", so this writes into a nested
        # directory under hub/ — presumably intended; confirm against
        # how Prompts.output_path is consumed.
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(ds),
        items=items,
    )
|
|
def extract_prompts_from_convos(
    ds_name: str,
    messages_field: str,
    role_field: str,
    content_field: str,
    user_role: str,
):
    """Build a ``Prompts`` source from a conversation-style Hub dataset.

    Loads the ``train`` split of *ds_name*; for every row, yields the
    content of the first message whose role equals *user_role*. Rows with
    no such message contribute nothing.

    Args:
        ds_name: Hugging Face Hub dataset identifier.
        messages_field: Column holding the list of message dicts.
        role_field: Key inside each message giving the speaker role.
        content_field: Key inside each message giving the text.
        user_role: Role value that marks a user-authored message.

    Returns:
        A ``Prompts`` object backed by a lazy generator over the dataset.
    """
    ds = cast(HfDataset, load_dataset(ds_name, split="train"))
    _missing = object()  # sentinel: distinguishes "no user turn" from None content

    def items():
        for row in ds:
            record = cast(dict[Any, Any], row)
            # First message authored by the user role, if any.
            first_user = next(
                (m for m in record[messages_field] if m[role_field] == user_role),
                _missing,
            )
            if first_user is not _missing:
                yield first_user[content_field]

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(ds),
        items=items,
    )
|
|
def main() -> None:
    """Run the full distillation pipeline.

    1. Collect prompts from several Hugging Face roleplay/QA datasets
       (conversation-style and single-column sources).
    2. Sample completions for every prompt from the teacher model
       (``hf:TheDrummer/Behemoth-123B-v1.2`` via glhf.chat's
       OpenAI-compatible API).
    3. Upload the extracted dataset to Together and launch a LoRA
       fine-tune (Llama 3.1 70B per the imported config helper).

    Requires ``GLHF_API_KEY``, ``TOGETHER_API_KEY`` and ``WANDB_API_KEY``
    in the environment; raises KeyError if any is missing.
    """
    output_dir = "output"
    # Conversation-style datasets: keep only the first user turn per convo.
    # Note the differing role labels ("user" vs "human") across datasets.
    rp_english = extract_prompts_from_convos(
        ds_name="OdiaGenAI/roleplay_english",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="user",
    )
    bluemoon = extract_prompts_from_convos(
        ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="human",
    )
    # Single-column datasets whose text embeds speaker tags; the regexes
    # strip the leading user tag and the trailing assistant tag.
    roleplay_prompts = gen_prompts(
        ds_name="AlekseyKorshuk/roleplay-io",
        text_field="input_text",
        start_regex=re.compile(r'^User: '),
        end_regex=re.compile(r'Bot:\s*$'),
    )
    roleplay_instr_prompts = gen_prompts(
        ds_name="iamketan25/roleplay-instructions-dataset",
        text_field="prompt",
        start_regex=re.compile(r'^Human: '),
        end_regex=re.compile(r'Assistant:\s*$'),
    )


    # Teacher-sampling stage: up to 50 concurrent requests against the
    # glhf.chat endpoint, retrying each request up to 20 times.
    extractor = Extractor(
        max_concurrent=50,
        output_dir=output_dir,
        client=OpenAiCompatClient(
            base_url="https://glhf.chat/api/openai/v1",
            api_key=os.environ["GLHF_API_KEY"],
            model="hf:TheDrummer/Behemoth-123B-v1.2",
            retries=20,
        ),
        dataset=Dataset(
            train=[
                hub_prompts(
                    name="mlabonne/harmful_behaviors",
                    text_field="text",
                    split="train",
                ),
                roleplay_instr_prompts,
                roleplay_prompts,
                rp_english,
                bluemoon,
                hub_prompts(
                    name="TheDrummer/AmoralQA-v2",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="vicgalle/OpenHermesPreferences-roleplay",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="ResplendentAI/theory_of_mind_fixed_output",
                    text_field="instruction",
                    split="train",
                ),
                hub_prompts(
                    name="mlabonne/harmless_alpaca",
                    text_field="text",
                    # Capped at 1000 rows — presumably to keep the harmless
                    # set from dominating the mix; confirm intent.
                    split=HubSplit(name="train", max_rows=1000),
                ),
            ],
        ),
    )
    extractor.run()
    dataset = extractor.output_dataset()
    # Fine-tuning stage: LoRA hyperparameters for the Together run.
    # evals_per_epoch=0 disables held-out evaluation during training.
    together_config = llama_3_1_70b_together(
        output_dir=output_dir,
        dataset=dataset,
        api_key=os.environ["TOGETHER_API_KEY"],
        settings=LoraSettings(
            rank=32,
            alpha=16,
            dropout=0.01,
            num_epochs=2,
            learning_rate=4e-4,
            evals_per_epoch=0,
            wandb_project="behemoth-distill",
            wandb_api_key=os.environ["WANDB_API_KEY"],
        )
    )
    files = together_config.upload_files()
    together_config.finetune(files)
|
|
|
|
# Entry point: run the pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|
|