kgrabko commited on
Commit
0feea22
·
verified ·
1 Parent(s): ea7894f

Update chatbot_gpt2.py

Browse files
Files changed (1) hide show
  1. chatbot_gpt2.py +178 -162
chatbot_gpt2.py CHANGED
@@ -1,163 +1,179 @@
1
- # Copyright (c) 2025 CMS Manhattan
2
- # All rights reserved.
3
- #
4
- # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
5
- # this code under the terms of the APACHE 2.0 license.
6
-
7
- import torch
8
- import torch.nn.functional as F
9
- from transformers import GPT2TokenizerFast
10
- from gpt_pytorch import GPTPyTorch # Using the same import as in fine_tune.py
11
- import os
12
- from pathlib import Path
13
-
14
- # ============================= GENERATION SETTINGS =============================
15
- # Temperature: Lower = more conservative and predictable answers.
16
- # Start with 0.7. Increase to 0.8 if the model starts repeating itself.
17
- TEMPERATURE = 0.7
18
-
19
- # Top-K: Limits sampling to the K most likely tokens.
20
- # Start with 50. Increase if responses feel too boring/repetitive.
21
- TOP_K = 50
22
-
23
- # Max Length: Maximum number of tokens to generate in one go
24
- MAX_LENGTH = 120
25
-
26
- # ============================= PATHS =============================
27
- # LAST_TRAINED_PATH = Path("models/gpt_last_trained.pt")
28
- LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch49/gpt_finetuned.pt")
29
- # FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/final")
30
- FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch49/gpt_finetuned.pt")
31
- MODEL_SAVE_NAME = "gpt_finetuned.pt"
32
-
33
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
-
35
- # ============================= Chatbot CLASS =============================
36
- class Chatbot:
37
- def __init__(self, model_path):
38
- # 1. Tokenizer
39
- print("Loading standard tokenizer (gpt2)...")
40
- self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
41
- self.tokenizer.pad_token = self.tokenizer.eos_token
42
-
43
- #2. Model
44
- print("Initializing model...")
45
- self.model = GPTPyTorch().to(device)
46
- self.model.eval()
47
-
48
- # Look for the latest weights: first check final folder, then last_trained
49
- load_path = None
50
- if (FINAL_OUTPUT_DIR / MODEL_SAVE_NAME).exists():
51
- load_path = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
52
- print(f"Weights from Epoch 50 found. Loading and moving to {device}...")
53
- elif model_path.exists():
54
- load_path = model_path
55
- print(f"Loading weights from {load_path} and moving to {device}...")
56
-
57
- if load_path:
58
- self.model.load_state_dict(torch.load(load_path, map_location=device))
59
- else:
60
- print("Warning: No trained weights found. Using randomly initialized model.")
61
-
62
- print(f"Model successfully loaded on {device} and ready for chat!")
63
-
64
- def generate_response(self, prompt, max_length=MAX_LENGTH, temperature=TEMPERATURE, top_k=TOP_K):
65
- # Tokenize input
66
- input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)
67
-
68
- # Generation loop
69
- with torch.no_grad():
70
- for _ in range(max_length):
71
- # Forward pass through the model
72
- logits, _ = self.model(input_ids)
73
-
74
- # Take logits only for the last token
75
- next_token_logits = logits[:, -1, :]
76
-
77
- # Apply temperature
78
- next_token_logits = next_token_logits / temperature
79
-
80
- # Apply Top-K sampling
81
- if top_k > 0:
82
- # Keep only the top-k most likely tokens
83
- values, indices = torch.topk(next_token_logits, top_k)
84
- # Zero out everything else (set to -inf)
85
- next_token_logits = torch.full_like(next_token_logits, float('-inf'))
86
- next_token_logits.scatter_(1, indices, values)
87
-
88
- # Convert to probabilities and sample the next token
89
- probabilities = F.softmax(next_token_logits, dim=-1)
90
- next_token = torch.multinomial(probabilities, num_samples=1)
91
-
92
- # Append generated token to the sequence
93
- input_ids = torch.cat([input_ids, next_token], dim=-1)
94
-
95
- # Stop if end-of-utterance (__eou__) or EOS token is generated
96
- generated_token = self.tokenizer.decode(next_token.squeeze().item())
97
- if "__eou__" in generated_token or next_token.squeeze().item() == self.tokenizer.eos_token_id:
98
- break
99
-
100
- # Decode the full generated sequence
101
- output = self.tokenizer.decode(input_ids.squeeze().tolist())
102
-
103
- # Remove the original prompt from the output
104
- response = output[len(prompt):].strip()
105
-
106
- # Clean up any leftover end-of-utterance tokens
107
- response = response.replace("__eou__", "").strip()
108
-
109
- return response
110
-
111
-
112
- def main():
113
- # Fix for modifying globals inside the function
114
- global TEMPERATURE, TOP_K
115
-
116
- chatbot = Chatbot(LAST_TRAINED_PATH)
117
-
118
- print("\n" + "="*60)
119
- print(f"CHATBOT ACTIVATED (PPL ~2.6 / Temperature {TEMPERATURE} / Top-K {TOP_K})")
120
- print("Type 'exit' or 'quit' to quit. Use 'set temp=0.x' or 'set k=N' to change settings.")
121
- print("="*60 + "\n")
122
-
123
- while True:
124
- try:
125
- user_input = input(">>> You: ")
126
-
127
- if user_input.lower() in ['quit', 'exit']:
128
- print("Goodbye!")
129
- break
130
-
131
- # Settings commands
132
- if user_input.lower().startswith('set temp='):
133
- try:
134
- TEMPERATURE = float(user_input.split('=')[1].strip())
135
- print(f"Temperature updated to {TEMPERATURE}")
136
- continue
137
- except ValueError:
138
- print("Invalid temperature. Use format: set temp=0.7")
139
- continue
140
-
141
- if user_input.lower().startswith('set k='):
142
- try:
143
- TOP_K = int(user_input.split('=')[1].strip())
144
- print(f"Top-K updated to {TOP_K}")
145
- continue
146
- except ValueError:
147
- print("Invalid value. Use format: set k=50")
148
- continue
149
-
150
- print("...Generating...")
151
- response = chatbot.generate_response(user_input)
152
- print(f"Model: {response}\n")
153
-
154
- except KeyboardInterrupt:
155
- print("\nGoodbye!")
156
- break
157
- except Exception as e:
158
- print(f"An error occurred: {e}")
159
- break
160
-
161
-
162
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  main()
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from transformers import GPT2TokenizerFast
26
+ from gpt_pytorch import GPTPyTorch # Using the same import as in fine_tune.py
27
+ import os
28
+ from pathlib import Path
29
+
30
# ============================= GENERATION SETTINGS =============================
# Sampling temperature. Lower values sharpen the distribution toward the most
# likely tokens (more conservative, predictable replies). 0.7 is the starting
# point; nudge toward 0.8 if the model begins repeating itself.
TEMPERATURE = 0.7

