NonameSsSs commited on
Commit
4d89c7c
·
verified ·
1 Parent(s): 9f66cf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -63
app.py CHANGED
@@ -1,20 +1,23 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- import gradio as gr
5
- import os
6
 
7
  # --- КОНФИГУРАЦИЯ ---
 
 
8
  BLOCK_SIZE = 64
 
9
  EMBED_SIZE = 64
10
  HEADS = 4
11
- MODEL_PATH = 'minigpt_checkpoint.pt'
12
- FILE_NAME = 'book.txt'
13
 
14
- # --- АРХИТЕКТУРА ---
15
  class MiniGPT(nn.Module):
16
  def __init__(self, vocab_size, embed_size, num_heads, block_size):
17
- super().__init__()
18
  self.block_size = block_size
19
  self.embedding = nn.Embedding(vocab_size, embed_size)
20
  self.pos_embedding = nn.Embedding(block_size, embed_size)
@@ -25,75 +28,61 @@ class MiniGPT(nn.Module):
25
  def forward(self, x):
26
  B, T = x.shape
27
  pos = torch.arange(T, device=x.device).unsqueeze(0)
28
- out = self.embedding(x) + self.pos_embedding(pos)
 
 
29
  out = self.transformer(out)
30
- return self.fc_out(out)
 
31
 
32
- # --- ДАННЫЕ И ТОКЕНИЗАЦИЯ ---
33
- if os.path.exists(FILE_NAME):
34
- with open(FILE_NAME, 'r', encoding='utf-8') as f: text = f.read()
35
- else:
36
- text = "привет как дела нормально пока" * 100
 
 
 
 
37
 
38
  chars = sorted(list(set(text)))
39
  vocab_size = len(chars)
40
  stoi = { ch:i for i,ch in enumerate(chars) }
41
  itos = { i:ch for i,ch in enumerate(chars) }
42
- encode = lambda s: [stoi.get(c, 0) for c in s]
43
  decode = lambda l: ''.join([itos[i] for i in l])
44
 
45
- # --- ЗАГРУЗКА МОДЕЛИ ---
 
 
 
46
  model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
47
- if os.path.exists(MODEL_PATH):
48
- # Загружаем модель на CPU, что важно для HF Spaces с базовым тарифом
49
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
50
- model.eval()
 
 
 
 
 
 
 
 
51
 
52
- # --- ЛОГИКА ГЕНЕРАЦИИ С ТОКЕНАМИ ---
53
- def predict(prompt, max_length, temperature):
54
- if not prompt: return "Введите текст"
55
-
56
- # Принудительно добавляем токен Модели к запросу пользователя
57
- full_prompt = prompt.strip() + "<|model|>"
58
-
59
- context_tokens = encode(full_prompt)[-BLOCK_SIZE:]
60
- context = torch.tensor(context_tokens, dtype=torch.long).unsqueeze(0)
61
-
62
- generated_tokens = []
63
- for _ in range(max_length):
64
- cond = context[:, -BLOCK_SIZE:]
65
- with torch.no_grad():
66
- logits = model(cond)[:, -1, :]
67
-
68
- if temperature == 0:
69
- probs = F.softmax(logits, dim=-1)
70
- next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
71
- else:
72
- probs = F.softmax(logits / temperature, dim=-1)
73
- next_token = torch.multinomial(probs, num_samples=1)
74
-
75
- # Остановка генерации, если модель сгенерировала начало токена '<'
76
- if decode([next_token.item()]) == '<':
77
- break
78
 
79
- context = torch.cat((context, next_token), dim=1)
80
- generated_tokens.append(next_token.item())
81
-
82
- return decode(generated_tokens)
83
 
84
- # --- ИНТЕРФЕЙС GRADIO ---
85
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
86
- gr.Markdown("# 🤖 MiniGPT Chat с настройками")
87
- with gr.Row():
88
- with gr.Column():
89
- input_text = gr.Textbox(label="Ваш запрос (начинайте с <|user|>)", placeholder="Напишите начало фразы...", lines=3)
90
- max_len_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Максимальная длина ответа")
91
- temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Температура (0=детерминированный)")
92
- btn = gr.Button("Сгенерировать")
93
 
