kgrabko commited on
Commit
36d6eb9
·
verified ·
1 Parent(s): 1db6855

Upload chatbot_1b.py

Browse files
Files changed (1) hide show
  1. chatbot_1b.py +158 -0
chatbot_1b.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ #
4
+ # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
5
+ # this code under the terms of the APACHE 2.0 license.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import GPT2TokenizerFast
10
+ from gpt_modern_8b import JiRackPyTorch # Используем тот же импорт, что и в fine_tune.py
11
+ import os
12
+ from pathlib import Path
13
+
14
+ # ============================= НАСТРОЙКИ ГЕНЕРАЦИИ =============================
15
+ # Temperature: Чем ниже, тем более консервативны и предсказуемы ответы.
16
+ # Начните с 0.7. Если модель повторяется, повысьте до 0.8.
17
+ TEMPERATURE = 0.7
18
+
19
+ # Top-K: Ограничивает выборку K наиболее вероятными токенами.
20
+ # Начните с 50. Увеличивайте, если ответы слишком скучные.
21
+ TOP_K = 50
22
+
23
+ # Max Length: Максимальное количество генерируемых токенов за раз
24
+ MAX_LENGTH = 120
25
+
26
+ # ============================= ПУТИ =============================
27
+ #LAST_TRAINED_PATH = Path("models/gpt_last_trained.pt")
28
+ LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch2/gpt_finetuned.pt")
29
+ #FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/final")
30
+ FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch2/gpt_finetuned.pt")
31
+ MODEL_SAVE_NAME = "gpt_finetuned.pt"
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ # ============================= КЛАСС Chatbot =============================
35
+ class Chatbot:
36
+ def __init__(self, model_path):
37
+ # 1. Токенизатор
38
+ print("Loading standard tokenizer (gpt2)...")
39
+ self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
40
+ self.tokenizer.pad_token = self.tokenizer.eos_token
41
+
42
+ # 2. Модель
43
+ print("Initializing model...")
44
+ self.model = JiRackPyTorch().to(device)
45
+ self.model.eval()
46
+
47
+ # Поиск последних весов: сначала финальная папка, потом last_trained
48
+ load_path = None
49
+ if (FINAL_OUTPUT_DIR / MODEL_SAVE_NAME).exists():
50
+ load_path = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
51
+ print(f"Weights for Epoch 50 found. Loading and moving to {device}...")
52
+ elif model_path.exists():
53
+ load_path = model_path
54
+ print(f"Loading weights from {load_path} and moving to {device}...")
55
+
56
+ if load_path:
57
+ self.model.load_state_dict(torch.load(load_path, map_location=device))
58
+ else:
59
+ print("Warning: No trained weights found. Using initialized model.")
60
+
61
+ print(f"Model successfully loaded on {device} and ready for chat!")
62
+
63
+ def generate_response(self, prompt, max_length=MAX_LENGTH, temperature=TEMPERATURE, top_k=TOP_K):
64
+ # Токенизируем ввод
65
+ input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)
66
+
67
+ # Запускаем генерацию
68
+ with torch.no_grad():
69
+ for _ in range(max_length):
70
+ # Пропускаем через модель
71
+ logits, _ = self.model(input_ids)
72
+
73
+ # Берем только логиты для последнего токена
74
+ next_token_logits = logits[:, -1, :]
75
+
76
+ # Применяем температуру
77
+ next_token_logits = next_token_logits / temperature
78
+
79
+ # Применяем Top-K сэмплирование
80
+ if top_k > 0:
81
+ # Отсекаем все токены, кроме TOP_K самых вероятных
82
+ values, indices = torch.topk(next_token_logits, top_k)
83
+ # Создаем маску для исключения остальных токенов
84
+ next_token_logits = torch.full_like(next_token_logits, -float('inf'))
85
+ next_token_logits.scatter_(1, indices, values)
86
+
87
+ # Преобразуем логиты в вероятности и сэмплируем следующий токен
88
+ probabilities = F.softmax(next_token_logits, dim=-1)
89
+ next_token = torch.multinomial(probabilities, num_samples=1)
90
+
91
+ # Добавляем сгенерированный токен к входным данным
92
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
93
+
94
+ # Проверяем, если сгенерирован токен конца диалога (__eou__) или конца текста (EOS)
95
+ generated_token = self.tokenizer.decode(next_token.squeeze().item())
96
+ if "__eou__" in generated_token or next_token.squeeze().item() == self.tokenizer.eos_token_id:
97
+ break
98
+
99
+ # Декодируем всю последовательность, обрезая исходный запрос
100
+ output = self.tokenizer.decode(input_ids.squeeze().tolist())
101
+
102
+ # Убираем исходный промт
103
+ response = output[len(prompt):].strip()
104
+
105
+ # Убираем токен конца диалога, если он остался в конце
106
+ response = response.replace("__eou__", "").strip()
107
+
108
+ return response
109
+
110
+ def main():
111
+ # === КОРРЕКТИРОВКА ОШИБКИ: Объявляем глобальные переменные в начале функции ===
112
+ global TEMPERATURE, TOP_K
113
+
114
+ chatbot = Chatbot(LAST_TRAINED_PATH)
115
+
116
+ print("\n" + "="*60)
117
+ print(f"🤖 CHATBOT ACTIVATED (PPL 2.6 / Temperature {TEMPERATURE} / Top-K {TOP_K})")
118
+ print("Enter 'exit' or 'quit' to quit. Use 'set temp=0.x' or 'set k=N' to change settings.")
119
+ print("="*60 + "\n")
120
+
121
+ while True:
122
+ try:
123
+ user_input = input(">>> You: ")
124
+ if user_input.lower() in ['quit', 'exit']:
125
+ break
126
+
127
+ # Команды управления параметрами (опционально)
128
+ if user_input.lower().startswith('set temp='):
129
+ try:
130
+ # Теперь мы можем присваивать значение напрямую, так как они объявлены глобальными выше
131
+ TEMPERATURE = float(user_input.split('=')[1].strip())
132
+ print(f"🤖 Temperature set to {TEMPERATURE}")
133
+ continue
134
+ except ValueError:
135
+ print("🤖 Invalid temperature value. Use 'set temp=0.x'.")
136
+ continue
137
+
138
+ if user_input.lower().startswith('set k='):
139
+ try:
140
+ # Теперь мы можем присваивать значение напрямую, так как они объявлены глобальными выше
141
+ TOP_K = int(user_input.split('=')[1].strip())
142
+ print(f"🤖 Top-K set to {TOP_K}")
143
+ continue
144
+ except ValueError:
145
+ print("🤖 Invalid K value. Use 'set k=N' (e.g., set k=50).")
146
+ continue
147
+
148
+ print("...Generating...")
149
+ response = chatbot.generate_response(user_input)
150
+ print(f"🤖 Model: {response}\n")
151
+
152
+ except Exception as e:
153
+ print(f"An error occurred: {e}")
154
+ break
155
+
156
+ if __name__ == "__main__":
157
+ from pathlib import Path
158
+ main()