Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments | |
| from peft import LoraConfig | |
| from trl import SFTTrainer | |
| from datasets import load_dataset | |
| import config as cfg | |
| def main(): | |
| # --- 1. Загрузка датасета --- | |
| try: | |
| dataset = load_dataset("json", data_files={"train": cfg.TRAINING_DATA_JSONL}, split="train") | |
| except Exception as e: | |
| print(f"Ошибка загрузки датасета из {cfg.TRAINING_DATA_JSONL}: {e}") | |
| print("Убедитесь, что файл существует и не пуст, и что скрипт 0_prepare_data.py успешно отработал.") | |
| return | |
| if not dataset or len(dataset) == 0: | |
| print("Датасет не загружен или пуст. Прерывание.") | |
| return | |
| print(f"Загружен датасет с {len(dataset)} примерами.") | |
| if len(dataset) > 0: | |
| print("Пример первого элемента датасета:", dataset[0]) | |
| # --- 2. Конфигурация квантизации (BitsAndBytes) --- | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, | |
| bnb_4bit_use_double_quant=True, | |
| ) | |
| # --- 3. Загрузка модели и токенизатора --- | |
| model = AutoModelForCausalLM.from_pretrained( | |
| cfg.BASE_MODEL_NAME, | |
| quantization_config=bnb_config, | |
| torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| # attn_implementation="flash_attention_2" # Если flash-attn установлен | |
| ) | |
| model.config.use_cache = False | |
| tokenizer = AutoTokenizer.from_pretrained(cfg.BASE_MODEL_NAME, trust_remote_code=True) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| tokenizer.padding_side = "right" | |
| # --- 4. Конфигурация LoRA --- | |
| lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] #, "gate_proj", "up_proj", "down_proj"] | |
| peft_config = LoraConfig( | |
| lora_alpha=16, | |
| lora_dropout=0.05, # Было 0.1, уменьшил для возможной борьбы с переобучением на малых данных | |
| r=8, # или 16 | |
| bias="none", | |
| task_type="CAUSAL_LM", | |
| target_modules=lora_target_modules | |
| ) | |
| # --- 5. Настройка аргументов обучения --- | |
| training_args = TrainingArguments( | |
| output_dir=cfg.OUTPUT_DIR, | |
| per_device_train_batch_size=cfg.TRAIN_BATCH_SIZE, | |
| gradient_accumulation_steps=cfg.GRAD_ACCUMULATION_STEPS, | |
| optim="paged_adamw_32bit", | |
| learning_rate=cfg.LEARNING_RATE, | |
| num_train_epochs=cfg.NUM_EPOCHS, | |
| lr_scheduler_type="cosine", | |
| warmup_ratio=0.03, | |
| logging_steps=10, | |
| save_strategy="epoch", | |
| fp16=not torch.cuda.is_bf16_supported(), | |
| bf16=torch.cuda.is_bf16_supported(), | |
| gradient_checkpointing=True, | |
| report_to="tensorboard", | |
| # evaluation_strategy="epoch", # Если есть eval_dataset | |
| # load_best_model_at_end=True, # Если есть eval_dataset | |
| ) | |
| # --- 6. Инициализация SFTTrainer --- | |
| trainer = SFTTrainer( | |
| model=model, | |
| tokenizer=tokenizer, | |
| args=training_args, | |
| train_dataset=dataset, | |
| peft_config=peft_config, | |
| dataset_text_field="text", | |
| max_seq_length=cfg.MAX_SEQ_LENGTH, | |
| packing=False, # С вашими данными, вероятно, лучше False | |
| ) | |
| # --- 7. Запуск обучения --- | |
| print("Начало обучения...") | |
| try: | |
| trainer.train() | |
| except Exception as e: | |
| print(f"Ошибка во время обучения: {e}") | |
| return | |
| # --- 8. Сохранение адаптера LoRA --- | |
| trainer.save_model(cfg.FINETUNED_ADAPTER_PATH) | |
| print(f"Обучение завершено. Адаптер LoRA сохранен в: {cfg.FINETUNED_ADAPTER_PATH}") | |
| if __name__ == "__main__": | |
| if not torch.cuda.is_available(): | |
| print("CUDA недоступна.") | |
| else: | |
| print(f"Доступно CUDA устройств: {torch.cuda.device_count()}") | |
| print(f"Текущее устройство CUDA: {torch.cuda.current_device()} ({torch.cuda.get_device_name(torch.cuda.current_device())})") | |
| if torch.cuda.is_bf16_supported(): print("BF16 поддерживается.") | |
| else: print("BF16 НЕ поддерживается.") | |
| main() |