Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.optim import AdamW | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import os | |
| import logging | |
| from .model import CodeEmbedder | |
# Module-level logger; records go through the standard `logging` hierarchy.
logger = logging.getLogger(__name__)
# Make INFO-level progress messages visible by default. basicConfig is a no-op
# when the root logger already has handlers, so an application that configures
# logging itself keeps its own setup.
# NOTE(review): configuring the root logger at import time is a side effect of
# importing this module; consider moving this into the entry-point script.
logging.basicConfig(level=logging.INFO)
class CodeTrainer:
    """Fine-tune a ``CodeEmbedder`` with a triplet margin objective.

    Each batch from ``train_loader`` (and ``val_loader``) must be a mapping
    with the keys ``anchor_input_ids``, ``anchor_attention_mask``,
    ``positive_input_ids``, ``positive_attention_mask``,
    ``negative_input_ids`` and ``negative_attention_mask``. Every value must
    support ``.to(device)``; presumably these are tensors produced by the
    project's Dataset/collate_fn (TODO confirm).

    Args:
        model: Embedding model, called as ``model(input_ids, attention_mask)``
            and expected to return one embedding per example.
        train_loader: Loader yielding triplet batches for training.
        val_loader: Optional loader with the same batch format. When given,
            the mean validation loss is logged after every epoch.
        epochs: Number of passes over ``train_loader``.
        learning_rate: AdamW learning rate (kept constant; no scheduler).
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer step. Must be >= 1.
        mixed_precision: Request automatic mixed precision. It is only applied
            on CUDA devices and is disabled elsewhere.
        output_dir: Directory under which ``checkpoint-<epoch>`` folders are
            written.
        device: Device that the model and batch tensors are moved to.

    Raises:
        ValueError: If ``accumulation_steps`` is less than 1.
    """

    def __init__(
        self,
        model: CodeEmbedder,
        train_loader: DataLoader,
        val_loader: DataLoader = None,
        epochs: int = 3,
        learning_rate: float = 2e-5,
        accumulation_steps: int = 1,
        mixed_precision: bool = True,
        output_dir: str = "./output",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        # Fail fast: a value < 1 would otherwise break the `step % n` logic
        # in train() (ZeroDivisionError for 0, never stepping for negatives).
        if accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )

        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.lr = learning_rate
        self.accumulation_steps = accumulation_steps
        self.mixed_precision = mixed_precision  # flag as requested by caller
        self.output_dir = output_dir
        self.device = device

        # AMP (CUDA autocast + GradScaler) only works on CUDA. On other devices
        # both would warn and disable themselves, so decide once, explicitly.
        self.use_amp = bool(mixed_precision) and torch.device(device).type == "cuda"

        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)
        # TODO: add an LR scheduler (e.g. linear warmup/decay); the LR is constant for now.

        # With enabled=False the scaler is a pass-through (scale() returns the
        # loss unchanged and step() just calls optimizer.step()).
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        # Triplet margin loss: max(d(a, p) - d(a, n) + margin, 0) with
        # Euclidean distance (p=2). It pulls anchor/positive together and
        # pushes anchor/negative apart until they differ by at least `margin`.
        self.criterion = nn.TripletMarginLoss(margin=1.0, p=2)

    def _embed(self, batch, prefix):
        """Move the ``<prefix>_input_ids``/``<prefix>_attention_mask`` entries to the device and embed them."""
        input_ids = batch[f"{prefix}_input_ids"].to(self.device)
        attention_mask = batch[f"{prefix}_attention_mask"].to(self.device)
        return self.model(input_ids, attention_mask)

    def train_step(self, batch):
        """Run the forward pass for one triplet batch and return its loss.

        Only the forward pass and loss are computed here, under autocast when
        AMP is active. Backward and optimizer updates are handled by
        ``train``. ``evaluate`` also reuses this under ``torch.no_grad``.
        """
        with torch.autocast(device_type="cuda", enabled=self.use_amp):
            anchor_emb = self._embed(batch, "anchor")
            positive_emb = self._embed(batch, "positive")
            negative_emb = self._embed(batch, "negative")
            loss = self.criterion(anchor_emb, positive_emb, negative_emb)
        return loss

    @torch.no_grad()
    def evaluate(self):
        """Return the mean triplet loss over ``val_loader``.

        Returns ``None`` when no ``val_loader`` was given or it yields no
        batches. The model is put in eval mode for the pass and always
        switched back to train mode afterwards.
        """
        if self.val_loader is None:
            return None

        self.model.eval()
        total, num_batches = 0.0, 0
        try:
            for batch in tqdm(self.val_loader, desc="Validation", leave=False):
                total += self.train_step(batch).item()
                num_batches += 1
        finally:
            self.model.train()
        return total / num_batches if num_batches else None

    def _optimizer_step(self):
        """Apply the accumulated gradients through the GradScaler, then clear them."""
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def _log_setup(self):
        """Log the device and batch-size configuration before training."""
        batch_size = getattr(self.train_loader, "batch_size", None)
        logger.info(f"Starting training on {self.device}...")
        logger.info(f"Batch Size: {batch_size}, Accumulation Steps: {self.accumulation_steps}")
        # DataLoader.batch_size is None when a batch_sampler is used, so the
        # effective size cannot be derived from it in that case.
        if batch_size is not None:
            logger.info(f"Effective Batch Size: {batch_size * self.accumulation_steps}")

    def train(self):
        """Train for ``self.epochs`` epochs, checkpointing after each one.

        Each micro-batch loss is divided by ``accumulation_steps`` before
        backward, so accumulated gradients average over the group. Gradients
        from a final, incomplete group are still applied at the end of the
        epoch rather than discarded. (That group is divided by the full
        ``accumulation_steps``, so its update is proportionally smaller.)
        """
        self._log_setup()
        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0.0
            num_batches = 0
            self.optimizer.zero_grad()
            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for step, batch in enumerate(progress_bar):
                # Normalize so the accumulated gradient is a mean over the group.
                loss = self.train_step(batch) / self.accumulation_steps
                self.scaler.scale(loss).backward()

                if (step + 1) % self.accumulation_steps == 0:
                    self._optimizer_step()

                # Undo the normalization so the reported loss is per-batch.
                total_loss += loss.item() * self.accumulation_steps
                num_batches = step + 1
                progress_bar.set_postfix({'loss': total_loss / num_batches})

            # Flush a trailing partial accumulation group. Otherwise its
            # gradients would be silently dropped by the next zero_grad().
            if num_batches % self.accumulation_steps != 0:
                self._optimizer_step()

            if num_batches:
                logger.info(f"Epoch {epoch+1}: mean train loss {total_loss / num_batches:.4f}")

            val_loss = self.evaluate()
            if val_loss is not None:
                logger.info(f"Epoch {epoch+1}: mean validation loss {val_loss:.4f}")

            self.save_model(epoch+1)

    def save_model(self, epoch):
        """Save a checkpoint to ``<output_dir>/checkpoint-<epoch>``.

        Writes ``self.model.encoder`` via ``save_pretrained`` as safetensors,
        then ``self.model.config`` via its ``save_pretrained``. Saving the
        encoder this way means others can load it directly.
        NOTE(review): any parameters ``CodeEmbedder`` holds outside
        ``encoder`` (e.g. a projection head, if one exists) are not written
        here. Confirm against the model definition.
        """
        save_path = os.path.join(self.output_dir, f"checkpoint-{epoch}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving model to {save_path}...")
        self.model.encoder.save_pretrained(save_path, safe_serialization=True)
        self.model.config.save_pretrained(save_path)