# training/train_programmer.py

import os

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

from core.device import DEVICE
from language.tokenizer import SimpleTokenizer
from language.embeddings import EmbeddingLayer
from language.encoder import SentenceEncoder
# ================================
# CONFIG
# ================================
ARTIFACTS_DIR = "artifacts"
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-4
MAX_SEQ_LEN = 128

os.makedirs(ARTIFACTS_DIR, exist_ok=True)
# ================================
# LOAD HF CODE DATASET
# ================================
print("[INFO] Loading CodeXGLUE dataset...")
dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv")

texts = []
labels = []

for item in dataset["train"]:
    texts.append(item["docstring"])  # natural-language docstring
    labels.append(1)                 # programming label

# Add some non-programming noise examples
noise_examples = [
    "Hello how are you",
    "Tell me a story",
    "What is the weather today",
    "Who are you",
]

for text in noise_examples:
    texts.append(text)
    labels.append(0)

print(f"[INFO] Loaded {len(texts)} samples")
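# NOTE: the labels are heavily imbalanced -- every CodeXGLUE training example is
# labelled 1, while only the four noise sentences above are labelled 0. A minimal
# mitigation sketch, assuming CrossEntropyLoss is kept (this weighting is an
# editorial suggestion, not part of the original script):
#
#   n_pos, n_neg = labels.count(1), labels.count(0)
#   class_weights = torch.tensor([n_pos / n_neg, 1.0])  # upweight the rare class 0
#   criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))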
# ================================
# TOKENIZER
# ================================
tokenizer = SimpleTokenizer()
tokenizer.build_vocab(texts)
tokenizer.freeze_vocab()
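# Assumption about SimpleTokenizer (language/tokenizer.py): build_vocab assigns an
# id to every token seen in `texts`, and freeze_vocab stops new ids from being
# added afterwards, so unseen words at inference time presumably fall back to an
# unknown-token id rather than growing the vocabulary.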
# ================================
# DATASET CLASS
# ================================
class ProgrammingDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Truncate to MAX_SEQ_LEN; padding to a common length happens in collate_fn
        token_ids = tokenizer.encode(self.texts[idx])[:MAX_SEQ_LEN]
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return token_ids, label
def collate_fn(batch):
    """Right-pad each sequence in the batch to the batch max length."""
    token_ids, labels = zip(*batch)
    max_len = max(len(t) for t in token_ids)
    padded = []
    for t in token_ids:
        pad_len = max_len - len(t)
        padded.append(
            torch.cat([
                t,
                torch.full(
                    (pad_len,),
                    tokenizer.vocab[tokenizer.PAD_TOKEN],
                    dtype=torch.long,
                ),
            ])
        )
    return torch.stack(padded), torch.tensor(labels)
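# A shorter equivalent of the padding loop above uses PyTorch's built-in
# pad_sequence; behaviour matches as long as the same pad id is supplied:
#
#   from torch.nn.utils.rnn import pad_sequence
#   padded = pad_sequence(token_ids, batch_first=True,
#                         padding_value=tokenizer.vocab[tokenizer.PAD_TOKEN])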
train_dataset = ProgrammingDataset(texts, labels)
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                    shuffle=True, collate_fn=collate_fn)
# ================================
# MODEL
# ================================
embedder = EmbeddingLayer(len(tokenizer.vocab),
                          pad_index=tokenizer.vocab[tokenizer.PAD_TOKEN])
encoder = SentenceEncoder()
classifier = nn.Linear(encoder.projection.out_features, 2)

embedder, encoder, classifier = embedder.to(DEVICE), encoder.to(DEVICE), classifier.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    list(embedder.parameters()) +
    list(encoder.parameters()) +
    list(classifier.parameters()),
    lr=LEARNING_RATE,
)
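# Assumed shapes for the custom modules (defined under language/): EmbeddingLayer
# maps (batch, seq) token ids to (batch, seq, dim) vectors, and SentenceEncoder
# pools them into a single (batch, hidden) sentence vector, where hidden equals
# encoder.projection.out_features -- which is why the classifier is sized from it.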
# ================================
# TRAIN
# ================================
def train():
    best_loss = float("inf")
    for epoch in range(EPOCHS):
        total_loss = 0.0
        for token_ids, labels_batch in loader:
            token_ids = token_ids.to(DEVICE)
            labels_batch = labels_batch.to(DEVICE)

            # Forward pass: embed tokens, mask out padding, pool, classify
            embeddings = embedder(token_ids)
            attention_mask = (token_ids != tokenizer.vocab[tokenizer.PAD_TOKEN]).long()
            sentence_vec = encoder(embeddings, attention_mask=attention_mask)
            logits = classifier(sentence_vec)

            loss = criterion(logits, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            # Clip the joint gradient norm of all three modules to 1.0
            torch.nn.utils.clip_grad_norm_(
                list(embedder.parameters()) +
                list(encoder.parameters()) +
                list(classifier.parameters()),
                1.0,
            )
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {avg_loss:.6f}")

        # Checkpoint whenever the epoch-average loss improves
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_models()

    print("[SUCCESS] Programming model training complete!")
# ================================
# SAVE
# ================================
def save_models():
    torch.save(encoder.state_dict(),
               os.path.join(ARTIFACTS_DIR, "programming_encoder.pt"))
    torch.save(classifier.state_dict(),
               os.path.join(ARTIFACTS_DIR, "programming_classifier.pt"))
    torch.save(embedder.state_dict(),
               os.path.join(ARTIFACTS_DIR, "programming_embedding.pt"))
    print("[INFO] Programming models saved")
# ================================
# ENTRY
# ================================
if __name__ == "__main__":
    train()
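# Usage sketch (an illustration, not part of the original script): reloading the
# artifacts written by save_models() for inference. `classify` is a hypothetical
# helper name; everything else reuses objects defined above.
#
#   embedder.load_state_dict(torch.load(
#       os.path.join(ARTIFACTS_DIR, "programming_embedding.pt"), map_location=DEVICE))
#   encoder.load_state_dict(torch.load(
#       os.path.join(ARTIFACTS_DIR, "programming_encoder.pt"), map_location=DEVICE))
#   classifier.load_state_dict(torch.load(
#       os.path.join(ARTIFACTS_DIR, "programming_classifier.pt"), map_location=DEVICE))
#
#   def classify(text):
#       ids = torch.tensor([tokenizer.encode(text)[:MAX_SEQ_LEN]], device=DEVICE)
#       mask = (ids != tokenizer.vocab[tokenizer.PAD_TOKEN]).long()
#       logits = classifier(encoder(embedder(ids), attention_mask=mask))
#       return logits.argmax(dim=-1).item()  # 1 = programming, 0 = other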