Dmitriy-Egorov commited on
Commit
654f591
·
verified ·
1 Parent(s): f5de412

Update 1_finetune_mixtral.py

Browse files
Files changed (1) hide show
  1. 1_finetune_mixtral.py +25 -54
1_finetune_mixtral.py CHANGED
@@ -3,27 +3,24 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  from peft import LoraConfig
4
  from trl import SFTTrainer
5
  from datasets import load_dataset
6
- import config as cfg # Импортируем наш конфиг
7
 
8
  def main():
9
  # --- 1. Загрузка датасета ---
10
- # Мы сохранили данные в jsonl, так что можем загрузить их так:
11
  try:
12
- dataset = load_dataset("json", data_files={"train": cfg.OUTPUT_TRAIN_FILE}, split="train")
13
- # Если у вас есть и валидационный набор:
14
- # dataset = load_dataset("json", data_files={"train": "data/train.jsonl", "validation": "data/validation.jsonl"})
15
  except Exception as e:
16
- print(f"Ошибка загрузки датасета: {e}")
17
- print("Убедитесь, что файл data/train_dataset_llm.jsonl существует и не пуст, и что скрипт 0_prepare_data.py успешно отработал.")
18
  return
19
 
20
- if not dataset:
21
  print("Датасет не загружен или пуст. Прерывание.")
22
  return
23
 
24
  print(f"Загружен датасет с {len(dataset)} примерами.")
25
- print("Пример первого элемента датасета:", dataset[0])
26
-
27
 
28
  # --- 2. Конфигурация квантизации (BitsAndBytes) ---
29
  bnb_config = BitsAndBytesConfig(
@@ -38,25 +35,23 @@ def main():
38
  cfg.BASE_MODEL_NAME,
39
  quantization_config=bnb_config,
40
  torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
41
- device_map="auto", # Распределит модель по доступным GPU
42
  trust_remote_code=True,
43
- # attn_implementation="flash_attention_2" # Если установлено и поддерживается
44
  )
45
- model.config.use_cache = False # Важно для PEFT
46
 
47
  tokenizer = AutoTokenizer.from_pretrained(cfg.BASE_MODEL_NAME, trust_remote_code=True)
48
  tokenizer.pad_token = tokenizer.eos_token
49
  tokenizer.padding_side = "right"
50
 
51
  # --- 4. Конфигурация LoRA ---
52
- # target_modules для Mixtral могут включать 'q_proj', 'k_proj', 'v_proj', 'o_proj',
53
- # 'gate_proj', 'up_proj', 'down_proj'. Начните с основных для проекций внимания.
54
- lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
55
 
56
  peft_config = LoraConfig(
57
  lora_alpha=16,
58
- lora_dropout=0.05,
59
- r=8,
60
  bias="none",
61
  task_type="CAUSAL_LM",
62
  target_modules=lora_target_modules
@@ -67,20 +62,19 @@ def main():
67
  output_dir=cfg.OUTPUT_DIR,
68
  per_device_train_batch_size=cfg.TRAIN_BATCH_SIZE,
69
  gradient_accumulation_steps=cfg.GRAD_ACCUMULATION_STEPS,
70
- optim="paged_adamw_32bit", # Экономит память
71
  learning_rate=cfg.LEARNING_RATE,
72
  num_train_epochs=cfg.NUM_EPOCHS,
73
- # max_steps=100, # Для быстрого теста, потом установите -1 для использования num_train_epochs
74
  lr_scheduler_type="cosine",
75
  warmup_ratio=0.03,
76
  logging_steps=10,
77
  save_strategy="epoch",
78
- # evaluation_strategy="epoch", # Если есть eval_dataset
79
- # load_best_model_at_end=True, # Если есть eval_dataset
80
- fp16=not torch.cuda.is_bf16_supported(), # Используйте fp16 если bfloat16 не доступен
81
  bf16=torch.cuda.is_bf16_supported(),
82
  gradient_checkpointing=True,
83
- report_to="tensorboard", # или "wandb"
 
 
84
  )
