# training/train_programmer.py
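"""Train a binary classifier that separates programming-related text
(CodeXGLUE docstrings, label 1) from general conversational text (label 0),
using the project's SimpleTokenizer, EmbeddingLayer and SentenceEncoder.
The best-loss weights are written to the artifacts/ directory."""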
import os
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from core.device import DEVICE
from language.tokenizer import SimpleTokenizer
from language.embeddings import EmbeddingLayer
from language.encoder import SentenceEncoder
# ================================
# CONFIG
# ================================
ARTIFACTS_DIR = "artifacts"
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-4
MAX_SEQ_LEN = 128
os.makedirs(ARTIFACTS_DIR, exist_ok=True)
# ================================
# LOAD HF CODE DATASET
# ================================
print("[INFO] Loading CodeXGLUE dataset...")
dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv")
texts = []
labels = []
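# Label scheme: 1 = programming-related text (CodeXGLUE docstrings),
#               0 = general conversational text (the noise examples below)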
for item in dataset["train"]:
texts.append(item["docstring"]) # Natural language
labels.append(1) # Programming label
# Add some non-programming noise examples
noise_examples = [
    "Hello how are you",
    "Tell me a story",
    "What is the weather today",
    "Who are you"
]
for text in noise_examples:
    texts.append(text)
    labels.append(0)
print(f"[INFO] Loaded {len(texts)} samples")
# ================================
# TOKENIZER
# ================================
tokenizer = SimpleTokenizer()
tokenizer.build_vocab(texts)
tokenizer.freeze_vocab()
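# Build the vocabulary once over all training texts, then freeze it so the
# token-to-id mapping stays fixed for the saved artifacts (assumes SimpleTokenizer
# handles out-of-vocabulary tokens itself after freezing).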
# ================================
# DATASET CLASS
# ================================
class ProgrammingDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        token_ids = tokenizer.encode(self.texts[idx])[:MAX_SEQ_LEN]
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return token_ids, label
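# Pad every sequence in a batch to the length of the longest one, using the
# tokenizer's PAD token id, so the batch can be stacked into a single tensor.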
def collate_fn(batch):
    token_ids, labels = zip(*batch)
    max_len = max(len(t) for t in token_ids)
    padded = []
    for t in token_ids:
        pad_len = max_len - len(t)
        padded.append(
            torch.cat([
                t,
                torch.full(
                    (pad_len,),
                    tokenizer.vocab[tokenizer.PAD_TOKEN],
                    dtype=torch.long
                )
            ])
        )
    return torch.stack(padded), torch.tensor(labels)
train_dataset = ProgrammingDataset(texts, labels)
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
# ================================
# MODEL
# ================================
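# Token ids -> EmbeddingLayer -> SentenceEncoder (pooled sentence vector, with an
# attention mask so PAD positions are ignored) -> nn.Linear head producing 2-way
# logits. The head's input width is taken from the encoder's projection layer.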
embedder = EmbeddingLayer(len(tokenizer.vocab),
                          pad_index=tokenizer.vocab[tokenizer.PAD_TOKEN])
encoder = SentenceEncoder()
classifier = nn.Linear(encoder.projection.out_features, 2)
embedder, encoder, classifier = embedder.to(DEVICE), encoder.to(DEVICE), classifier.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    list(embedder.parameters()) +
    list(encoder.parameters()) +
    list(classifier.parameters()),
    lr=LEARNING_RATE
)
# ================================
# TRAIN
# ================================
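# Cross-entropy loss over the 2-way logits, gradient clipping at max-norm 1.0,
# and a checkpoint (save_models) whenever the epoch's average loss improves.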
def train():
    best_loss = float("inf")
    for epoch in range(EPOCHS):
        total_loss = 0
        for token_ids, labels_batch in loader:
            token_ids = token_ids.to(DEVICE)
            labels_batch = labels_batch.to(DEVICE)

            embeddings = embedder(token_ids)
            attention_mask = (token_ids != tokenizer.vocab[tokenizer.PAD_TOKEN]).long()
            sentence_vec = encoder(embeddings, attention_mask=attention_mask)
            logits = classifier(sentence_vec)
            loss = criterion(logits, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(embedder.parameters()) +
                list(encoder.parameters()) +
                list(classifier.parameters()),
                1.0
            )
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}/{EPOCHS}] Loss: {avg_loss:.6f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_models()

    print("[SUCCESS] Programming model training complete!")
# ================================
# SAVE
# ================================
def save_models():
    torch.save(encoder.state_dict(),
               os.path.join(ARTIFACTS_DIR, "programming_encoder.pt"))
    torch.save(classifier.state_dict(),
               os.path.join(ARTIFACTS_DIR, "programming_classifier.pt"))
    torch.save(embedder.state_dict(),
               os.path.join(ARTIFACTS_DIR, "programming_embedding.pt"))
    print("[INFO] Programming models saved")
# ================================
# ENTRY
# ================================
if __name__ == "__main__":
train()
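# --------------------------------
# Illustrative inference sketch (assumes the same tokenizer vocabulary is rebuilt
# first, since only the model weights are persisted above):
#
#   embedder = EmbeddingLayer(len(tokenizer.vocab),
#                             pad_index=tokenizer.vocab[tokenizer.PAD_TOKEN])
#   encoder = SentenceEncoder()
#   classifier = nn.Linear(encoder.projection.out_features, 2)
#   embedder.load_state_dict(torch.load(os.path.join(ARTIFACTS_DIR, "programming_embedding.pt")))
#   encoder.load_state_dict(torch.load(os.path.join(ARTIFACTS_DIR, "programming_encoder.pt")))
#   classifier.load_state_dict(torch.load(os.path.join(ARTIFACTS_DIR, "programming_classifier.pt")))
# --------------------------------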