NonameSsSs committed on
Commit
7044743
·
verified ·
1 Parent(s): 7011363

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ import os
6
+
7
# --- CONFIGURATION ---
# Maximum number of token positions the model receives at once; it sizes the
# positional-embedding table and bounds the generation context window.
BLOCK_SIZE = 64
# Width of the token/position embeddings (the transformer's d_model).
EMBED_SIZE = 64
# Number of attention heads in each transformer layer.
HEADS = 4
# Checkpoint file whose weights are loaded at startup when it exists.
MODEL_PATH = 'minigpt_checkpoint.pt'
12
+
13
# --- ARCHITECTURE ---
class MiniGPT(nn.Module):
    """Small transformer that maps a batch of token ids to per-position logits.

    Token and learned positional embeddings are summed, passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks, and projected back to
    vocabulary size.

    NOTE(review): no attention mask is passed to the encoder, so every
    position can attend to later positions. Confirm this matches how the
    checkpoint was trained before adding a causal mask.

    Args:
        vocab_size: Number of distinct tokens (rows of the token embedding,
            width of the output logits).
        embed_size: Embedding width, used as the transformer ``d_model``.
        num_heads: Attention heads per encoder layer.
        block_size: Maximum sequence length; size of the positional table.
        num_layers: Number of stacked encoder layers (defaults to 2, the
            previously hard-coded value).
    """

    def __init__(self, vocab_size, embed_size, num_heads, block_size, num_layers=2):
        super().__init__()
        self.block_size = block_size
        # Attribute names are part of the checkpoint's state_dict keys;
        # renaming them would break loading of existing checkpoints.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_embedding = nn.Embedding(block_size, embed_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, seq_len, vocab_size)``.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Raises:
            IndexError: If ``seq_len`` exceeds ``block_size`` (there is no
                positional embedding for the extra positions).
        """
        _, seq_len = x.shape
        if seq_len > self.block_size:
            # Fail early with a readable message instead of an opaque
            # out-of-range error from inside the embedding lookup.
            raise IndexError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        # Shape (1, seq_len): broadcasts over the batch when added below.
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        out = self.embedding(x) + self.pos_embedding(pos)
        out = self.transformer(out)
        return self.fc_out(out)
30
+
31
# --- DATA AND TOKENIZATION ---
# (In production the vocabulary should be persisted, e.g. as JSON; this is a
# simplified version that rebuilds it from the corpus on every start.)
FILE_NAME = 'book.txt'
_FALLBACK_TEXT = "привет как дела нормально пока" * 100

try:
    with open(FILE_NAME, 'r', encoding='utf-8') as f:
        text = f.read()
except FileNotFoundError:
    text = _FALLBACK_TEXT

if not text:
    # An existing but empty corpus would give a zero-size vocabulary, which
    # leaves nothing to sample from; use the built-in corpus instead.
    text = _FALLBACK_TEXT

# Character-level vocabulary: one token id per distinct character, in sorted
# order so the id assignment is deterministic for a given corpus.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Return the list of token ids for the characters of ``s``.

    Characters missing from the vocabulary map to id 0, i.e. they are
    silently treated as the first character of the sorted vocabulary.
    """
    return [stoi.get(c, 0) for c in s]


def decode(l):  # noqa: E741 - parameter name kept from the original lambda
    """Return the string formed by the characters for token ids ``l``.

    Raises:
        KeyError: If an id is not present in the vocabulary.
    """
    return ''.join(itos[i] for i in l)
45
+
46
# --- MODEL LOADING ---
# The vocabulary size comes from the corpus read at startup, so the corpus
# must yield the same vocabulary the checkpoint was trained with; otherwise
# load_state_dict fails on mismatched embedding/output shapes.
model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
if os.path.exists(MODEL_PATH):
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
# NOTE(review): when no checkpoint is found the model keeps its random
# initial weights and the app still starts; generated text is then noise.

# Inference only: switch off dropout inside the transformer layers.
model.eval()
51
+
52
# --- GENERATION LOGIC ---
def predict(prompt, max_length=50):
    """Sample a continuation of ``prompt`` from the module-level model.

    Args:
        prompt: Text to continue. An empty/falsy prompt short-circuits with a
            hint message instead of running the model.
        max_length: Number of tokens to generate (coerced to ``int``).

    Returns:
        Only the newly generated text (the prompt itself is not included),
        or the hint message when ``prompt`` is empty.
    """
    if not prompt:
        return "Введите текст"

    # Keep only the most recent BLOCK_SIZE tokens: the model has no
    # positional embeddings beyond that length.
    context_tokens = encode(prompt)[-BLOCK_SIZE:]
    context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)

    generated = []
    # One no_grad scope for the whole loop; no gradients are needed here.
    with torch.no_grad():
        for _ in range(int(max_length)):
            # Logits for the last position only -> distribution of next token.
            logits = model(context)[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # Sample rather than take the argmax, for more varied output.
            next_token = torch.multinomial(probs, num_samples=1)
            # Append and trim so the context never exceeds BLOCK_SIZE; the
            # model input is the same window the untrimmed version would use.
            context = torch.cat((context, next_token), dim=1)[:, -BLOCK_SIZE:]
            generated.append(next_token.item())

    return decode(generated)
71
+
72
# --- GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 MiniGPT Chat")
    # Prompt and model output side by side.
    with gr.Row():
        input_text = gr.Textbox(label="Ваш запрос", placeholder="Напишите начало фразы...")
        output_text = gr.Textbox(label="Ответ модели")

    btn = gr.Button("Сгенерировать")
    # Only the prompt is wired in, so predict runs with its default max_length.
    btn.click(fn=predict, inputs=[input_text], outputs=[output_text])

if __name__ == "__main__":
    demo.launch()