"""Train the emotion classifier and persist its artefacts.

Running this module builds a vocabulary from the cleaned corpus, trains an
``EmotionTransformer`` on it and writes three files to the current working
directory: ``vocab.pkl``, ``label_encoder.pkl`` and
``emotion_transformer_model.pth``.
"""
from collections import Counter

import joblib
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader

from src.data_processing import load_and_clean_data
from src.model_def import EmotionTransformer

# Hyperparameters
MAX_LEN = 32
BATCH_SIZE = 16
EPOCHS = 5
LR = 1e-3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reserved vocabulary entries. The ids are fixed by the encoding step:
# padding appends 0 and unknown tokens fall back to 1.
PAD_TOKEN, PAD_IDX = '<PAD>', 0
UNK_TOKEN, UNK_IDX = '<UNK>', 1


# Dataset wrapper
class EmotionDataset(Dataset):
    """Pair fixed-length token-id sequences with integer class labels."""

    def __init__(self, X, y):
        """Convert the encoded sequences ``X`` and labels ``y`` to long tensors."""
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` pair stored at ``idx``."""
        return self.X[idx], self.y[idx]


def _build_vocab(token_lists):
    """Map every token to an integer id, most frequent words first.

    Ids 0 and 1 are reserved for ``PAD_TOKEN`` and ``UNK_TOKEN``. Ids are
    assigned contiguously, so ``len(vocab)`` is always one more than the
    largest id and can be used directly as a lookup-table size.
    """
    counts = Counter(tok for sent in token_lists for tok in sent)
    vocab = {PAD_TOKEN: PAD_IDX, UNK_TOKEN: UNK_IDX}
    for word, _ in counts.most_common():
        # setdefault keeps the reserved ids intact even if the corpus happens
        # to contain a token spelled like one of the reserved names.
        vocab.setdefault(word, len(vocab))
    return vocab


def _encode(tokens, vocab, max_len=MAX_LEN):
    """Turn ``tokens`` into exactly ``max_len`` ids, truncating or padding."""
    ids = [vocab.get(tok, UNK_IDX) for tok in tokens[:max_len]]
    return ids + [PAD_IDX] * (max_len - len(ids))


def _evaluate(model, loader, crit):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader``.

    Puts the model in eval mode and runs without gradient tracking. Both
    values are NaN when the loader yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            batch = yb.size(0)
            # Weight by batch size so a short final batch is not over-counted.
            loss_sum += crit(logits, yb).item() * batch
            n_correct += (logits.argmax(dim=1) == yb).sum().item()
            n_seen += batch
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen


def train():
    """Train the classifier end to end and save the vocabulary, labels and weights.

    Reads the corpus through ``load_and_clean_data()``; the result is indexed
    by its ``'clean'`` (text) and ``'label'`` columns. Prints the mean training
    loss and the validation loss/accuracy after every epoch.
    """
    df = load_and_clean_data()
    toks = df['clean'].str.split()

    # Build vocab
    vocab = _build_vocab(toks)
    joblib.dump(vocab, 'vocab.pkl')

    # Encode + pad
    X = [_encode(sent, vocab) for sent in toks]

    # Encode labels
    le = LabelEncoder()
    y = le.fit_transform(df['label'])
    joblib.dump(le, 'label_encoder.pkl')

    # Split & loader
    X_tr, X_va, y_tr, y_va = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )
    tr_loader = DataLoader(EmotionDataset(X_tr, y_tr), batch_size=BATCH_SIZE, shuffle=True)
    va_loader = DataLoader(EmotionDataset(X_va, y_va), batch_size=BATCH_SIZE)

    # Model, optimizer, loss
    # len(vocab) is passed as the model's first argument; it must exceed every
    # token id, which the contiguous ids from _build_vocab guarantee.
    model = EmotionTransformer(len(vocab), num_classes=len(le.classes_)).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    crit = torch.nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(EPOCHS):
        model.train()  # _evaluate leaves the model in eval mode
        total_loss = 0
        for xb, yb in tr_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS} Loss: {total_loss/len(tr_loader):.4f}")

        va_loss, va_acc = _evaluate(model, va_loader, crit)
        print(f"Epoch {epoch+1}/{EPOCHS} Val Loss: {va_loss:.4f} Val Acc: {va_acc:.4f}")

    # Save weights
    torch.save(model.state_dict(), 'emotion_transformer_model.pth')


if __name__ == '__main__':
    train()