# Top-K cutoff: only the K highest-probability tokens are eligible at each
# sampling step. Raise it if responses feel bland or repetitive.
TOP_K = 50

# Hard ceiling on the number of tokens generated for a single reply.
MAX_LENGTH = 120
41
+
42
# ============================= PATHS =============================
# Checkpoint written by fine-tuning (epoch 49).
# LAST_TRAINED_PATH = Path("models/gpt_last_trained.pt")
LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch49/gpt_finetuned.pt")
# Directory that holds the final fine-tuned weights.
# BUG FIX: this previously pointed at the .pt FILE itself, so
# FINAL_OUTPUT_DIR / MODEL_SAVE_NAME resolved to
# ".../gpt_finetuned.pt/gpt_finetuned.pt" — a path that never exists —
# and the preferred-checkpoint branch in Chatbot.__init__ was dead code.
# FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/final")
FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch49")
MODEL_SAVE_NAME = "gpt_finetuned.pt"

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+
51
# ============================= Chatbot CLASS =============================
class Chatbot:
    """Interactive wrapper around the project's fine-tuned GPT model.

    Loads the standard GPT-2 tokenizer and a GPTPyTorch model, restores the
    most recent fine-tuned weights it can find, and exposes
    generate_response() for sampling chat replies.
    """

    def __init__(self, model_path):
        """Build tokenizer and model, then load weights.

        model_path: pathlib.Path to a fallback state-dict checkpoint, used
        when the preferred FINAL_OUTPUT_DIR checkpoint is absent.
        """
        # 1. Tokenizer. GPT-2 ships without a pad token, so reuse EOS.
        print("Loading standard tokenizer (gpt2)...")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2. Model in inference mode (eval() disables dropout etc.).
        print("Initializing model...")
        self.model = GPTPyTorch().to(device)
        self.model.eval()

        # Look for the latest weights: prefer the final output directory,
        # then fall back to the caller-supplied checkpoint path.
        load_path = None
        if (FINAL_OUTPUT_DIR / MODEL_SAVE_NAME).exists():
            load_path = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
            print(f"Weights from Epoch 50 found. Loading and moving to {device}...")
        elif model_path.exists():
            load_path = model_path
            print(f"Loading weights from {load_path} and moving to {device}...")

        if load_path:
            # NOTE(review): checkpoint is assumed trusted (local training
            # output); torch.load unpickles arbitrary objects otherwise.
            self.model.load_state_dict(torch.load(load_path, map_location=device))
        else:
            print("Warning: No trained weights found. Using randomly initialized model.")

        print(f"Model successfully loaded on {device} and ready for chat!")

    def generate_response(self, prompt, max_length=MAX_LENGTH, temperature=TEMPERATURE, top_k=TOP_K):
        """Sample up to max_length tokens continuing `prompt`.

        Returns the generated continuation only (prompt removed), with any
        "__eou__" end-of-utterance markers stripped. Generation stops early
        when "__eou__" or the EOS token is produced.
        """
        # Tokenize input and remember its length in TOKENS so the reply can
        # be split off exactly, independent of decode round-tripping.
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)
        prompt_len = input_ids.shape[1]

        # Autoregressive sampling loop (no gradients needed at inference).
        with torch.no_grad():
            for _ in range(max_length):
                # Forward pass; keep logits only for the last position.
                logits, _ = self.model(input_ids)
                next_token_logits = logits[:, -1, :]

                # Temperature scaling.
                next_token_logits = next_token_logits / temperature

                # Top-K filtering: everything outside the top k gets -inf
                # so softmax assigns it zero probability.
                if top_k > 0:
                    values, indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, indices, values)

                # Sample the next token from the filtered distribution.
                probabilities = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)

                # Append and check stop conditions. Hoisted token_id: the
                # original computed next_token.squeeze().item() twice.
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                token_id = next_token.squeeze().item()
                generated_token = self.tokenizer.decode(token_id)
                if "__eou__" in generated_token or token_id == self.tokenizer.eos_token_id:
                    break

        # BUG FIX: decode ONLY the generated tokens. The previous code
        # decoded the whole sequence and sliced off len(prompt) CHARACTERS,
        # which corrupted the reply whenever decode() did not reproduce the
        # prompt byte-for-byte (e.g. whitespace/unicode normalization).
        response = self.tokenizer.decode(input_ids[0, prompt_len:].tolist()).strip()

        # Clean up any leftover end-of-utterance markers.
        response = response.replace("__eou__", "").strip()

        return response