94
- output_text = gr.Textbox(label=твет модели", lines=10)
95
-
96
- btn.click(fn=predict, inputs=[input_text, max_len_slider, temp_slider], outputs=[output_text])
97
 
98
- if __name__ == "__main__":
99
- demo.launch()
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ import time
5
+ import os # Добавлено для проверки наличия файла
6
 
7
  # --- КОНФИГУРАЦИЯ ---
8
+ FILE_NAME = 'book.txt'
9
+ MODEL_PATH = 'minigpt_checkpoint.pt'
10
  BLOCK_SIZE = 64
11
+ BATCH_SIZE = 16
12
  EMBED_SIZE = 64
13
  HEADS = 4
14
+ LR = 0.001
15
+ EPOCHS = 300 # Увеличено для лучшего обучения на новых данных
16
 
17
+ # --- 1. АРХИТЕКТУРА МОДЕЛИ ---
18
  class MiniGPT(nn.Module):
19
  def __init__(self, vocab_size, embed_size, num_heads, block_size):
20
+ super(MiniGPT, self).__init__()
21
  self.block_size = block_size
22
  self.embedding = nn.Embedding(vocab_size, embed_size)
23
  self.pos_embedding = nn.Embedding(block_size, embed_size)
 
28
  def forward(self, x):
29
  B, T = x.shape
30
  pos = torch.arange(T, device=x.device).unsqueeze(0)
31
+ tok_emb = self.embedding(x)
32
+ pos_emb = self.pos_embedding(pos)
33
+ out = tok_emb + pos_emb
34
  out = self.transformer(out)
35
+ logits = self.fc_out(out)
36
+ return logits
37
 
38
+ # --- 2. ПОДГОТОВКА ДАННЫХ И ТОКЕНИЗАЦИЯ ---
39
+ try:
40
+ with open(FILE_NAME, 'r', encoding='utf-8') as f:
41
+ text = f.read()
42
+ print(f"Успешно прочитан файл: {FILE_NAME}, размер текста: {len(text)} символов.")
43
+ except FileNotFoundError:
44
+ print(f"Ошибка: файл '{FILE_NAME}' не найден. Использую fallback текст.")
45
+ # Fallback текст должен содержать символы '<', '|', '>', 'u', 's', 'e', 'r', 'm', 'o', 'd', 'l'
46
+ text = "<|user|>привет<|model|>нормально" * 100
47
 
48
  chars = sorted(list(set(text)))
49
  vocab_size = len(chars)
50
  stoi = { ch:i for i,ch in enumerate(chars) }
51
  itos = { i:ch for i,ch in enumerate(chars) }
52
+ encode = lambda s: [stoi[c] for c in s]
53
  decode = lambda l: ''.join([itos[i] for i in l])
54
 
55
+ data = torch.tensor(encode(text), dtype=torch.long)
56
+ print(f"Данные закодированы в тензор размером: {data.shape}")
57
+
58
+ # --- 3. НАСТРОЙКИ ОБУЧЕНИЯ И ИНИЦИАЛИЗАЦИЯ ---
59
  model = MiniGPT(vocab_size, EMBED_SIZE, HEADS, BLOCK_SIZE)
60
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR)
61
+ criterion = nn.CrossEntropyLoss()
62
+
63
+ # --- 4. ЦИКЛ ОБУЧЕНИЯ ---
64
+ print("Начинаю обучение...")
65
+ model.train()
66
+
67
+ for epoch in range(EPOCHS):
68
+ # Генерация случайных батчей из данных
69
+ ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
70
+ xb = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
71
+ yb = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in ix])
72
 
73
+ logits = model(xb)
74
+ B, T, C = logits.shape
75
+ loss = criterion(logits.view(B*T, C), yb.view(B*T))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ optimizer.zero_grad()
78
+ loss.backward()
79
+ optimizer.step()
 
80
 
81
+ if epoch % 50 == 0:
82
+ print(f"Эпоха {epoch}, Ошибка: {loss.item():.4f}")
 
 
 
 
 
 
 
83
 
84
+ print("Обучение завершено.")
 
 
85
 
86
+ # --- 5. СОХРАНЕНИЕ МОДЕЛИ ---
87
+ torch.save(model.state_dict(), MODEL_PATH)
88
+ print(f"Модель сохранена в файл {MODEL_PATH}")