# aithing / ai_thingy.py
# (The following upload metadata was scraped along with the file and is kept
#  here as a comment so the module remains valid Python:
#  uploader: Aobangaming — commit: "Rename ai thingy.py to ai_thingy.py" — 46d8c0c verified)
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import math
import time
import csv
import os
import re
# --- HYPERPARAMETERS ---
D_MODEL = 128          # Embedding / hidden width of the transformer
NUM_HEADS = 8          # Attention heads per layer (must divide D_MODEL)
NUM_LAYERS = 10        # Stacked decoder layers
DROPOUT = 0.2
MAX_SEQ_LENGTH = 17    # Fixed token length for every encoded sentence
LEARNING_RATE = 0.0002
BATCH_SIZE = 64
# NUM_EPOCHS was previously assigned twice (20, then 100); the first
# assignment was dead code, so only the effective value is kept.
NUM_EPOCHS = 100        # Default full training epochs (increased)
INTERACTIVE_EPOCHS = 50 # Epochs for quick retraining (increased)
# --- GENERATION SETTINGS ---
TOP_K = 3
REPETITION_PENALTY = 3
TEMPERATURE = 1
PENALTIES_FILE = 'Penalties.csv'

def load_penalties():
    """Load the persisted penalty history from PENALTIES_FILE.

    Returns:
        list[list]: ``[sentence, penalty_value]`` pairs. Legacy one-column
        rows are upgraded with the historical default penalty of 3.0.
    """
    loaded_penalties = []
    # Use the module-level constant instead of a duplicated literal path,
    # so open()/exists() can never drift apart from the writer functions.
    if os.path.exists(PENALTIES_FILE):
        with open(PENALTIES_FILE, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            for row in reader:
                if len(row) >= 2:
                    # Store as a list of [sentence, penalty_value]
                    loaded_penalties.append([row[0], float(row[1])])
                elif row:
                    # Fallback for old 1-column rows
                    loaded_penalties.append([row[0], 3.0])
    return loaded_penalties
def save_single_penalty(penalty_string, penalty_value=3.0):
    """Append one penalty record to the CSV immediately.

    Args:
        penalty_string: The offending prompt+output sentence.
        penalty_value: Strength of the penalty. Defaults to 3.0, which is
            exactly the value load_penalties() assumes for legacy
            one-column rows, so existing behavior is preserved.

    Writing [sentence, value] keeps the row format consistent with what
    load_penalties() expects, instead of the old single-column row that
    forced the loader onto its fallback path.
    """
    with open(PENALTIES_FILE, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow([penalty_string, penalty_value])
SETTINGS_FILE = 'settings.csv'

def save_settings(penalty, temp):
    """Append the current (repetition penalty, temperature) pair to SETTINGS_FILE.

    Appending (mode 'a') rather than overwriting keeps a full history of
    every "knob" turn; load_settings() reads back only the last row.
    """
    # Use the SETTINGS_FILE constant instead of a duplicated literal path.
    with open(SETTINGS_FILE, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        # We save both so we have a history of your "knob" turns
        writer.writerow([penalty, temp])
    print(f"[Console] Logged to settings history: Penalty={penalty}, Temp={temp}")
def load_settings():
    """Return the most recent (repetition penalty, temperature) pair.

    Reads the last non-empty row of SETTINGS_FILE (the file is an
    append-only history written by save_settings()).

    Returns:
        tuple[float, float]: (penalty, temperature); (3.0, 1.0) when the
        file is missing or contains no usable rows.
    """
    # Use the SETTINGS_FILE constant instead of a duplicated literal path.
    if os.path.exists(SETTINGS_FILE):
        with open(SETTINGS_FILE, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            last_row = None
            # Scan to the last non-empty row; the file is a chronological log.
            for row in reader:
                if row:
                    last_row = row
            if last_row:
                return float(last_row[0]), float(last_row[1])  # Return penalty and temp
    return 3.0, 1.0  # Default if file doesn't exist (or holds no rows)
# --- At the start of your script ---
# Seed the in-memory penalty history from Penalties.csv; each record is a
# [sentence, penalty_value] pair (see load_penalties above).
penalties = load_penalties()
# --- PERSISTENCE CONFIGURATION ---
DATA_FILE = 'training_data.csv'  # File where all training data is stored

# --- INITIAL DATA FALLBACK (The 27 sentences you provided) ---
# Used by load_data_from_csv() whenever DATA_FILE is missing, empty, or
# unreadable; interactive_mode() then force-saves these so the CSV is seeded
# on the first run.
DEFAULT_TRAINING_DATA = [
    "The quick brown fox jumps over the lazy dog.",
    "A glass of water is clear.",
    "The sun is shining bright and the sky is clear.",
    "The dog and the fox are friends forever.",
    "Coding with Pytorch and Transformers is fun and very rewarding.",
    "A computer runs very fast and never stops.",
    "The windows are big and bright.",
    "A green park is a great place to relax.",
    "The sky is clear today, with no clouds.",
    "The cat jumped over the fence.",
    "The plane has many windows.",
    "A big bird flew over the house.",
    "The plane smoothly landed on the concrete runway.",
    "The bird flew above the bustling city.",
    "The plane had an engine failure and had to land in the river.",
    "The Cessna 172 is a low-wing monoplane.",
    "The plane flew by the trees.",
    "The plane, almost out of fuel, finally landed at an airport.",
    "The angry bird flew away furiously.",
    "A plane is a machine that flies.",
    "The fast plane landed at the bright airport.",
    "The plane quickly landed on the runway.",
    "The letter A is part of the alphabet.",
    "The plane landed hardly on a grass runway in the forest.",
    "The clouds were floating above the ground.",
    "The plane was a very bright plane, it's livery glimmered in the night sky.",
    "The GPWS sounds on a plane are like Caution Terrain PULL up PULL up."
]
# --- FILE I/O FUNCTIONS (CRITICAL FOR PERSISTENCE) ---
def load_data_from_csv(filepath):
    """Loads all training sentences from the CSV file, or returns the default data.

    Each CSV row's first column is treated as raw text; rows may contain
    several sentences, which are split apart, normalized, filtered for
    length/noise, and de-duplicated (consecutive identical sentences are
    capped at two copies). On any read error, or when nothing usable is
    loaded, DEFAULT_TRAINING_DATA is returned instead.
    """
    texts = []

    def split_into_sentences(paragraph):
        # Split on sentence end punctuation followed by whitespace and a capital or number.
        # Use a safe regex string and fall back to newline/sentence punctuation splitting on error.
        try:
            pattern = r'(?<=[\.\!?])\s+(?=[A-Z0-9"\'""\u201c])'
            parts = re.split(pattern, paragraph)
            return [p.strip() for p in parts if p and p.strip()]
        except re.error:
            # fallback: split on sentence enders and newlines
            parts = re.split(r'[\.\!?]\s+|\n+', paragraph)
            return [p.strip() for p in parts if p and p.strip()]

    # Attempt to read existing data
    if os.path.exists(filepath) and os.path.getsize(filepath) > 0:
        print(f"[Console] Loading training data from {filepath}...")
        try:
            with open(filepath, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                raw_rows = []
                for row in reader:
                    if row and row[0].strip():
                        raw_text = row[0].strip()
                        # Remove surrounding quotes if present
                        if (raw_text.startswith('"') and raw_text.endswith('"')) or (raw_text.startswith("'") and raw_text.endswith("'")):
                            raw_text = raw_text[1:-1].strip()
                        if raw_text:
                            raw_rows.append(raw_text)
            # Now split rows into sentences, filter and handle adjacent runs
            sequence = []
            for raw in raw_rows:
                # If the row contains multiple sentences, split them
                parts = split_into_sentences(raw)
                # If splitting produced only one part but it contains multiple internal newlines, also split on newlines
                if len(parts) == 1 and '\n' in parts[0]:
                    parts = [p.strip() for p in parts[0].splitlines() if p.strip()]
                for s in parts:
                    # Normalize whitespace and strip quotes
                    s_clean = ' '.join(s.split()).strip(' "\'')
                    words = s_clean.split()
                    # Basic length filters to remove garbage/too-short sentences
                    if len(words) < 3:
                        continue
                    if len(words) > 300:
                        # skip extremely long paragraphs
                        continue
                    # Filter out noisy/corrupted lines:
                    # skip if contains excessive repetition (same word 3+ times in a row)
                    is_noisy = False
                    for i in range(len(words) - 2):
                        if words[i] == words[i+1] == words[i+2]:
                            is_noisy = True
                            break
                    if is_noisy:
                        continue
                    # Skip lines that look like training artifacts (high ratio of junk words).
                    # NOTE(review): the original comment said "more than 30% junk", but the
                    # multiplier 0.9999999 means this filter only fires when essentially
                    # EVERY word is a junk word — confirm which threshold was intended.
                    junk_patterns = ['pull', 'up', 'land', 'river', 'sky', 'clear', 'table']
                    junk_count = sum(1 for w in words if w in junk_patterns)
                    if junk_count > len(words) * 0.9999999: # more than 30% junk
                        continue
                    sequence.append(s_clean)
            # Collapse consecutive identical sentences (runs) to at most two copies.
            # (All members of a run are equal, so appending sequence[i] twice is
            # equivalent to "keep first and last occurrence".)
            i = 0
            while i < len(sequence):
                j = i + 1
                while j < len(sequence) and sequence[j] == sequence[i]:
                    j += 1
                run_len = j - i
                if run_len == 1:
                    texts.append(sequence[i])
                else:
                    # keep first and last occurrence of the run
                    texts.append(sequence[i])
                    texts.append(sequence[i])
                i = j
        except Exception as e:
            print(f"[Console Bug] Error loading CSV: {e}. Falling back to default data.")
            texts = [] # Clear corrupted load
    # If no data loaded (file missing, empty, or corrupted), use the default knowledge base
    if not texts:
        print("[Console] CSV file not found or empty. Using default knowledge base.")
        return list(DEFAULT_TRAINING_DATA)
    # Debug: report how many sentences were actually loaded and sample content
    print(f"[Console] Loaded {len(texts)} sentence(s) from {filepath}.")
    sample_head = texts[:10]
    sample_tail = texts[-10:]  # NOTE(review): assigned but never used below
    print("[Console] First loaded sentences:")
    for i, s in enumerate(sample_head, 1):
        print(f" {i}: {s[:200]}")
    if len(texts) > 10:
        print("[Console] Last loaded sentences:")
        start_index = max(0, len(texts) - 10)
        for i, s in enumerate(texts[start_index:], start_index + 1):
            print(f" {i}: {s[:200]}")
    return texts
def save_data_to_csv(filepath, texts):
    """Persist the complete training corpus to *filepath*, one sentence per row."""
    # Opening in 'w' mode deliberately replaces the file so it always mirrors
    # the full, up-to-date dataset rather than accumulating duplicates.
    print(f"[Console] Saving {len(texts)} sentences to {filepath} using 'w' mode...")
    try:
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            # Each sentence becomes a single one-column CSV row.
            csv.writer(f).writerows([sentence] for sentence in texts)
    except Exception as e:
        print(f"[Console Bug] Error saving to CSV: {e}")
# --- TOKENIZER ---
class SimpleTokenizer:
    """Lowercasing whitespace word tokenizer with <PAD>=0 and <UNK>=1 specials."""

    # Characters stripped from both ends of every word, used identically when
    # building the vocabulary and when encoding.  BUG FIX: the original code
    # stripped "/<>" during vocab building but ".,!?" during encoding, so a
    # word like "dog." entered the vocabulary with its punctuation while
    # encode() looked up the stripped form "dog" and mapped it to <UNK> —
    # making punctuated vocabulary entries unreachable.
    _STRIP_CHARS = ".,!?/<>"

    def __init__(self, texts):
        self.word_to_idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx_to_word = {0: "<PAD>", 1: "<UNK>"}
        self.build_vocab(texts)

    def build_vocab(self, texts):
        """Assign the next free index to every previously unseen word in *texts*."""
        for text in texts:
            for word in text.lower().split():
                word = word.strip(self._STRIP_CHARS)
                # Skip tokens that strip down to nothing (e.g. bare punctuation).
                if word and word not in self.word_to_idx:
                    idx = len(self.word_to_idx)
                    self.word_to_idx[word] = idx
                    self.idx_to_word[idx] = word

    def encode(self, text, max_len):
        """Encode *text* into a LongTensor of exactly *max_len* word indices.

        Unknown words map to <UNK>; short inputs are right-padded with <PAD>,
        long inputs are truncated.
        """
        words = [word.strip(self._STRIP_CHARS) for word in text.lower().split()]
        indices = [self.word_to_idx.get(word, self.word_to_idx["<UNK>"]) for word in words]
        # Padding and Truncation
        if len(indices) < max_len:
            indices.extend([self.word_to_idx["<PAD>"]] * (max_len - len(indices)))
        elif len(indices) > max_len:
            indices = indices[:max_len]
        return torch.tensor(indices, dtype=torch.long)

    def decode(self, indices):
        """Decode a 1-D tensor of indices back to a space-joined string, skipping <PAD>."""
        return " ".join([self.idx_to_word.get(idx.item(), "<UNK>") for idx in indices if idx.item() != self.word_to_idx["<PAD>"]])

    @property
    def vocab_size(self):
        # Number of known words including the two special tokens.
        return len(self.word_to_idx)
# --- DATASET ---
class TextDataset(Dataset):
    """Dataset of pre-encoded sentences: encoding happens once up front so
    DataLoader batching just indexes cached tensors."""

    def __init__(self, texts, tokenizer, max_len):
        # Eagerly encode every sentence to a fixed-length tensor of token ids.
        self.data = [tokenizer.encode(sentence, max_len) for sentence in texts]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# --- TRANSFORMER MODEL COMPONENTS (UNMODIFIED) ---
class PositionalEncoding(nn.Module):
    """Adds the classic fixed sinusoidal position signal to token embeddings.

    Even feature dims carry sin(pos * freq), odd dims carry cos(pos * freq),
    with geometrically spaced frequencies (base 10000).
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Stored as (max_len, 1, d_model); registered as a buffer so it moves
        # with the module across devices but is not a trainable parameter.
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # x is batch-first (batch, seq, d_model); slice the table to the
        # sequence length and transpose back to (1, seq, d_model) to broadcast.
        return x + self.pe[:x.size(1), :].transpose(0, 1)
class TransformerLanguageModel(nn.Module):
    """Decoder-only causal language model.

    Pipeline: token embedding (scaled by sqrt(d_model)) -> sinusoidal
    positional encoding -> stacked nn.TransformerDecoder layers (self- and
    cross-attention over the same sequence, both causally masked) -> linear
    projection to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout, max_len):
        super(TransformerLanguageModel, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        # Use decoder layers for proper causal masking in text generation
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model*4,
            dropout=dropout,
            batch_first=True,
            activation='gelu'
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.init_weights()

    def init_weights(self):
        # Small uniform init for the embedding and output projection; the
        # decoder layers keep PyTorch's defaults.
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """Map (batch, seq) token ids to (batch, seq, vocab_size) logits."""
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        # Create causal mask to prevent attending to future tokens
        seq_len = src.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=src.device) * float('-inf'), diagonal=1)
        # Decoder expects (tgt, memory); we feed the same sequence for both
        # (causal language modeling).  BUG FIX: the mask must be applied to
        # cross-attention as well (memory_mask) — previously only tgt_mask
        # was set, so each position could still attend to FUTURE tokens
        # through the unmasked cross-attention over memory=src, leaking
        # future context in a supposedly causal model.
        output = self.transformer_decoder(src, src, tgt_mask=causal_mask, memory_mask=causal_mask)
        return self.fc_out(output)
# --- TRAINING FUNCTIONS ---
def train_model(model, data_loader, optimizer, criterion, device, epochs):
    """Run next-token training for *epochs* passes over *data_loader*.

    Each batch of shape (B, L) is split into src = tokens[:-1] and
    tgt = tokens[1:], and the model is trained to predict tgt from src
    with the provided criterion (CrossEntropyLoss ignoring <PAD>).
    Progress is printed roughly ten times per run.
    """
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for batch in data_loader:
            batch = batch.to(device)
            # Shift by one token: predict position t+1 from positions <= t.
            src = batch[:, :-1]
            tgt = batch[:, 1:]
            optimizer.zero_grad()
            # 1. Forward pass: (B, L-1, vocab) logits.
            output = model(src)
            # 2. Flatten (batch, seq) for CrossEntropyLoss.
            current_loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            # 3. Extra "penalty" while loss is above 1.9.
            # NOTE(review): adding the constant 1 shifts the *reported* loss
            # but has zero gradient, so it does not change the optimizer step
            # at all — confirm whether a gradient-affecting penalty was meant.
            if current_loss.item() > 1.9:
                penalty = 1
                current_loss = current_loss + penalty
            # 4. Backprop and update.
            current_loss.backward()
            optimizer.step()
            total_loss += current_loss.item()
        avg_loss = total_loss / len(data_loader)
        # Print ~10 progress lines for long runs; every 10th epoch for short runs.
        if epochs > 10 and epoch % (epochs // 10) == 0:
            print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")
        elif epochs > 0 and epochs <= 50 and epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")
    print(f"[Console] The {NUM_LAYERS} Layers have been updated.")
# BUG FIX: 'penalties' is already populated from Penalties.csv at import time
# (penalties = load_penalties() above), as a list of [sentence, value] pairs —
# exactly the shape generate_text() iterates over.  The re-initialization
# "penalties = []" that used to live here silently discarded that persisted
# penalty history on every run, so it has been removed.
# --- GENERATION FUNCTION ---
def generate_text(model, tokenizer, prompt, max_len, device, top_k=40, penalty=1.8, temperature=1.0):
    """Autoregressively extend *prompt* up to *max_len* tokens and return only
    the generated continuation (prompt words removed, <pad> stripped).

    Sampling: per-step logits are penalized for sentences recorded in the
    global `penalties` list, divided/multiplied by `penalty` for already
    generated tokens (repetition penalty), temperature-scaled, then sampled
    from the top-k distribution.
    """
    global penalties, last_generated_text
    global REPETITION_PENALTY, TEMPERATURE  # fall back to the persisted settings
    # If the caller passed None explicitly, use the globally loaded knob values.
    if penalty is None:
        penalty = REPETITION_PENALTY
    if temperature is None:
        temperature = TEMPERATURE
    model.eval()
    encoded_prompt = tokenizer.encode(prompt, max_len=max_len).to(device)
    # Count non-PAD tokens in encoded prompt to get true prompt length
    pad_idx = tokenizer.word_to_idx["<PAD>"]
    prompt_len = (encoded_prompt != pad_idx).sum().item()
    generated_indices = encoded_prompt[:prompt_len].tolist()
    input_ids = encoded_prompt.unsqueeze(0)
    # NOTE(review): this banned_ids computation is immediately shadowed by the
    # per-step loop below and `prompt in penalties` only matches if the prompt
    # equals a whole record — looks like dead code; confirm before removing.
    if prompt in penalties:
        word = last_generated_text.lower().split()
        banned_ids = [tokenizer.word_to_idx.get(w, 1) for w in word]
    for i in range(prompt_len, max_len):
        src_input = input_ids[:, :i]
        with torch.no_grad():
            output = model(src_input)
        # Logits for the next token come from the last visible position.
        logits = output[0, i-1, :]
        # Apply every recorded penalty whose bad sentence starts with this prompt.
        # NOTE(review): load_penalties() yields [sentence, value] pairs, but the
        # interactive commands append plain strings to `penalties`; for those,
        # record[0]/record[1] index single *characters* and saved_penalty is a
        # str — this branch only behaves as intended for the pair records.
        for record in penalties:
            bad_sentence = record[0] # The string
            saved_penalty = record[1] # The specific value for THIS mistake
            if bad_sentence.startswith(prompt):
                bad_words = bad_sentence[len(prompt):].strip().lower().split()
                banned_ids = [tokenizer.word_to_idx.get(w, 1) for w in bad_words]
                for bid in banned_ids:
                    # Use the per-record value, not the global knob.
                    logits[bid] -= saved_penalty
        # Apply Repetition Penalty: dampen every token already produced.
        history = generated_indices
        for idx in set(history):
            if logits[idx] > 0:
                logits[idx] /= penalty
            else:
                logits[idx] *= penalty
        # Apply temperature scaling before top-k
        logits = logits / temperature
        # Apply Top-K Sampling
        top_k_values, top_k_indices = torch.topk(logits, min(top_k, len(logits)))
        probabilities = torch.softmax(top_k_values, dim=0)
        try:
            next_token_idx = torch.multinomial(probabilities, num_samples=1).item()
        except RuntimeError:
            # Degenerate distribution: fall back to greedy; stop on <PAD>.
            predicted_token = top_k_indices[0].item()
            if predicted_token == tokenizer.word_to_idx["<PAD>"]:
                break
        else:
            predicted_token = top_k_indices[next_token_idx].item()
        generated_indices.append(predicted_token)
        # Write the chosen token into the padded slot for the next step's input.
        input_ids[0, i] = predicted_token
    # Decode the whole sequence, then drop the prompt words so only the
    # continuation is returned.
    decoded_text = tokenizer.decode(torch.tensor(generated_indices, dtype=torch.long))
    prompt_words = [word.strip(".,!?") for word in prompt.lower().split()]
    decoded_words = decoded_text.split()
    start_index = len(prompt_words)
    continuation_text = " ".join(decoded_words[start_index:])
    return continuation_text.replace(" <pad>", "").strip()
# --- MAIN EXECUTION ---
# Global variables for model/tokenizer instances
last_generated_text = None  # Continuation produced by the most recent generation
last_user_prompt = None  # Prompt that produced last_generated_text
current_tokenizer = None  # Active SimpleTokenizer (rebuilt on retrain)
current_model = None  # Active TransformerLanguageModel instance
device = torch.device("cpu")  # Re-bound in interactive_mode() when CUDA is available
live_data_updates = [] # Temporary queue for new sentences added during the current session
initial_training_texts = [] # Stores all data loaded from CSV
def initialize_or_retrain(initial_train=True, use_live_data=False, epochs=NUM_EPOCHS):
    """(Re)build the tokenizer/model and train for *epochs* epochs.

    Args:
        initial_train: Reload the base corpus from DATA_FILE and force a
            model rebuild (first run of a session).
        use_live_data: Also train on the queued `live_data_updates` and, on
            success, persist the merged corpus back to DATA_FILE.
        epochs: Number of training epochs to run.

    Side effects: rebinds the global `current_tokenizer`/`current_model`,
    may load and vocabulary-expand weights from aoban_weights.pth, and may
    rewrite DATA_FILE and clear `live_data_updates`.
    """
    global current_tokenizer, current_model, live_data_updates, initial_training_texts
    # 1. Load Data (Permanent)
    if initial_train:
        initial_training_texts = load_data_from_csv(DATA_FILE)
    training_data = list(initial_training_texts)
    # 2. Add Live Data
    if use_live_data:
        print(f"[Console] Retraining on {len(initial_training_texts)} base examples plus {len(live_data_updates)} new examples.")
        training_data.extend(live_data_updates)
    # 3. Tokenizer Initialization and Model Rebuild if necessary
    old_vocab_size = current_tokenizer.vocab_size if current_tokenizer else 0
    current_tokenizer = SimpleTokenizer(training_data)
    new_vocab_size = current_tokenizer.vocab_size
    # Rebuild the model whenever the vocabulary changed (embedding/output
    # dimensions depend on it) or on the initial run.
    if new_vocab_size != old_vocab_size or initial_train:
        if initial_train:
            print(f"Tokenizer Vocabulary Size: {new_vocab_size}")
            print(f"\nModel D_MODEL={D_MODEL}, NUM_HEADS={NUM_HEADS}, NUM_LAYERS={NUM_LAYERS}")
        current_model = TransformerLanguageModel(
            vocab_size=new_vocab_size,
            d_model=D_MODEL,
            nhead=NUM_HEADS,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT,
            max_len=MAX_SEQ_LENGTH
        ).to(device)
        # Restore previously saved weights, expanding vocab-sized layers if needed.
        if os.path.exists("aoban_weights.pth"):
            checkpoint = torch.load("aoban_weights.pth", map_location=device)
            # Get the sizes from the saved file vs the current model
            saved_vocab_size = checkpoint['embedding.weight'].shape[0]
            current_vocab_size = current_model.embedding.weight.shape[0]
            if saved_vocab_size != current_vocab_size:
                print(f"[Console] Expanding Aoban's brain from {saved_vocab_size} to {current_vocab_size} words...")
                # 1. Start from the freshly initialized state dict.
                new_state_dict = current_model.state_dict()
                # 2. Copy saved tensors in; vocab-sized tensors are copied
                #    row-by-row into the first `saved_vocab_size` slots.
                for key, value in checkpoint.items():
                    if key in new_state_dict:
                        if value.shape == new_state_dict[key].shape:
                            # Interior layers (attention/FFN) fit unchanged.
                            new_state_dict[key] = value
                        else:
                            # Embedding / output layers need partial copy.
                            print(f"[Surgery] Patching {key}...")
                            # NOTE(review): assumes the vocab dimension is dim 0
                            # for every mismatched tensor — true for embedding
                            # and fc_out weights/bias here.
                            new_state_dict[key][:saved_vocab_size] = value[:saved_vocab_size]
                # 3. Load the expanded state into the model.
                current_model.load_state_dict(new_state_dict)
            else:
                # If sizes match, just load normally
                current_model.load_state_dict(checkpoint)
    # 4. Training Setup and Execution
    dataset = TextDataset(training_data, current_tokenizer, MAX_SEQ_LENGTH)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = Adam(current_model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=current_tokenizer.word_to_idx["<PAD>"])
    print(f"\n[Console] Starting {epochs} epochs with {len(dataset)} examples...")
    train_model(current_model, data_loader, optimizer, criterion, device, epochs)
    # 5. Persistence Update (Saves data if it was a retraining session)
    if use_live_data:
        # 5a. Update the base list to include the new data
        initial_training_texts = training_data
        # 5b. Save the combined data permanently to the CSV
        save_data_to_csv(DATA_FILE, initial_training_texts)
        # 5c. Clear the temporary queue
        live_data_updates = []
        print("[Console] Retraining complete. New knowledge acquired and **permanently saved**.")
def interactive_mode():
    """Run the interactive console: initial training, then a REPL serving
    free-text generation and the !add / !accept / !help / !retrain /
    !endorse / !refine / !instead / !penalty commands until 'quit'/'exit'.
    """
    global live_data_updates, last_generated_text, last_user_prompt, penalties
    global REPETITION_PENALTY, TEMPERATURE  # generation knobs persisted in settings.csv
    # 0. Load saved settings from settings.csv
    REPETITION_PENALTY, TEMPERATURE = load_settings()
    # Check if the file exists before initial training
    file_existed_before_run = os.path.exists(DATA_FILE)
    global device
    device = torch.device("cuda")
    if not torch.cuda.is_available():
        print("Required GPU not located, running on CPU instead.")
        device = torch.device("cpu")
    # Run the initial, long training session
    print(f"[Console] Using device: {device}")
    initialize_or_retrain(initial_train=True, use_live_data=False, epochs=NUM_EPOCHS)
    # IMPORTANT: If the file did not exist before this run (meaning default data was used),
    # we force a save right now to write the 27 default sentences to the CSV file immediately.
    if not file_existed_before_run:
        print("\n[SYSTEM] CSV file was empty/missing. Forcing initial save of default knowledge...")
        save_data_to_csv(DATA_FILE, initial_training_texts)
        print("[SYSTEM] Default 27 sentences are now permanently written to training_data.csv.")
    print("\n" + "=" * 60)
    print("🤖 Console Information🤖")
    print("1. Type a phrase to generate text (max 10 words).")
    print("2. Use '!add [sentence]' to queue new training data.")
    print("3. Use '!accept' to add the model's last **full** sentence to the training queue.")
    print(f"4. Use '!retrain' to re-train the model on new data (runs for {INTERACTIVE_EPOCHS} epochs) **and save it**.")
    print(f"5. Use '!refine' to re-train on existing data (runs for {INTERACTIVE_EPOCHS} epochs) **without saving.**")
    print("6. Use '!penalty <value>' to regenerate with a different repetition penalty (higher = less repetition).")
    print("7. Type 'quit' or 'exit' to stop.")
    print("8. Type '!help' to see this message again.")
    print("9. Use '!instead [corrected text]' to replace the last output with a corrected version.")
    print("=" * 60)
    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() in ['quit', 'exit']:
                break
            # --- !ADD: queue a training sentence ---
            if user_input.lower().startswith('!add '):
                sentence = user_input[5:].strip()
                if sentence:
                    live_data_updates.append(sentence)
                    print(f"[Console] Added sentence to update queue: '{sentence}'")
                    print(f"[Console] Current update queue size: {len(live_data_updates)}. Type '!retrain' to apply and save changes.")
                last_generated_text = None # Clear accepted text
                last_user_prompt = None
                continue
            # --- !ACCEPT COMMAND ---
            if user_input.lower().strip() == '!accept':
                if last_generated_text and last_user_prompt:
                    # CRITICAL: Reconstruct the full sentence by joining prompt and output
                    full_sentence_parts = [last_user_prompt.strip(), last_generated_text.strip()]
                    sentence_to_add = " ".join(full_sentence_parts)
                    # Basic cleaning: ensure there aren't double spaces
                    sentence_to_add = " ".join(sentence_to_add.split())
                    if sentence_to_add and len(sentence_to_add.split()) > 4:
                        live_data_updates.append(sentence_to_add)
                        print(f"[Console] ACCEPTED: The full sentence '{sentence_to_add}' added to update queue.")
                        print(f"[Console] Current update queue size: {len(live_data_updates)}. Type '!retrain' to apply and save changes.")
                        last_generated_text = None # Clear after acceptance
                        last_user_prompt = None
                    else:
                        print("[Console] Cannot accept: The reconstructed sentence was too short or incomplete. Please use '!add [full sentence]' instead.")
                else:
                    print("[Console] No text generated or prompt found. Generate text first.")
                continue
            # --- END !ACCEPT COMMAND ---
            if user_input.lower() == '!help':
                print("\n🤖 Console Information🤖")
                print("1. Type a phrase to generate text (max 10 words).")
                print("2. Use '!add [sentence]' to queue new training data.")
                print("3. Use '!accept' to add the model's last **full** sentence to the training queue.")
                print(f"4. Use '!retrain' to re-train the model on new data (runs for {INTERACTIVE_EPOCHS} epochs) **and save it**.")
                print(f"5. Use '!refine' to re-train on existing data (runs for {INTERACTIVE_EPOCHS} epochs) **without saving.**")
                print("6. Use '!penalty <value>' to regenerate with a different repetition penalty (higher = less repetition).")
                print("7. Type 'quit' or 'exit' to stop.")
                print("8. Type '!help' to see this message again.")
                print("9. Use '!instead [corrected text]' to replace the last output with a corrected version.")
                print("=" * 60 + "\n")
                continue
            # --- !RETRAIN: train on queued data and persist everything ---
            if user_input.lower() == '!retrain':
                if not live_data_updates:
                    print("[Console] No new data to train on. Use '!add [sentence]' first.")
                    continue
                print(f"\n[Console] RETRAINING MODEL ON NEW DATA ({INTERACTIVE_EPOCHS} EPOCHS)...")
                initialize_or_retrain(initial_train=False, use_live_data=True, epochs=INTERACTIVE_EPOCHS)
                last_generated_text = None # Clear the accepted text cache
                last_user_prompt = None
                # Persist the freshly trained weights so the next run reloads them.
                torch.save(current_model.state_dict(), "aoban_weights.pth")
                print("[Console] Model weights permanently saved to aoban_weights.pth")
                last_generated_text = None
                last_user_prompt = None
                continue
            # --- !ENDORSE COMMAND ---
            if user_input.lower().startswith('!endorse'):
                if last_generated_text and last_user_prompt:
                    try:
                        # Usage: !endorse 10 (or just !endorse for default 5)
                        parts = user_input.split()
                        multiplier = int(parts[1]) if len(parts) > 1 else 5
                        # We only want to endorse the GOOD part (e.g., "hello how can i help you today")
                        # You can manually edit the last_generated_text before endorsing if you want
                        full_sentence = f"{last_user_prompt.strip()} {last_generated_text.strip()}"
                        for _ in range(multiplier):
                            live_data_updates.append(full_sentence)
                        print(f"[SYSTEM] ENDORSED: Lore added {multiplier}x to queue.")
                        print(f"[SYSTEM] Target: {full_sentence}")
                    except ValueError:
                        print("[SYSTEM] Usage: !endorse <number>")
                else:
                    print("[SYSTEM] Nothing to endorse.")
                continue
            # --- !REFINE: extra epochs on the existing corpus, no save ---
            if user_input.lower() == '!refine':
                print(f"\n[Console] REFINING MODEL ON EXISTING DATA ({INTERACTIVE_EPOCHS} EPOCHS)...")
                initialize_or_retrain(initial_train=False, use_live_data=False, epochs=INTERACTIVE_EPOCHS)
                print("[Console] Refinement complete. Knowledge deepened on existing data.")
                continue
            # --- !INSTEAD: penalize the last output and queue a correction ---
            if user_input.lower().startswith('!instead '):
                if last_user_prompt and last_generated_text:
                    # The user provides the "Correct" version of the response
                    corrected_output = user_input[9:].strip()
                    # 1. LOG THE BAD ONE AS A PENALTY
                    # We pair the prompt with the bad output so the model learns to avoid it
                    penalty_record = f"{last_user_prompt} {last_generated_text}"
                    penalties.append(penalty_record)
                    with open('Penalties.csv', 'a', newline='', encoding='utf-8') as f:
                        csv.writer(f).writerow([penalty_record, REPETITION_PENALTY])
                    # 2. ADD THE CORRECT ONE TO THE QUEUE (Endorse it 5x)
                    full_correct_sentence = f"{last_user_prompt} {corrected_output}"
                    for _ in range(5):
                        live_data_updates.append(full_correct_sentence)
                    print(f"[SYSTEM] Fixed! '{last_generated_text}' is now penalized.")
                    print(f"[SYSTEM] Added correction: '{full_correct_sentence}' to training queue.")
                    # Optional: Set the corrected text as the 'last_generated_text'
                    # so you can !accept or !endorse it further
                    last_generated_text = corrected_output
                    torch.save(current_model.state_dict(), "aoban_weights.pth")
                    print("[Console] Model weights permanently saved to aoban_weights.pth")
                else:
                    print("[SYSTEM] Nothing to replace. Generate text first.")
                continue
            # --- !PENALTY COMMAND ---
            if user_input.lower().startswith('!penalty '):
                try:
                    new_val = float(user_input[9:].strip())
                    REPETITION_PENALTY = new_val
                    save_settings(REPETITION_PENALTY, TEMPERATURE) # Persist the new knob value
                    if last_user_prompt and last_generated_text:
                        penalty_record = f"{last_user_prompt} {last_generated_text}"
                        penalties.append(penalty_record)
                        # Save the sentence PAIRED with the penalty value used
                        with open('Penalties.csv', 'a', newline='', encoding='utf-8') as f:
                            writer = csv.writer(f)
                            # The index is automatically created by the row position in the CSV
                            writer.writerow([penalty_record, REPETITION_PENALTY])
                        # Regenerate using the new permanent penalty
                        print(f"[Console] Regenerating with new saved penalty={REPETITION_PENALTY}...")
                        generated_text = generate_text(
                            current_model,
                            current_tokenizer,
                            last_user_prompt,
                            MAX_SEQ_LENGTH,
                            device,
                            TOP_K,
                            REPETITION_PENALTY, # Using the updated variable
                            TEMPERATURE
                        )
                        print(f"Model: {generated_text}")
                        last_generated_text = generated_text
                        print("\n[Console] If this full sentence is perfect, type '!accept'.")
                        torch.save(current_model.state_dict(), "aoban_weights.pth")
                        print("[Console] Model weights permanently saved to aoban_weights.pth")
                except ValueError:
                    print(f"[Console] Invalid value. Usage: !penalty <number>.")
                continue
            # --- Plain text: generate a continuation ---
            if user_input.strip() and not user_input.lower().startswith(('!',)):
                # Text generation logic
                prompt = user_input.strip()
                if len(prompt.split()) > MAX_SEQ_LENGTH - 1:
                    print(f"[Console] Prompt too long. Max {MAX_SEQ_LENGTH - 1} words supported.")
                    last_generated_text = None
                    last_user_prompt = None
                    continue
                # 1. Store the prompt BEFORE generation
                last_user_prompt = prompt
                generated_text = generate_text(
                    current_model,
                    current_tokenizer,
                    prompt,
                    MAX_SEQ_LENGTH,
                    device,
                    TOP_K,
                    REPETITION_PENALTY,
                    TEMPERATURE
                )
                print(f"Model: {generated_text}")
                # 2. Store the continuation AFTER generation
                last_generated_text = generated_text
                print("\n[Console] If this full sentence is perfect, type '!accept' to add it to the training queue.")
        except KeyboardInterrupt:
            print("\nExiting interactive mode.")
            break
        except Exception as e:
            # Catch-all so one bad command does not lose the trained model;
            # the loop still exits because of the break below.
            print(f"An error occurred: {e}")
            break
if __name__ == "__main__":
    # Load settings globally before anything starts.
    # (interactive_mode() reloads them again as its own step 0.)
    REPETITION_PENALTY, TEMPERATURE = load_settings()
    print(f"[Console] Global Settings Initialized: Penalty={REPETITION_PENALTY}")
    interactive_mode()