85
 
86
  # --- 6. Инициализация SFTTrainer ---
@@ -88,12 +82,11 @@ def main():
88
  model=model,
89
  tokenizer=tokenizer,
90
  args=training_args,
91
- train_dataset=dataset, # dataset.select(range(100)) для теста на малом подмножестве
92
- # eval_dataset=dataset["validation"], # Если есть
93
  peft_config=peft_config,
94
- dataset_text_field="text", # Название колонки с текстом в вашем датасете
95
  max_seq_length=cfg.MAX_SEQ_LENGTH,
96
- packing=False, # Упаковка может быть полезна для коротких последовательностей, но начните без нее
97
  )
98
 
99
  # --- 7. Запуск обучения ---
@@ -102,40 +95,18 @@ def main():
102
  trainer.train()
103
  except Exception as e:
104
  print(f"Ошибка во время обучения: {e}")
105
- print("Возможные причины: нехватка VRAM (уменьшите batch_size, max_seq_length, LoRA r), проблемы с данными.")
106
  return
107
 
108
  # --- 8. Сохранение адаптера LoRA ---
109
- trainer.save_model(cfg.FINETUNED_ADAPTER_PATH) # Сохраняем адаптер
110
- # tokenizer.save_pretrained(cfg.FINETUNED_ADAPTER_PATH) # Токенизатор тоже можно сохранить рядом
111
  print(f"Обучение завершено. Адаптер LoRA сохранен в: {cfg.FINETUNED_ADAPTER_PATH}")
112
 
113
- # (Опционально) Слияние и сохранение полной модели
114
- # Это потребует больше RAM/VRAM
115
- # print("Слияние модели...")
116
- # merged_model = model.merge_and_unload() # Если использовали get_peft_model
117
- # Если SFTTrainer сам создал PeftModel, то нужно сначала получить базовую модель и PeftModel
118
- # base_model_for_merge = AutoModelForCausalLM.from_pretrained(
119
- # cfg.BASE_MODEL_NAME,
120
- # torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
121
- # device_map="cpu", # Сливаем на CPU, если VRAM мало
122
- # trust_remote_code=True
123
- # )
124
- # merged_model = PeftModel.from_pretrained(base_model_for_merge, cfg.FINETUNED_ADAPTER_PATH)
125
- # merged_model = merged_model.merge_and_unload()
126
- # merged_model.save_pretrained(cfg.MERGED_MODEL_PATH)
127
- # tokenizer.save_pretrained(cfg.MERGED_MODEL_PATH)
128
- # print(f"Смерженная модель сохранена в: {cfg.MERGED_MODEL_PATH}")
129
-
130
-
131
  if __name__ == "__main__":
132
  if not torch.cuda.is_available():
133
- print("CUDA недоступна. Обучение на CPU будет невозможным или крайне медленным для Mixtral.")
134
  else:
135
  print(f"Доступно CUDA устройств: {torch.cuda.device_count()}")
136
  print(f"Текущее устройство CUDA: {torch.cuda.current_device()} ({torch.cuda.get_device_name(torch.cuda.current_device())})")
137
- if torch.cuda.is_bf16_supported():
138
- print("BF16 поддерживается.")
139
- else:
140
- print("BF16 НЕ поддерживается. Будет использоваться FP16 (если включено) или FP32.")
141
  main()
 
3
  from peft import LoraConfig
4
  from trl import SFTTrainer
5
  from datasets import load_dataset
6
+ import config as cfg
7
 
8
  def main():
9
  # --- 1. Загрузка датасета ---
 
10
  try:
11
+ dataset = load_dataset("json", data_files={"train": cfg.TRAINING_DATA_JSONL}, split="train")
 
 
12
  except Exception as e:
13
+ print(f"Ошибка загрузки датасета из {cfg.TRAINING_DATA_JSONL}: {e}")
14
+ print("Убедитесь, что файл существует и не пуст, и что скрипт 0_prepare_data.py успешно отработал.")
15
  return
