Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import AutoTokenizer | |
| from scripts.core.training.model import CodeEmbedder | |
| from scripts.core.training.trainer import CodeTrainer | |
| import json | |
# Triplet dataset backed by a JSONL file on disk.
class RealCodeDataset(Dataset):
    """Serve tokenized (anchor, positive, negative) triplets read from a JSONL file.

    Every non-blank line of the file is parsed as one JSON object and kept in
    memory. Each object is expected to provide ``anchor``, ``positive`` and
    ``negative`` entries; they are tokenized lazily, when an item is requested.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=512):
        """Read all records from ``jsonl_path`` into ``self.data``.

        Args:
            jsonl_path: Path of the UTF-8 encoded JSONL file to load.
            tokenizer: Callable applied to each text. It is invoked with
                ``return_tensors='pt'`` and must return a mapping holding
                ``input_ids`` and ``attention_mask``.
            max_length: Value forwarded to the tokenizer as ``max_length``
                (together with ``padding='max_length'`` and
                ``truncation=True``).
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        print(f"Loading data from {jsonl_path}...")
        with open(jsonl_path, 'r', encoding='utf-8') as handle:
            # Whitespace-only lines carry no record and are skipped.
            self.data = [json.loads(row) for row in handle if row.strip()]
        print(f"Loaded {len(self.data)} triplets.")

    def __len__(self):
        """Return the number of loaded triplets."""
        return len(self.data)

    def _encode(self, text):
        """Run the tokenizer on one text with the dataset's fixed settings."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Tokenize the triplet at ``idx`` and return it as one flat dict.

        Returns:
            A dict with ``<role>_input_ids`` and ``<role>_attention_mask`` for
            each role in anchor, positive, negative (in that order). The
            leading dimension of every tokenizer output is squeezed away.
        """
        record = self.data[idx]
        sample = {}
        for role in ('anchor', 'positive', 'negative'):
            encoded = self._encode(record[role])
            # squeeze(0) removes the leading dimension of the tokenizer output.
            sample[f'{role}_input_ids'] = encoded['input_ids'].squeeze(0)
            sample[f'{role}_attention_mask'] = encoded['attention_mask'].squeeze(0)
        return sample
# Synthetic stand-in dataset for smoke-testing the pipeline when no real
# triplet file is available.
class DummyCodeDataset(Dataset):
    """Serve ``size`` copies of one fixed (anchor, positive, negative) triplet.

    Items have the same flat layout as the real triplet dataset:
    ``<role>_input_ids`` and ``<role>_attention_mask`` for each of the roles
    anchor, positive and negative.
    """

    def __init__(self, tokenizer, size=100, max_length=128):
        """Build the in-memory list of identical dummy triplets.

        Args:
            tokenizer: Callable applied to each text. It is invoked with
                ``return_tensors='pt'`` and must return a mapping holding
                ``input_ids`` and ``attention_mask``.
            size: Number of (identical) triplets to expose.
            max_length: Value forwarded to the tokenizer as ``max_length``
                (together with ``padding='max_length'`` and
                ``truncation=True``). Defaults to 128, the previously
                hard-coded value.
        """
        self.tokenizer = tokenizer
        self.size = size
        self.max_length = max_length
        # Every entry references the same dict; items are only ever read.
        self.data = [
            {
                "anchor": "def hello(): return 'world'",
                "positive": "def hi(): return 'earth'",
                "negative": "class Foo: pass",
            }
        ] * size

    def __len__(self):
        """Return the number of triplets actually held."""
        # Derived from the data so a non-positive ``size`` yields 0 instead of
        # an invalid negative length.
        return len(self.data)

    def _encode(self, text):
        """Run the tokenizer on one text with the dataset's fixed settings."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Tokenize the triplet at ``idx`` and return it as one flat dict.

        Returns:
            A dict with ``<role>_input_ids`` and ``<role>_attention_mask`` for
            each role in anchor, positive, negative (in that order). The
            leading dimension of every tokenizer output is squeezed away.
        """
        item = self.data[idx]
        sample = {}
        for role in ('anchor', 'positive', 'negative'):
            encoded = self._encode(item[role])
            # squeeze(0) removes the leading dimension of the tokenizer output.
            sample[f'{role}_input_ids'] = encoded['input_ids'].squeeze(0)
            sample[f'{role}_attention_mask'] = encoded['attention_mask'].squeeze(0)
        return sample
def main():
    """Parse CLI arguments, assemble tokenizer, data, model and trainer, then train.

    Dataset selection:
        * ``--dry_run``: always use the dummy dataset and train for one epoch,
          as the flag's help text promises.
        * ``--data_path`` pointing at an existing file: use the real dataset.
        * otherwise: fall back to the dummy dataset so the pipeline can still
          be verified end to end.
    """
    parser = argparse.ArgumentParser(description="Train CodeMode Embeddings")
    parser.add_argument("--model_name", type=str, default="microsoft/codebert-base", help="Hub model name")
    parser.add_argument("--data_path", type=str, required=False, help="Path to parsed chunks.jsonl")
    parser.add_argument("--output_dir", type=str, default="./output", help="Where to save checkpoints")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--accumulation_steps", type=int, default=4, help="Gradient Accumulation Steps")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--dry_run", action="store_true", help="Run with dummy data for 1 epoch")
    args = parser.parse_args()

    # The flag used to be parsed but ignored; a dry run is capped at one epoch.
    epochs = 1 if args.dry_run else args.epochs

    print("Initializing Training Pipeline...")
    print(f" Model: {args.model_name}")
    print(f" Output: {args.output_dir}")
    print(f" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

    # 1. Initialize Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # 2. Load Dataset (Real or Dummy)
    if args.dry_run:
        print("Dry run requested. Using DUMMY data for 1 epoch.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)
    elif args.data_path and os.path.exists(args.data_path):
        train_dataset = RealCodeDataset(args.data_path, tokenizer)
    else:
        print("No data path provided or file missing. Using DUMMY data for verification.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    # 3. Initialize Model
    model = CodeEmbedder(model_name_or_path=args.model_name)

    # 4. Initialize Trainer
    # NOTE(review): mixed precision is requested even when the device printed
    # above is "cpu"; how CodeTrainer handles that is not visible here - confirm.
    trainer = CodeTrainer(
        model=model,
        train_loader=train_loader,
        epochs=epochs,
        learning_rate=args.lr,
        accumulation_steps=args.accumulation_steps,
        mixed_precision=True,  # Hardcoded True for the "Zero-Cost" philosophy
        output_dir=args.output_dir,
    )

    # 5. Connect and Train
    trainer.train()
    print("Training Complete.")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()