File size: 3,847 Bytes
f521886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import math

import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Model hyperparameters (must match app.py)
VOCAB_SIZE = 10000  # rows of the embedding table and width of the output layer
EMBED_SIZE = 256  # embedding / model dimension
NUM_HEADS = 8  # attention heads per decoder layer
NUM_LAYERS = 6  # number of stacked decoder layers
FFN_DIM = 512  # third argument of each decoder layer (presumably the feed-forward width)
DROPOUT = 0.1  # dropout probability shared by the positional encoding and decoder layers

# Model definition (mirrors app.py so this script is self-contained)
class TransformerModel(nn.Module):
    """Decoder-only language model over batch-first token index tensors.

    Pipeline: token embedding (scaled by ``sqrt(embed_size)``) -> positional
    encoding -> stack of ``nn.TransformerDecoderLayer`` -> linear projection
    to vocabulary logits.

    The module tree (``embedding``, ``pos_encoder``, ``transformer_decoder``,
    ``fc_out``) is unchanged, so ``state_dict`` key names stay the same.
    NOTE(review): app.py is expected to hold the same definition -- confirm it
    carries the same forward-pass fixes before loading checkpoints there.

    Args:
        vocab_size: Number of distinct token indices (embedding rows and
            output logits).
        embed_size: Embedding / model dimension.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        ffn_dim: Hidden width of each layer's feed-forward sub-network.
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, ffn_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)
        # batch_first=True: the embedding output and the positional encoding
        # both use a (batch, seq_len, embed_size) layout, whereas the layer
        # default is sequence-first.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def forward(self, src, src_mask=None):
        """Return logits of shape ``(batch, seq_len, vocab_size)``.

        Args:
            src: Long tensor of token indices, shape ``(batch, seq_len)``.
            src_mask: Optional ``(seq_len, seq_len)`` attention mask (e.g. a
                causal mask). ``None`` means unrestricted attention.
        """
        src = self.embedding(src) * math.sqrt(self.embed_size)
        src = self.pos_encoder(src)
        # There is no separate encoder, and the decoder layers' cross-attention
        # cannot run with memory=None. Feed the sequence itself as memory and
        # apply the same mask there, so a causal mask also blocks future
        # positions in the cross-attention path.
        output = self.transformer_decoder(
            src,
            memory=src,
            tgt_mask=src_mask,
            memory_mask=src_mask,
        )
        return self.fc_out(output)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    A table of shape ``(1, max_len, embed_size)`` is precomputed once and kept
    as the buffer ``pe``: even feature columns hold sines and odd columns hold
    cosines of ``position * 10000 ** (-column / embed_size)``.

    Args:
        embed_size: Feature dimension of the inputs (odd sizes are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, embed_size, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        angles = position * div_term
        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even
        # column, so trim the angle table to the number of cosine columns.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        # Registered as a buffer: moves with .to(device) and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_size)``.
        """
        return self.dropout(x + self.pe[:, :x.size(1)])

# Tokenizer and vocabulary
tokenizer = get_tokenizer('basic_english')
def yield_tokens(data_iter):
    """Lazily yield the token list of every text in ``data_iter``."""
    yield from map(tokenizer, data_iter)

# Example data (replace with your own dataset)
sample_data = ["Hello world", "This is a test", "Build a neural network"] * 1000
# Cap the vocabulary at VOCAB_SIZE (special tokens included) so that every
# index the vocab can return is a valid row of the model's embedding table,
# even when the sample data is swapped for a larger corpus.
vocab = build_vocab_from_iterator(
    yield_tokens(sample_data),
    specials=['<unk>', '<pad>'],
    max_tokens=VOCAB_SIZE,
)
# Tokens missing from the vocabulary map to '<unk>'.
vocab.set_default_index(vocab['<unk>'])

# Model initialization: a module-level instance built from the hyperparameter
# constants defined at the top of the file.
model = TransformerModel(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    ffn_dim=FFN_DIM,
    dropout=DROPOUT
)

# Training routine
def train_model(model, data, epochs=5, device='cpu', save_path="model.pt"):
    """Train ``model`` for next-token prediction on ``data`` and save its weights.

    Every text is tokenized, truncated to 50 tokens, mapped to vocabulary
    indices and used as one training sample (batch size 1). The model receives
    all tokens but the last and is trained to predict, at each position, the
    token that follows. Texts with fewer than two tokens are skipped.

    Args:
        model: Module called as ``model(src, src_mask=mask)`` that returns
            logits of shape ``(1, seq_len, num_classes)``.
        data: Re-iterable collection of raw text strings (iterated once per
            epoch).
        epochs: Number of passes over ``data``.
        device: Device name or ``torch.device`` to train on.
        save_path: Destination of the trained ``state_dict``.
    """
    max_len = 50  # truncate long texts to bound the sequence length
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        steps = 0
        for text in data:
            indices = [vocab[token] for token in tokenizer(text)[:max_len]]
            if len(indices) < 2:
                continue
            src = torch.tensor(indices[:-1], dtype=torch.long, device=device).unsqueeze(0)
            tgt = torch.tensor(indices[1:], dtype=torch.long, device=device)
            # Causal mask: position i may attend only to positions <= i.
            # Without it the model can read the very tokens it must predict.
            seq_len = src.size(1)
            causal_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device),
                diagonal=1,
            )
            optimizer.zero_grad()
            output = model(src, src_mask=causal_mask)
            # Flatten using the model's own output width rather than a global
            # constant, so models built with another vocabulary size work too.
            loss = criterion(output.reshape(-1, output.size(-1)), tgt)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            steps += 1
        # Average over the samples actually trained on; skipped texts do not
        # count, and an empty epoch reports NaN instead of dividing by zero.
        avg_loss = total_loss / steps if steps else float("nan")
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")
    torch.save(model.state_dict(), save_path)

# Run training when executed as a script
if __name__ == "__main__":
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_model(model, sample_data, epochs=5, device=device)