NonameSsSs commited on
Commit
207efb8
·
verified ·
1 Parent(s): d1ec634

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -51
app.py CHANGED
@@ -1,23 +1,20 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- import time
5
- import os # Добавлено для проверки наличия файла
6
 
7
  # --- КОНФИГУРАЦИЯ ---
8
- FILE_NAME = 'book.txt'
9
- MODEL_PATH = 'minigpt_checkpoint.pt'
10
  BLOCK_SIZE = 64
11
- BATCH_SIZE = 16
12
  EMBED_SIZE = 64
13
  HEADS = 4
14
- LR = 0.001
15
- EPOCHS = 300 # Увеличено для лучшего обучения на новых данных
16
 
17
- # --- 1. АРХИТЕКТУРА МОДЕЛИ ---
18
  class MiniGPT(nn.Module):
19
  def __init__(self, vocab_size, embed_size, num_heads, block_size):
20
- super(MiniGPT, self).__init__()
21
  self.block_size = block_size
22
  self.embedding = nn.Embedding(vocab_size, embed_size)
23
  self.pos_embedding = nn.Embedding(block_size, embed_size)
@@ -28,61 +25,77 @@ class MiniGPT(nn.Module):
28
  def forward(self, x):
29
  B, T = x.shape
30
  pos = torch.arange(T, device=x.device).unsqueeze(0)
31
- tok_emb = self.embedding(x)
32
- pos_emb = self.pos_embedding(pos)
33
- out = tok_emb + pos_emb
34
  out = self.transformer(out)
35
- logits = self.fc_out(out)
36
- return logits
37
 
38
- # --- 2. ПОДГОТОВКА ДАННЫХ И ТОКЕНИЗАЦИЯ ---
39
- try:
40
- with open(FILE_NAME, 'r', encoding='utf-8') as f:
41
- text = f.read()
42
- print(f"Успешно прочитан файл: {FILE_NAME}, размер текста: {len(text)} символов.")
43
- except FileNotFoundError:
44
- print(f"Ошибка: файл '{FILE_NAME}' не найден. Использую fallback текст.")
45
- # Fallback текст должен содержать символы '<', '|', '>', 'u', 's', 'e', 'r', 'm', 'o', 'd', 'l'
46
  text = "<|user|>привет<|model|>нормально" * 100
47
 
48
  chars = sorted(list(set(text)))
49
  vocab_size = len(chars)
50
  stoi = { ch:i for i,ch in enumerate(chars) }
51
  itos = { i:ch for i,ch in enumerate(chars) }
52
- encode = lambda s: [stoi[c] for c in s]
53
  decode = lambda l: ''.join([itos[i] for i in l])
54
 
55
- data = torch.tensor(encode(text), dtype=torch.long)
56
- print(f"Данные закодированы в тензор размером: {data.shape}")
57
-
58
- # --- 3. НАСТРОЙКИ ОБУЧЕНИЯ И ИНИЦИАЛИЗАЦИЯ ---
59
  model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
60
- optimizer = torch.optim.Adam(model.parameters(), lr=LR)
61
- criterion = nn.CrossEntropyLoss()
62
-
63
- # --- 4. ЦИКЛ ОБУЧЕНИЯ ---
64
- print("Начинаю обучение...")
65
- model.train()
66
-
67
- for epoch in range(EPOCHS):
68
- # Генерация случайных батчей из данных
69
- ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
70
- xb = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
71
- yb = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in ix])
72
 
73
- logits = model(xb)
74
- B, T, C = logits.shape
75
- loss = criterion(logits.view(B*T, C), yb.view(B*T))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- optimizer.zero_grad()
78
- loss.backward()
79
- optimizer.step()
 
80
 
81
- if epoch % 50 == 0:
82
- print(f"Эпоха {epoch}, Ошибка: {loss.item():.4f}")
 
 
 
 
 
 
 
 
83
 
84
- print("Обучение завершено.")
 
 
85
 
86
- # --- 5. СОХРАНЕНИЕ МОДЕЛИ ---
87
- torch.save(model.state_dict(), MODEL_PATH)
88
- print(f"Модель сохранена в файл {MODEL_PATH}")
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ import gradio as gr
5
+ import os
6
 
7
  # --- КОНФИГУРАЦИЯ ---
 
 
8
  BLOCK_SIZE = 64
 
