import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import os
# --- CONFIGURATION ---
BLOCK_SIZE = 64   # context window length in characters
EMBED_SIZE = 64   # embedding dimension
HEADS = 4         # attention heads per transformer layer
MODEL_PATH = 'minigpt_checkpoint.pt'
FILE_NAME = 'book.txt'  # used to rebuild the character vocabulary
# --- MODEL ARCHITECTURE ---
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads, block_size):
        super().__init__()
        self.block_size = block_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_embedding = nn.Embedding(block_size, embed_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        B, T = x.shape
        # Add learned positional embeddings to the token embeddings.
        pos = torch.arange(T, device=x.device).unsqueeze(0)
        out = self.embedding(x) + self.pos_embedding(pos)
        out = self.transformer(out)
        return self.fc_out(out)  # logits over the vocabulary, shape (B, T, vocab_size)
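# Editor's note: the forward pass above runs the encoder with no attention mask,
# so every position can attend to every other one. A GPT-style decoder normally
# applies a causal mask; a minimal sketch of what that would look like here,
# assuming the checkpoint was also trained that way (enabling it only at
# inference time would mismatch training):
#
#     mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
#     out = self.transformer(out, mask=mask, is_causal=True)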
# --- DATA PREPARATION AND TOKENIZATION ---
if os.path.exists(FILE_NAME):
    with open(FILE_NAME, 'r', encoding='utf-8') as f:
        text = f.read()
else:
    # Fallback text if book.txt is missing. It must contain every character the
    # checkpoint was trained on, including the special-token markup; the Russian
    # content is kept as-is because the vocabulary is built from it.
    text = "<|user|>привет<|model|>нормально" * 100

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi.get(c, 0) for c in s]  # unknown characters fall back to index 0
decode = lambda l: ''.join([itos[i] for i in l])
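
# Minimal sanity check (editor's addition): the character-level codec should
# round-trip any string built from the vocabulary itself.
assert decode(encode(text[:32])) == text[:32]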
# --- MODEL LOADING ---
model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
if os.path.exists(MODEL_PATH):
    # Load onto the CPU, which matters on free-tier HF Spaces with no GPU.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()
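# For reference (illustrative, assuming the checkpoint is a plain state_dict as
# the loader above expects), the training script would have saved it with:
#
#     torch.save(model.state_dict(), MODEL_PATH)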
# --- GENERATION WITH SPECIAL TOKENS AND SAMPLING SETTINGS ---
def predict(prompt, max_length, temperature):
    if not prompt:
        return "Please enter some text"
    # Force-append the model token so generation continues as the model's turn.
    full_prompt = prompt.strip() + "<|model|>"
    context_tokens = encode(full_prompt)[-BLOCK_SIZE:]
    context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)
    generated_tokens = []
    for _ in range(max_length):
        # Condition only on the last BLOCK_SIZE tokens the model can see.
        cond = context[:, -BLOCK_SIZE:]
        with torch.no_grad():
            logits = model(cond)[:, -1, :]
        if temperature == 0:
            # Greedy decoding: argmax directly over the logits (softmax is
            # monotonic, so it would not change the winner).
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        # Stop as soon as the model emits '<', the first character of a
        # special token such as <|user|>.
        if decode([next_token.item()]) == '<':
            break
        context = torch.cat((context, next_token), dim=1)
        generated_tokens.append(next_token.item())
    return decode(generated_tokens)
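
# Example call outside the Gradio UI (illustrative):
#
#     print(predict("<|user|>привет", max_length=100, temperature=0.7))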
# --- GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 MiniGPT Chat with Settings")
    with gr.Row():
        with gr.Column():
            # Hint the user to start with the token so the model stays in format.
            input_text = gr.Textbox(label="Your prompt (start with <|user|>)", placeholder="Type the beginning of a phrase...", lines=3)
            max_len_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Maximum response length")
            temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Temperature (0 = deterministic)")
            btn = gr.Button("Generate")
        output_text = gr.Textbox(label="Model response", lines=10)
    btn.click(fn=predict, inputs=[input_text, max_len_slider, temp_slider], outputs=[output_text])
if __name__ == "__main__":
    demo.launch()