File size: 7,052 Bytes
8b7cba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36d6eb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.

from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from transformers import GPT2TokenizerFast

from gpt_modern_8b import JiRackPyTorch  # Same import used in fine-tuning

# ============================= GENERATION SETTINGS =============================
# Sampling temperature. Lower values give more focused, predictable replies.
# Begin at 0.7; raise it to 0.8–0.9 if the model starts repeating itself.
TEMPERATURE = 0.7

# Top-K: only the K most probable next tokens are eligible for sampling.
# Begin at 50; raise it if replies feel too safe or bland.
TOP_K = 50

# Upper bound on how many new tokens a single reply may contain.
MAX_LENGTH = 120

# ============================= PATHS =============================
FINAL_OUTPUT_DIR = Path("build/fine_tuning_output/epoch2")  # Folder holding the checkpoint
MODEL_SAVE_NAME = "gpt_finetuned.pt"
# Full path of the most recent checkpoint, built from the two pieces above.
LAST_TRAINED_PATH = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ============================= CHATBOT CLASS =============================
class Chatbot:
    """Interactive wrapper around a fine-tuned ``JiRackPyTorch`` model.

    Loads the GPT-2 tokenizer, builds the model on the module-level ``device``
    and generates replies token by token with temperature + top-k sampling.
    """

    # End-of-utterance marker that stops generation. It is matched against the
    # decoded reply text because it is not a single token in the stock GPT-2
    # vocabulary, so checking one decoded token at a time would never see it.
    _EOU_MARKER = "__eou__"

    def __init__(self, model_path: Path):
        """Load the tokenizer, build the model and restore fine-tuned weights.

        Args:
            model_path: Checkpoint (a ``state_dict`` readable by ``torch.load``).
                It is used when it points to an existing file; otherwise
                ``FINAL_OUTPUT_DIR / MODEL_SAVE_NAME`` is tried. If neither
                exists, the model keeps its random initialization and a warning
                is printed.
        """
        # 1. Tokenizer. For fully offline use, replace "gpt2" with a local copy
        #    (e.g. "./tokenizers/gpt2") after the first download.
        print("Loading standard GPT-2 tokenizer...")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        # GPT-2 has no pad token of its own; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # 2. Model architecture on the target device, in inference mode.
        print("Initializing JiRackPyTorch model...")
        self.model = JiRackPyTorch().to(device)
        self.model.eval()

        # 3. Weights: an explicitly requested checkpoint wins over the default location.
        load_path = self._resolve_weights_path(Path(model_path))
        if load_path is not None:
            print(f"Loading state dict from {load_path}...")
            # NOTE(review): torch.load unpickles the file — only load trusted
            # checkpoints (or pass weights_only=True on torch >= 1.13).
            state_dict = torch.load(load_path, map_location=device)
            self.model.load_state_dict(state_dict)
            print("Weights loaded successfully!")

        print(f"Model is now running on {device} — ready for chat!\n")

    @staticmethod
    def _resolve_weights_path(model_path: Path) -> Optional[Path]:
        """Return the checkpoint file to load, or ``None`` when none exists."""
        if model_path.is_file():
            print(f"Loading weights from: {model_path}")
            return model_path

        fallback = FINAL_OUTPUT_DIR / MODEL_SAVE_NAME
        if fallback.is_file():
            print(f"Found weights in final folder: {fallback}")
            return fallback

        print("Warning: No trained weights found. Running with randomly initialized model.")
        return None

    @staticmethod
    def _sample_next_token(next_token_logits: torch.Tensor, temperature: float,
                           top_k: int) -> torch.Tensor:
        """Choose one token id per row of a 2-D logits tensor.

        Args:
            next_token_logits: Raw logits of shape ``(rows, vocab)``.
            temperature: Softmax temperature. Values <= 0 select greedily
                (argmax) instead of dividing by zero.
            top_k: Keep only the ``top_k`` highest logits before sampling;
                0 or less disables the filter. Clamped to the vocabulary size.

        Returns:
            Tensor of shape ``(rows, 1)`` with the chosen token ids.
        """
        if temperature <= 0:
            # Greedy decoding: dividing by 0 would yield inf/NaN probabilities.
            return next_token_logits.argmax(dim=-1, keepdim=True)

        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        if top_k > 0:
            # torch.topk raises if k exceeds the last dimension (e.g. "set k=100000").
            k = min(top_k, next_token_logits.size(-1))
            values, indices = torch.topk(next_token_logits, k)
            masked = torch.full_like(next_token_logits, float("-inf"))
            next_token_logits = masked.scatter(1, indices, values)

        probabilities = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probabilities, num_samples=1)

    def generate_response(self, prompt: str, max_length: Optional[int] = None,
                          temperature: Optional[float] = None,
                          top_k: Optional[int] = None) -> str:
        """Generate a reply to ``prompt``.

        Args:
            prompt: User text fed to the model as-is.
            max_length: Maximum number of new tokens; ``None`` uses ``MAX_LENGTH``.
            temperature: Sampling temperature; values <= 0 mean greedy decoding.
                ``None`` uses the current module-level ``TEMPERATURE``.
            top_k: Sample only among the ``top_k`` most likely tokens; 0 or less
                disables the filter. ``None`` uses the current ``TOP_K``.

        Returns:
            Only the generated continuation (the prompt is excluded), cut at the
            end-of-utterance marker, with special tokens removed and whitespace
            stripped.
        """
        # Resolve defaults at call time. Binding them as default arguments froze
        # the values at class-definition time, so live "set temp=" / "set k="
        # changes to the module settings were silently ignored.
        if max_length is None:
            max_length = MAX_LENGTH
        if temperature is None:
            temperature = TEMPERATURE
        if top_k is None:
            top_k = TOP_K

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
        prompt_length = input_ids.shape[1]
        eos_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_length):
                # JiRackPyTorch returns (logits, past_kv); the cache is not reused,
                # so the whole sequence is fed through the model every step.
                logits, _ = self.model(input_ids)
                next_token = self._sample_next_token(logits[:, -1, :], temperature, top_k)
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if next_token.item() == eos_id:
                    break
                # The marker may span several BPE tokens, so inspect the decoded reply.
                if self._EOU_MARKER in self.tokenizer.decode(input_ids[0, prompt_length:]):
                    break

        # Decode only the generated ids. Slicing the decoded full text by
        # len(prompt) misaligns whenever decoding does not reproduce the prompt
        # character-for-character.
        response = self.tokenizer.decode(input_ids[0, prompt_length:], skip_special_tokens=True)
        # Drop the end-of-utterance marker and anything produced after it.
        response = response.split(self._EOU_MARKER, 1)[0]
        return response.strip()


