| import torch |
| import torch.nn as nn |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| import os |
| import logging |
| from .model import CodeEmbedder |
|
|
| |
# Import-time side effect: configures the root logger at INFO level as soon as
# this module is imported (basicConfig is a no-op if the root logger already
# has handlers).
| logging.basicConfig(level=logging.INFO) |
# Module-level logger named after this module; used by the trainer below.
| logger = logging.getLogger(__name__) |
|
|
class CodeTrainer:
    """Train an embedding model with a triplet-margin objective.

    The trainer runs a plain epoch loop with optional gradient accumulation
    and optional CUDA automatic mixed precision (AMP), and writes a
    checkpoint directory after every epoch.

    Each batch yielded by ``train_loader`` must be a mapping containing the
    keys ``{anchor,positive,negative}_input_ids`` and
    ``{anchor,positive,negative}_attention_mask`` whose values support
    ``.to(device)``.
    """

    def __init__(
        self,
        model: "CodeEmbedder",
        train_loader: DataLoader,
        val_loader: DataLoader = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        accumulation_steps: int = 1,
        mixed_precision: bool = True,
        output_dir: str = "./output",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Set up the optimizer, gradient scaler and loss.

        Args:
            model: Module called as ``model(input_ids, attention_mask)``;
                it is moved to ``device``.
            train_loader: Iterable of training batches.
            val_loader: Optional validation loader. It is stored on the
                instance but not used by any method of this class.
            epochs: Number of passes over ``train_loader``.
            learning_rate: Learning rate handed to ``AdamW``.
            accumulation_steps: Number of batches whose gradients are summed
                before one optimizer step. Must be >= 1.
            mixed_precision: Request AMP. It only takes effect when
                ``device`` is a CUDA device.
            output_dir: Directory under which checkpoints are written.
            device: Device the model and the batches are moved to.

        Raises:
            ValueError: If ``accumulation_steps`` is smaller than 1.
        """
        # Fail early: a value of 0 would otherwise only surface as a
        # ZeroDivisionError deep inside the training loop.
        if accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )

        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.lr = learning_rate
        self.accumulation_steps = accumulation_steps
        self.mixed_precision = mixed_precision
        self.output_dir = output_dir
        self.device = device

        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

        # The CUDA GradScaler must stay disabled off-GPU: enabling it while
        # training on CPU tensors is unsupported.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self._amp_enabled)

        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)

    @property
    def _amp_enabled(self) -> bool:
        """Whether AMP is both requested and usable on the target device."""
        return bool(self.mixed_precision) and torch.device(self.device).type == "cuda"

    def _embed(self, batch, prefix):
        """Move one ``<prefix>_*`` input pair to the device and embed it."""
        input_ids = batch[f"{prefix}_input_ids"].to(self.device)
        attention_mask = batch[f"{prefix}_attention_mask"].to(self.device)
        return self.model(input_ids, attention_mask)

    def _apply_gradients(self):
        """Run one (scaled) optimizer step and clear the gradients."""
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def train_step(self, batch):
        """Run the forward pass for one batch and return its loss.

        No backward pass or optimizer step happens here; the returned loss
        is not divided by ``accumulation_steps``.

        Args:
            batch: Mapping with the anchor/positive/negative ``input_ids``
                and ``attention_mask`` entries.

        Returns:
            The triplet-margin loss for the batch.
        """
        with torch.cuda.amp.autocast(enabled=self._amp_enabled):
            anchor_emb = self._embed(batch, "anchor")
            positive_emb = self._embed(batch, "positive")
            negative_emb = self._embed(batch, "negative")
            loss = self.criterion(anchor_emb, positive_emb, negative_emb)

        return loss

    def train(self):
        """Train for ``self.epochs`` epochs, saving a checkpoint after each.

        Gradients are accumulated over ``accumulation_steps`` batches. A
        trailing, incomplete accumulation window at the end of an epoch is
        still applied instead of being thrown away.
        """
        batch_size = self.train_loader.batch_size
        logger.info(f"Starting training on {self.device}...")
        logger.info(f"Batch Size: {batch_size}, Accumulation Steps: {self.accumulation_steps}")
        # batch_size is None when the loader was built with a batch_sampler;
        # the effective size cannot be computed in that case.
        if batch_size is not None:
            logger.info(f"Effective Batch Size: {batch_size * self.accumulation_steps}")

        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0.0
            has_pending_grads = False
            self.optimizer.zero_grad()

            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for step, batch in enumerate(progress_bar):
                raw_loss = self.train_step(batch)

                # Scale down so the summed gradients of a full window match
                # the gradient of the window's mean loss.
                self.scaler.scale(raw_loss / self.accumulation_steps).backward()
                has_pending_grads = True

                if (step + 1) % self.accumulation_steps == 0:
                    self._apply_gradients()
                    has_pending_grads = False

                total_loss += raw_loss.item()
                progress_bar.set_postfix({'loss': total_loss / (step + 1)})

            # Flush the leftover batches when the epoch length is not a
            # multiple of accumulation_steps; otherwise their gradients would
            # be wiped by the next zero_grad() without ever being used.
            if has_pending_grads:
                self._apply_gradients()

            self.save_model(epoch + 1)

    def save_model(self, epoch):
        """Write the encoder weights and model config to a checkpoint dir.

        Args:
            epoch: Value used in the directory name
                ``<output_dir>/checkpoint-<epoch>``.
        """
        save_path = os.path.join(self.output_dir, f"checkpoint-{epoch}")
        os.makedirs(save_path, exist_ok=True)

        logger.info(f"Saving model to {save_path}...")

        # NOTE(review): presumably `encoder` and `config` are Hugging Face
        # objects exposing save_pretrained -- confirm against CodeEmbedder.
        self.model.encoder.save_pretrained(save_path, safe_serialization=True)
        self.model.config.save_pretrained(save_path)
| |
| |
|
|