File size: 6,966 Bytes
0feea22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03744b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.

import torch
import torch.nn.functional as F
from transformers import GPT2TokenizerFast
from gpt_pytorch import GPTPyTorch  # Using the same import as in fine_tune.py
import os
from pathlib import Path

# ============================= GENERATION SETTINGS =============================
# Temperature: Lower = more conservative and predictable answers.
# Start with 0.7. Increase to 0.8 if the model starts repeating itself.
TEMPERATURE = 0.7

# Top-K: Limits sampling to the K most likely tokens.
# Start with 50. Increase if responses feel too boring/repetitive.
TOP_K = 50

# Max Length: Maximum number of tokens to generate in one go
MAX_LENGTH = 120

# ============================= PATHS =============================
# Fallback checkpoint file, used when nothing is found under FINAL_OUTPUT_DIR.
# LAST_TRAINED_PATH = Path("models/gpt_last_trained.pt")
LAST_TRAINED_PATH = Path("build/fine_tuning_output/epoch49/gpt_finetuned.pt")
# FINAL_OUTPUT_DIR must be a DIRECTORY: the loader joins it with
# MODEL_SAVE_NAME (FINAL_OUTPUT_DIR / MODEL_SAVE_NAME). The previous value
# pointed at the .pt file itself, producing the non-existent path
# ".../gpt_finetuned.pt/gpt_finetuned.pt", so the preferred checkpoint was
# never found and the loader always fell through to LAST_TRAINED_PATH.
# FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/final")
FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch49")
MODEL_SAVE_NAME = "gpt_finetuned.pt"

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================= Chatbot CLASS =============================
class Chatbot:
    def __init__(self, model_path):
        # 1. Tokenizer
        print("Loading standard tokenizer (gpt2)...")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        #2. Model
        print("Initializing model...")
        self.model = GPTPyTorch().to(device)
        self.model.eval()

        # Look for the latest weights: first check final folder, then last_trained
        load_path = None
        if (FINAL_OUTPUT_DIR / MODEL_SAVE_NAME).exists():
            load_path = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
            print(f"Weights from Epoch 50 found. Loading and moving to {device}...")
        elif model_path.exists():
            load_path = model_path
            print(f"Loading weights from {load_path} and moving to {device}...")

        if load_path:
            self.model.load_state_dict(torch.load(load_path, map_location=device))
        else:
            print("Warning: No trained weights found. Using randomly initialized model.")

        print(f"Model successfully loaded on {device} and ready for chat!")

    def generate_response(self, prompt, max_length=MAX_LENGTH, temperature=TEMPERATURE, top_k=TOP_K):
        # Tokenize input
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)

        # Generation loop
        with torch.no_grad():
            for _ in range(max_length):
                # Forward pass through the model
                logits, _ = self.model(input_ids)

                # Take logits only for the last token
                next_token_logits = logits[:, -1, :]

                # Apply temperature
                next_token_logits = next_token_logits / temperature

                # Apply Top-K sampling
                if top_k > 0:
                    # Keep only the top-k most likely tokens
                    values, indices = torch.topk(next_token_logits, top_k)
                    # Zero out everything else (set to -inf)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, indices, values)

                # Convert to probabilities and sample the next token
                probabilities = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)

                # Append generated token to the sequence
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop if end-of-utterance (__eou__) or EOS token is generated
                generated_token = self.tokenizer.decode(next_token.squeeze().item())
                if "__eou__" in generated_token or next_token.squeeze().item() == self.tokenizer.eos_token_id:
                    break

        # Decode the full generated sequence
        output = self.tokenizer.decode(input_ids.squeeze().tolist())

        # Remove the original prompt from the output
        response = output[len(prompt):].strip()

        # Clean up any leftover end-of-utterance tokens
        response = response.replace("__eou__", "").strip()

        return response


def main():
    """Run the interactive chat REPL with runtime-tunable sampling settings."""
    # Declare globals so the 'set temp=' / 'set k=' commands below can
    # rebind the module-level settings.
    global TEMPERATURE, TOP_K

    chatbot = Chatbot(LAST_TRAINED_PATH)

    print("\n" + "="*60)
    print(f"CHATBOT ACTIVATED (PPL ~2.6 / Temperature {TEMPERATURE} / Top-K {TOP_K})")
    print("Type 'exit' or 'quit' to quit. Use 'set temp=0.x' or 'set k=N' to change settings.")
    print("="*60 + "\n")

    while True:
        try:
            user_input = input(">>> You: ")

            if user_input.lower() in ['quit', 'exit']:
                print("Goodbye!")
                break

            # Settings commands
            if user_input.lower().startswith('set temp='):
                try:
                    TEMPERATURE = float(user_input.split('=')[1].strip())
                    print(f"Temperature updated to {TEMPERATURE}")
                except ValueError:
                    print("Invalid temperature. Use format: set temp=0.7")
                continue

            if user_input.lower().startswith('set k='):
                try:
                    TOP_K = int(user_input.split('=')[1].strip())
                    print(f"Top-K updated to {TOP_K}")
                except ValueError:
                    print("Invalid value. Use format: set k=50")
                continue

            print("...Generating...")
            # BUG FIX: pass the current settings explicitly. The defaults in
            # generate_response (temperature=TEMPERATURE, top_k=TOP_K) were
            # evaluated once when the method was DEFINED, so updating the
            # globals via 'set temp=' / 'set k=' previously had no effect on
            # generation at all.
            response = chatbot.generate_response(
                user_input, temperature=TEMPERATURE, top_k=TOP_K
            )
            print(f"Model: {response}\n")

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            break


# Start the interactive chat loop only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()