# NOTE: the "Spaces: Sleeping / Sleeping" lines that used to sit here were
# Hugging Face Spaces page artifacts captured when this file was pasted,
# not part of the program.
import csv
import math
import os
import re
import time

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
# --- HYPERPARAMETERS ---
D_MODEL = 128
NUM_HEADS = 8
NUM_LAYERS = 10
DROPOUT = 0.2
MAX_SEQ_LENGTH = 17
LEARNING_RATE = 0.0002
BATCH_SIZE = 64
# BUG FIX: NUM_EPOCHS was assigned twice (20, then 100); only the final
# assignment ever took effect, so keep the single effective value.
NUM_EPOCHS = 100         # Default full training epochs
INTERACTIVE_EPOCHS = 50  # Epochs for quick retraining (!retrain / !refine)

# --- GENERATION SETTINGS ---
TOP_K = 3                # sample among the 3 most likely next tokens
REPETITION_PENALTY = 3   # scales down logits of already-generated tokens
TEMPERATURE = 1          # 1 = no temperature scaling
PENALTIES_FILE = 'Penalties.csv'  # persisted "bad output" records
def load_penalties(filepath='Penalties.csv'):
    """Load penalty records from *filepath*.

    Returns a list of ``[sentence, penalty_value]`` pairs.  Legacy rows
    containing only the sentence fall back to a penalty value of 3.0,
    as do rows whose second column is not a valid number.  Returns an
    empty list when the file does not exist.
    """
    loaded_penalties = []
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            for row in csv.reader(f):
                if len(row) >= 2:
                    try:
                        value = float(row[1])
                    except ValueError:
                        # ROBUSTNESS FIX: a junk second column used to crash
                        value = 3.0
                    loaded_penalties.append([row[0], value])
                elif row:
                    # Fallback for old 1-column rows
                    loaded_penalties.append([row[0], 3.0])
    return loaded_penalties
def save_single_penalty(penalty_string, penalty_value=3.0, filepath=None):
    """Append one penalty record to the penalties CSV immediately.

    CONSISTENCY FIX: this used to write a single-column row, which
    load_penalties() could only read via its legacy fallback (assuming
    a value of 3.0).  It now writes the ``[sentence, penalty_value]``
    pair that load_penalties() expects; the default of 3.0 matches the
    old effective behavior.
    """
    if filepath is None:
        filepath = PENALTIES_FILE  # resolved at call time
    with open(filepath, 'a', newline='', encoding='utf-8') as f:
        csv.writer(f).writerow([penalty_string, penalty_value])
SETTINGS_FILE = 'settings.csv'  # append-only history of (penalty, temperature)


