# NOTE(review): removed non-Python scrape artifacts that preceded the code
# (page "File size" header, commit hashes f521886/0aee636, and a 1-110
# line-number gutter). They were not part of the program.
import torch
import torch.nn as nn
import gradio as gr
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import math

# Model hyperparameters
VOCAB_SIZE = 10000  # Vocabulary size the model is built with
EMBED_SIZE = 256    # Embedding dimension
NUM_HEADS = 8       # Number of attention heads per transformer layer
NUM_LAYERS = 6      # Number of decoder layers
FFN_DIM = 512       # Hidden size of the feed-forward sub-layer
DROPOUT = 0.1       # Dropout probability used throughout the model

# Model definition
class TransformerModel(nn.Module):
    """Decoder-only transformer language model.

    Embeds token indices, adds sinusoidal positions, runs a stack of
    transformer decoder layers in self-attention-only mode, and projects
    back to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, ffn_dim, dropout):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)
        # batch_first=True: the rest of the script feeds (batch, seq) tensors —
        # inputs are built with unsqueeze(0), PositionalEncoding indexes dim 1
        # as the sequence axis, and generation reads output[:, -1, :].
        decoder_layer = TransformerDecoderLayer(
            embed_size, num_heads, ffn_dim, dropout, batch_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def forward(self, src, src_mask=None):
        """Return vocabulary logits of shape (batch, seq, vocab_size).

        Args:
            src: LongTensor of token indices, shape (batch, seq).
            src_mask: optional attention mask; a causal mask is built when
                omitted so each position only attends to earlier ones.
        """
        src = self.embedding(src) * math.sqrt(self.embed_size)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Autoregressive LM: mask out future positions.
            src_mask = nn.Transformer.generate_square_subsequent_mask(
                src.size(1)
            ).to(src.device)
        # BUG FIX: nn.TransformerDecoder requires a Tensor memory; the original
        # memory=None raised at runtime. For a decoder-only LM we feed the
        # sequence itself as memory (self-attention over the same tokens).
        output = self.transformer_decoder(src, memory=src, tgt_mask=src_mask)
        output = self.fc_out(output)
        return output

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to token embeddings.

    Precomputes the standard sin/cos table from "Attention Is All You Need"
    for up to ``max_len`` positions and registers it as a (1, max_len,
    embed_size) buffer so it moves with the module across devices.
    """

    def __init__(self, embed_size, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric progression of inverse frequencies over even dimensions.
        inv_freq = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        table = torch.zeros(max_len, embed_size)
        table[:, 0::2] = torch.sin(positions * inv_freq)  # even dims: sine
        table[:, 1::2] = torch.cos(positions * inv_freq)  # odd dims: cosine
        # Leading batch axis of 1 lets the table broadcast over any batch.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add positions for the first x.size(1) steps, then apply dropout."""
        return self.dropout(x + self.pe[:, :x.size(1)])

# Parameter counting
def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Tokenizer and vocabulary
tokenizer = get_tokenizer('basic_english')  # lowercasing word-level tokenizer from torchtext
def yield_tokens(data_iter):
    """Yield the token list of each text in *data_iter*.

    Shaped as an iterator-of-token-lists, the form expected by
    ``build_vocab_from_iterator``.
    """
    return (tokenizer(sample) for sample in data_iter)

# Example data (replace with your own dataset)
sample_data = ["Hello world", "This is a test", "Build a neural network"] * 1000
# Build the vocabulary from the corpus; '<unk>' and '<pad>' are reserved
# special tokens. NOTE(review): this toy corpus yields only ~12 distinct
# tokens, far below VOCAB_SIZE=10000 the model is sized for — confirm the
# intended dataset before training.
vocab = build_vocab_from_iterator(yield_tokens(sample_data), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab['<unk>'])  # unknown words map to '<unk>'

# Model initialization
model = TransformerModel(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    ffn_dim=FFN_DIM,
    dropout=DROPOUT
)
print(f"Количество параметров модели: {count_parameters(model)}")

# Load trained weights if present; otherwise fall back to random init.
# NOTE(review): only FileNotFoundError is handled — a checkpoint with a
# mismatched architecture would still raise; presumably train.py saves a
# state_dict compatible with the hyperparameters above — verify.
try:
    model.load_state_dict(torch.load("model.pt"))
except FileNotFoundError:
    print("Веса модели не найдены. Запустите train.py для обучения.")

# Text generation function
def generate_text(prompt, max_length=50):
    """Greedily continue *prompt* with up to *max_length* tokens.

    Args:
        prompt: seed text; tokenized with the module-level tokenizer and
            mapped through the module-level vocab.
        max_length: maximum number of tokens to append.

    Returns:
        The prompt plus its generated continuation as a space-joined
        token string ('' for an empty/unknown-only prompt).
    """
    model.eval()
    indices = [vocab[token] for token in tokenizer(prompt)]
    if not indices:
        # An empty prompt would produce a zero-length sequence, which the
        # transformer cannot attend over.
        return ""
    src = torch.tensor(indices, dtype=torch.long).unsqueeze(0)
    pad_idx = vocab['<pad>']
    # no_grad hoisted out of the loop: nothing here needs gradients.
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(src)
            # Greedy decoding: pick the highest-scoring next token.
            next_token = logits[:, -1, :].argmax(-1).item()
            src = torch.cat(
                [src, torch.tensor([[next_token]], dtype=torch.long)], dim=-1
            )
            # Original behavior kept: '<pad>' acts as the stop token and is
            # included in the output.
            if next_token == pad_idx:
                break
    itos = vocab.get_itos()  # index -> token list, fetched once
    # squeeze(0), not squeeze(): a length-1 sequence must stay 1-D so
    # tolist() yields an iterable rather than a bare int.
    return ' '.join(itos[idx] for idx in src.squeeze(0).tolist())

# Gradio web interface: a single textbox in, generated text out.
iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=2, placeholder="Введите начало текста..."),
    outputs="text",
    title="Моя нейросеть (~10M параметров)",
    description="Введите текст, и модель продолжит его."
)

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()