import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import math
import time
import csv
import os
import re

# --- HYPERPARAMETERS ---
D_MODEL = 128
NUM_HEADS = 8
NUM_LAYERS = 10
DROPOUT = 0.2
MAX_SEQ_LENGTH = 17
LEARNING_RATE = 0.0002
BATCH_SIZE = 64
# BUG FIX: NUM_EPOCHS was assigned twice (20, then 100); the later assignment
# always won, so only the effective value is kept.
NUM_EPOCHS = 100         # Default full training epochs (increased)
INTERACTIVE_EPOCHS = 50  # Epochs for quick retraining (increased)

# --- GENERATION SETTINGS ---
TOP_K = 3
REPETITION_PENALTY = 3
TEMPERATURE = 1

PENALTIES_FILE = 'Penalties.csv'


def load_penalties():
    """Load penalty records from PENALTIES_FILE.

    Returns:
        list of [sentence, penalty_value] pairs. Old 1-column rows are
        upgraded with the default value 3.0.
    """
    loaded_penalties = []
    # BUG FIX: use the PENALTIES_FILE constant instead of a hard-coded name.
    if os.path.exists(PENALTIES_FILE):
        with open(PENALTIES_FILE, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            for row in reader:
                if len(row) >= 2:
                    # Store as a list of [sentence, penalty_value]
                    loaded_penalties.append([row[0], float(row[1])])
                elif row:
                    # Fallback for old 1-column rows
                    loaded_penalties.append([row[0], 3.0])
    return loaded_penalties


def save_single_penalty(penalty_string):
    """Appends a new penalty to the CSV immediately.

    BUG FIX: writes the current REPETITION_PENALTY as a second column so the
    row matches the 2-column format load_penalties() expects.
    """
    with open(PENALTIES_FILE, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow([penalty_string, REPETITION_PENALTY])


SETTINGS_FILE = 'settings.csv'


def save_settings(penalty, temp):
    """Append the current (penalty, temperature) pair to the settings history."""
    # BUG FIX: use the SETTINGS_FILE constant instead of a hard-coded name.
    with open(SETTINGS_FILE, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        # We save both so we have a history of your "knob" turns
        writer.writerow([penalty, temp])
    print(f"[Console] Logged to settings history: Penalty={penalty}, Temp={temp}")


def load_settings():
    """Return the most recently saved (penalty, temperature) pair.

    Falls back to (3.0, 1.0) when no settings history exists.
    """
    if os.path.exists(SETTINGS_FILE):
        with open(SETTINGS_FILE, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            last_row = None
            for row in reader:
                if row:
                    last_row = row
            if last_row:
                # Return penalty and temp
                return float(last_row[0]), float(last_row[1])
    return 3.0, 1.0  # Default if file doesn't exist


# --- At the start of your script ---
# BUG FIX: the original re-bound `penalties = []` further down the module,
# clobbering this loaded list and making persisted penalties dead on arrival.
# The list loaded here is now the single source of truth.
penalties = load_penalties()

# --- PERSISTENCE CONFIGURATION ---
DATA_FILE = 'training_data.csv'  # File where all training data is stored

# --- INITIAL DATA FALLBACK (The 27 sentences you provided) ---
DEFAULT_TRAINING_DATA = [
    "The quick brown fox jumps over the lazy dog.",
    "A glass of water is clear.",
    "The sun is shining bright and the sky is clear.",
    "The dog and the fox are friends forever.",
    "Coding with Pytorch and Transformers is fun and very rewarding.",
    "A computer runs very fast and never stops.",
    "The windows are big and bright.",
    "A green park is a great place to relax.",
    "The sky is clear today, with no clouds.",
    "The cat jumped over the fence.",
    "The plane has many windows.",
    "A big bird flew over the house.",
    "The plane smoothly landed on the concrete runway.",
    "The bird flew above the bustling city.",
    "The plane had an engine failure and had to land in the river.",
    "The Cessna 172 is a low-wing monoplane.",
    "The plane flew by the trees.",
    "The plane, almost out of fuel, finally landed at an airport.",
    "The angry bird flew away furiously.",
    "A plane is a machine that flies.",
    "The fast plane landed at the bright airport.",
    "The plane quickly landed on the runway.",
    "The letter A is part of the alphabet.",
    "The plane landed hardly on a grass runway in the forest.",
    "The clouds were floating above the ground.",
    "The plane was a very bright plane, it's livery glimmered in the night sky.",
    "The GPWS sounds on a plane are like Caution Terrain PULL up PULL up.",
]


# --- FILE I/O FUNCTIONS (CRITICAL FOR PERSISTENCE) ---
def load_data_from_csv(filepath):
    """Loads all training sentences from the CSV file, or returns the default data.

    Rows are split into sentences, whitespace-normalised, and filtered for
    obvious noise (too short, too long, repeated-word runs).
    """
    texts = []

    def split_into_sentences(paragraph):
        # Split on sentence-end punctuation followed by whitespace and a
        # capital letter, digit, or opening quote.
        try:
            pattern = r'(?<=[\.\!?])\s+(?=[A-Z0-9"\'""\u201c])'
            parts = re.split(pattern, paragraph)
            return [p.strip() for p in parts if p and p.strip()]
        except re.error:
            # fallback: split on sentence enders and newlines
            parts = re.split(r'[\.\!?]\s+|\n+', paragraph)
            return [p.strip() for p in parts if p and p.strip()]

    # Attempt to read existing data
    if os.path.exists(filepath) and os.path.getsize(filepath) > 0:
        print(f"[Console] Loading training data from {filepath}...")
        try:
            with open(filepath, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                raw_rows = []
                for row in reader:
                    if row and row[0].strip():
                        raw_text = row[0].strip()
                        # Remove surrounding quotes if present
                        if (raw_text.startswith('"') and raw_text.endswith('"')) or \
                           (raw_text.startswith("'") and raw_text.endswith("'")):
                            raw_text = raw_text[1:-1].strip()
                        if raw_text:
                            raw_rows.append(raw_text)

            # Now split rows into sentences, filter, and handle adjacent runs.
            sequence = []
            for raw in raw_rows:
                parts = split_into_sentences(raw)
                # If splitting produced only one part but it contains internal
                # newlines, also split on newlines.
                if len(parts) == 1 and '\n' in parts[0]:
                    parts = [p.strip() for p in parts[0].splitlines() if p.strip()]
                for s in parts:
                    # Normalize whitespace and strip quotes
                    s_clean = ' '.join(s.split()).strip(' "\'')
                    words = s_clean.split()
                    # Basic length filters to remove garbage/too-short sentences
                    if len(words) < 3:
                        continue
                    if len(words) > 300:  # skip extremely long paragraphs
                        continue
                    # Skip lines with excessive repetition (same word 3+ times
                    # in a row) — a common training artifact.
                    is_noisy = False
                    for i in range(len(words) - 2):
                        if words[i] == words[i + 1] == words[i + 2]:
                            is_noisy = True
                            break
                    if is_noisy:
                        continue
                    junk_patterns = ['pull', 'up', 'land', 'river', 'sky', 'clear', 'table']
                    junk_count = sum(1 for w in words if w in junk_patterns)
                    # NOTE(review): the 0.9999999 threshold means essentially
                    # every word must be junk, so this filter is effectively
                    # disabled; lower it (e.g. 0.3) to actually drop junk lines.
                    if junk_count > len(words) * 0.9999999:
                        continue
                    sequence.append(s_clean)

            # Collapse consecutive identical sentences (runs) to at most two copies.
            i = 0
            while i < len(sequence):
                j = i + 1
                while j < len(sequence) and sequence[j] == sequence[i]:
                    j += 1
                run_len = j - i
                if run_len == 1:
                    texts.append(sequence[i])
                else:
                    # Keep two copies of any repeated run.
                    texts.append(sequence[i])
                    texts.append(sequence[i])
                i = j
        except Exception as e:
            print(f"[Console Bug] Error loading CSV: {e}. Falling back to default data.")
            texts = []  # Clear corrupted load

    # If no data loaded (file missing, empty, or corrupted), use the default
    # knowledge base.
    if not texts:
        print("[Console] CSV file not found or empty. Using default knowledge base.")
        return list(DEFAULT_TRAINING_DATA)

    # Debug: report how many sentences were actually loaded and sample content.
    print(f"[Console] Loaded {len(texts)} sentence(s) from {filepath}.")
    sample_head = texts[:10]
    print("[Console] First loaded sentences:")
    for i, s in enumerate(sample_head, 1):
        print(f" {i}: {s[:200]}")
    if len(texts) > 10:
        print("[Console] Last loaded sentences:")
        start_index = max(0, len(texts) - 10)
        for i, s in enumerate(texts[start_index:], start_index + 1):
            print(f" {i}: {s[:200]}")
    return texts


def save_data_to_csv(filepath, texts):
    """Saves the entire list of training sentences (including new ones) to the CSV."""
    # The 'w' mode ensures the file is overwritten with the complete, updated dataset.
    print(f"[Console] Saving {len(texts)} sentences to {filepath} using 'w' mode...")
    try:
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            # Write each sentence as a single row/column entry
            for text in texts:
                writer.writerow([text])
    except Exception as e:
        print(f"[Console Bug] Error saving to CSV: {e}")


# --- TOKENIZER ---
class SimpleTokenizer:
    """Whitespace word tokenizer with one special token.

    Index 0 ("") doubles as both the padding and unknown-word token, matching
    the rest of the script, which looks the special token up via
    `word_to_idx[""]`.

    BUG FIX: the original initialised the vocab as {"": 0, "": 1}, a duplicate
    dict key that collapses to {"": 1}.  build_vocab then also assigned index 1
    to the first real word, which therefore collided with the pad/unk token and
    was silently dropped by decode() and the loss's ignore_index.
    """

    def __init__(self, texts):
        self.word_to_idx = {"": 0}
        self.idx_to_word = {0: ""}
        self.build_vocab(texts)

    def build_vocab(self, texts):
        """Assign a fresh index to every new lowercase word seen in `texts`."""
        for text in texts:
            for word in text.lower().split():
                # BUG FIX: also strip sentence punctuation here; encode()
                # strips ".,!?" so vocab entries like "dog." could never be
                # looked up and sentence-final words always encoded as unknown.
                word = word.strip("/<>").strip(".,!?")
                if word not in self.word_to_idx:
                    idx = len(self.word_to_idx)
                    self.word_to_idx[word] = idx
                    self.idx_to_word[idx] = word

    def encode(self, text, max_len):
        """Encode `text` into a LongTensor of exactly `max_len` indices.

        Unknown words map to the pad/unk index; output is padded or truncated
        to `max_len`.
        """
        words = [word.strip(".,!?") for word in text.lower().split()]
        indices = [self.word_to_idx.get(word, self.word_to_idx[""]) for word in words]
        # Padding and Truncation
        if len(indices) < max_len:
            indices.extend([self.word_to_idx[""]] * (max_len - len(indices)))
        elif len(indices) > max_len:
            indices = indices[:max_len]
        return torch.tensor(indices, dtype=torch.long)

    def decode(self, indices):
        """Decode a tensor of indices back to a space-joined string, skipping pads."""
        return " ".join([self.idx_to_word.get(idx.item(), "")
                         for idx in indices if idx.item() != self.word_to_idx[""]])

    @property
    def vocab_size(self):
        return len(self.word_to_idx)


# --- DATASET ---
class TextDataset(Dataset):
    """Pre-encodes every sentence once so __getitem__ is a cheap list lookup."""

    def __init__(self, texts, tokenizer, max_len):
        self.data = []
        for text in texts:
            encoded = tokenizer.encode(text, max_len)
            self.data.append(encoded)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# --- TRANSFORMER MODEL COMPONENTS ---
class PositionalEncoding(nn.Module):
    """Classic fixed sinusoidal positional encoding (Vaswani et al. layout)."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Stored as (max_len, 1, d_model); registered as a buffer so it moves
        # with the module but is not a trainable parameter.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is batch-first (batch, seq, d_model); slice to seq length and
        # broadcast over the batch dimension.
        return x + self.pe[:x.size(1), :].transpose(0, 1)


class TransformerLanguageModel(nn.Module):
    """Causal (decoder-only style) transformer language model.

    forward() takes a (batch, seq) LongTensor of token indices and returns
    (batch, seq, vocab_size) logits.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout, max_len):
        super(TransformerLanguageModel, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        # Use decoder layers for proper causal masking in text generation
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            activation='gelu'
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Small uniform init for embedding and output projection."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        # Create causal mask to prevent attending to future tokens
        seq_len = src.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=src.device) * float('-inf'), diagonal=1)
        # Decoder expects (tgt, memory); we pass the same sequence for both
        # (causal language modeling).
        output = self.transformer_decoder(src, src, tgt_mask=causal_mask)
        return self.fc_out(output)


# --- TRAINING FUNCTIONS ---
def train_model(model, data_loader, optimizer, criterion, device, epochs):
    """Train `model` on shifted next-token targets for `epochs` epochs.

    Progress is printed roughly every 10% of the run (or every 10 epochs for
    short runs).
    """
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for batch in data_loader:
            batch = batch.to(device)
            # Shifted input/target pairs for causal LM training.
            src = batch[:, :-1]
            tgt = batch[:, 1:]
            optimizer.zero_grad()
            output = model(src)
            # CrossEntropy wants (N, C) logits vs (N,) targets.
            current_loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            # "Black box penalty": inflate the loss when it sits above 1.9.
            # NOTE(review): adding a constant does not change gradients, so
            # this only affects the *reported* average loss, not the update.
            if current_loss.item() > 1.9:
                penalty = 1
                current_loss = current_loss + penalty
            current_loss.backward()
            optimizer.step()
            total_loss += current_loss.item()
        avg_loss = total_loss / len(data_loader)
        if epochs > 10 and epoch % (epochs // 10) == 0:
            print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")
        elif epochs > 0 and epochs <= 50 and epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")
    print(f"[Console] The {NUM_LAYERS} Layers have been updated.")


# --- GENERATION FUNCTION ---
def generate_text(model, tokenizer, prompt, max_len, device, top_k=40, penalty=1.8, temperature=1.0):
    """Autoregressively continue `prompt` and return only the continuation.

    Applies per-record penalties from the global `penalties` list (records are
    [bad_sentence, value] pairs), a repetition penalty over already-generated
    tokens, temperature scaling, and top-k sampling.
    """
    global penalties, last_generated_text
    global REPETITION_PENALTY, TEMPERATURE  # linked to the saved settings
    # If the function wasn't given a specific number, use the global loaded one.
    if penalty is None:
        penalty = REPETITION_PENALTY
    if temperature is None:
        temperature = TEMPERATURE
    model.eval()
    encoded_prompt = tokenizer.encode(prompt, max_len=max_len).to(device)
    # Count non-PAD tokens in the encoded prompt to get the true prompt length.
    pad_idx = tokenizer.word_to_idx[""]
    prompt_len = (encoded_prompt != pad_idx).sum().item()
    generated_indices = encoded_prompt[:prompt_len].tolist()
    input_ids = encoded_prompt.unsqueeze(0)
    # BUG FIX: removed a dead `if prompt in penalties:` pre-loop — `penalties`
    # holds [sentence, value] lists so the membership test never matched, and
    # its `banned_ids` result was never used.
    for i in range(prompt_len, max_len):
        src_input = input_ids[:, :i]
        with torch.no_grad():
            output = model(src_input)
        # Logits for the next position come from the last generated slot.
        logits = output[0, i - 1, :]
        # Down-weight words from previously penalised continuations of this prompt.
        for record in penalties:
            bad_sentence = record[0]
            saved_penalty = record[1]  # the specific value for THIS mistake
            if bad_sentence.startswith(prompt):
                bad_words = bad_sentence[len(prompt):].strip().lower().split()
                banned_ids = [tokenizer.word_to_idx.get(w, 1) for w in bad_words]
                for bid in banned_ids:
                    # Use saved_penalty here, NOT the global one.
                    logits[bid] -= saved_penalty
        # Apply Repetition Penalty over everything generated so far.
        history = generated_indices
        for idx in set(history):
            if logits[idx] > 0:
                logits[idx] /= penalty
            else:
                logits[idx] *= penalty
        # Apply temperature scaling before top-k.
        logits = logits / temperature
        # Apply Top-K Sampling.
        top_k_values, top_k_indices = torch.topk(logits, min(top_k, len(logits)))
        probabilities = torch.softmax(top_k_values, dim=0)
        try:
            next_token_idx = torch.multinomial(probabilities, num_samples=1).item()
        except RuntimeError:
            # multinomial can fail on degenerate distributions; fall back to argmax.
            predicted_token = top_k_indices[0].item()
            if predicted_token == tokenizer.word_to_idx[""]:
                break
        else:
            predicted_token = top_k_indices[next_token_idx].item()
        generated_indices.append(predicted_token)
        input_ids[0, i] = predicted_token
    # Decode only the continuation text.
    decoded_text = tokenizer.decode(torch.tensor(generated_indices, dtype=torch.long))
    prompt_words = [word.strip(".,!?") for word in prompt.lower().split()]
    decoded_words = decoded_text.split()
    start_index = len(prompt_words)
    continuation_text = " ".join(decoded_words[start_index:])
    # BUG FIX: the original returned continuation_text.replace(" ", "").strip(),
    # which deleted every space and mashed the words together.
    return continuation_text.strip()


# --- MAIN EXECUTION ---
# Global variables for model/tokenizer instances
last_generated_text = None
last_user_prompt = None
current_tokenizer = None
current_model = None
device = torch.device("cpu")
live_data_updates = []       # Temporary queue for new sentences added during the current session
initial_training_texts = []  # Stores all data loaded from CSV
def initialize_or_retrain(initial_train=True, use_live_data=False, epochs=NUM_EPOCHS):
    """(Re)build the tokenizer/model and run a training session.

    Args:
        initial_train: load the permanent dataset from DATA_FILE first.
        use_live_data: also train on the session's queued sentences, then
            persist the combined dataset back to DATA_FILE.
        epochs: number of training epochs to run.
    """
    global current_tokenizer, current_model, live_data_updates, initial_training_texts

    # 1. Load Data (Permanent)
    if initial_train:
        initial_training_texts = load_data_from_csv(DATA_FILE)
    training_data = list(initial_training_texts)

    # 2. Add Live Data
    if use_live_data:
        print(f"[Console] Retraining on {len(initial_training_texts)} base examples plus {len(live_data_updates)} new examples.")
        training_data.extend(live_data_updates)

    # 3. Tokenizer Initialization and Model Rebuild if necessary
    old_vocab_size = current_tokenizer.vocab_size if current_tokenizer else 0
    current_tokenizer = SimpleTokenizer(training_data)
    new_vocab_size = current_tokenizer.vocab_size

    if new_vocab_size != old_vocab_size or initial_train:
        if initial_train:
            print(f"Tokenizer Vocabulary Size: {new_vocab_size}")
            print(f"\nModel D_MODEL={D_MODEL}, NUM_HEADS={NUM_HEADS}, NUM_LAYERS={NUM_LAYERS}")
        current_model = TransformerLanguageModel(
            vocab_size=new_vocab_size,
            d_model=D_MODEL,
            nhead=NUM_HEADS,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT,
            max_len=MAX_SEQ_LENGTH
        ).to(device)

        if os.path.exists("aoban_weights.pth"):
            checkpoint = torch.load("aoban_weights.pth", map_location=device)
            # Compare the checkpoint's vocab size against the current model's.
            saved_vocab_size = checkpoint['embedding.weight'].shape[0]
            current_vocab_size = current_model.embedding.weight.shape[0]
            if saved_vocab_size != current_vocab_size:
                print(f"[Console] Expanding Aoban's brain from {saved_vocab_size} to {current_vocab_size} words...")
                # 1. Start from the model's freshly initialised weights.
                new_state_dict = current_model.state_dict()
                # 2. Copy the saved weights into the new state.
                for key, value in checkpoint.items():
                    if key in new_state_dict:
                        if value.shape == new_state_dict[key].shape:
                            # Interior layers fit unchanged.
                            new_state_dict[key] = value
                        else:
                            # Embedding / output layers need the old rows
                            # pasted into the (larger) new tensors.
                            # NOTE(review): this assumes the vocab only grows;
                            # a shrunken vocab would raise here — confirm.
                            print(f"[Surgery] Patching {key}...")
                            new_state_dict[key][:saved_vocab_size] = value[:saved_vocab_size]
                # 3. Load the expanded weights into the model.
                current_model.load_state_dict(new_state_dict)
            else:
                # If sizes match, just load normally.
                current_model.load_state_dict(checkpoint)

    # 4. Training Setup and Execution
    dataset = TextDataset(training_data, current_tokenizer, MAX_SEQ_LENGTH)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = Adam(current_model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=current_tokenizer.word_to_idx[""])
    print(f"\n[Console] Starting {epochs} epochs with {len(dataset)} examples...")
    train_model(current_model, data_loader, optimizer, criterion, device, epochs)

    # 5. Persistence Update (Saves data if it was a retraining session)
    if use_live_data:
        # 5a. Update the base list to include the new data.
        initial_training_texts = training_data
        # 5b. Save the combined data permanently to the CSV.
        save_data_to_csv(DATA_FILE, initial_training_texts)
        # 5c. Clear the temporary queue.
        live_data_updates = []
        print("[Console] Retraining complete. New knowledge acquired and **permanently saved**.")


def _print_help_items():
    """Print the nine numbered REPL help lines (shared by the banner and !help)."""
    print("1. Type a phrase to generate text (max 10 words).")
    print("2. Use '!add [sentence]' to queue new training data.")
    print("3. Use '!accept' to add the model's last **full** sentence to the training queue.")
    print(f"4. Use '!retrain' to re-train the model on new data (runs for {INTERACTIVE_EPOCHS} epochs) **and save it**.")
    print(f"5. Use '!refine' to re-train on existing data (runs for {INTERACTIVE_EPOCHS} epochs) **without saving.**")
    print("6. Use '!penalty ' to regenerate with a different repetition penalty (higher = less repetition).")
    print("7. Type 'quit' or 'exit' to stop.")
    print("8. Type '!help' to see this message again.")
    print("9. Use '!instead [corrected text]' to replace the last output with a corrected version.")


def interactive_mode():
    """Run the interactive console: initial training, then a command/generation loop."""
    global live_data_updates, last_generated_text, last_user_prompt, penalties
    global REPETITION_PENALTY, TEMPERATURE

    # 0. Load saved settings from settings.csv
    REPETITION_PENALTY, TEMPERATURE = load_settings()

    # Check if the file exists before initial training.
    file_existed_before_run = os.path.exists(DATA_FILE)

    global device
    device = torch.device("cuda")
    if not torch.cuda.is_available():
        print("Required GPU not located, running on CPU instead.")
        device = torch.device("cpu")

    # Run the initial, long training session.
    print(f"[Console] Using device: {device}")
    initialize_or_retrain(initial_train=True, use_live_data=False, epochs=NUM_EPOCHS)

    # IMPORTANT: If the file did not exist before this run (meaning default data was used),
    # we force a save right now to write the 27 default sentences to the CSV file immediately.
    if not file_existed_before_run:
        print("\n[SYSTEM] CSV file was empty/missing. Forcing initial save of default knowledge...")
        save_data_to_csv(DATA_FILE, initial_training_texts)
        print("[SYSTEM] Default 27 sentences are now permanently written to training_data.csv.")

    print("\n" + "=" * 60)
    print("šŸ¤– Console InformationšŸ¤–")
    _print_help_items()
    print("=" * 60)

    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() in ['quit', 'exit']:
                break

            if user_input.lower().startswith('!add '):
                sentence = user_input[5:].strip()
                if sentence:
                    live_data_updates.append(sentence)
                    print(f"[Console] Added sentence to update queue: '{sentence}'")
                    print(f"[Console] Current update queue size: {len(live_data_updates)}. Type '!retrain' to apply and save changes.")
                    last_generated_text = None  # Clear accepted text
                    last_user_prompt = None
                continue

            # --- !ACCEPT COMMAND ---
            if user_input.lower().strip() == '!accept':
                if last_generated_text and last_user_prompt:
                    # CRITICAL: Reconstruct the full sentence by joining prompt and output.
                    full_sentence_parts = [last_user_prompt.strip(), last_generated_text.strip()]
                    sentence_to_add = " ".join(full_sentence_parts)
                    # Basic cleaning: ensure there aren't double spaces.
                    sentence_to_add = " ".join(sentence_to_add.split())
                    if sentence_to_add and len(sentence_to_add.split()) > 4:
                        live_data_updates.append(sentence_to_add)
                        print(f"[Console] ACCEPTED: The full sentence '{sentence_to_add}' added to update queue.")
                        print(f"[Console] Current update queue size: {len(live_data_updates)}. Type '!retrain' to apply and save changes.")
                        last_generated_text = None  # Clear after acceptance
                        last_user_prompt = None
                    else:
                        print("[Console] Cannot accept: The reconstructed sentence was too short or incomplete. Please use '!add [full sentence]' instead.")
                else:
                    print("[Console] No text generated or prompt found. Generate text first.")
                continue
            # --- END !ACCEPT COMMAND ---

            if user_input.lower() == '!help':
                print("\nšŸ¤– Console InformationšŸ¤–")
                _print_help_items()
                print("=" * 60 + "\n")
                continue

            if user_input.lower() == '!retrain':
                if not live_data_updates:
                    print("[Console] No new data to train on. Use '!add [sentence]' first.")
                    continue
                print(f"\n[Console] RETRAINING MODEL ON NEW DATA ({INTERACTIVE_EPOCHS} EPOCHS)...")
                initialize_or_retrain(initial_train=False, use_live_data=True, epochs=INTERACTIVE_EPOCHS)
                # BUG FIX: the accepted-text cache was cleared twice here under a
                # misleading "To Load" comment; clear it once after saving.
                torch.save(current_model.state_dict(), "aoban_weights.pth")
                print("[Console] Model weights permanently saved to aoban_weights.pth")
                last_generated_text = None
                last_user_prompt = None
                continue

            # --- !ENDORSE COMMAND ---
            if user_input.lower().startswith('!endorse'):
                if last_generated_text and last_user_prompt:
                    try:
                        # Usage: !endorse 10 (or just !endorse for default 5)
                        parts = user_input.split()
                        multiplier = int(parts[1]) if len(parts) > 1 else 5
                        # We only want to endorse the GOOD part; you can manually
                        # edit last_generated_text before endorsing if you want.
                        full_sentence = f"{last_user_prompt.strip()} {last_generated_text.strip()}"
                        for _ in range(multiplier):
                            live_data_updates.append(full_sentence)
                        print(f"[SYSTEM] ENDORSED: Lore added {multiplier}x to queue.")
                        print(f"[SYSTEM] Target: {full_sentence}")
                    except ValueError:
                        print("[SYSTEM] Usage: !endorse ")
                else:
                    print("[SYSTEM] Nothing to endorse.")
                continue

            if user_input.lower() == '!refine':
                print(f"\n[Console] REFINING MODEL ON EXISTING DATA ({INTERACTIVE_EPOCHS} EPOCHS)...")
                initialize_or_retrain(initial_train=False, use_live_data=False, epochs=INTERACTIVE_EPOCHS)
                print("[Console] Refinement complete. Knowledge deepened on existing data.")
                continue

            if user_input.lower().startswith('!instead '):
                if last_user_prompt and last_generated_text:
                    # The user provides the "Correct" version of the response.
                    corrected_output = user_input[9:].strip()
                    # 1. LOG THE BAD ONE AS A PENALTY
                    # We pair the prompt with the bad output so the model learns to avoid it.
                    penalty_record = f"{last_user_prompt} {last_generated_text}"
                    # BUG FIX: append the [sentence, value] pair — generate_text
                    # reads record[0]/record[1], so a bare string was inert.
                    penalties.append([penalty_record, REPETITION_PENALTY])
                    with open(PENALTIES_FILE, 'a', newline='', encoding='utf-8') as f:
                        csv.writer(f).writerow([penalty_record, REPETITION_PENALTY])
                    # 2. ADD THE CORRECT ONE TO THE QUEUE (Endorse it 5x)
                    full_correct_sentence = f"{last_user_prompt} {corrected_output}"
                    for _ in range(5):
                        live_data_updates.append(full_correct_sentence)
                    print(f"[SYSTEM] Fixed! '{last_generated_text}' is now penalized.")
                    print(f"[SYSTEM] Added correction: '{full_correct_sentence}' to training queue.")
                    # Set the corrected text as the 'last_generated_text' so you
                    # can !accept or !endorse it further.
                    last_generated_text = corrected_output
                    torch.save(current_model.state_dict(), "aoban_weights.pth")
                    print("[Console] Model weights permanently saved to aoban_weights.pth")
                else:
                    print("[SYSTEM] Nothing to replace. Generate text first.")
                continue

            # --- !PENALTY COMMAND ---
            if user_input.lower().startswith('!penalty '):
                try:
                    new_val = float(user_input[9:].strip())
                    REPETITION_PENALTY = new_val
                    save_settings(REPETITION_PENALTY, TEMPERATURE)
                    if last_user_prompt and last_generated_text:
                        penalty_record = f"{last_user_prompt} {last_generated_text}"
                        # BUG FIX: store the [sentence, value] pair in memory,
                        # matching load_penalties() and generate_text().
                        penalties.append([penalty_record, REPETITION_PENALTY])
                        # Save the sentence PAIRED with the penalty value used.
                        with open(PENALTIES_FILE, 'a', newline='', encoding='utf-8') as f:
                            writer = csv.writer(f)
                            writer.writerow([penalty_record, REPETITION_PENALTY])
                        # Regenerate using the new permanent penalty.
                        print(f"[Console] Regenerating with new saved penalty={REPETITION_PENALTY}...")
                        generated_text = generate_text(
                            current_model,
                            current_tokenizer,
                            last_user_prompt,
                            MAX_SEQ_LENGTH,
                            device,
                            TOP_K,
                            REPETITION_PENALTY,  # Using the updated variable
                            TEMPERATURE
                        )
                        print(f"Model: {generated_text}")
                        last_generated_text = generated_text
                        print("\n[Console] If this full sentence is perfect, type '!accept'.")
                        torch.save(current_model.state_dict(), "aoban_weights.pth")
                        print("[Console] Model weights permanently saved to aoban_weights.pth")
                except ValueError:
                    print(f"[Console] Invalid value. Usage: !penalty .")
                continue

            if user_input.strip() and not user_input.lower().startswith(('!',)):
                # Text generation logic
                prompt = user_input.strip()
                if len(prompt.split()) > MAX_SEQ_LENGTH - 1:
                    print(f"[Console] Prompt too long. Max {MAX_SEQ_LENGTH - 1} words supported.")
                    last_generated_text = None
                    last_user_prompt = None
                    continue
                # 1. Store the prompt BEFORE generation.
                last_user_prompt = prompt
                generated_text = generate_text(
                    current_model,
                    current_tokenizer,
                    prompt,
                    MAX_SEQ_LENGTH,
                    device,
                    TOP_K,
                    REPETITION_PENALTY,
                    TEMPERATURE
                )
                print(f"Model: {generated_text}")
                # 2. Store the continuation AFTER generation.
                last_generated_text = generated_text
                print("\n[Console] If this full sentence is perfect, type '!accept' to add it to the training queue.")
        except KeyboardInterrupt:
            print("\nExiting interactive mode.")
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            break


if __name__ == "__main__":
    # Load settings globally before anything starts.
    REPETITION_PENALTY, TEMPERATURE = load_settings()
    print(f"[Console] Global Settings Initialized: Penalty={REPETITION_PENALTY}")
    interactive_mode()