kgrabko commited on
Commit
8b7cba2
·
verified ·
1 Parent(s): 108847b

Update chatbot_1b.py

Browse files
Files changed (1) hide show
  1. chatbot_1b.py +180 -164
chatbot_1b.py CHANGED
@@ -1,165 +1,181 @@
1
- # Copyright (c) 2025 CMS Manhattan
2
- # All rights reserved.
3
- #
4
- # This file is part of a project authored by CMS Manhattan.
5
- # You may use, distribute, and modify this code under the terms of the Apache 2.0 license.
6
-
7
- import torch
8
- import torch.nn.functional as F
9
- from transformers import GPT2TokenizerFast
10
- from gpt_modern_8b import JiRackPyTorch # Same import used in fine-tuning
11
- from pathlib import Path
12
-
13
- # ============================= GENERATION SETTINGS =============================
14
- # Temperature: Lower = more focused, conservative, and predictable responses
15
- # Start with 0.7. Increase to 0.8–0.9 if the model starts repeating itself
16
- TEMPERATURE = 0.7
17
-
18
- # Top-K: Limits sampling to the K most likely next tokens
19
- # Start with 50. Increase if output feels too safe/boring
20
- TOP_K = 50
21
-
22
- # Max Length: Maximum number of new tokens to generate per response
23
- MAX_LENGTH = 120
24
-
25
- # ============================= PATHS =============================
26
- LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch2/gpt_finetuned.pt")
27
- FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch2") # Folder containing the .pt
28
- MODEL_SAVE_NAME = "gpt_finetuned.pt"
29
-
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- print(f"Using device: {device}")
32
-
33
- # ============================= CHATBOT CLASS =============================
34
- class Chatbot:
35
- def __init__(self, model_path: Path):
36
- # 1. Load tokenizer (offline-safe recommended — see note below)
37
- print("Loading standard GPT-2 tokenizer...")
38
- # For full offline use, replace "gpt2" with "./tokenizers/gpt2" after first download
39
- self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
40
- self.tokenizer.pad_token = self.tokenizer.eos_token
41
-
42
- # 2. Initialize model architecture
43
- print("Initializing JiRackPyTorch model...")
44
- self.model = JiRackPyTorch().to(device)
45
- self.model.eval()
46
-
47
- # 3. Load latest trained weights
48
- load_path = None
49
- candidate1 = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
50
- candidate2 = model_path if model_path.is_file() else None
51
-
52
- if candidate1.exists():
53
- load_path = candidate1
54
- print(f"Found weights in final folder: {load_path}")
55
- elif candidate2 and candidate2.exists():
56
- load_path = candidate2
57
- print(f"Loading weights from: {load_path}")
58
- else:
59
- print("Warning: No trained weights found. Running with randomly initialized model.")
60
-
61
- if load_path:
62
- print(f"Loading state dict from {load_path}...")
63
- self.model.load_state_dict(torch.load(load_path, map_location=device))
64
- print("Weights loaded successfully!")
65
-
66
- print(f"Model is now running on {device} — ready for chat!\n")
67
-
68
- def generate_response(self, prompt: str, max_length: int = MAX_LENGTH,
69
- temperature: float = TEMPERATURE, top_k: int = TOP_K) -> str:
70
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
71
-
72
- with torch.no_grad():
73
- for _ in range(max_length):
74
- # Forward pass
75
- logits, _ = self.model(input_ids) # JiRackPyTorch returns (logits, past_kv)
76
-
77
- # Get logits for the last generated token
78
- next_token_logits = logits[:, -1, :]
79
-
80
- # Apply temperature
81
- if temperature != 1.0:
82
- next_token_logits = next_token_logits / temperature
83
-
84
- # Apply Top-K sampling
85
- if top_k > 0:
86
- values, indices = torch.topk(next_token_logits, top_k)
87
- next_token_logits = torch.full_like(next_token_logits, float('-inf'))
88
- next_token_logits.scatter_(1, indices, values)
89
-
90
- # Sample next token
91
- probabilities = F.softmax(next_token_logits, dim=-1)
92
- next_token = torch.multinomial(probabilities, num_samples=1)
93
-
94
- # Append to sequence
95
- input_ids = torch.cat([input_ids, next_token], dim=-1)
96
-
97
- # Early stop on EOS or custom end-of-utterance token
98
- token_str = self.tokenizer.decode(next_token.item())
99
- if "__eou__" in token_str or next_token.item() == self.tokenizer.eos_token_id:
100
- break
101
-
102
- # Decode full output and strip prompt
103
- full_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=False)
104
- response = full_output[len(prompt):].strip()
105
-
106
- # Clean up any leftover markers
107
- response = response.replace("__eou__", "").strip()
108
-
109
- return response
110
-
111
-
112
- # ============================= MAIN CHAT LOOP =============================
113
- def main():
114
- global TEMPERATURE, TOP_K
115
-
116
- print("Starting JiRack Chatbot...")
117
- chatbot = Chatbot(LAST_TRAINED_PATH)
118
-
119
- print("\n" + "=" * 70)
120
- print(f"JIRACK CHATBOT ONLINE")
121
- print(f"Temperature: {TEMPERATURE} | Top-K: {TOP_K} | Max Length: {MAX_LENGTH}")
122
- print("Type 'quit' or 'exit' to exit")
123
- print("Change settings: set temp=0.8 or set k=80")
124
- print("=" * 70 + "\n")
125
-
126
- while True:
127
- try:
128
- user_input = input("You: ").strip()
129
-
130
- if user_input.lower() in {"quit", "exit", "bye"}:
131
- print("Goodbye!")
132
- break
133
-
134
- # Live parameter tuning
135
- if user_input.lower().startswith("set temp="):
136
- try:
137
- TEMPERATURE = float(user_input.split("=")[1])
138
- print(f"Temperature {TEMPERATURE}")
139
- except:
140
- print("Invalid format. Use: set temp=0.7")
141
- continue
142
-
143
- if user_input.lower().startswith("set k="):
144
- try:
145
- TOP_K = int(user_input.split("=")[1])
146
- print(f"Top-K {TOP_K}")
147
- except:
148
- print("Invalid format. Use: set k=50")
149
- continue
150
-
151
- if not user_input:
152
- continue
153
-
154
- print("Generating...", end="\r")
155
- response = chatbot.generate_response(user_input)
156
- print(f"JiRack: {response}\n")
157
-
158
- except KeyboardInterrupt:
159
- print("\n\nShutting down...")
160
- break
161
- except Exception as e:
162
- print(f"Error: {e}")
163
-
164
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  main()
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from transformers import GPT2TokenizerFast
26
+ from gpt_modern_8b import JiRackPyTorch # Same import used in fine-tuning
27
+ from pathlib import Path
28
+
29
# ============================= GENERATION SETTINGS =============================
# Temperature: Lower = more focused, conservative, and predictable responses
# Start with 0.7. Increase to 0.8–0.9 if the model starts repeating itself
TEMPERATURE = 0.7

