Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from torch.nn import TransformerDecoder, TransformerDecoderLayer | |
| from torchtext.data.utils import get_tokenizer | |
| from torchtext.vocab import build_vocab_from_iterator | |
| import math | |
# Model hyperparameters
VOCAB_SIZE = 10000  # Vocabulary size
EMBED_SIZE = 256  # Embedding dimension
NUM_HEADS = 8  # Number of attention heads in the transformer
NUM_LAYERS = 6  # Number of layers
FFN_DIM = 512  # Hidden size of the feed-forward network (FFN)
DROPOUT = 0.1
# Model definition
class TransformerModel(nn.Module):
    """Decoder-only Transformer language model built from ``TransformerDecoder``.

    The model embeds token ids, adds positional encodings, runs the result
    through a stack of ``TransformerDecoderLayer`` modules and projects the
    hidden states onto the vocabulary.

    Args:
        vocab_size: Number of rows in the embedding table and number of output
            logits per position.
        embed_size: Embedding / model dimension.
        num_heads: Number of attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        ffn_dim: Hidden size of each layer's feed-forward network.
        dropout: Dropout probability used by the positional encoding and the
            decoder layers.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, ffn_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)
        # batch_first=True: the positional encoding slices on dim 1 and the
        # generation code builds (batch, seq) inputs, so the decoder must use
        # the same layout. With the default (seq, batch, embed) layout every
        # token would be treated as its own length-1 sequence.
        decoder_layer = TransformerDecoderLayer(
            embed_size, num_heads, ffn_dim, dropout, batch_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def forward(self, src, src_mask=None):
        """Compute next-token logits for every position of ``src``.

        Args:
            src: Tensor of token ids laid out as (batch, seq).
            src_mask: Optional (seq, seq) attention mask. When omitted, a
                causal mask is built so a position cannot attend to later
                positions.

        Returns:
            Tensor of logits shaped (batch, seq, vocab_size).
        """
        if src_mask is None:
            seq_len = src.size(1)
            # 0 on and below the diagonal, -inf above it.
            src_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=src.device),
                diagonal=1,
            )
        hidden = self.embedding(src) * math.sqrt(self.embed_size)
        hidden = self.pos_encoder(hidden)
        # TransformerDecoder requires a real `memory` tensor: passing None
        # raises inside the cross-attention block. There is no encoder here,
        # so the embedded input doubles as memory, under the same mask so
        # cross-attention cannot see future tokens either.
        # NOTE(review): confirm this matches how train.py drives the model.
        output = self.transformer_decoder(
            hidden, memory=hidden, tgt_mask=src_mask, memory_mask=src_mask
        )
        return self.fc_out(output)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    A table of shape (1, max_len, embed_size) is precomputed and stored as the
    buffer ``pe``. Even columns hold sines and odd columns hold cosines.

    Args:
        embed_size: Size of the last dimension of the inputs. Odd sizes are
            supported.
        dropout: Dropout probability applied after adding the encodings.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, embed_size, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        angles = position * div_term
        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embed_size there is one fewer odd column than even
        # column, so trim the cosine terms to fit instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        # Leading dimension of 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence position and whose last
                dimension is ``embed_size``.

        Returns:
            Tensor of the same shape as ``x``, after dropout.
        """
        return self.dropout(x + self.pe[:, :x.size(1)])
# Parameter counting
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Tokenizer and vocabulary
tokenizer = get_tokenizer("basic_english")
def yield_tokens(data_iter):
    """Lazily yield ``tokenizer(text)`` for each text in ``data_iter``."""
    yield from (tokenizer(text) for text in data_iter)
# Sample data (replace with your own dataset)
sample_data = ["Hello world", "This is a test", "Build a neural network"] * 1000
vocab = build_vocab_from_iterator(
    yield_tokens(sample_data),
    specials=['<unk>', '<pad>'],
)
# Tokens missing from the vocabulary map to '<unk>'.
vocab.set_default_index(vocab['<unk>'])
# Model initialisation
model = TransformerModel(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    ffn_dim=FFN_DIM,
    dropout=DROPOUT,
)
print(f"Количество параметров модели: {count_parameters(model)}")

# Load trained weights (if available)
try:
    # map_location="cpu": the model is never moved off the CPU in this file,
    # and a checkpoint saved from a GPU cannot be loaded on a CPU-only host
    # without remapping its tensors.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load("model.pt", map_location="cpu"))
except FileNotFoundError:
    print("Веса модели не найдены. Запустите train.py для обучения.")
# Text generation
def generate_text(prompt, max_length=50):
    """Greedily continue ``prompt`` with the module-level model.

    The prompt is tokenized and mapped to vocabulary ids. The highest-scoring
    next token is then appended repeatedly until ``max_length`` tokens have
    been added or the model predicts ``<pad>``.

    Args:
        prompt: Text to continue. An empty or whitespace-only prompt yields
            an empty string.
        max_length: Maximum number of tokens to generate.

    Returns:
        The prompt tokens followed by the generated tokens, joined by spaces.
        The terminating ``<pad>`` token is not included.
    """
    model.eval()
    indices = [vocab[token] for token in tokenizer(prompt or '')]
    if not indices:
        # Nothing to condition on; a zero-length sequence cannot be decoded.
        return ''

    itos = vocab.get_itos()  # fetched once rather than per generated token
    pad_index = vocab['<pad>']
    src = torch.tensor(indices, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_length):
            # The model's output layer may be wider than the vocabulary
            # (VOCAB_SIZE vs. len(vocab)); only ids that map back to a token
            # are eligible, otherwise the itos lookup below would fail.
            logits = model(src)[:, -1, :len(itos)]
            next_token = logits.argmax(-1).item()
            if next_token == pad_index:
                # Stop before appending so '<pad>' never reaches the output.
                break
            src = torch.cat(
                [src, torch.tensor([[next_token]], dtype=torch.long)], dim=-1
            )
    # src[0] stays a list even for a single token (squeeze() would yield a scalar).
    return ' '.join(itos[idx] for idx in src[0].tolist())
# Gradio interface
iface = gr.Interface(
    title="Моя нейросеть (~10M параметров)",
    description="Введите текст, и модель продолжит его.",
    fn=generate_text,
    inputs=gr.Textbox(lines=2, placeholder="Введите начало текста..."),
    outputs="text",
)

# Launch the interface only when executed as a script
if __name__ == "__main__":
    iface.launch()