from unsloth import FastVisionModel, is_bf16_supported  # import unsloth first so its patches apply
from unsloth.trainer import UnslothVisionDataCollator
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import TextStreamer
import datetime

# Timestamp used to give each run its own output directory.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


# Load the continued-pretraining checkpoint in bf16 (no 4-bit quantisation).
model, tokenizer = FastVisionModel.from_pretrained(
    model_name = "/home/rzhong/project/unsloth/model_pretrain_20250301_113944",
    load_in_4bit = False,  # keep 16-bit weights; set True to train in 4-bit
    use_gradient_checkpointing = "unsloth",  # Unsloth's memory-efficient checkpointing
    max_seq_length = 2048,
    dtype = torch.bfloat16,
)
# Attach LoRA adapters to both the vision tower and the language model.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers = True,
    finetune_language_layers = True,
    finetune_attention_modules = True,
    finetune_mlp_modules = True,

    r = 16,  # LoRA rank
    lora_alpha = 16,  # scaling factor; alpha == r is a common default
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,  # rank-stabilised LoRA disabled
    loftq_config = None,  # no LoftQ initialisation

    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)


dataset = load_dataset("/home/share/rzhong/dataset/google-landmark/dataset_4/dataset_file", split = "train")
print(dataset)
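
# The conversion step below assumes each sample exposes an "image" column
# (a PIL image) and a "text" column holding the target caption.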

# Prompt sent with every image ("Describe this image." in Chinese).
instruction = "描述这张图片。"


# Wrap each sample in the chat-message format the vision collator expects:
# a user turn carrying the prompt and the image, and an assistant turn
# carrying the target caption.
def convert_to_conversation(sample):
    conversation = [
        {"role": "user",
         "content": [
             {"type": "text", "text": instruction},
             {"type": "image", "image": sample["image"]},
         ]},
        {"role": "assistant",
         "content": [
             {"type": "text", "text": sample["text"]},
         ]},
    ]
    return {"messages": conversation}


converted_dataset = [convert_to_conversation(sample) for sample in dataset]
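# Note: this materialises every converted sample in one Python list, the
# layout the Unsloth vision collator consumes; simple here, but memory-hungry
# for very large datasets.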

# Inspect the first converted sample to verify the message structure.
print(converted_dataset[0])


# Put the model in training mode (adapters trainable, checkpointing active).
FastVisionModel.for_training(model)


trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer),  # handles image + text batching
    train_dataset = converted_dataset,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,  # effective batch size of 2 * 4 = 8
        warmup_steps = 5,
        num_train_epochs = 10,
        learning_rate = 5e-5,
        fp16 = not is_bf16_supported(),  # fall back to fp16 on GPUs without bf16
        bf16 = is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = f"outputs_pretrain_sft_{timestamp}",
        report_to = "none",  # set to "wandb" or similar to enable tracking

        # The settings below are required for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_seq_length = 2048,
    ),
)

trainer_stats = trainer.train()
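# trainer.train() returns a transformers TrainOutput; printing its metrics
# shows the final training loss and runtime:
print(trainer_stats.metrics)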

# Save only the LoRA adapter weights plus the tokenizer/processor files.
# (Unsloth's model.save_pretrained_merged can export merged 16-bit weights
# instead, if a standalone checkpoint is needed.)
model.save_pretrained(f"lora_model_pretrain_sft_{timestamp}")
tokenizer.save_pretrained(f"lora_model_pretrain_sft_{timestamp}")
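

# --- Optional sanity check: a minimal inference sketch following the Unsloth
# vision notebooks. Assumes a CUDA device and reuses the first training image;
# the generation settings are illustrative, not tuned. ---
FastVisionModel.for_inference(model)  # switch adapters into inference mode

image = dataset[0]["image"]
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction},
    ]},
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

# Stream the generated caption to stdout as tokens arrive (uses the
# TextStreamer imported at the top of the script).
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True)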