Spaces:
Sleeping
Sleeping
File size: 4,501 Bytes
7044743 207efb8 7044743 c751690 7044743 207efb8 7044743 207efb8 7044743 207efb8 7044743 207efb8 7044743 207efb8 7044743 207efb8 4d89c7c 7044743 207efb8 7044743 207efb8 7044743 207efb8 7044743 207efb8 c751690 207efb8 7044743 207efb8 1f7a666 207efb8 7044743 207efb8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import os
# --- КОНФИГУРАЦИЯ ---
BLOCK_SIZE = 64
EMBED_SIZE = 64
HEADS = 4
MODEL_PATH = 'minigpt_checkpoint.pt'
FILE_NAME = 'book.txt' # Используется для загрузки словаря
# --- АРХИТЕКТУРА МОДЕЛИ ---
class MiniGPT(nn.Module):
def __init__(self, vocab_size, embed_size, num_heads, block_size):
super().__init__()
self.block_size = block_size
self.embedding = nn.Embedding(vocab_size, embed_size)
self.pos_embedding = nn.Embedding(block_size, embed_size)
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads, batch_first=True)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
self.fc_out = nn.Linear(embed_size, vocab_size)
def forward(self, x):
B, T = x.shape
pos = torch.arange(T, device=x.device).unsqueeze(0)
out = self.embedding(x) + self.pos_embedding(pos)
out = self.transformer(out)
return self.fc_out(out)
# --- ПОДГОТОВКА ДАННЫХ И ТОКЕНИЗАЦИЯ ---
if os.path.exists(FILE_NAME):
with open(FILE_NAME, 'r', encoding='utf-8') as f: text = f.read()
else:
# Fallback текст, если book.txt не найден (должен содержать все токены)
text = "<|user|>привет<|model|>нормально" * 100
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi.get(c, 0) for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
# --- ЗАГРУЗКА МОДЕЛИ ---
model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
if os.path.exists(MODEL_PATH):
# Загружаем модель на CPU, что важно для HF Spaces с базовым тарифом
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()
# --- ЛОГИКА ГЕНЕРАЦИИ С ТОКЕНАМИ И НАСТРОЙКАМИ ---
def predict(prompt, max_length, temperature):
if not prompt: return "Введите текст"
# Принудительно добавляем токен Модели к запросу пользователя
full_prompt = prompt.strip() + "<|model|>"
context_tokens = encode(full_prompt)[-BLOCK_SIZE:]
context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)
generated_tokens = []
for _ in range(max_length):
cond = context[:, -BLOCK_SIZE:]
with torch.no_grad():
logits = model(cond)[:, -1, :]
if temperature == 0:
probs = F.softmax(logits, dim=-1)
next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
else:
probs = F.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Остановка генерации, если модель сгенерировала начало токена '<'
if decode([next_token.item()]) == '<':
break
context = torch.cat((context, next_token), dim=1)
generated_tokens.append(next_token.item())
return decode(generated_tokens)
# --- ИНТЕРФЕЙС GRADIO ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🤖 MiniGPT Chat с настройками")
with gr.Row():
with gr.Column():
# Подсказка пользователю начинать с токена для лучшей работы
input_text = gr.Textbox(label="Ваш запрос (начинайте с <|user|>)", placeholder="Напишите начало фразы...", lines=3)
max_len_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Максимальная длина ответа")
temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Температура (0=детерминированный)")
btn = gr.Button("Сгенерировать")
output_text = gr.Textbox(label="Ответ модели", lines=10)
btn.click(fn=predict, inputs=[input_text, max_len_slider, temp_slider], outputs=[output_text])
if __name__ == "__main__":
demo.launch()
|