import os
from enum import Enum

import packaging.version
import torch
import transformers
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoraConfig


DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
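
# Illustrative example (not part of the original file): for a single-turn conversation
# [{"role": "user", "content": "Hi"}] with add_generation_prompt=True, the ChatML template
# above renders roughly to (exact whitespace depends on the Jinja settings used by
# tokenizer.apply_chat_template):
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant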


class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]
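
# For example, .list() returns every special-token string so it can be registered with the
# tokenizer via additional_special_tokens:
#   ChatmlSpecialTokens.list()
#   # -> ['<|im_start|>user', '<|im_start|>assistant', '<|im_start|>system', '<|im_end|>', '<s>', '<pad>']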


def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    def preprocess(samples):
        batch = []
        for conversation in samples["messages"]:
            batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))
        return {"content": batch}

    raw_datasets = DatasetDict()
    for split in data_args.splits.split(","):
        try:
            # First, try to load the split from a Hub repo
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # Otherwise, fall back to a local dataset saved on disk
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    print(f"A sample of the train dataset: {train_data[0]}")
    return train_data, valid_data
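

# Hypothetical call (the data_args fields used above, dataset_name and splits, are assumed to be
# populated by the training script's argument parser; the values below are placeholders):
#   train_data, valid_data = create_datasets(
#       tokenizer,
#       data_args,       # e.g. data_args.dataset_name="<hub-id-or-local-path>", data_args.splits="train,test"
#       training_args,
#       apply_chat_template=True,
#   )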


def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel
    bnb_config = None
    quant_storage_dtype = None

    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16; you can accelerate training with the argument --bf16")
                print("=" * 80)
    elif args.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)
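
    # Illustrative example (the values are assumptions, not defaults enforced here): a typical QLoRA
    # run would pass bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16" and
    # bnb_4bit_quant_storage_dtype="bfloat16", producing roughly:
    #   BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
    #                      bnb_4bit_compute_dtype=torch.bfloat16,
    #                      bnb_4bit_use_double_quant=True,
    #                      bnb_4bit_quant_storage=torch.bfloat16)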

    if args.use_unsloth:
        if torch.xpu.is_available():
            raise NotImplementedError("Unsloth is not supported on XPU yet")
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=training_args.max_seq_length,
            dtype=None,
            load_in_4bit=args.use_4bit_quantization,
        )
    else:
        torch_dtype = (
            quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32
        )

        # Prepare model loading arguments
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch_dtype,
        }
        if args.use_flash_attn:
            if torch.xpu.is_available():
                print("flash_attn is not supported on XPU yet; using the eager implementation instead.")
                model_kwargs["attn_implementation"] = "eager"
            else:
                model_kwargs["attn_implementation"] = "flash_attention_2"

        # Only add quantization_config if bnb_config is not None
        if bnb_config is not None:
            model_kwargs["quantization_config"] = bnb_config

        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
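
    # Note (assumption based on common FSDP + QLoRA setups): loading the non-quantized layers in the
    # same floating dtype as bnb_4bit_quant_storage (e.g. torch.bfloat16) keeps all parameters in a
    # single dtype, which is what allows FSDP to wrap and shard a 4-bit quantized model.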

    peft_config = None
    chat_template = None
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )
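
    # For example (illustrative values): args.lora_target_modules="q_proj,v_proj" becomes
    # target_modules=["q_proj", "v_proj"], while the literal string "all-linear" is passed through
    # unchanged, which PEFT interprets as targeting all linear layers.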

    special_tokens = None
    chat_template = None
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE

    if special_tokens is not None:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,
            bos_token=special_tokens.bos_token.value,
            eos_token=special_tokens.eos_token.value,
            additional_special_tokens=special_tokens.list(),
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template

        # make embedding resizing configurable?
        # Transformers 4.46.0+ uses mean_resizing by default, which fails with QLoRA + FSDP because the
        # embedding could be on a meta device; therefore, we set mean_resizing=False in that case (i.e. the
        # status quo ante). See https://github.com/huggingface/accelerate/issues/1620.
        uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
        uses_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
        # Check if the model is quantized
        is_quantized = (bnb_config is not None) or (getattr(model, "hf_quantizer", None) is not None)
        if is_quantized and uses_fsdp and uses_transformers_4_46:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
        else:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        # No custom chat template: keep the base tokenizer and reuse EOS as the padding token
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth:
        # Do model patching and add fast LoRA weights
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=training_args.max_seq_length,
        )

    return model, peft_config, tokenizer
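

# A possible end-to-end wiring (sketch with assumed argument objects; a training script would
# typically build `args`, `data_args` and `training_args` with HfArgumentParser and then hand
# everything to a trainer such as TRL's SFTTrainer):
#
#   model, peft_config, tokenizer = create_and_prepare_model(args, data_args, training_args)
#   train_dataset, eval_dataset = create_datasets(
#       tokenizer, data_args, training_args, apply_chat_template=(args.chat_template_format != "none")
#   )
#   # ... pass model, tokenizer, peft_config and the datasets to the trainer of your choice.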