9
  EMBED_SIZE = 64
10
  HEADS = 4
11
+ MODEL_PATH = 'minigpt_checkpoint.pt'
12
+ FILE_NAME = 'book.txt' # Используется для загрузки словаря
13
 
14
+ # --- АРХИТЕКТУРА МОДЕЛИ ---
15
  class MiniGPT(nn.Module):
16
  def __init__(self, vocab_size, embed_size, num_heads, block_size):
17
+ super().__init__()
18
  self.block_size = block_size
19
  self.embedding = nn.Embedding(vocab_size, embed_size)
20
  self.pos_embedding = nn.Embedding(block_size, embed_size)
 
25
  def forward(self, x):
26
  B, T = x.shape
27
  pos = torch.arange(T, device=x.device).unsqueeze(0)
28
+ out = self.embedding(x) + self.pos_embedding(pos)
 
 
29
  out = self.transformer(out)
30
+ return self.fc_out(out)
 
31
 
32
+ # --- ПОДГОТОВКА ДАННЫХ И ТОКЕНИЗАЦИЯ ---
33
+ if os.path.exists(FILE_NAME):
34
+ with open(FILE_NAME, 'r', encoding='utf-8') as f: text = f.read()
35
+ else:
36
+ # Fallback текст, если book.txt не найден (должен содержать все токены)
 
 
 
37
  text = "<|user|>привет<|model|>нормально" * 100
38
 
39
  chars = sorted(list(set(text)))
40
  vocab_size = len(chars)
41
  stoi = { ch:i for i,ch in enumerate(chars) }
42
  itos = { i:ch for i,ch in enumerate(chars) }
43
+ encode = lambda s: [stoi.get(c, 0) for c in s]
44
  decode = lambda l: ''.join([itos[i] for i in l])
45
 
46
+ # --- ЗАГРУЗКА МОДЕЛИ ---
 
 
 
47
  model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
48
+ if os.path.exists(MODEL_PATH):
49
+ # Загружаем модель на CPU, что важно для HF Spaces с базовым тарифом
50
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
51
+ model.eval()
 
 
 
 
 
 
 
 
52
 
53
+ # --- ЛОГИКА ГЕНЕРАЦИИ С ТОКЕНАМИ И НАСТРОЙКАМИ ---
54
+ def predict(prompt, max_length, temperature):
55
+ if not prompt: return "Введите текст"
56
+
57
+ # Принудительно добавляем токен Модели к запросу пользователя
58
+ full_prompt = prompt.strip() + "<|model|>"
59
+
60
+ context_tokens = encode(full_prompt)[-BLOCK_SIZE:]
61
+ context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)
62
+
63
+ generated_tokens = []
64
+ for _ in range(max_length):
65
+ cond = context[:, -BLOCK_SIZE:]
66
+ with torch.no_grad():
67
+ logits = model(cond)[:, -1, :]
68
+
69
+ if temperature == 0:
70
+ probs = F.softmax(logits, dim=-1)
71
+ next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
72
+ else:
73
+ probs = F.softmax(logits / temperature, dim=-1)
74
+ next_token = torch.multinomial(probs, num_samples=1)
75
+
76
+ # Остановка генерации, если модель сгенерировала начало токена '<'
77
+ if decode([next_token.item()]) == '<':
78
+ break
79
 
80
+ context = torch.cat((context, next_token), dim=1)
81
+ generated_tokens.append(next_token.item())
82
+
83
+ return decode(generated_tokens)
84
 
85
+ # --- ИНТЕРФЕЙС GRADIO ---
86
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
87
+ gr.Markdown("# 🤖 MiniGPT Chat с настройками")
88
+ with gr.Row():
89
+ with gr.Column():
90
+ # Подсказка пользователю начинать с токена для лучшей работы
91
+ input_text = gr.Textbox(label="Ваш запрос (начинайте с <|user|>)", placeholder="Напишите начало фразы...", lines=3)
92
+ max_len_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Максимальная длина ответа")
93
+ temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Температура (0=детерминированный)")
94
+ btn = gr.Button("Сгенерировать")
95
 
96
+ output_text = gr.Textbox(label=твет модели", lines=10)
97
+
98
+ btn.click(fn=predict, inputs=[input_text, max_len_slider, temp_slider], outputs=[output_text])
99
 
100
+ if __name__ == "__main__":
101
+ demo.launch()