# ============================= MAIN CHAT LOOP =============================
def main():
    """Run the interactive console chat loop.

    Besides ordinary prompts, it understands ``quit`` / ``exit`` / ``bye`` and the
    live tuning commands ``set temp=<float>`` and ``set k=<int>``. These update
    the module-level ``TEMPERATURE`` / ``TOP_K`` used for subsequent replies.
    """
    global TEMPERATURE, TOP_K

    print("Starting JiRack Chatbot...")
    chatbot = Chatbot(LAST_TRAINED_PATH)

    print("\n" + "=" * 70)
    print("JIRACK CHATBOT ONLINE")
    print(f"Temperature: {TEMPERATURE} | Top-K: {TOP_K} | Max Length: {MAX_LENGTH}")
    print("Type 'quit' or 'exit' to exit")
    print("Change settings: set temp=0.8   or   set k=80")
    print("=" * 70 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
            command = user_input.lower()

            if command in {"quit", "exit", "bye"}:
                print("Goodbye!")
                break

            # Live parameter tuning
            if command.startswith("set temp="):
                try:
                    TEMPERATURE = float(user_input.partition("=")[2])
                    print(f"Temperature → {TEMPERATURE}")
                except ValueError:
                    print("Invalid format. Use: set temp=0.7")
                continue

            if command.startswith("set k="):
                try:
                    TOP_K = int(user_input.partition("=")[2])
                    print(f"Top-K → {TOP_K}")
                except ValueError:
                    print("Invalid format. Use: set k=50")
                continue

            if not user_input:
                continue

            print("Generating...", end="\r")
            # Pass the current settings explicitly so "set temp=" / "set k="
            # actually affect generation.
            response = chatbot.generate_response(
                user_input, temperature=TEMPERATURE, top_k=TOP_K
            )
            print(f"JiRack: {response}\n")

        except (KeyboardInterrupt, EOFError):
            # Ctrl+C, Ctrl+D or end of piped stdin. EOFError used to fall into the
            # generic handler below, and input() then failed forever in a loop.
            print("\n\nShutting down...")
            break
        except Exception as e:
            # Keep the session alive after a failed generation; report and continue.
            print(f"Error: {e}")


if __name__ == "__main__":
    main()