16
 
17
+ if not dataset or len(dataset) == 0:
18
  print("Датасет не загружен или пуст. Прерывание.")
19
  return
20
 
21
  print(f"Загружен датасет с {len(dataset)} примерами.")
22
+ if len(dataset) > 0:
23
+ print("Пример первого элемента датасета:", dataset[0])
24
 
25
  # --- 2. Конфигурация квантизации (BitsAndBytes) ---
26
  bnb_config = BitsAndBytesConfig(
 
35
  cfg.BASE_MODEL_NAME,
36
  quantization_config=bnb_config,
37
  torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
38
+ device_map="auto",
39
  trust_remote_code=True,
40
+ # attn_implementation="flash_attention_2" # Если flash-attn установлен
41
  )
42
+ model.config.use_cache = False
43
 
44
  tokenizer = AutoTokenizer.from_pretrained(cfg.BASE_MODEL_NAME, trust_remote_code=True)
45
  tokenizer.pad_token = tokenizer.eos_token
46
  tokenizer.padding_side = "right"
47
 
48
  # --- 4. Конфигурация LoRA ---
49
+ lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] #, "gate_proj", "up_proj", "down_proj"]
 
 
50
 
51
  peft_config = LoraConfig(
52
  lora_alpha=16,
53
+ lora_dropout=0.05, # Было 0.1, уменьшил для возможной борьбы с переобучением на малых данных
54
+ r=8, # или 16
55
  bias="none",
56
  task_type="CAUSAL_LM",
57
  target_modules=lora_target_modules
 
62
  output_dir=cfg.OUTPUT_DIR,
63
  per_device_train_batch_size=cfg.TRAIN_BATCH_SIZE,
64
  gradient_accumulation_steps=cfg.GRAD_ACCUMULATION_STEPS,
65
+ optim="paged_adamw_32bit",
66
  learning_rate=cfg.LEARNING_RATE,
67
  num_train_epochs=cfg.NUM_EPOCHS,
 
68
  lr_scheduler_type="cosine",
69
  warmup_ratio=0.03,
70
  logging_steps=10,
71
  save_strategy="epoch",
72
+ fp16=not torch.cuda.is_bf16_supported(),
 
 
73
  bf16=torch.cuda.is_bf16_supported(),
74
  gradient_checkpointing=True,
75
+ report_to="tensorboard",
76
+ # evaluation_strategy="epoch", # Если есть eval_dataset
77
+ # load_best_model_at_end=True, # Если есть eval_dataset
78
  )
79
 
80
  # --- 6. Инициализация SFTTrainer ---
 
82
  model=model,
83
  tokenizer=tokenizer,
84
  args=training_args,
85
+ train_dataset=dataset,
 
86
  peft_config=peft_config,
87
+ dataset_text_field="text",
88
  max_seq_length=cfg.MAX_SEQ_LENGTH,
89
+ packing=False, # С вашими данными, вероятно, лучше False
90
  )
91
 
92
  # --- 7. Запуск обучения ---
 
95
  trainer.train()
96
  except Exception as e:
97
  print(f"Ошибка во время обучения: {e}")
 
98
  return
99
 
100
  # --- 8. Сохранение адаптера LoRA ---
101
+ trainer.save_model(cfg.FINETUNED_ADAPTER_PATH)
 
102
  print(f"Обучение завершено. Адаптер LoRA сохранен в: {cfg.FINETUNED_ADAPTER_PATH}")
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if __name__ == "__main__":
105
  if not torch.cuda.is_available():
106
+ print("CUDA недоступна.")
107
  else:
108
  print(f"Доступно CUDA устройств: {torch.cuda.device_count()}")
109
  print(f"Текущее устройство CUDA: {torch.cuda.current_device()} ({torch.cuda.get_device_name(torch.cuda.current_device())})")
110
+ if torch.cuda.is_bf16_supported(): print("BF16 поддерживается.")
111
+ else: print("BF16 НЕ поддерживается.")
 
 
112
  main()