Spaces:
Configuration error
Configuration error
File size: 4,434 Bytes
f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 0aee636 f521886 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch
import torch.nn as nn
import gradio as gr
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import math
# Model hyperparameters
VOCAB_SIZE = 10_000  # embedding rows and output-layer width of the model
EMBED_SIZE = 256  # embedding / model dimension
NUM_HEADS = 8  # attention heads per layer (EMBED_SIZE must be divisible by this)
NUM_LAYERS = 6  # number of stacked decoder layers
FFN_DIM = 512  # hidden width of each layer's feed-forward sub-network
DROPOUT = 0.1  # dropout probability
# Model definition
class TransformerModel(nn.Module):
    """Autoregressive language model built from TransformerDecoderLayer blocks.

    There is no encoder: token ids are embedded, scaled by sqrt(embed_size),
    combined with positional encodings, run through the decoder stack and
    projected to ``vocab_size`` logits at every position.

    All tensors are batch-first: ``src`` is ``(batch, seq_len)`` of token ids
    and the output is ``(batch, seq_len, vocab_size)``.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, ffn_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)
        # batch_first=True: PositionalEncoding and the generation code treat
        # dim 0 as batch and dim 1 as sequence. With the default
        # (batch_first=False) attention would run across the batch axis.
        # This flag adds no parameters, so the state_dict layout is unchanged.
        decoder_layer = TransformerDecoderLayer(
            embed_size, num_heads, ffn_dim, dropout, batch_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    @staticmethod
    def _causal_mask(size, device, dtype):
        """Return a (size, size) additive mask: 0 on/below the diagonal, -inf above."""
        return torch.triu(
            torch.full((size, size), float("-inf"), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(self, src, src_mask=None):
        """Compute next-token logits for every position of ``src``.

        Args:
            src: LongTensor of token ids, shape (batch, seq_len).
            src_mask: Optional (seq_len, seq_len) attention mask. When omitted,
                a causal mask is used so position i only sees positions <= i.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        x = self.embedding(src) * math.sqrt(self.embed_size)
        x = self.pos_encoder(x)
        if src_mask is None:
            src_mask = self._causal_mask(x.size(1), x.device, x.dtype)
        # Every TransformerDecoderLayer also cross-attends to ``memory``, and
        # passing None crashes. With no encoder, the decoder input is reused
        # as memory. The same mask is applied there, so cross-attention
        # cannot see future tokens either. This keeps the original parameter
        # set, so existing state_dicts for this class still load.
        output = self.transformer_decoder(
            x, x, tgt_mask=src_mask, memory_mask=src_mask
        )
        return self.fc_out(output)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    A ``(1, max_len, embed_size)`` table is precomputed. Even channels hold
    sin(pos * freq) and odd channels hold cos(pos * freq), with frequencies
    decaying geometrically from 1 down towards 1/10000. The first
    ``seq_len`` rows are added to the input, followed by dropout.

    Args:
        embed_size: Embedding dimension (odd sizes are supported).
        dropout: Dropout probability applied after the addition.
        max_len: Longest sequence the table covers.
    """

    def __init__(self, embed_size, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embed_size there is one fewer odd channel than even
        # channels, so trim the frequency vector to fit (no-op when even).
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        # A buffer, not a Parameter: it follows .to()/.cuda() and is stored in
        # state_dict, but is not trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape (batch, seq_len, embed_size).

        ``seq_len`` must not exceed ``max_len``.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
# Parameter counting
def count_parameters(model):
    """Return the total number of trainable scalars in ``model``.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Tokenizer and vocabulary: torchtext's built-in "basic_english" tokenizer,
# shared by vocabulary building and text generation.
tokenizer = get_tokenizer("basic_english")
def yield_tokens(data_iter):
    """Lazily yield ``tokenizer(text)`` (a token list) for each text in ``data_iter``.

    Intended as the iterator argument of ``build_vocab_from_iterator``.
    """
    yield from (tokenizer(text) for text in data_iter)
# Example data (replace with your own dataset).
_base_sentences = ["Hello world", "This is a test", "Build a neural network"]
sample_data = _base_sentences * 1000
# '<unk>' and '<pad>' are registered as special tokens; lookups of unknown
# tokens fall back to the '<unk>' index.
# NOTE: this corpus yields only a handful of tokens, far fewer than
# VOCAB_SIZE (the model's embedding/output size).
vocab = build_vocab_from_iterator(
    yield_tokens(sample_data),
    specials=['<unk>', '<pad>'],
)
vocab.set_default_index(vocab['<unk>'])
# Model construction from the module-level hyperparameters.
_model_kwargs = dict(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    ffn_dim=FFN_DIM,
    dropout=DROPOUT,
)
model = TransformerModel(**_model_kwargs)
# Prints the trainable-parameter count (message: "Number of model parameters").
print(f"Количество параметров модели: {count_parameters(model)}")
# Load trained weights, if a checkpoint exists.
# map_location="cpu": the model is never moved to a GPU in this file, and a
# checkpoint saved from CUDA tensors fails to load on a CPU-only host unless
# its storages are remapped.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
try:
    _state_dict = torch.load("model.pt", map_location="cpu")
except FileNotFoundError:
    # Message: "Model weights not found. Run train.py to train."
    print("Веса модели не найдены. Запустите train.py для обучения.")
else:
    model.load_state_dict(_state_dict)
# Text generation
def generate_text(prompt, max_length=50):
    """Greedily continue ``prompt`` by up to ``max_length`` tokens.

    The prompt is split with the module-level ``tokenizer`` and mapped to ids
    through ``vocab`` (unknown words map to '<unk>'). At each step the
    highest-scoring next token is appended. Generation stops early right
    after a '<pad>' token is produced.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of tokens to generate.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces, or an empty string when the prompt has no tokens.
    """
    model.eval()
    tokens = tokenizer(prompt) if prompt else []
    if not tokens:
        # An empty sequence would crash the model / the [:, -1] lookup.
        return ""

    itos = vocab.get_itos()  # fetched once instead of once per output token
    vocab_len = len(itos)
    pad_idx = vocab['<pad>']
    src = torch.tensor([vocab[token] for token in tokens], dtype=torch.long).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(src)[:, -1, :]
            # The output layer has VOCAB_SIZE logits, but ``vocab`` can be much
            # smaller. Only ids that exist in ``vocab`` can be decoded, so the
            # argmax is restricted to them (otherwise itos[idx] raises IndexError).
            next_token = logits[:, :vocab_len].argmax(-1).item()
            src = torch.cat([src, torch.tensor([[next_token]], dtype=torch.long)], dim=-1)
            if next_token == pad_idx:
                break

    # src[0] (not squeeze()) so a length-1 sequence still yields a list.
    return ' '.join(itos[idx] for idx in src[0].tolist())
# Gradio interface. The UI strings are Russian: placeholder "Enter the
# beginning of the text...", title "My neural network (~10M parameters)",
# description "Enter text and the model will continue it."
_prompt_input = gr.Textbox(lines=2, placeholder="Введите начало текста...")
iface = gr.Interface(
    fn=generate_text,
    inputs=_prompt_input,
    outputs="text",
    title="Моя нейросеть (~10M параметров)",
    description="Введите текст, и модель продолжит его.",
)
# Start the Gradio web interface when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()