|
|
|
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModelForCausalLM, |
|
|
TrainingArguments, Trainer, DataCollatorForLanguageModeling, |
|
|
BitsAndBytesConfig |
|
|
) |
|
|
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model |
|
|
from distributed_utils import RankZeroOnly |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
class DistributedProductionTrainer:
    """Production trainer with proper distributed (DDP) training support.

    Thin orchestration layer over the Hugging Face ``Trainer``:

    * ``__init__``               — parse/coerce the config dict.
    * ``load_model_and_tokenizer`` — load base model + tokenizer, optionally
      apply LoRA and wrap in DistributedDataParallel.
    * ``prepare_datasets``       — tokenize train/val/test splits.
    * ``setup_training``         — build ``TrainingArguments`` + ``Trainer``.
    * ``train`` / ``evaluate``   — run training / distributed evaluation.

    Rank-0-only console output is handled via the project's ``RankZeroOnly``
    context manager.
    """

    def __init__(self, config: Dict):
        """Initialize trainer state from a config dict.

        Args:
            config: Must contain ``base_model`` and ``training_recipe``.
                May contain distributed topology (``ddp``, ``rank``,
                ``world_size``, ``is_main``, ``device``) and hyperparameters
                (``learning_rate``, ``num_train_epochs``,
                ``per_device_train_batch_size``,
                ``gradient_accumulation_steps``).

        Raises:
            KeyError: if ``base_model`` or ``training_recipe`` is missing.
        """
        self.config = config
        self.output_dir = Path("runs/humigence")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Distributed topology; defaults describe a single-process run.
        self.ddp = config.get("ddp", False)
        self.rank = config.get("rank", 0)
        self.world_size = config.get("world_size", 1)
        self.is_main = config.get("is_main", True)
        self.device = torch.device(config.get("device", "cuda:0"))

        # Core hyperparameters. Values may arrive as strings (e.g. parsed
        # from JSON/YAML), so coerce explicitly rather than trusting types.
        self.base_model = config["base_model"]
        self.training_recipe = config["training_recipe"]
        self.learning_rate = float(config.get("learning_rate", "2e-4"))
        self.num_epochs = int(config.get("num_train_epochs", "1"))
        self.batch_size = int(config.get("per_device_train_batch_size", "2"))
        self.gradient_accumulation = int(config.get("gradient_accumulation_steps", "4"))

        # Populated later by load_model_and_tokenizer() / setup_training().
        self.tokenizer = None
        self.model = None
        self.trainer = None

    def load_model_and_tokenizer(self):
        """Load model and tokenizer with proper device placement.

        Side effects: sets ``self.tokenizer`` and ``self.model``. Applies
        LoRA when the recipe name contains ``"LoRA"`` and wraps the model in
        DDP when ``self.ddp`` is set.
        """
        with RankZeroOnly(self.is_main) as rank_zero:
            rank_zero.print(f"[blue]🤖 Loading model: {self.base_model}[/blue]")

        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model, trust_remote_code=True)
        # Many causal-LM tokenizers ship without a pad token; reuse EOS —
        # but only when missing, so we never clobber a real pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Recipe string encodes the compute dtype (e.g. "QLoRA_BF16").
        dtype = torch.bfloat16 if "BF16" in self.training_recipe else torch.float16

        if self.ddp:
            # Under DDP every rank holds a full replica: disable HF's
            # automatic sharding (device_map=None) and move the whole model
            # to this rank's device explicitly.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.base_model,
                device_map=None,
                trust_remote_code=True,
                torch_dtype=dtype,
            )
            self.model = self.model.to(self.device)
        else:
            # Single process: let HF place (and possibly shard) the model.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.base_model,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=dtype,
            )

        if "LoRA" in self.training_recipe:
            self._apply_lora()

        if self.ddp:
            # NOTE(review): HF Trainer also wraps models in DDP when launched
            # under torchrun; wrapping here is preserved from the original
            # design — confirm the installed transformers version does not
            # double-wrap.
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.device.index],
                output_device=self.device.index,
            )
            with RankZeroOnly(self.is_main) as rank_zero:
                rank_zero.print(f"[blue]✅ Model wrapped with DDP (rank {self.rank})[/blue]")

    def _apply_lora(self):
        """Apply LoRA adapters to ``self.model`` (attention projections only)."""
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)

        with RankZeroOnly(self.is_main) as rank_zero:
            rank_zero.print("[blue]✅ LoRA configuration applied[/blue]")

    def prepare_datasets(self, train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]):
        """Tokenize the three splits and build the LM data collator.

        Args:
            train_data / val_data / test_data: lists of records, each with a
                ``"text"`` key holding the raw training string.

        Side effects: sets ``self.train_dataset`` / ``self.val_dataset`` /
        ``self.test_dataset``, ``self.data_collator`` and (under DDP)
        ``self.train_sampler`` / ``self.val_sampler``.
        """
        from datasets import Dataset

        train_dataset = Dataset.from_list(train_data)
        val_dataset = Dataset.from_list(val_data)
        test_dataset = Dataset.from_list(test_data)

        def tokenize_function(examples):
            # Padding is deferred to the collator (dynamic per-batch padding).
            return self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=512,
            )

        # BUGFIX: drop the raw "text" column after tokenization. Because
        # TrainingArguments uses remove_unused_columns=False, any leftover
        # string column would be forwarded to the collator, which cannot
        # tensorize strings and would crash at batch time.
        train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
        test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

        # mlm=False => causal-LM labels (inputs shifted internally by model).
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        if self.ddp:
            # Kept for callers that build their own DataLoaders; the HF
            # Trainer shards data itself in distributed runs (see
            # setup_training).
            self.train_sampler = DistributedSampler(
                train_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True,
            )
            self.val_sampler = DistributedSampler(
                val_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False,  # deterministic eval order
            )
        else:
            self.train_sampler = None
            self.val_sampler = None

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def setup_training(self):
        """Build ``TrainingArguments`` and the ``Trainer`` instance.

        Must be called after ``load_model_and_tokenizer`` and
        ``prepare_datasets``. Side effect: sets ``self.trainer``.
        """
        training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            gradient_accumulation_steps=self.gradient_accumulation,
            num_train_epochs=self.num_epochs,
            learning_rate=self.learning_rate,
            logging_steps=10,
            # save/eval cadence must match for load_best_model_at_end.
            save_steps=100,
            eval_steps=100,
            evaluation_strategy="steps",
            save_strategy="steps",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            ddp_find_unused_parameters=False,
            remove_unused_columns=False,
            dataloader_pin_memory=True,
            dataloader_num_workers=4,
        )

        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
            data_collator=self.data_collator,
            tokenizer=self.tokenizer,
        )

        # BUGFIX: the previous code assigned to trainer.train_dataloader /
        # trainer.eval_dataloader under DDP. Trainer has no such attributes
        # (dataloaders are built lazily via get_train_dataloader()), so that
        # raised AttributeError. It is also unnecessary: Trainer shards the
        # data across ranks itself when launched in a distributed run, so no
        # manual sampler attachment is performed here.

    def train(self):
        """Run training with proper distributed handling.

        Returns:
            dict with ``status`` and ``message`` keys.
        """
        with RankZeroOnly(self.is_main) as rank_zero:
            rank_zero.print("[blue]🚀 Starting training...[/blue]")

        self.trainer.train()

        # Only rank 0 writes the checkpoint to avoid concurrent writes.
        if self.is_main:
            self.trainer.save_model()
            with RankZeroOnly(self.is_main) as rank_zero:
                rank_zero.print("[blue]💾 Model saved[/blue]")

        # Keep non-main ranks alive until the save completes.
        if self.ddp:
            dist.barrier()

        return {"status": "success", "message": "Training completed"}

    def evaluate(self):
        """Run evaluation; under DDP, average numeric metrics across ranks.

        Returns:
            dict of evaluation metrics. On non-main DDP ranks, the local
            (un-averaged) results are returned.
        """
        with RankZeroOnly(self.is_main) as rank_zero:
            rank_zero.print("[blue]🧪 Running evaluation...[/blue]")

        eval_results = self.trainer.evaluate()

        if self.ddp:
            # Collect every rank's metric dict onto all ranks.
            gathered_results = [None] * self.world_size
            dist.all_gather_object(gathered_results, eval_results)

            if self.is_main:
                # Simple mean over ranks for numeric metrics; non-numeric
                # entries are taken from the local results unchanged.
                avg_results = {}
                for key, local_value in eval_results.items():
                    if isinstance(local_value, (int, float)):
                        values = [r[key] for r in gathered_results if r is not None]
                        avg_results[key] = sum(values) / len(values)
                    else:
                        avg_results[key] = local_value
                eval_results = avg_results

        return eval_results
|
|
|