def save_settings(penalty, temp, filepath=None):
    """Append the current (penalty, temperature) pair to the settings history.

    The file is append-only so it keeps a history of every "knob" turn;
    load_settings() reads back only the last row.

    CONSISTENCY FIX: uses the SETTINGS_FILE constant instead of a second
    hard-coded 'settings.csv' literal; *filepath* is overridable for tests.
    """
    if filepath is None:
        filepath = SETTINGS_FILE
    with open(filepath, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        # We save both so we have a history of your "knob" turns
        writer.writerow([penalty, temp])
    print(f"[Console] Logged to settings history: Penalty={penalty}, Temp={temp}")
def load_settings(filepath='settings.csv'):
    """Return the most recent (penalty, temperature) pair from the history.

    Falls back to ``(3.0, 1.0)`` when the file is missing, empty, or the
    last row is unparseable.  A legacy single-column last row yields the
    stored penalty with the default temperature of 1.0.
    """
    if os.path.exists(filepath):
        last_row = None
        with open(filepath, 'r', encoding='utf-8') as f:
            for row in csv.reader(f):
                if row:
                    last_row = row
        if last_row:
            try:
                # ROBUSTNESS FIX: a 1-column or junk last row used to raise
                # IndexError/ValueError instead of falling back to defaults.
                penalty = float(last_row[0])
                temp = float(last_row[1]) if len(last_row) > 1 else 1.0
                return penalty, temp
            except ValueError:
                pass
    return 3.0, 1.0  # Default if nothing usable is stored
# --- Startup state: penalty records loaded from disk ---
# NOTE(review): this value is clobbered later by the module-level
# `penalties = []` that precedes generate_text(), so the records loaded
# here are effectively discarded — confirm which assignment is intended.
penalties = load_penalties()

# --- PERSISTENCE CONFIGURATION ---
DATA_FILE = 'training_data.csv'  # File where all training data is stored

# --- INITIAL DATA FALLBACK ---
# Used only when DATA_FILE is missing, empty, or fails to load.
DEFAULT_TRAINING_DATA = [
    "The quick brown fox jumps over the lazy dog.",
    "A glass of water is clear.",
    "The sun is shining bright and the sky is clear.",
    "The dog and the fox are friends forever.",
    "Coding with Pytorch and Transformers is fun and very rewarding.",
    "A computer runs very fast and never stops.",
    "The windows are big and bright.",
    "A green park is a great place to relax.",
    "The sky is clear today, with no clouds.",
    "The cat jumped over the fence.",
    "The plane has many windows.",
    "A big bird flew over the house.",
    "The plane smoothly landed on the concrete runway.",
    "The bird flew above the bustling city.",
    "The plane had an engine failure and had to land in the river.",
    "The Cessna 172 is a low-wing monoplane.",
    "The plane flew by the trees.",
    "The plane, almost out of fuel, finally landed at an airport.",
    "The angry bird flew away furiously.",
    "A plane is a machine that flies.",
    "The fast plane landed at the bright airport.",
    "The plane quickly landed on the runway.",
    "The letter A is part of the alphabet.",
    "The plane landed hardly on a grass runway in the forest.",
    "The clouds were floating above the ground.",
    "The plane was a very bright plane, it's livery glimmered in the night sky.",
    "The GPWS sounds on a plane are like Caution Terrain PULL up PULL up."
]

# --- FILE I/O FUNCTIONS (CRITICAL FOR PERSISTENCE) ---
def load_data_from_csv(filepath):
    """Loads all training sentences from the CSV file, or returns the default data.

    Pipeline: read rows -> split rows into sentences -> normalize and
    filter noisy/short/overlong sentences -> collapse runs of identical
    consecutive sentences to at most two copies.  Falls back to
    DEFAULT_TRAINING_DATA when the file is missing, empty, or unreadable.
    """
    texts = []

    def split_into_sentences(paragraph):
        # Split on sentence end punctuation followed by whitespace and a
        # capital letter, digit, or an opening quote character.
        # Use a safe regex string and fall back to newline/sentence
        # punctuation splitting on error.
        try:
            pattern = r'(?<=[\.\!?])\s+(?=[A-Z0-9"\'""\u201c])'
            parts = re.split(pattern, paragraph)
            return [p.strip() for p in parts if p and p.strip()]
        except re.error:
            # fallback: split on sentence enders and newlines
            parts = re.split(r'[\.\!?]\s+|\n+', paragraph)
            return [p.strip() for p in parts if p and p.strip()]

    # Attempt to read existing data
    if os.path.exists(filepath) and os.path.getsize(filepath) > 0:
        print(f"[Console] Loading training data from {filepath}...")
        try:
            with open(filepath, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                raw_rows = []
                for row in reader:
                    if row and row[0].strip():
                        raw_text = row[0].strip()
                        # Remove surrounding quotes if present
                        if (raw_text.startswith('"') and raw_text.endswith('"')) or (raw_text.startswith("'") and raw_text.endswith("'")):
                            raw_text = raw_text[1:-1].strip()
                        if raw_text:
                            raw_rows.append(raw_text)
            # Now split rows into sentences, filter, and handle adjacent runs
            sequence = []
            for raw in raw_rows:
                # If the row contains multiple sentences, split them
                parts = split_into_sentences(raw)
                # If splitting produced only one part but it contains internal
                # newlines, also split per line
                if len(parts) == 1 and '\n' in parts[0]:
                    parts = [p.strip() for p in parts[0].splitlines() if p.strip()]
                for s in parts:
                    # Normalize whitespace and strip surrounding quotes
                    s_clean = ' '.join(s.split()).strip(' "\'')
                    words = s_clean.split()
                    # Basic length filters to remove garbage/too-short sentences
                    if len(words) < 3:
                        continue
                    if len(words) > 300:
                        # skip extremely long paragraphs
                        continue
                    # Skip lines with the same word 3+ times in a row
                    # (a common artifact of degenerate model output)
                    is_noisy = False
                    for i in range(len(words) - 2):
                        if words[i] == words[i+1] == words[i+2]:
                            is_noisy = True
                            break
                    if is_noisy:
                        continue
                    # Junk-word ratio filter.
                    # NOTE(review): the threshold 0.9999999 means a line is
                    # only skipped when essentially EVERY word is junk, but
                    # the original comment described a 30% threshold —
                    # confirm which was intended before tightening this.
                    junk_patterns = ['pull', 'up', 'land', 'river', 'sky', 'clear', 'table']
                    junk_count = sum(1 for w in words if w in junk_patterns)
                    if junk_count > len(words) * 0.9999999:
                        continue
                    sequence.append(s_clean)
            # Collapse consecutive identical sentences (runs) to at most two
            # copies.  Appending sequence[i] twice is the "first and last"
            # of the run — the strings in a run are identical anyway.
            i = 0
            while i < len(sequence):
                j = i + 1
                while j < len(sequence) and sequence[j] == sequence[i]:
                    j += 1
                run_len = j - i
                if run_len == 1:
                    texts.append(sequence[i])
                else:
                    texts.append(sequence[i])
                    texts.append(sequence[i])
                i = j
        except Exception as e:
            print(f"[Console Bug] Error loading CSV: {e}. Falling back to default data.")
            texts = []  # Clear corrupted load

    # If no data loaded (file missing, empty, or corrupted), use the defaults
    if not texts:
        print("[Console] CSV file not found or empty. Using default knowledge base.")
        return list(DEFAULT_TRAINING_DATA)

    # Debug: report how many sentences were actually loaded and sample content
    print(f"[Console] Loaded {len(texts)} sentence(s) from {filepath}.")
    sample_head = texts[:10]
    sample_tail = texts[-10:]  # NOTE(review): collected but unused below
    print("[Console] First loaded sentences:")
    for i, s in enumerate(sample_head, 1):
        print(f" {i}: {s[:200]}")
    if len(texts) > 10:
        print("[Console] Last loaded sentences:")
        start_index = max(0, len(texts) - 10)
        for i, s in enumerate(texts[start_index:], start_index + 1):
            print(f" {i}: {s[:200]}")
    return texts
def save_data_to_csv(filepath, texts):
    """Overwrite *filepath* with the full training corpus, one sentence per row.

    Opening in 'w' mode ensures the file always reflects the complete,
    up-to-date dataset rather than an append-only log.
    """
    print(f"[Console] Saving {len(texts)} sentences to {filepath} using 'w' mode...")
    try:
        with open(filepath, 'w', newline='', encoding='utf-8') as handle:
            # Each sentence becomes a single one-column CSV row.
            csv.writer(handle).writerows([sentence] for sentence in texts)
    except Exception as e:
        print(f"[Console Bug] Error saving to CSV: {e}")
# --- TOKENIZER ---
class SimpleTokenizer:
    """Whitespace word tokenizer with a dynamically built vocabulary.

    Index 0 is reserved for <PAD> and index 1 for <UNK>.
    """

    # BUG FIX: build_vocab used to strip "/<>" while encode stripped ".,!?",
    # so vocabulary entries like "dog." could never match the "dog" that
    # encode() produced and were silently mapped to <UNK>.  Both methods now
    # share a single strip set (union of the two originals).
    _STRIP_CHARS = ".,!?/<>"

    def __init__(self, texts):
        self.word_to_idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx_to_word = {0: "<PAD>", 1: "<UNK>"}
        self.build_vocab(texts)

    def build_vocab(self, texts):
        """Assign the next free index to every new word found in *texts*."""
        for text in texts:
            for word in text.lower().split():
                word = word.strip(self._STRIP_CHARS)
                # Skip tokens that strip down to nothing (e.g. "...")
                if word and word not in self.word_to_idx:
                    idx = len(self.word_to_idx)
                    self.word_to_idx[word] = idx
                    self.idx_to_word[idx] = word

    def encode(self, text, max_len):
        """Encode *text* as a LongTensor of length *max_len* (padded/truncated)."""
        words = [word.strip(self._STRIP_CHARS) for word in text.lower().split()]
        indices = [self.word_to_idx.get(word, self.word_to_idx["<UNK>"]) for word in words]
        # Padding and truncation to the fixed sequence length
        if len(indices) < max_len:
            indices.extend([self.word_to_idx["<PAD>"]] * (max_len - len(indices)))
        elif len(indices) > max_len:
            indices = indices[:max_len]
        return torch.tensor(indices, dtype=torch.long)

    def decode(self, indices):
        """Decode a tensor of indices to a space-joined string, skipping <PAD>."""
        return " ".join([self.idx_to_word.get(idx.item(), "<UNK>") for idx in indices if idx.item() != self.word_to_idx["<PAD>"]])

    def vocab_size(self):
        """Number of vocabulary entries, including <PAD> and <UNK>."""
        return len(self.word_to_idx)
# --- DATASET ---
class TextDataset(Dataset):
    """Pre-encodes every sentence once so __getitem__ is a cheap list lookup."""

    def __init__(self, texts, tokenizer, max_len):
        # Eagerly encode the entire corpus up front.
        self.data = [tokenizer.encode(text, max_len) for text in texts]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# --- TRANSFORMER MODEL COMPONENTS ---
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to batch-first embeddings."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dims: sine
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dims: cosine
        # Stored as (max_len, 1, d_model) so a slice can broadcast over batch.
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # x is (batch, seq, d_model); take the first seq rows of the table
        # and move them to (1, seq, d_model) for broadcasting over the batch.
        return x + self.pe[:x.size(1), :].transpose(0, 1)
class TransformerLanguageModel(nn.Module):
    """Causal (decoder-only) transformer language model over tokenizer indices.

    NOTE: attribute names (`embedding`, `transformer_decoder`, `fc_out`)
    are load-bearing — the checkpoint "surgery" elsewhere in this file
    looks up state-dict keys such as 'embedding.weight' by name.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout, max_len):
        super(TransformerLanguageModel, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        # Use decoder layers for proper causal masking in text generation
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model*4,
            dropout=dropout,
            batch_first=True,
            activation='gelu'
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.init_weights()

    def init_weights(self):
        # Small uniform initialization for embedding and output projection.
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """Map (batch, seq) token ids to (batch, seq, vocab_size) logits."""
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        # Create causal mask (-inf above the diagonal) so each position
        # cannot attend to future tokens.
        seq_len = src.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=src.device) * float('-inf'), diagonal=1)
        # Decoder expects (tgt, memory); the same tensor is used for both,
        # which makes this behave as a decoder-only causal language model.
        output = self.transformer_decoder(src, src, tgt_mask=causal_mask)
        return self.fc_out(output)
# --- TRAINING FUNCTIONS ---
def train_model(model, data_loader, optimizer, criterion, device, epochs):
    """Run *epochs* of next-token training over *data_loader*.

    Each batch of shape (B, MAX_SEQ_LENGTH) is split into src = tokens[:-1]
    and tgt = tokens[1:] for standard shifted language modeling.
    """
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for batch in data_loader:
            batch = batch.to(device)
            # Shifted LM pair: predict token i+1 from tokens 0..i
            src = batch[:, :-1]
            tgt = batch[:, 1:]
            optimizer.zero_grad()
            output = model(src)
            # Flatten (B, T, V) -> (B*T, V) and (B, T) -> (B*T,) for CE loss
            current_loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            # NOTE(review): the constant added below inflates the *reported*
            # loss when it is above 1.9 but does NOT change the gradients —
            # d(loss + 1)/dθ == d(loss)/dθ — so it has no training effect.
            # (The original comment called this a "0.5 penalty" but adds 1.)
            if current_loss.item() > 1.9:
                penalty = 1
                current_loss = current_loss + penalty
            current_loss.backward()
            optimizer.step()
            total_loss += current_loss.item()
        avg_loss = total_loss / len(data_loader)
        # Progress reporting: ~10 updates for long runs, every 10 epochs
        # for short runs (<= 10 epochs only ever prints at epoch 10).
        if epochs > 10 and epoch % (epochs // 10) == 0:
            print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")
        elif epochs > 0 and epochs <= 50 and epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs}, Average Loss: {avg_loss:.4f}")
    print(f"[Console] The {NUM_LAYERS} Layers have been updated.")
# BUG FIX: this line previously reset `penalties = []`, silently discarding
# the [sentence, penalty_value] records loaded from Penalties.csv at startup.
# Reload them instead so persisted penalties survive into generation.
penalties = load_penalties()
# --- GENERATION FUNCTION ---
def generate_text(model, tokenizer, prompt, max_len, device, top_k=40, penalty=1.8, temperature=1.0):
    """Autoregressively extend *prompt* and return only the continuation text.

    Per-step pipeline: penalty-record suppression -> repetition penalty ->
    temperature scaling -> top-k sampling.

    NOTE(review): `penalty`/`temperature` have non-None defaults, so the
    `is None` fallbacks to the saved globals below only trigger if a caller
    explicitly passes None — confirm intent.
    """
    global penalties, last_generated_text
    global REPETITION_PENALTY, TEMPERATURE  # link to the saved settings
    # If the function wasn't given a specific number, use the global one
    if penalty is None:
        penalty = REPETITION_PENALTY
    if temperature is None:
        temperature = TEMPERATURE
    model.eval()
    encoded_prompt = tokenizer.encode(prompt, max_len=max_len).to(device)
    # Count non-PAD tokens in encoded prompt to get true prompt length
    pad_idx = tokenizer.word_to_idx["<PAD>"]
    prompt_len = (encoded_prompt != pad_idx).sum().item()
    generated_indices = encoded_prompt[:prompt_len].tolist()
    input_ids = encoded_prompt.unsqueeze(0)
    # NOTE(review): dead code — this `banned_ids` is never read; the loop
    # below rebuilds it from the penalty records each step.  `prompt in
    # penalties` also only matches plain-string entries, not the
    # [sentence, value] pairs produced by load_penalties().
    if prompt in penalties:
        word = last_generated_text.lower().split()
        banned_ids = [tokenizer.word_to_idx.get(w, 1) for w in word]
    for i in range(prompt_len, max_len):
        src_input = input_ids[:, :i]
        with torch.no_grad():
            output = model(src_input)
        logits = output[0, i-1, :]  # logits for the next token position
        # Suppress words from recorded "bad" completions of this prompt.
        # Each record is assumed to be a [sentence, penalty_value] pair
        # (the load_penalties() shape) — TODO confirm every append site
        # in this file stores that shape, not a bare string.
        for record in penalties:
            bad_sentence = record[0]  # the offending full sentence
            saved_penalty = record[1]  # the value recorded for THIS mistake
            if bad_sentence.startswith(prompt):
                bad_words = bad_sentence[len(prompt):].strip().lower().split()
                banned_ids = [tokenizer.word_to_idx.get(w, 1) for w in bad_words]
                for bid in banned_ids:
                    # We use saved_penalty here, NOT the global one!
                    logits[bid] -= saved_penalty
        # Apply repetition penalty to every token generated so far:
        # dividing positive logits and multiplying negative ones both
        # reduce the token's probability.
        history = generated_indices
        for idx in set(history):
            if logits[idx] > 0:
                logits[idx] /= penalty
            else:
                logits[idx] *= penalty
        # Apply temperature scaling before top-k
        logits = logits / temperature
        # Apply Top-K sampling
        top_k_values, top_k_indices = torch.topk(logits, min(top_k, len(logits)))
        probabilities = torch.softmax(top_k_values, dim=0)
        try:
            next_token_idx = torch.multinomial(probabilities, num_samples=1).item()
        except RuntimeError:
            # multinomial can fail on degenerate distributions (inf/nan);
            # fall back to the greedy choice, stopping on <PAD>.
            predicted_token = top_k_indices[0].item()
            if predicted_token == tokenizer.word_to_idx["<PAD>"]:
                break
        else:
            predicted_token = top_k_indices[next_token_idx].item()
        generated_indices.append(predicted_token)
        input_ids[0, i] = predicted_token
    # Decode only the continuation text (skip the prompt words)
    decoded_text = tokenizer.decode(torch.tensor(generated_indices, dtype=torch.long))
    prompt_words = [word.strip(".,!?") for word in prompt.lower().split()]
    decoded_words = decoded_text.split()
    start_index = len(prompt_words)
    continuation_text = " ".join(decoded_words[start_index:])
    return continuation_text.replace(" <pad>", "").strip()
# --- MAIN EXECUTION ---
# Global variables for model/tokenizer instances
last_generated_text = None  # continuation from the most recent generation
last_user_prompt = None  # prompt that produced it (used by !accept/!instead)
current_tokenizer = None  # SimpleTokenizer, rebuilt on every (re)train
current_model = None  # TransformerLanguageModel instance
device = torch.device("cpu")  # switched to CUDA in interactive_mode() if available
live_data_updates = []  # Temporary queue for new sentences added during the current session
initial_training_texts = []  # Stores all data loaded from CSV
def initialize_or_retrain(initial_train=True, use_live_data=False, epochs=NUM_EPOCHS):
    """(Re)build the tokenizer/model and run a training session.

    initial_train: load the permanent corpus from DATA_FILE first.
    use_live_data: also train on the queued live sentences, then persist
                   the combined corpus back to DATA_FILE.
    epochs:        number of training epochs to run.
    """
    global current_tokenizer, current_model, live_data_updates, initial_training_texts

    # 1. Load Data (Permanent)
    if initial_train:
        initial_training_texts = load_data_from_csv(DATA_FILE)
    training_data = list(initial_training_texts)

    # 2. Add Live Data
    if use_live_data:
        print(f"[Console] Retraining on {len(initial_training_texts)} base examples plus {len(live_data_updates)} new examples.")
        training_data.extend(live_data_updates)

    # 3. Tokenizer Initialization and Model Rebuild if necessary
    # BUG FIX: vocab_size is a *method*; the original code referenced
    # `current_tokenizer.vocab_size` without calling it, so it compared
    # bound-method objects and passed a method object to nn.Embedding.
    old_vocab_size = current_tokenizer.vocab_size() if current_tokenizer else 0
    current_tokenizer = SimpleTokenizer(training_data)
    new_vocab_size = current_tokenizer.vocab_size()
    if new_vocab_size != old_vocab_size or initial_train:
        if initial_train:
            print(f"Tokenizer Vocabulary Size: {new_vocab_size}")
            print(f"\nModel D_MODEL={D_MODEL}, NUM_HEADS={NUM_HEADS}, NUM_LAYERS={NUM_LAYERS}")
        current_model = TransformerLanguageModel(
            vocab_size=new_vocab_size,
            d_model=D_MODEL,
            nhead=NUM_HEADS,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT,
            max_len=MAX_SEQ_LENGTH
        ).to(device)
        # Restore previously saved weights into the freshly built model,
        # expanding the embedding/output rows when the vocabulary grew.
        if os.path.exists("aoban_weights.pth"):
            checkpoint = torch.load("aoban_weights.pth", map_location=device)
            saved_vocab_size = checkpoint['embedding.weight'].shape[0]
            current_vocab_size = current_model.embedding.weight.shape[0]
            if saved_vocab_size != current_vocab_size:
                print(f"[Console] Expanding Aoban's brain from {saved_vocab_size} to {current_vocab_size} words...")
                # 1. Start from the model's current (fresh) weights
                new_state_dict = current_model.state_dict()
                # 2. Inject the saved memories into the new state
                for key, value in checkpoint.items():
                    if key in new_state_dict:
                        if value.shape == new_state_dict[key].shape:
                            # Interior layers fit perfectly
                            new_state_dict[key] = value
                        else:
                            # Embedding/output rows need surgical pasting.
                            # NOTE(review): assumes the vocabulary only ever
                            # GROWS; a shrunken vocab would overflow here.
                            print(f"[Surgery] Patching {key}...")
                            new_state_dict[key][:saved_vocab_size] = value[:saved_vocab_size]
                # 3. Load the expanded weights into the model
                current_model.load_state_dict(new_state_dict)
            else:
                # If sizes match, just load normally
                current_model.load_state_dict(checkpoint)

    # 4. Training Setup and Execution
    dataset = TextDataset(training_data, current_tokenizer, MAX_SEQ_LENGTH)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = Adam(current_model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=current_tokenizer.word_to_idx["<PAD>"])
    print(f"\n[Console] Starting {epochs} epochs with {len(dataset)} examples...")
    train_model(current_model, data_loader, optimizer, criterion, device, epochs)

    # 5. Persistence Update (saves data if it was a retraining session)
    if use_live_data:
        # 5a. Update the base list to include the new data
        initial_training_texts = training_data
        # 5b. Save the combined data permanently to the CSV
        save_data_to_csv(DATA_FILE, initial_training_texts)
        # 5c. Clear the temporary queue
        live_data_updates = []
        print("[Console] Retraining complete. New knowledge acquired and **permanently saved**.")
def _print_console_help():
    """Print the numbered command reference shared by startup and !help."""
    print("1. Type a phrase to generate text (max 10 words).")
    print("2. Use '!add [sentence]' to queue new training data.")
    print("3. Use '!accept' to add the model's last **full** sentence to the training queue.")
    print(f"4. Use '!retrain' to re-train the model on new data (runs for {INTERACTIVE_EPOCHS} epochs) **and save it**.")
    print(f"5. Use '!refine' to re-train on existing data (runs for {INTERACTIVE_EPOCHS} epochs) **without saving.**")
    print("6. Use '!penalty <value>' to regenerate with a different repetition penalty (higher = less repetition).")
    print("7. Type 'quit' or 'exit' to stop.")
    print("8. Type '!help' to see this message again.")
    print("9. Use '!instead [corrected text]' to replace the last output with a corrected version.")


def interactive_mode():
    """Run the interactive console: generation plus training-data commands."""
    global live_data_updates, last_generated_text, last_user_prompt, penalties
    global REPETITION_PENALTY, TEMPERATURE

    # 0. Load saved settings from settings.csv
    REPETITION_PENALTY, TEMPERATURE = load_settings()

    # Check if the file exists before initial training
    file_existed_before_run = os.path.exists(DATA_FILE)

    global device
    device = torch.device("cuda")
    if not torch.cuda.is_available():
        print("Required GPU not located, running on CPU instead.")
        device = torch.device("cpu")

    # Run the initial, long training session
    print(f"[Console] Using device: {device}")
    initialize_or_retrain(initial_train=True, use_live_data=False, epochs=NUM_EPOCHS)

    # If the data file did not exist before this run (default data was used),
    # force a save now so the defaults are written to disk immediately.
    if not file_existed_before_run:
        print("\n[SYSTEM] CSV file was empty/missing. Forcing initial save of default knowledge...")
        save_data_to_csv(DATA_FILE, initial_training_texts)
        print("[SYSTEM] Default 27 sentences are now permanently written to training_data.csv.")

    print("\n" + "=" * 60)
    print("🤖 Console Information🤖")
    _print_console_help()
    print("=" * 60)

    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() in ['quit', 'exit']:
                break

            if user_input.lower().startswith('!add '):
                sentence = user_input[5:].strip()
                if sentence:
                    live_data_updates.append(sentence)
                    print(f"[Console] Added sentence to update queue: '{sentence}'")
                    print(f"[Console] Current update queue size: {len(live_data_updates)}. Type '!retrain' to apply and save changes.")
                    last_generated_text = None  # Clear accepted text
                    last_user_prompt = None
                continue

            # --- !ACCEPT COMMAND ---
            if user_input.lower().strip() == '!accept':
                if last_generated_text and last_user_prompt:
                    # Reconstruct the full sentence: prompt + continuation
                    full_sentence_parts = [last_user_prompt.strip(), last_generated_text.strip()]
                    sentence_to_add = " ".join(full_sentence_parts)
                    # Basic cleaning: collapse double spaces
                    sentence_to_add = " ".join(sentence_to_add.split())
                    if sentence_to_add and len(sentence_to_add.split()) > 4:
                        live_data_updates.append(sentence_to_add)
                        print(f"[Console] ACCEPTED: The full sentence '{sentence_to_add}' added to update queue.")
                        print(f"[Console] Current update queue size: {len(live_data_updates)}. Type '!retrain' to apply and save changes.")
                        last_generated_text = None  # Clear after acceptance
                        last_user_prompt = None
                    else:
                        print("[Console] Cannot accept: The reconstructed sentence was too short or incomplete. Please use '!add [full sentence]' instead.")
                else:
                    print("[Console] No text generated or prompt found. Generate text first.")
                continue

            if user_input.lower() == '!help':
                print("\n🤖 Console Information🤖")
                _print_console_help()
                print("=" * 60 + "\n")
                continue

            if user_input.lower() == '!retrain':
                if not live_data_updates:
                    print("[Console] No new data to train on. Use '!add [sentence]' first.")
                    continue
                print(f"\n[Console] RETRAINING MODEL ON NEW DATA ({INTERACTIVE_EPOCHS} EPOCHS)...")
                initialize_or_retrain(initial_train=False, use_live_data=True, epochs=INTERACTIVE_EPOCHS)
                # Clear the accepted-text cache and persist the new weights
                last_generated_text = None
                last_user_prompt = None
                torch.save(current_model.state_dict(), "aoban_weights.pth")
                print("[Console] Model weights permanently saved to aoban_weights.pth")
                continue

            # --- !ENDORSE COMMAND ---
            if user_input.lower().startswith('!endorse'):
                if last_generated_text and last_user_prompt:
                    try:
                        # Usage: !endorse 10 (or just !endorse for default 5)
                        parts = user_input.split()
                        multiplier = int(parts[1]) if len(parts) > 1 else 5
                        # Queue the full sentence multiple times to weight it
                        full_sentence = f"{last_user_prompt.strip()} {last_generated_text.strip()}"
                        for _ in range(multiplier):
                            live_data_updates.append(full_sentence)
                        print(f"[SYSTEM] ENDORSED: Lore added {multiplier}x to queue.")
                        print(f"[SYSTEM] Target: {full_sentence}")
                    except ValueError:
                        print("[SYSTEM] Usage: !endorse <number>")
                else:
                    print("[SYSTEM] Nothing to endorse.")
                continue

            if user_input.lower() == '!refine':
                print(f"\n[Console] REFINING MODEL ON EXISTING DATA ({INTERACTIVE_EPOCHS} EPOCHS)...")
                initialize_or_retrain(initial_train=False, use_live_data=False, epochs=INTERACTIVE_EPOCHS)
                print("[Console] Refinement complete. Knowledge deepened on existing data.")
                continue

            if user_input.lower().startswith('!instead '):
                if last_user_prompt and last_generated_text:
                    # The user provides the "correct" version of the response
                    corrected_output = user_input[9:].strip()
                    # 1. Log the bad output as a penalty record.
                    penalty_record = f"{last_user_prompt} {last_generated_text}"
                    # BUG FIX: generate_text() reads each record as a
                    # [sentence, value] pair; appending the bare string made
                    # record[0] a single character and record[1] a str.
                    penalties.append([penalty_record, REPETITION_PENALTY])
                    with open('Penalties.csv', 'a', newline='', encoding='utf-8') as f:
                        csv.writer(f).writerow([penalty_record, REPETITION_PENALTY])
                    # 2. Queue the corrected sentence 5x (built-in endorsement)
                    full_correct_sentence = f"{last_user_prompt} {corrected_output}"
                    for _ in range(5):
                        live_data_updates.append(full_correct_sentence)
                    print(f"[SYSTEM] Fixed! '{last_generated_text}' is now penalized.")
                    print(f"[SYSTEM] Added correction: '{full_correct_sentence}' to training queue.")
                    # Keep the correction as the latest output so it can be
                    # !accept-ed or !endorse-d further.
                    last_generated_text = corrected_output
                    torch.save(current_model.state_dict(), "aoban_weights.pth")
                    print("[Console] Model weights permanently saved to aoban_weights.pth")
                else:
                    print("[SYSTEM] Nothing to replace. Generate text first.")
                continue

            # --- !PENALTY COMMAND ---
            if user_input.lower().startswith('!penalty '):
                try:
                    new_val = float(user_input[9:].strip())
                    REPETITION_PENALTY = new_val
                    save_settings(REPETITION_PENALTY, TEMPERATURE)
                    if last_user_prompt and last_generated_text:
                        penalty_record = f"{last_user_prompt} {last_generated_text}"
                        # BUG FIX: store the [sentence, value] pair expected
                        # by generate_text(), not a bare string.
                        penalties.append([penalty_record, REPETITION_PENALTY])
                        # Persist the sentence PAIRED with the penalty value used
                        with open('Penalties.csv', 'a', newline='', encoding='utf-8') as f:
                            writer = csv.writer(f)
                            writer.writerow([penalty_record, REPETITION_PENALTY])
                        # Regenerate using the new permanent penalty
                        print(f"[Console] Regenerating with new saved penalty={REPETITION_PENALTY}...")
                        generated_text = generate_text(
                            current_model,
                            current_tokenizer,
                            last_user_prompt,
                            MAX_SEQ_LENGTH,
                            device,
                            TOP_K,
                            REPETITION_PENALTY,
                            TEMPERATURE
                        )
                        print(f"Model: {generated_text}")
                        last_generated_text = generated_text
                        print("\n[Console] If this full sentence is perfect, type '!accept'.")
                        torch.save(current_model.state_dict(), "aoban_weights.pth")
                        print("[Console] Model weights permanently saved to aoban_weights.pth")
                except ValueError:
                    print(f"[Console] Invalid value. Usage: !penalty <number>.")
                continue

            if user_input.strip() and not user_input.lower().startswith(('!',)):
                # Text generation logic
                prompt = user_input.strip()
                if len(prompt.split()) > MAX_SEQ_LENGTH - 1:
                    print(f"[Console] Prompt too long. Max {MAX_SEQ_LENGTH - 1} words supported.")
                    last_generated_text = None
                    last_user_prompt = None
                    continue
                # 1. Store the prompt BEFORE generation
                last_user_prompt = prompt
                generated_text = generate_text(
                    current_model,
                    current_tokenizer,
                    prompt,
                    MAX_SEQ_LENGTH,
                    device,
                    TOP_K,
                    REPETITION_PENALTY,
                    TEMPERATURE
                )
                print(f"Model: {generated_text}")
                # 2. Store the continuation AFTER generation
                last_generated_text = generated_text
                print("\n[Console] If this full sentence is perfect, type '!accept' to add it to the training queue.")
        except KeyboardInterrupt:
            print("\nExiting interactive mode.")
            break
        except Exception as e:
            # Top-level boundary: report and stop the session rather than
            # looping on a broken state.
            print(f"An error occurred: {e}")
            break
if __name__ == "__main__":
    # Load settings globally before anything starts.
    # NOTE(review): interactive_mode() reloads these same settings again
    # as its first step, so this early load only affects the banner below.
    REPETITION_PENALTY, TEMPERATURE = load_settings()
    print(f"[Console] Global Settings Initialized: Penalty={REPETITION_PENALTY}")
    interactive_mode()