# transformer/app.py — Gradio demo for a small decoder-only Transformer language model.
# (Header reconstructed from Hugging Face upload-page residue: HeavensHackDev, commit f521886.)
import torch
import torch.nn as nn
import gradio as gr
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import math
# Model hyper-parameters
VOCAB_SIZE = 10000 # Vocabulary size
EMBED_SIZE = 256 # Embedding / model dimension
NUM_HEADS = 8 # Number of attention heads per layer
NUM_LAYERS = 6 # Number of decoder layers
FFN_DIM = 512 # Hidden size of the position-wise feed-forward network
DROPOUT = 0.1  # Dropout probability
# Model definition
class TransformerModel(nn.Module):
    """Decoder-only Transformer language model.

    Args:
        vocab_size: number of tokens in the vocabulary.
        embed_size: embedding / model dimension.
        num_heads: attention heads per decoder layer.
        num_layers: number of stacked decoder layers.
        ffn_dim: hidden size of the position-wise feed-forward network.
        dropout: dropout probability.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, ffn_dim, dropout):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, dropout)
        # batch_first=True matches the (batch, seq) tensors built by
        # generate_text(); it does not change parameter/state_dict names,
        # so previously saved weights still load.
        decoder_layer = TransformerDecoderLayer(
            embed_size, num_heads, ffn_dim, dropout, batch_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def forward(self, src, src_mask=None):
        """Return next-token logits of shape (batch, seq, vocab_size).

        Args:
            src: LongTensor of token ids, shape (batch, seq).
            src_mask: optional attention mask; if None, a causal
                (upper-triangular) mask is generated so each position can
                only attend to itself and earlier positions.
        """
        if src_mask is None:
            # Bug fix: the original passed tgt_mask=None, so autoregressive
            # generation let every token attend to future tokens.
            src_mask = nn.Transformer.generate_square_subsequent_mask(
                src.size(1)
            ).to(src.device)
        src = self.embedding(src) * math.sqrt(self.embed_size)
        src = self.pos_encoder(src)
        # Bug fix: TransformerDecoder requires a memory tensor — memory=None
        # raised at runtime. In a decoder-only LM the sequence serves as its
        # own memory, with the same causal mask applied to cross-attention.
        output = self.transformer_decoder(
            src, memory=src, tgt_mask=src_mask, memory_mask=src_mask
        )
        return self.fc_out(output)
class PositionalEncoding(nn.Module):
    """Adds the fixed sinusoidal positional encoding of Vaswani et al. to embeddings."""

    def __init__(self, embed_size, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        table = torch.zeros(max_len, embed_size)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Buffer (not a Parameter): moves with .to(device) but is never trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """x: (batch, seq, embed) -> same shape with position codes added, then dropout."""
        return self.dropout(x + self.pe[:, :x.size(1)])
# Parameter counting
def count_parameters(model):
    """Return the total number of trainable parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Tokenizer and vocabulary
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    """Yield the token list for each text string in `data_iter`."""
    yield from map(tokenizer, data_iter)
# Sample corpus (replace with a real dataset)
sample_data = ["Hello world", "This is a test", "Build a neural network"] * 1000
vocab = build_vocab_from_iterator(
    yield_tokens(sample_data), specials=['<unk>', '<pad>']
)
# Unknown words map to '<unk>' instead of raising.
vocab.set_default_index(vocab['<unk>'])
# Model construction
model = TransformerModel(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    ffn_dim=FFN_DIM,
    dropout=DROPOUT
)
print(f"Количество параметров модели: {count_parameters(model)}")

# Load trained weights if present (train.py is expected to produce model.pt).
try:
    # Fix: map_location keeps loading working on CPU-only hosts (e.g. a
    # Hugging Face Space) even when the checkpoint was saved from a GPU run.
    model.load_state_dict(torch.load("model.pt", map_location="cpu"))
except FileNotFoundError:
    print("Веса модели не найдены. Запустите train.py для обучения.")
# Text generation
def generate_text(prompt, max_length=50):
    """Greedily continue `prompt` with up to `max_length` generated tokens.

    Args:
        prompt: seed text; tokenized with the module-level tokenizer/vocab.
        max_length: maximum number of tokens to generate.

    Returns:
        The detokenized prompt plus continuation as a single string.
        Generation stops early if the model emits '<pad>'.
    """
    model.eval()
    indices = [vocab[token] for token in tokenizer(prompt)]
    if not indices:
        # Robustness fix: an empty prompt would give a zero-length sequence
        # and crash inside the transformer — nothing to continue.
        return prompt
    src = torch.tensor([indices], dtype=torch.long)
    pad_idx = vocab['<pad>']
    # no_grad hoisted out of the loop — gradients are never needed here.
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(src)
            next_token = logits[:, -1, :].argmax(-1).item()
            src = torch.cat(
                [src, torch.tensor([[next_token]], dtype=torch.long)], dim=-1
            )
            if next_token == pad_idx:
                break
    # squeeze(0) (not squeeze()) so a length-1 sequence stays iterable.
    itos = vocab.get_itos()
    return ' '.join(itos[idx] for idx in src.squeeze(0).tolist())
# Gradio UI
prompt_box = gr.Textbox(lines=2, placeholder="Введите начало текста...")

iface = gr.Interface(
    fn=generate_text,
    inputs=prompt_box,
    outputs="text",
    title="Моя нейросеть (~10M параметров)",
    description="Введите текст, и модель продолжит его."
)

# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()