126
+
127
+
128
def main():
    """Run the interactive console chat loop."""
    # The settings commands below rebind the module-level globals.
    global TEMPERATURE, TOP_K

    chatbot = Chatbot(LAST_TRAINED_PATH)

    banner = "=" * 60
    print("\n" + banner)
    print(f"CHATBOT ACTIVATED (PPL ~2.6 / Temperature {TEMPERATURE} / Top-K {TOP_K})")
    print("Type 'exit' or 'quit' to quit. Use 'set temp=0.x' or 'set k=N' to change settings.")
    print(banner + "\n")

    while True:
        try:
            user_input = input(">>> You: ")
            lowered = user_input.lower()

            # Exit commands.
            if lowered in ['quit', 'exit']:
                print("Goodbye!")
                break

            # Runtime settings commands.
            if lowered.startswith('set temp='):
                try:
                    TEMPERATURE = float(user_input.split('=')[1].strip())
                    print(f"Temperature updated to {TEMPERATURE}")
                except ValueError:
                    print("Invalid temperature. Use format: set temp=0.7")
                continue

            if lowered.startswith('set k='):
                try:
                    TOP_K = int(user_input.split('=')[1].strip())
                    print(f"Top-K updated to {TOP_K}")
                except ValueError:
                    print("Invalid value. Use format: set k=50")
                continue

            # Anything else is a chat prompt.
            print("...Generating...")
            print(f"Model: {chatbot.generate_response(user_input)}\n")

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as err:
            print(f"An error occurred: {err}")
            break


if __name__ == "__main__":
    main()