# Top-K: Limits sampling to the K most likely next tokens
# Start with 50. Increase if output feels too safe/boring
TOP_K = 50

# Max Length: Maximum number of new tokens to generate per response
MAX_LENGTH = 120

# ============================= PATHS =============================
# Checkpoint produced by fine-tuning; Chatbot.__init__ also falls back to
# FINAL_OUTPUT_DIR / MODEL_SAVE_NAME when this exact file is absent.
LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch2/gpt_finetuned.pt")
FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch2") # Folder containing the .pt
MODEL_SAVE_NAME = "gpt_finetuned.pt"

# Run on GPU when available; the model and all input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
48
+
49
+ # ============================= CHATBOT CLASS =============================
50
class Chatbot:
    """Interactive wrapper around a fine-tuned JiRackPyTorch language model.

    Loads the standard GPT-2 tokenizer and the most recent fine-tuned
    weights (falling back to a randomly initialized model when none are
    found), then generates replies with temperature + top-k sampling.
    """

    def __init__(self, model_path: Path):
        # 1. Load tokenizer (offline-safe recommended — see note below)
        print("Loading standard GPT-2 tokenizer...")
        # For full offline use, replace "gpt2" with "./tokenizers/gpt2" after first download
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        # GPT-2 ships without a pad token; reuse EOS so padding-aware code works.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2. Initialize model architecture
        print("Initializing JiRackPyTorch model...")
        self.model = JiRackPyTorch().to(device)
        self.model.eval()

        # 3. Load latest trained weights — prefer the canonical output folder,
        #    then the explicit path passed in; otherwise warn and run untrained.
        load_path = None
        candidate1 = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
        candidate2 = model_path if model_path.is_file() else None

        if candidate1.exists():
            load_path = candidate1
            print(f"Found weights in final folder: {load_path}")
        elif candidate2 and candidate2.exists():
            load_path = candidate2
            print(f"Loading weights from: {load_path}")
        else:
            print("Warning: No trained weights found. Running with randomly initialized model.")

        if load_path:
            print(f"Loading state dict from {load_path}...")
            # weights_only=True: the checkpoint is a plain state dict, so avoid
            # unpickling arbitrary objects from an on-disk file (safer torch.load).
            self.model.load_state_dict(
                torch.load(load_path, map_location=device, weights_only=True)
            )
            print("Weights loaded successfully!")

        print(f"Model is now running on {device} — ready for chat!\n")

    def generate_response(self, prompt: str, max_length: int = MAX_LENGTH,
                          temperature: float = TEMPERATURE, top_k: int = TOP_K) -> str:
        """Generate a reply for ``prompt`` via temperature/top-k sampling.

        Stops early on the GPT-2 EOS token or the dataset's ``__eou__``
        end-of-utterance marker. Returns only the newly generated text.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
        # Remember how many tokens the prompt occupies so we can strip it
        # off the output by TOKEN position, not by character count.
        prompt_len = input_ids.shape[1]

        with torch.no_grad():
            for _ in range(max_length):
                # Forward pass over the full sequence each step.
                # NOTE(review): past_kv is discarded — presumably a KV cache
                # could speed this up; confirm against JiRackPyTorch's API.
                logits, _ = self.model(input_ids)  # JiRackPyTorch returns (logits, past_kv)

                # Logits for the next-token position only
                next_token_logits = logits[:, -1, :]

                # Temperature scaling (lower = sharper distribution)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Top-K filtering: keep the K best logits, mask the rest to -inf
                if top_k > 0:
                    values, indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, indices, values)

                # Sample the next token from the filtered distribution
                probabilities = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)

                # Append to the running sequence
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Early stop on EOS or custom end-of-utterance token
                token_str = self.tokenizer.decode(next_token.item())
                if "__eou__" in token_str or next_token.item() == self.tokenizer.eos_token_id:
                    break

        # BUGFIX: decode only the generated tokens instead of slicing the full
        # decoded string with full_output[len(prompt):] — tokenizer round-trips
        # are not guaranteed byte-identical to the original prompt, so the old
        # character-based slice could clip or leak prompt text.
        response = self.tokenizer.decode(
            input_ids[0, prompt_len:], skip_special_tokens=False
        ).strip()

        # Clean up any leftover markers
        response = response.replace("__eou__", "").strip()

        return response
126
+
127
+
128
+ # ============================= MAIN CHAT LOOP =============================
129
def main():
    """Run the interactive chat REPL: read input, handle commands, print replies."""
    # The live-tuning commands below rebind the module-level sampling settings.
    global TEMPERATURE, TOP_K

    print("Starting JiRack Chatbot...")
    chatbot = Chatbot(LAST_TRAINED_PATH)

    print("\n" + "=" * 70)
    print("JIRACK CHATBOT ONLINE")
    print(f"Temperature: {TEMPERATURE} | Top-K: {TOP_K} | Max Length: {MAX_LENGTH}")
    print("Type 'quit' or 'exit' to exit")
    print("Change settings: set temp=0.8 or set k=80")
    print("=" * 70 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()

            if user_input.lower() in {"quit", "exit", "bye"}:
                print("Goodbye!")
                break

            # Live parameter tuning
            if user_input.lower().startswith("set temp="):
                try:
                    TEMPERATURE = float(user_input.split("=")[1])
                    print(f"Temperature → {TEMPERATURE}")
                except ValueError:
                    # Narrowed from bare `except:` — only a malformed number
                    # should trigger the usage hint, not KeyboardInterrupt etc.
                    print("Invalid format. Use: set temp=0.7")
                continue

            if user_input.lower().startswith("set k="):
                try:
                    TOP_K = int(user_input.split("=")[1])
                    # Arrow added for consistency with the temperature feedback.
                    print(f"Top-K → {TOP_K}")
                except ValueError:
                    print("Invalid format. Use: set k=50")
                continue

            if not user_input:
                continue

            print("Generating...", end="\r")
            response = chatbot.generate_response(user_input)
            print(f"JiRack: {response}\n")

        except KeyboardInterrupt:
            print("\n\nShutting down...")
            break
        except Exception as e:
            # Top-level guard: keep the chat loop alive on unexpected errors.
            print(f"Error: {e}")
179
+
180
# Script entry point — start the interactive chat loop.
if __name__ == "__main__":
    main()