| import argparse |
| import os |
| import torch |
| from torch.utils.data import DataLoader, Dataset |
| from transformers import AutoTokenizer |
|
|
| from scripts.core.training.model import CodeEmbedder |
| from scripts.core.training.trainer import CodeTrainer |
|
|
| import json |
|
|
| |
class RealCodeDataset(Dataset):
    """Triplet dataset read from a JSON Lines file.

    Every non-blank line of the file must be a JSON object containing the
    fields ``anchor``, ``positive`` and ``negative``. Records are parsed and
    validated eagerly in ``__init__`` so that a broken file is reported with
    its path and line number before training starts, rather than surfacing
    later from ``__getitem__``.

    Args:
        jsonl_path: Path of the JSON Lines file to load (read as UTF-8).
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            padding='max_length', truncation=True, max_length=...)``. Its
            result is indexed with ``'input_ids'`` and ``'attention_mask'``.
        max_length: Value forwarded to the tokenizer as ``max_length``.

    Raises:
        json.JSONDecodeError: A non-blank line is not valid JSON.
        TypeError: A line decodes to something other than a JSON object.
        KeyError: A record lacks one of the three required fields.
    """

    # Field names of one triplet, in the order they appear in each sample.
    _FIELDS = ("anchor", "positive", "negative")

    def __init__(self, jsonl_path, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        print(f"Loading data from {jsonl_path}...")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                # Blank lines (e.g. a trailing newline) are not records.
                if not line.strip():
                    continue
                self.data.append(self._parse_record(line, jsonl_path, line_no))
        print(f"Loaded {len(self.data)} triplets.")

    @classmethod
    def _parse_record(cls, line, jsonl_path, line_no):
        """Decode one line and check it is a complete triplet object."""
        where = f"{jsonl_path}:{line_no}"
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            # Same exception type as before, now carrying the file location.
            raise json.JSONDecodeError(
                f"{where}: {err.msg}", err.doc, err.pos
            ) from err

        if not isinstance(record, dict):
            raise TypeError(
                f"{where}: expected a JSON object, got {type(record).__name__}"
            )

        missing = [name for name in cls._FIELDS if name not in record]
        if missing:
            raise KeyError(f"{where}: missing field(s) {missing}")
        return record

    def __len__(self):
        """Return the number of loaded triplets."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Tokenize the triplet at ``idx``.

        Returns:
            A dict with the keys ``<field>_input_ids`` and
            ``<field>_attention_mask`` for each of anchor, positive and
            negative (in that order).
        """
        item = self.data[idx]
        sample = {}
        for name in self._FIELDS:
            encoded = self._encode(item[name])
            # squeeze(0) drops the leading dimension of the tokenizer output
            # (presumably the size-1 batch axis added by return_tensors='pt').
            sample[f'{name}_input_ids'] = encoded['input_ids'].squeeze(0)
            sample[f'{name}_attention_mask'] = encoded['attention_mask'].squeeze(0)
        return sample
|
|
| |
class DummyCodeDataset(Dataset):
    """Synthetic triplet dataset that repeats one fixed code example.

    Intended for checking that the pipeline runs without needing a real data
    file: every index yields the same anchor/positive/negative triplet.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            padding='max_length', truncation=True, max_length=...)``. Its
            result is indexed with ``'input_ids'`` and ``'attention_mask'``.
        size: Number of (identical) samples the dataset reports.
        max_length: Value forwarded to the tokenizer as ``max_length``.
            Defaults to 128, the value that was previously hard-coded.
    """

    # Field names of one triplet, in the order they appear in each sample.
    _FIELDS = ("anchor", "positive", "negative")

    def __init__(self, tokenizer, size=100, max_length=128):
        self.tokenizer = tokenizer
        self.size = size
        self.max_length = max_length

        # The list holds `size` references to one shared dict; it is only read.
        self.data = [{"anchor": "def hello(): return 'world'", "positive": "def hi(): return 'earth'", "negative": "class Foo: pass"}] * size

    def __len__(self):
        """Return the configured number of samples."""
        return self.size

    def _encode(self, text):
        """Tokenize one text, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Tokenize the triplet at ``idx``.

        Returns:
            A dict with the keys ``<field>_input_ids`` and
            ``<field>_attention_mask`` for each of anchor, positive and
            negative (in that order).
        """
        item = self.data[idx]
        sample = {}
        for name in self._FIELDS:
            encoded = self._encode(item[name])
            # squeeze(0) drops the leading dimension of the tokenizer output
            # (presumably the size-1 batch axis added by return_tensors='pt').
            sample[f'{name}_input_ids'] = encoded['input_ids'].squeeze(0)
            sample[f'{name}_attention_mask'] = encoded['attention_mask'].squeeze(0)
        return sample
|
|
def main():
    """Parse command-line options, build the data pipeline, and run training.

    The dataset is ``RealCodeDataset`` when ``--data_path`` names an existing
    path, otherwise ``DummyCodeDataset``. With ``--dry_run`` the dummy dataset
    is always used and training is limited to one epoch, as the flag's help
    text states.
    """
    parser = argparse.ArgumentParser(description="Train CodeMode Embeddings")

    parser.add_argument("--model_name", type=str, default="microsoft/codebert-base", help="Hub model name")
    parser.add_argument("--data_path", type=str, required=False, help="Path to parsed chunks.jsonl")
    parser.add_argument("--output_dir", type=str, default="./output", help="Where to save checkpoints")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--accumulation_steps", type=int, default=4, help="Gradient Accumulation Steps")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--dry_run", action="store_true", help="Run with dummy data for 1 epoch")

    args = parser.parse_args()

    # --dry_run was previously parsed but never consulted; honour it here.
    epochs = 1 if args.dry_run else args.epochs

    print("Initializing Training Pipeline...")
    print(f"   Model: {args.model_name}")
    print(f"   Output: {args.output_dir}")
    print(f"   Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    if args.dry_run:
        print("Dry run requested. Using DUMMY data for 1 epoch.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)
    elif args.data_path and os.path.exists(args.data_path):
        train_dataset = RealCodeDataset(args.data_path, tokenizer)
    else:
        # Deliberate fallback so the pipeline can be exercised without data.
        print("No data path provided or file missing. Using DUMMY data for verification.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    model = CodeEmbedder(model_name_or_path=args.model_name)

    # NOTE(review): mixed_precision is passed as True unconditionally, even
    # when the device reported above is 'cpu' - confirm CodeTrainer copes.
    trainer = CodeTrainer(
        model=model,
        train_loader=train_loader,
        epochs=epochs,
        learning_rate=args.lr,
        accumulation_steps=args.accumulation_steps,
        mixed_precision=True,
        output_dir=args.output_dir,
    )

    trainer.train()

    print("Training Complete.")


if __name__ == "__main__":
    main()
|
|