lora / 1_finetune_mixtral.py
Dmitriy-Egorov's picture
Update 1_finetune_mixtral.py
654f591 verified
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from datasets import load_dataset
import config as cfg
def main():
# --- 1. Загрузка датасета ---
try:
dataset = load_dataset("json", data_files={"train": cfg.TRAINING_DATA_JSONL}, split="train")
except Exception as e:
print(f"Ошибка загрузки датасета из {cfg.TRAINING_DATA_JSONL}: {e}")
print("Убедитесь, что файл существует и не пуст, и что скрипт 0_prepare_data.py успешно отработал.")
return
if not dataset or len(dataset) == 0:
print("Датасет не загружен или пуст. Прерывание.")
return
print(f"Загружен датасет с {len(dataset)} примерами.")
if len(dataset) > 0:
print("Пример первого элемента датасета:", dataset[0])
# --- 2. Конфигурация квантизации (BitsAndBytes) ---
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
bnb_4bit_use_double_quant=True,
)
# --- 3. Загрузка модели и токенизатора ---
model = AutoModelForCausalLM.from_pretrained(
cfg.BASE_MODEL_NAME,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
device_map="auto",
trust_remote_code=True,
# attn_implementation="flash_attention_2" # Если flash-attn установлен
)
model.config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained(cfg.BASE_MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# --- 4. Конфигурация LoRA ---
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] #, "gate_proj", "up_proj", "down_proj"]
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05, # Было 0.1, уменьшил для возможной борьбы с переобучением на малых данных
r=8, # или 16
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_target_modules
)
# --- 5. Настройка аргументов обучения ---
training_args = TrainingArguments(
output_dir=cfg.OUTPUT_DIR,
per_device_train_batch_size=cfg.TRAIN_BATCH_SIZE,
gradient_accumulation_steps=cfg.GRAD_ACCUMULATION_STEPS,
optim="paged_adamw_32bit",
learning_rate=cfg.LEARNING_RATE,
num_train_epochs=cfg.NUM_EPOCHS,
lr_scheduler_type="cosine",
warmup_ratio=0.03,
logging_steps=10,
save_strategy="epoch",
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
gradient_checkpointing=True,
report_to="tensorboard",
# evaluation_strategy="epoch", # Если есть eval_dataset
# load_best_model_at_end=True, # Если есть eval_dataset
)
# --- 6. Инициализация SFTTrainer ---
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=cfg.MAX_SEQ_LENGTH,
packing=False, # С вашими данными, вероятно, лучше False
)
# --- 7. Запуск обучения ---
print("Начало обучения...")
try:
trainer.train()
except Exception as e:
print(f"Ошибка во время обучения: {e}")
return
# --- 8. Сохранение адаптера LoRA ---
trainer.save_model(cfg.FINETUNED_ADAPTER_PATH)
print(f"Обучение завершено. Адаптер LoRA сохранен в: {cfg.FINETUNED_ADAPTER_PATH}")
if __name__ == "__main__":
if not torch.cuda.is_available():
print("CUDA недоступна.")
else:
print(f"Доступно CUDA устройств: {torch.cuda.device_count()}")
print(f"Текущее устройство CUDA: {torch.cuda.current_device()} ({torch.cuda.get_device_name(torch.cuda.current_device())})")
if torch.cuda.is_bf16_supported(): print("BF16 поддерживается.")
else: print("BF16 НЕ поддерживается.")
main()