File size: 3,172 Bytes
fba3401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizer
import torch.optim as optim
from sklearn.metrics import accuracy_score
from tqdm import tqdm

from data.imdb_data import  IMDBDataset
from model.sentiment_model import SentimentAnalysisModel


def evaluate(model, dataloader, device):
    """Run inference over *dataloader* and print/return classification accuracy.

    Args:
        model: Classifier whose forward takes (input_ids, attention_mask)
            and returns per-class logits of shape (batch, num_classes).
        dataloader: Yields dicts with "input_ids", "attention_mask", "labels".
        device: torch.device to run inference on.

    Returns:
        float: Accuracy over the whole dataloader, as computed by
        sklearn's accuracy_score.
    """
    model.eval()
    all_predictions = []
    all_targets = []

    # Gradients are never needed for evaluation; no_grad saves memory/time.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)

            logits = model(ids, mask)
            predicted = logits.argmax(dim=1)

            # Move back to CPU so plain Python lists can accumulate results.
            all_predictions += predicted.cpu().tolist()
            all_targets += targets.cpu().tolist()

    acc = accuracy_score(all_targets, all_predictions)
    print(f"Validation Accuracy: {acc:.4f}")
    return acc


def train(
    model_name="bert-base-uncased",
    batch_size=8,
    max_length=256,
    lr=2e-5,
    epochs=3,
):
    """Fine-tune a BERT-based sentiment classifier on the IMDB dataset.

    Hyperparameters were previously hard-coded; they are now keyword
    parameters with identical defaults, so a bare ``train()`` call behaves
    exactly as before while allowing experiments without editing the source.

    Args:
        model_name: HuggingFace model identifier for tokenizer and encoder.
        batch_size: Samples per batch for both train and test loaders.
        max_length: Maximum token length passed to the dataset/tokenizer.
        lr: AdamW learning rate.
        epochs: Number of full passes over the training set.

    Side effects:
        Trains the model, prints progress, evaluates on the test split after
        every epoch, and saves the final weights to "bert_imdb_sentiment.pth".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    tokenizer = BertTokenizer.from_pretrained(model_name)

    train_dataset = IMDBDataset("train", tokenizer, max_length)
    test_dataset = IMDBDataset("test", tokenizer, max_length)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,  # speeds up host-to-GPU transfers when using CUDA
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep evaluation order deterministic
        num_workers=4,
        pin_memory=True,
    )

    model = SentimentAnalysisModel(model_name).to(device)
    print("Model device:", next(model.parameters()).device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # ================== Training loop ==================
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for step, batch in enumerate(loop):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} Training Loss: {avg_loss:.4f}")

        # Evaluate on the held-out test split after every epoch.
        evaluate(model, test_loader, device)

    # ================== Save model ==================
    torch.save(model.state_dict(), "bert_imdb_sentiment.pth")
    print("Model saved.")


# Script entry point: run the full fine-tuning pipeline when executed directly.
if __name__ == "__main__":
    train()