File size: 6,966 Bytes
0feea22 03744b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
import torch
import torch.nn.functional as F
from transformers import GPT2TokenizerFast
from gpt_pytorch import GPTPyTorch # Using the same import as in fine_tune.py
import os
from pathlib import Path
# ============================= GENERATION SETTINGS =============================
# Temperature: lower values give more conservative, predictable answers.
# Start with 0.7; increase to 0.8 if the model starts repeating itself.
TEMPERATURE = 0.7
# Top-K: restricts sampling to the K most likely tokens (0 disables the filter).
# Start with 50; increase if responses feel too bland or repetitive.
TOP_K = 50
# Max length: maximum number of new tokens generated for a single reply.
MAX_LENGTH = 120
# ============================= PATHS =============================
# Other locations used by earlier runs: models/gpt_last_trained.pt (checkpoint)
# and build/fine_tuning_output/final (output directory).
# Fallback checkpoint; main() passes it to Chatbot() as model_path.
LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch49/gpt_finetuned.pt")
# Directory searched first for MODEL_SAVE_NAME. This must be a *directory*:
# Chatbot joins it with MODEL_SAVE_NAME, so pointing it at the .pt file itself
# produced ".../gpt_finetuned.pt/gpt_finetuned.pt", a path that never exists.
FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch49")
MODEL_SAVE_NAME = "gpt_finetuned.pt"
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ============================= Chatbot CLASS =============================
class Chatbot:
    """Interactive wrapper around a GPTPyTorch model and the GPT-2 tokenizer."""

    def __init__(self, model_path):
        """Load the tokenizer, build the model and restore trained weights.

        Weights are looked up first at ``FINAL_OUTPUT_DIR / MODEL_SAVE_NAME``
        and then at ``model_path``. If neither file exists, the randomly
        initialized model is kept and a warning is printed.

        Args:
            model_path: Fallback checkpoint location (``str`` or ``Path``)
                containing a state dict for ``GPTPyTorch``.
        """
        # 1. Tokenizer
        print("Loading standard tokenizer (gpt2)...")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        # GPT-2 defines no pad token by default; reuse EOS for it.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2. Model
        print("Initializing model...")
        self.model = GPTPyTorch().to(device)
        self.model.eval()

        # Look for the latest weights: first the final folder, then model_path.
        model_path = Path(model_path)
        final_path = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
        load_path = None
        if final_path.exists():
            load_path = final_path
            print(f"Weights from Epoch 50 found. Loading and moving to {device}...")
        elif model_path.exists():
            load_path = model_path
            print(f"Loading weights from {load_path} and moving to {device}...")

        if load_path is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(load_path, map_location=device)
            self.model.load_state_dict(state_dict)
        else:
            print("Warning: No trained weights found. Using randomly initialized model.")
        print(f"Model successfully loaded on {device} and ready for chat!")

    @staticmethod
    def _sample_next_token(next_token_logits, temperature, top_k):
        """Pick one token id per row from last-position logits (batch, vocab).

        Returns a tensor of shape (batch, 1) so it can be concatenated directly
        onto the running ``input_ids``.
        """
        if temperature <= 0:
            # Greedy decoding: dividing by a non-positive temperature would give
            # inf/NaN probabilities and make torch.multinomial raise.
            return torch.argmax(next_token_logits, dim=-1, keepdim=True)

        next_token_logits = next_token_logits / temperature
        if top_k > 0:
            # torch.topk raises when k exceeds the vocabulary size, so clamp it.
            k = min(top_k, next_token_logits.size(-1))
            values, indices = torch.topk(next_token_logits, k)
            # Every logit outside the top-k becomes -inf, i.e. probability 0.
            masked = torch.full_like(next_token_logits, float('-inf'))
            next_token_logits = masked.scatter(1, indices, values)

        probabilities = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probabilities, num_samples=1)

    def generate_response(self, prompt, max_length=MAX_LENGTH, temperature=TEMPERATURE, top_k=TOP_K):
        """Sample a reply to ``prompt`` one token at a time.

        Args:
            prompt: User text; its tokens form the generation context.
            max_length: Maximum number of new tokens to generate.
            temperature: Logit divisor. Values <= 0 select the most likely
                token each step (greedy decoding) instead of sampling.
            top_k: Sample only among the ``top_k`` most likely tokens. Values
                <= 0 disable the filter; values above the vocabulary size are
                clamped.

        Returns:
            The generated text without the prompt and without special tokens,
            cut at the first ``__eou__`` marker and stripped of surrounding
            whitespace. Returns ``""`` for a prompt that encodes to no tokens.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)
        prompt_len = input_ids.shape[-1]
        if prompt_len == 0:
            # Nothing to condition on: there would be no last-position logits.
            return ""

        with torch.no_grad():
            for _ in range(max_length):
                # NOTE(review): the context grows every step; if GPTPyTorch has a
                # fixed maximum sequence length, long prompts/replies may exceed it.
                logits, _ = self.model(input_ids)
                next_token = self._sample_next_token(logits[:, -1, :], temperature, top_k)
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break
                # "__eou__" spans several GPT-2 tokens, so it must be searched for
                # in the decoded reply, not in the single newest token.
                if "__eou__" in self.tokenizer.decode(input_ids[0, prompt_len:].tolist()):
                    break

        # Decode only the generated ids: slicing the decoded text by len(prompt)
        # breaks whenever decoding does not reproduce the prompt exactly.
        reply = self.tokenizer.decode(
            input_ids[0, prompt_len:].tolist(), skip_special_tokens=True
        )
        # Keep only the first utterance and drop the end-of-utterance marker.
        return reply.split("__eou__", 1)[0].strip()
def main():
    """Run the interactive chat loop on stdin/stdout.

    Besides ordinary prompts, the loop understands ``exit``/``quit`` and the
    runtime settings commands ``set temp=<positive float>`` and ``set k=<int>``.
    Ctrl-C or end of input (Ctrl-D) ends the session.
    """
    # The settings commands update the module-level defaults.
    global TEMPERATURE, TOP_K
    chatbot = Chatbot(LAST_TRAINED_PATH)
    print("\n" + "="*60)
    print(f"CHATBOT ACTIVATED (PPL ~2.6 / Temperature {TEMPERATURE} / Top-K {TOP_K})")
    print("Type 'exit' or 'quit' to quit. Use 'set temp=0.x' or 'set k=N' to change settings.")
    print("="*60 + "\n")
    while True:
        try:
            user_input = input(">>> You: ")
            command = user_input.strip().lower()
            if command in ('quit', 'exit'):
                print("Goodbye!")
                break
            if not command:
                # Blank prompts encode to zero tokens; there is nothing to answer.
                continue

            # Settings commands
            if command.startswith('set temp='):
                try:
                    value = float(command.partition('=')[2].strip())
                except ValueError:
                    value = None
                # `not value > 0` also rejects NaN; non-positive values would
                # break sampling.
                if value is None or not value > 0:
                    print("Invalid temperature. Use format: set temp=0.7")
                    continue
                TEMPERATURE = value
                print(f"Temperature updated to {TEMPERATURE}")
                continue
            if command.startswith('set k='):
                try:
                    TOP_K = int(command.partition('=')[2].strip())
                except ValueError:
                    print("Invalid value. Use format: set k=50")
                    continue
                print(f"Top-K updated to {TOP_K}")
                continue

            print("...Generating...")
            # Pass the current settings explicitly: generate_response's defaults
            # were captured when the class was defined and ignore later changes.
            response = chatbot.generate_response(
                user_input, temperature=TEMPERATURE, top_k=TOP_K
            )
            print(f"Model: {response}\n")
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            break
# Start the interactive chat only when run as a script, not when imported.
if __name__ == "__main__":
    main()