Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Main application for the Hugging Face Memory Project. | |
| Handles conversation interface and memory management. | |
| """ | |
import os
from typing import List, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load environment variables from a .env file when python-dotenv is installed.
# The import itself lives inside the try: previously load_dotenv was imported
# unconditionally at the top of the file, so the except ImportError around the
# *call* was unreachable — a missing package crashed at import time instead of
# falling back to defaults.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    print("β οΈ python-dotenv not available, using default values")

# Import our advanced conversation model
try:
    from src.conversation_model import ConversationModel
except ImportError:
    # Fallback import for direct execution (running from inside the package dir)
    from conversation_model import ConversationModel
class MemoryAI:
    """Conversational assistant with a bounded, FIFO buffer of past exchanges.

    Replies are produced either by an advanced ConversationModel or, as a
    fallback, by a basic Hugging Face causal LM loaded via transformers.
    """

    def __init__(self, use_advanced_model: bool = True):
        """Initialize the AI model and memory system.

        Args:
            use_advanced_model: When True, try to load ConversationModel first
                and fall back to the basic transformers model on any failure.
        """
        # All tunables come from environment variables with defaults.
        self.model_name = os.getenv("MODEL_NAME", "gpt2")
        self.max_memory = int(os.getenv("MAX_MEMORY_ENTRIES", 100))
        self.data_dir = os.getenv("DATA_DIR", "data")
        self.models_dir = os.getenv("MODELS_DIR", "models")
        # Load generation parameters from environment
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 80))
        self.top_p = float(os.getenv("TOP_P", 0.9))
        self.repetition_penalty = float(os.getenv("REPETITION_PENALTY", 1.2))
        # Rolling memory buffer; add_memory() evicts the oldest entry when full.
        self.memories = []
        # Initialize conversation model
        self.use_advanced_model = use_advanced_model
        self.conversation_model = None
        if use_advanced_model:
            try:
                print("Loading advanced conversation model...")
                self.conversation_model = ConversationModel()
                print("β Advanced conversation model loaded!")
            except Exception as e:
                # Any failure (download, missing deps, ...) downgrades to basic.
                print(f"β Error loading advanced model: {e}")
                print("Falling back to basic model...")
                self.use_advanced_model = False
        # NOTE: self.tokenizer / self.model exist ONLY on this fallback path;
        # the advanced path never defines them (generate_response guards this
        # with hasattr before touching self.model).
        if not self.use_advanced_model:
            print(f"Loading basic {self.model_name} model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
            # Move model to GPU if available
            if torch.cuda.is_available():
                self.model = self.model.to('cuda')
                print("Using CUDA (GPU acceleration)")
            else:
                print("Using CPU")
        print(f"Memory capacity: {self.max_memory} entries")
        print(f"Generation params - Temp: {self.temperature}, Max tokens: {self.max_new_tokens}")
| def add_memory(self, memory_text): | |
| """Add a memory entry to the system.""" | |
| if len(self.memories) >= self.max_memory: | |
| self.memories.pop(0) # Remove oldest memory | |
| self.memories.append(memory_text) | |
| print(f"Memory added. Total memories: {len(self.memories)}") | |
    def generate_response(self, prompt, max_new_tokens=80, conversation_history=None):
        """Generate a response using the AI model with improved quality.

        Args:
            prompt: The user's message.
            max_new_tokens: NOTE(review): currently IGNORED — the basic-model
                path always uses self.max_new_tokens from the environment.
            conversation_history: Optional list of {"role", "content"} dicts.

        Returns:
            The generated reply string.
        """
        # Use advanced conversation model if available
        if self.use_advanced_model and self.conversation_model:
            try:
                # Convert memory to conversation history format
                conv_history = []
                if conversation_history:
                    for entry in conversation_history:
                        # Accept either "content" or legacy "text" keys.
                        conv_history.append({"role": entry.get("role", "user"),
                                             "content": entry.get("content", entry.get("text", ""))})
                # Generate response using advanced model
                response = self.conversation_model.generate_response(prompt, conv_history)
                # Add to memories — only the advanced path records memories;
                # the basic fallback below does not.
                self.add_memory(f"User: {prompt}")
                self.add_memory(f"AI: {response}")
                return response
            except Exception as e:
                print(f"Advanced model error: {e}")
                # Fallback to basic model
                pass
        # Fallback to basic model
        # Improved prompt engineering for conversational AI
        if "microsoft/DialoGPT" in self.model_name:
            # DialoGPT uses a different format (raw utterance, no role tag)
            improved_prompt = prompt
        else:
            # For other models, use better prompt engineering
            improved_prompt = f"{prompt}\n\nAssistant:"
        inputs = self.tokenizer(improved_prompt, return_tensors="pt")
        # Move inputs to same device as model (hasattr guards the advanced-only
        # case where self.model was never created).
        if hasattr(self, 'model') and next(self.model.parameters()).is_cuda:
            inputs = {k: v.to('cuda') for k, v in inputs.items()}
        # Generate with better parameters
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            do_sample=True,  # Enable sampling
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=2,  # Prevent exact repeats
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract only the new part of the response — assumes the decoded text
        # starts with the prompt verbatim; TODO confirm for all tokenizers.
        response = response[len(improved_prompt):].strip()
        # Clean up response
        response = self._clean_response(response)
        # Fallback for poor responses (empty, one word, or a stock "I'm ..."
        # opener, which this codebase treats as a low-quality generation).
        if not response or len(response.split()) < 2 or response.startswith("I'm"):
            response = self._generate_fallback_response(prompt)
        return response
| def _generate_fallback_response(self, prompt): | |
| """Generate a fallback response when the model produces poor output.""" | |
| # Simple rule-based responses for common questions | |
| prompt_lower = prompt.lower() | |
| if "hello" in prompt_lower or "hi" in prompt_lower or "hey" in prompt_lower: | |
| return "Hello! I'm MemoryAI, your conversational assistant. How can I help you today?" | |
| elif "how are you" in prompt_lower: | |
| return "I'm doing well, thank you for asking! As an AI, I'm always ready to chat. How about you?" | |
| elif "your name" in prompt_lower: | |
| return "I'm MemoryAI! I'm designed to remember our conversations and provide helpful responses." | |
| elif "memory" in prompt_lower and ("work" in prompt_lower or "how" in prompt_lower): | |
| return "I remember our past conversations and use that context to provide better, more relevant responses. It's like having a conversation with someone who remembers what you've talked about before!" | |
| elif "thank" in prompt_lower or "thanks" in prompt_lower: | |
| return "You're welcome! I'm happy to help. Is there anything else you'd like to talk about?" | |
| elif "joke" in prompt_lower: | |
| return "Why don't scientists trust atoms? Because they make up everything!" | |
| elif "weather" in prompt_lower: | |
| return "I can't check real-time weather, but I hope it's nice where you are! What city are you in?" | |
| else: | |
| # Generic fallback | |
| return "That's an interesting question! As an AI with memory, I can tell you that we've talked about various topics. What would you like to discuss?" | |
| def _clean_response(self, response): | |
| """Clean up the AI response for better quality.""" | |
| # Remove incomplete sentences at the end | |
| if response.endswith(('...', '..', '.')) and len(response.split()) < 3: | |
| # If it's a very short response ending with dots, keep it | |
| pass | |
| else: | |
| # Remove trailing incomplete words | |
| response = response.rstrip('.,;:!?') | |
| # Remove excessive repetition | |
| words = response.split() | |
| if len(words) > 1: | |
| # Check for repeated phrases | |
| for i in range(min(3, len(words) // 2)): | |
| phrase = ' '.join(words[-i-1:-1]) | |
| if response.endswith(f" {phrase} {phrase}"): | |
| response = response[:-len(f" {phrase}")].rstrip() | |
| break | |
| # Capitalize first letter if it's a complete sentence | |
| if len(response) > 0 and response[0].islower(): | |
| response = response[0].upper() + response[1:] | |
| # Add punctuation if missing | |
| if len(response) > 0 and response[-1] not in ('.', '!', '?'): | |
| response += '.' | |
| return response | |
    def converse(self):
        """Run the interactive chat loop until the user types 'quit'.

        Special commands (checked before any generation):
          !memories - print the most recent stored memories
          !clear    - wipe stored memories
          !summary  - conversation summary (advanced model only; with the
                      basic model the text falls through and is answered as
                      an ordinary prompt)
          !reset    - reset the advanced model's conversation state
        """
        print("π€ MemoryAI - Advanced Conversation Mode")
        print("Type 'quit' to exit.")
        print("Type '!memories' to see recent memories, '!clear' to clear memories")
        print("Type '!summary' for conversation summary, '!reset' to reset conversation")
        print("=" * 60)
        # Initialize conversation history for advanced model
        conversation_history = []
        while True:
            user_input = input("π€ You: ")
            if user_input.lower() == 'quit':
                print("π€ AI: Goodbye! Have a great day!")
                break
            # Handle special commands
            if user_input.lower() == '!memories':
                recent_memories = self.get_recent_memories()
                print("π Recent memories:")
                for i, memory in enumerate(recent_memories, 1):
                    print(f" {i}. {memory}")
                continue
            if user_input.lower() == '!clear':
                self.clear_memories()
                print("ποΈ Memories cleared!")
                continue
            if user_input.lower() == '!summary' and self.use_advanced_model:
                summary = self.get_conversation_summary()
                print(f"π {summary}")
                continue
            if user_input.lower() == '!reset':
                self.reset_conversation()
                # Also forget the locally-tracked history
                conversation_history = []
                continue
            # Ignore blank input
            if user_input.strip():
                # Generate response with conversation history
                response = self.generate_response(user_input, conversation_history=conversation_history)
                print(f"π€ AI: {response}")
                # Update conversation history
                conversation_history.append({"role": "user", "content": user_input})
                conversation_history.append({"role": "assistant", "content": response})
                # Show conversation stats if using advanced model.
                # NOTE(review): assumes the stats dict has 'current_topic' and
                # 'user_emotion' keys — defined in ConversationModel; verify there.
                if self.use_advanced_model and self.conversation_model:
                    stats = self.conversation_model.get_conversation_stats()
                    print(f"π Topic: {stats['current_topic']} | Emotion: {stats['user_emotion']}")
| def get_available_models(self): | |
| """Get a list of commonly available models.""" | |
| models = [ | |
| "gpt2", | |
| "distilgpt2", | |
| "gpt2-medium", | |
| "gpt2-large", | |
| "EleutherAI/gpt-neo-125M", | |
| "facebook/opt-125m", | |
| "microsoft/DialoGPT-small", | |
| "microsoft/DialoGPT-medium" | |
| ] | |
| # Add advanced conversation models | |
| if self.use_advanced_model: | |
| models.extend([ | |
| "facebook/blenderbot-400M-distill", | |
| "facebook/blenderbot-1B-distill", | |
| "microsoft/DialoGPT-large" | |
| ]) | |
| return models | |
| def get_conversation_summary(self) -> str: | |
| """Get a summary of the current conversation.""" | |
| if not self.use_advanced_model or not self.conversation_model: | |
| return "Conversation summary available only with advanced model." | |
| return self.conversation_model.get_conversation_summary() | |
| def find_similar_memories(self, query: str, top_k: int = 3) -> list: | |
| """Find memories similar to the query using semantic search.""" | |
| if not self.use_advanced_model or not self.conversation_model: | |
| return [] | |
| return self.conversation_model.find_similar_conversations(query, top_k) | |
| def reset_conversation(self): | |
| """Reset the conversation state.""" | |
| if self.use_advanced_model and self.conversation_model: | |
| self.conversation_model.reset_conversation() | |
| print("Conversation reset successfully!") | |
| def save_memories(self): | |
| """Save memories to a file.""" | |
| memory_file = os.path.join(self.data_dir, "memories.txt") | |
| with open(memory_file, 'w') as f: | |
| for memory in self.memories: | |
| f.write(memory + "\n") | |
| print(f"Memories saved to {memory_file}") | |
| def load_memories(self): | |
| """Load memories from a file.""" | |
| memory_file = os.path.join(self.data_dir, "memories.txt") | |
| if os.path.exists(memory_file): | |
| with open(memory_file, 'r') as f: | |
| self.memories = [line.strip() for line in f.readlines() if line.strip()] | |
| print(f"Loaded {len(self.memories)} memories from {memory_file}") | |
| else: | |
| print("No existing memories found.") | |
| def get_recent_memories(self, count=5): | |
| """Get the most recent memories.""" | |
| return self.memories[-count:] if self.memories else [] | |
| def clear_memories(self): | |
| """Clear all memories.""" | |
| self.memories = [] | |
| print("All memories cleared.") | |
if __name__ == "__main__":
    # Build the assistant, restore persisted memories, chat, and persist again.
    ai = MemoryAI()
    ai.load_memories()  # Load existing memories
    try:
        ai.converse()
    finally:
        # Previously any exception inside converse() — e.g. Ctrl-C or EOF at
        # the input() prompt — skipped save_memories() and lost the session.
        ai.save_memories()