""" interactive_chat.py - Fully Customizable MAP-NEO Mini Interactive Chat Interface Features: Real-time parameter tuning, conversation memory, context management, multiple responses """ import torch from transformers import AutoTokenizer from model_neo import NeoMini, NeoMiniConfig import os import json import time from pathlib import Path from datetime import datetime import gc class InteractiveChat: def __init__(self, checkpoint_path="checkpoints/extended_context_model.pt"): self.model = None self.tokenizer = None self.device = "cuda" if torch.cuda.is_available() else "cpu" self.conversation_history = [] self.max_context_length = 16384 # Generation parameters (fully customizable) self.params = { 'temperature': 0.7, 'top_k': 50, 'top_p': 0.9, 'repetition_penalty': 1.1, 'max_length': 150, 'do_sample': True, 'num_responses': 1 } print("๐Ÿš€ MAP-NEO Mini Interactive Chat Interface") print("=" * 60) self.load_model(checkpoint_path) def clear_gpu_cache(self): """Clear GPU memory cache""" if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() gc.collect() def get_memory_usage(self): """Get current GPU memory usage""" if not torch.cuda.is_available(): return "CPU only" allocated = torch.cuda.memory_allocated(0) / 1024**3 cached = torch.cuda.memory_reserved(0) / 1024**3 total = torch.cuda.get_device_properties(0).total_memory / 1024**3 return f"{allocated:.2f}GB/{total:.2f}GB (cached: {cached:.2f}GB)" def load_model(self, checkpoint_path): """Load model and tokenizer""" print(f"๐Ÿ“‚ Loading model from {checkpoint_path}...") if not os.path.exists(checkpoint_path): print(f"โŒ Checkpoint not found: {checkpoint_path}") return False try: checkpoint = torch.load(checkpoint_path, map_location=self.device) # Get context length from config if 'config' in checkpoint: self.max_context_length = checkpoint['config'].get('max_seq_len', 16384) # Load model config = NeoMiniConfig() config.max_seq_len = self.max_context_length self.model = NeoMini(config) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.eval() self.model = self.model.to(self.device) # Load tokenizer tokenizer_path = "data/tokenizer" if Path(tokenizer_path).exists(): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) else: print("Using GPT-2 tokenizer as fallback...") self.tokenizer = AutoTokenizer.from_pretrained("gpt2") if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token print(f"โœ… Model loaded successfully!") print(f"๐Ÿง  Parameters: {self.model.get_num_params():,}") print(f"๐Ÿ“ Context window: {self.max_context_length:,} tokens") print(f"๐Ÿ’พ Memory: {self.get_memory_usage()}") return True except Exception as e: print(f"โŒ Error loading model: {e}") return False def format_conversation_context(self): """Format conversation history for model input""" if not self.conversation_history: return "" context = "The following is a conversation between a human and an AI assistant. 
The AI assistant is helpful, harmless, and honest.\n\n" for exchange in self.conversation_history: context += f"Human: {exchange['human']}\n" context += f"AI: {exchange['ai']}\n\n" return context def generate_response(self, user_input, num_responses=None): """Generate AI response(s) to user input""" if num_responses is None: num_responses = self.params['num_responses'] # Build full context context = self.format_conversation_context() full_prompt = context + f"Human: {user_input}\nAI: " # Check context length input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device) prompt_length = input_ids.size(1) print(f"๐Ÿ“ Context: {prompt_length:,}/{self.max_context_length:,} tokens") if prompt_length >= self.max_context_length: print("โš ๏ธ Context too long, trimming conversation history...") self.trim_conversation_history() context = self.format_conversation_context() full_prompt = context + f"Human: {user_input}\nAI: " input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device) prompt_length = input_ids.size(1) # Generate response(s) responses = [] for i in range(num_responses): print(f"๐Ÿค– Generating response {i+1}/{num_responses}...") with torch.no_grad(): generated = input_ids.clone() max_new_tokens = min(self.params['max_length'], self.max_context_length - prompt_length) for step in range(max_new_tokens): logits = self.model(generated) next_token_logits = logits[0, -1, :] / self.params['temperature'] # Apply repetition penalty if self.params['repetition_penalty'] != 1.0: for token_id in set(generated[0].tolist()): if next_token_logits[token_id] < 0: next_token_logits[token_id] *= self.params['repetition_penalty'] else: next_token_logits[token_id] /= self.params['repetition_penalty'] # Top-k filtering if self.params['top_k'] > 0: top_k_logits, _ = torch.topk(next_token_logits, self.params['top_k']) min_top_k = top_k_logits[-1] next_token_logits[next_token_logits < min_top_k] = float("-inf") # Top-p filtering if self.params['top_p'] < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > self.params['top_p'] sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices[sorted_indices_to_remove] next_token_logits[indices_to_remove] = float("-inf") # Sample next token if self.params['do_sample']: probs = torch.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1) # Check stopping conditions if next_token.item() == self.tokenizer.eos_token_id: break # Check for natural stopping points if step > 10: # Only check after minimum generation decoded = self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True) if decoded.strip().endswith(('.', '!', '?', '\n\n')): break # Extract just the AI response full_response = self.tokenizer.decode(generated[0], skip_special_tokens=True) ai_response = full_response[len(full_prompt):].strip() # Clean up response if '\nHuman:' in ai_response: ai_response = ai_response.split('\nHuman:')[0].strip() responses.append(ai_response) return responses def trim_conversation_history(self): """Remove oldest conversation turns to fit context""" while len(self.conversation_history) > 1: 
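
    # Worked example of the top-p (nucleus) mask in generate_response above,
    # with illustrative numbers (not taken from the model):
    #   sorted probs:      [0.50, 0.30, 0.15, 0.05], top_p = 0.9
    #   cumulative sum:    [0.50, 0.80, 0.95, 1.00]
    #   cumsum > top_p:    [F,    F,    T,    T ]
    #   after right-shift: [F,    F,    F,    T ]
    # The right-shift keeps the first token that crosses the threshold, so at
    # least one candidate always survives the filter.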

    def trim_conversation_history(self):
        """Remove oldest conversation turns to fit context."""
        while len(self.conversation_history) > 1:
            self.conversation_history.pop(0)
            context = self.format_conversation_context()
            if len(self.tokenizer.encode(context)) < self.max_context_length // 2:
                break
        print(f"🧹 Trimmed conversation history to {len(self.conversation_history)} turns")

    def update_parameters(self):
        """Interactive parameter adjustment."""
        print("\n🎛️ Current Generation Parameters:")
        for key, value in self.params.items():
            print(f"  {key}: {value}")
        print("\nEnter new values (press Enter to keep current):")

        # Temperature
        temp_input = input(f"Temperature (0.1-2.0, current: {self.params['temperature']}): ").strip()
        if temp_input:
            try:
                self.params['temperature'] = max(0.1, min(2.0, float(temp_input)))
            except ValueError:
                print("❌ Invalid temperature, keeping current value")

        # Top-k
        topk_input = input(f"Top-k (0-100, current: {self.params['top_k']}): ").strip()
        if topk_input:
            try:
                self.params['top_k'] = max(0, min(100, int(topk_input)))
            except ValueError:
                print("❌ Invalid top-k, keeping current value")

        # Top-p
        topp_input = input(f"Top-p (0.1-1.0, current: {self.params['top_p']}): ").strip()
        if topp_input:
            try:
                self.params['top_p'] = max(0.1, min(1.0, float(topp_input)))
            except ValueError:
                print("❌ Invalid top-p, keeping current value")

        # Max length
        maxlen_input = input(f"Max length (10-500, current: {self.params['max_length']}): ").strip()
        if maxlen_input:
            try:
                self.params['max_length'] = max(10, min(500, int(maxlen_input)))
            except ValueError:
                print("❌ Invalid max length, keeping current value")

        # Number of responses
        num_resp = input(f"Number of responses (1-3, current: {self.params['num_responses']}): ").strip()
        if num_resp:
            try:
                self.params['num_responses'] = max(1, min(3, int(num_resp)))
            except ValueError:
                print("❌ Invalid number, keeping current value")

        print("✅ Parameters updated!")

    def save_conversation(self):
        """Save conversation to file."""
        if not self.conversation_history:
            print("❌ No conversation to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"
        conversation_data = {
            'timestamp': timestamp,
            'model_info': {
                'max_context': self.max_context_length,
                'parameters': self.model.get_num_params(),
            },
            'generation_params': self.params,
            'conversation': self.conversation_history,
        }
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, indent=2, ensure_ascii=False)
        print(f"💾 Conversation saved to {filename}")

    def load_conversation(self, filename):
        """Load conversation from file."""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                conversation_data = json.load(f)
            self.conversation_history = conversation_data['conversation']
            print(f"📂 Loaded conversation with {len(self.conversation_history)} turns")
        except Exception as e:
            print(f"❌ Error loading conversation: {e}")

    def show_help(self):
        """Show available commands."""
        print("\n🔧 Available Commands:")
        print("  /help         - Show this help message")
        print("  /params       - Adjust generation parameters")
        print("  /clear        - Clear conversation history")
        print("  /save         - Save current conversation")
        print("  /load <file>  - Load conversation from file")
        print("  /memory       - Show GPU memory usage")
        print("  /context      - Show current context usage")
        print("  /multi <n>    - Generate n responses to next input")
        print("  /exit, /quit  - Exit the chat")

    def run(self):
        """Main chat loop."""
        if not self.model or not self.tokenizer:
            print("❌ Model not loaded. Exiting.")
            return

        print(f"\n💬 Chat started! Context window: {self.max_context_length:,} tokens")
        print("Type /help for commands, /exit to quit")
        print("-" * 60)

        while True:
            try:
                # Get user input
                user_input = input("\n👤 You: ").strip()
                if not user_input:
                    continue

                # Handle commands. Lowercase only the command word so that
                # arguments such as filenames keep their original case.
                if user_input.startswith('/'):
                    parts = user_input.split(maxsplit=1)
                    command = parts[0].lower()
                    arg = parts[1].strip() if len(parts) > 1 else ""

                    if command in ('/exit', '/quit'):
                        print("👋 Goodbye!")
                        break
                    elif command == '/help':
                        self.show_help()
                    elif command == '/params':
                        self.update_parameters()
                    elif command == '/clear':
                        self.conversation_history = []
                        self.clear_gpu_cache()
                        print("🧹 Conversation history cleared")
                    elif command == '/save':
                        self.save_conversation()
                    elif command == '/load':
                        if arg:
                            self.load_conversation(arg)
                        else:
                            print("❌ Usage: /load <filename>")
                    elif command == '/memory':
                        print(f"💾 GPU Memory: {self.get_memory_usage()}")
                    elif command == '/context':
                        context_length = len(self.tokenizer.encode(self.format_conversation_context()))
                        print(f"📏 Current context: {context_length:,}/{self.max_context_length:,} tokens")
                    elif command == '/multi':
                        try:
                            self.params['num_responses'] = max(1, min(3, int(arg)))
                            print(f"🎯 Next response will generate {self.params['num_responses']} options")
                        except ValueError:
                            print("❌ Usage: /multi <1-3>")
                    else:
                        print("❌ Unknown command. Type /help for available commands.")
                    continue

                # Generate response(s)
                start_time = time.time()
                responses = self.generate_response(user_input)
                generation_time = time.time() - start_time

                # Display response(s)
                if len(responses) == 1:
                    print(f"\n🤖 AI: {responses[0]}")
                    chosen_response = responses[0]
                else:
                    print(f"\n🤖 AI generated {len(responses)} responses:")
                    for i, response in enumerate(responses, 1):
                        print(f"\n[{i}] {response}")
                    while True:
                        choice = input(f"\nChoose response (1-{len(responses)}, Enter for 1): ").strip()
                        if not choice:
                            choice = "1"
                        try:
                            choice_idx = int(choice) - 1
                            if 0 <= choice_idx < len(responses):
                                chosen_response = responses[choice_idx]
                                break
                            else:
                                print(f"❌ Invalid choice. Enter 1-{len(responses)}")
                        except ValueError:
                            print("❌ Invalid input. Enter a number.")

                # Add to conversation history
                self.conversation_history.append({
                    'human': user_input,
                    'ai': chosen_response,
                    'timestamp': datetime.now().isoformat(),
                    'generation_time': round(generation_time, 2),
                })

                # Reset num_responses if it was changed
                if self.params['num_responses'] != 1:
                    self.params['num_responses'] = 1

                print(f"⏱️ Generated in {generation_time:.2f}s | 💾 {self.get_memory_usage()}")

            except KeyboardInterrupt:
                print("\n\n👋 Chat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}")
                self.clear_gpu_cache()


def main():
    """Entry point; an optional command-line argument overrides the checkpoint path."""
    checkpoint_path = "checkpoints/extended_context_model.pt"
    if len(sys.argv) > 1:
        checkpoint_path = sys.argv[1]
    chat = InteractiveChat(checkpoint_path)
    chat.run()


if __name__ == "__main__":
    main()
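
# Example usage, assuming a trained checkpoint and tokenizer exist at the
# paths referenced above (the custom checkpoint name below is illustrative):
#
#   python interactive_chat.py                          # default checkpoint
#   python interactive_chat.py checkpoints/my_run.pt    # custom checkpoint
#
# In-chat, type a message to get a reply, or use commands such as:
#   /params    adjust temperature / top-k / top-p interactively
#   /multi 3   sample three candidate replies and pick one
#   /save      write the conversation to conversation_<timestamp>.json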