# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
import torch
import torch.nn.functional as F
from transformers import GPT2TokenizerFast
from gpt_modern_8b import JiRackPyTorch # Same import used in fine-tuning
from pathlib import Path
# ============================= GENERATION SETTINGS =============================
# Temperature: Lower = more focused, conservative, and predictable responses
# Start with 0.7. Increase to 0.8–0.9 if the model starts repeating itself
TEMPERATURE = 0.7
# Top-K: Limits sampling to the K most likely next tokens
# Start with 50. Increase if output feels too safe/boring
TOP_K = 50
# Max Length: Maximum number of new tokens to generate per response
MAX_LENGTH = 120
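# Illustrative note on the two knobs above (hedged; the numbers are made up for demonstration):
# dividing logits by a temperature < 1 sharpens the softmax, so sampling gets more deterministic,
# while Top-K simply discards everything outside the K highest-scoring tokens.
#   logits = torch.tensor([[2.0, 1.0, 0.5]])
#   F.softmax(logits, dim=-1)        # roughly [0.63, 0.23, 0.14]
#   F.softmax(logits / 0.7, dim=-1)  # more peaked, roughly [0.74, 0.18, 0.09]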
# ============================= PATHS =============================
LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch2/gpt_finetuned.pt")
FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch2") # Folder containing the .pt
MODEL_SAVE_NAME = "gpt_finetuned.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ============================= CHATBOT CLASS =============================
class Chatbot:
    def __init__(self, model_path: Path):
        # 1. Load tokenizer (offline-safe use recommended; see note below)
        print("Loading standard GPT-2 tokenizer...")
        # For full offline use, replace "gpt2" with "./tokenizers/gpt2" after first download
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
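        # Offline note (hedged sketch; the local path below is the one suggested above,
        # not something this script creates): after one online run the tokenizer can be
        # cached with self.tokenizer.save_pretrained("./tokenizers/gpt2") and later
        # loaded via GPT2TokenizerFast.from_pretrained("./tokenizers/gpt2") without network access.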
        # 2. Initialize model architecture
        print("Initializing JiRackPyTorch model...")
        self.model = JiRackPyTorch().to(device)
        self.model.eval()

        # 3. Load latest trained weights
        load_path = None
        candidate1 = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
        candidate2 = model_path if model_path.is_file() else None

        if candidate1.exists():
            load_path = candidate1
            print(f"Found weights in final folder: {load_path}")
        elif candidate2 and candidate2.exists():
            load_path = candidate2
            print(f"Loading weights from: {load_path}")
        else:
            print("Warning: No trained weights found. Running with randomly initialized model.")

        if load_path:
            print(f"Loading state dict from {load_path}...")
            self.model.load_state_dict(torch.load(load_path, map_location=device))
            print("Weights loaded successfully!")
print(f"Model is now running on {device} — ready for chat!\n")
    def generate_response(self, prompt: str, max_length: int = MAX_LENGTH,
                          temperature: float = TEMPERATURE, top_k: int = TOP_K) -> str:
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            for _ in range(max_length):
                # Forward pass
                logits, _ = self.model(input_ids)  # JiRackPyTorch returns (logits, past_kv)

                # Get logits for the last generated token
                next_token_logits = logits[:, -1, :]

                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Apply Top-K sampling
                if top_k > 0:
                    values, indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, indices, values)
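                # Worked example of the masking above (hypothetical logits, for illustration only):
                # with logits [2.1, 0.3, -1.0, 1.5] and top_k=2, only the positions of the two
                # largest values (2.1 and 1.5) keep their scores; the rest become -inf, so the
                # softmax below assigns them zero probability and they can never be sampled.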
                # Sample next token
                probabilities = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)

                # Append to sequence
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Early stop on EOS or custom end-of-utterance token
                token_str = self.tokenizer.decode(next_token.item())
                if "__eou__" in token_str or next_token.item() == self.tokenizer.eos_token_id:
                    break

        # Decode full output and strip prompt
        full_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=False)
        response = full_output[len(prompt):].strip()

        # Clean up any leftover markers
        response = response.replace("__eou__", "").strip()
        return response
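    # Usage sketch (hedged; the prompt text is invented and this only runs if trained
    # weights exist at the default path):
    #   bot = Chatbot(LAST_TRAINED_PATH)
    #   reply = bot.generate_response("Hi, how are you today?", temperature=0.8, top_k=80)
    #   print(reply)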
# ============================= MAIN CHAT LOOP =============================
def main():
    global TEMPERATURE, TOP_K

    print("Starting JiRack Chatbot...")
    chatbot = Chatbot(LAST_TRAINED_PATH)

    print("\n" + "=" * 70)
    print("JIRACK CHATBOT ONLINE")
    print(f"Temperature: {TEMPERATURE} | Top-K: {TOP_K} | Max Length: {MAX_LENGTH}")
    print("Type 'quit' or 'exit' to exit")
    print("Change settings: set temp=0.8 or set k=80")
    print("=" * 70 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()

            if user_input.lower() in {"quit", "exit", "bye"}:
                print("Goodbye!")
                break

            # Live parameter tuning
            if user_input.lower().startswith("set temp="):
                try:
                    TEMPERATURE = float(user_input.split("=")[1])
                    print(f"Temperature → {TEMPERATURE}")
                except ValueError:
                    print("Invalid format. Use: set temp=0.7")
                continue

            if user_input.lower().startswith("set k="):
                try:
                    TOP_K = int(user_input.split("=")[1])
                    print(f"Top-K → {TOP_K}")
                except ValueError:
                    print("Invalid format. Use: set k=50")
                continue

            if not user_input:
                continue

            print("Generating...", end="\r")
            # Pass the (possibly updated) globals explicitly: the defaults baked into
            # generate_response were captured at definition time, so "set temp=" and
            # "set k=" would otherwise have no effect on generation.
            response = chatbot.generate_response(user_input,
                                                 temperature=TEMPERATURE,
                                                 top_k=TOP_K)
            print(f"JiRack: {response}\n")

        except KeyboardInterrupt:
            print("\n\nShutting down...")
            break
        except Exception as e:
            print(f"Error: {e}")
if __name__ == "__main__":
    main()
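# Example interactive session (illustrative transcript; the model reply is invented):
#   You: set temp=0.8
#   Temperature → 0.8
#   You: hello there
#   JiRack: hi ! nice to meet you . what would you like to talk about ?
#   You: quit
#   Goodbye!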