# training/train_programmer.py
import os

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

from core.device import DEVICE
from language.tokenizer import SimpleTokenizer
from language.embeddings import EmbeddingLayer
from language.encoder import SentenceEncoder

# ================================
# CONFIG
# ================================
ARTIFACTS_DIR = "artifacts"
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-4
MAX_SEQ_LEN = 128

os.makedirs(ARTIFACTS_DIR, exist_ok=True)

# ================================
# LOAD HF CODE DATASET
# ================================
print("[INFO] Loading CodeXGLUE dataset...")
hf_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv")

texts = []
labels = []

for item in hf_dataset["train"]:
    if item["docstring"]:  # guard against empty/missing docstrings
        texts.append(item["docstring"])  # natural-language description of code
        labels.append(1)  # programming label

# Add some non-programming noise examples.
# NOTE: these four negatives are vastly outnumbered by the positive
# docstrings, so the classifier trains on a heavily imbalanced label set.
noise_examples = [
    "Hello how are you",
    "Tell me a story",
    "What is the weather today",
    "Who are you",
]
for text in noise_examples:
    texts.append(text)
    labels.append(0)

print(f"[INFO] Loaded {len(texts)} samples")

# ================================
# TOKENIZER
# ================================
tokenizer = SimpleTokenizer()
tokenizer.build_vocab(texts)
tokenizer.freeze_vocab()

PAD_ID = tokenizer.vocab[tokenizer.PAD_TOKEN]

# ================================
# DATASET CLASS
# ================================
class ProgrammingDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        token_ids = tokenizer.encode(self.texts[idx])[:MAX_SEQ_LEN]
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return token_ids, label


def collate_fn(batch):
    # Pad every sequence in the batch to the length of the longest one.
    token_ids, batch_labels = zip(*batch)
    max_len = max(len(t) for t in token_ids)
    padded = []
    for t in token_ids:
        pad_len = max_len - len(t)
        padded.append(
            torch.cat([t, torch.full((pad_len,), PAD_ID, dtype=torch.long)])
        )
    return torch.stack(padded), torch.tensor(batch_labels)


train_dataset = ProgrammingDataset(texts, labels)
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# ================================
# MODEL
# ================================
embedder = EmbeddingLayer(len(tokenizer.vocab), pad_index=PAD_ID)
encoder = SentenceEncoder()
classifier = nn.Linear(encoder.projection.out_features, 2)

embedder, encoder, classifier = embedder.to(DEVICE), encoder.to(DEVICE), classifier.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    list(embedder.parameters()) + list(encoder.parameters()) + list(classifier.parameters()),
    lr=LEARNING_RATE,
)

# ================================
# TRAIN
# ================================
def train():
    best_loss = float("inf")
    embedder.train()
    encoder.train()
    classifier.train()

    for epoch in range(EPOCHS):
        total_loss = 0.0
        for token_ids, labels_batch in loader:
            token_ids = token_ids.to(DEVICE)
            labels_batch = labels_batch.to(DEVICE)

            embeddings = embedder(token_ids)
            # Mask out PAD positions so they don't influence the sentence vector.
            attention_mask = (token_ids != PAD_ID).long()
            sentence_vec = encoder(embeddings, attention_mask=attention_mask)
            logits = classifier(sentence_vec)

            loss = criterion(logits, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(embedder.parameters()) + list(encoder.parameters()) + list(classifier.parameters()),
                1.0,
            )
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {avg_loss:.6f}")

        # Checkpoint only when the epoch loss improves.
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_models()

    print("[SUCCESS] Programming model training complete!")

# ================================
# SAVE
# ================================
def save_models():
    torch.save(encoder.state_dict(), os.path.join(ARTIFACTS_DIR, "programming_encoder.pt"))
    torch.save(classifier.state_dict(), os.path.join(ARTIFACTS_DIR, "programming_classifier.pt"))
    torch.save(embedder.state_dict(), os.path.join(ARTIFACTS_DIR, "programming_embedding.pt"))
    # The saved embedder is only usable with the same vocabulary, so persist
    # it alongside the weights (assumes tokenizer.vocab is a picklable dict).
    torch.save(tokenizer.vocab, os.path.join(ARTIFACTS_DIR, "programming_vocab.pt"))
    print("[INFO] Programming models saved")

# ================================
# ENTRY
# ================================
if __name__ == "__main__":
    train()
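
# ================================
# USAGE SKETCH (illustrative)
# ================================
# A minimal sketch of single-string inference with the modules trained above;
# it is not part of the original pipeline. It assumes train() has already run
# in this process, so embedder/encoder/classifier and the tokenizer vocabulary
# are still in memory; the `classify` helper name itself is hypothetical.
@torch.no_grad()
def classify(text: str) -> int:
    """Return 1 if `text` looks programming-related, else 0."""
    embedder.eval()
    encoder.eval()
    classifier.eval()
    ids = torch.tensor(tokenizer.encode(text)[:MAX_SEQ_LEN], dtype=torch.long)
    ids = ids.unsqueeze(0).to(DEVICE)  # add a batch dimension
    mask = (ids != PAD_ID).long()      # a single sequence has no padding
    logits = classifier(encoder(embedder(ids), attention_mask=mask))
    return int(logits.argmax(dim=-1).item())

# Example: after training, classify("Sort a list of integers") is expected to
# return 1 and classify("Tell me a story") 0, though results depend on the run.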