#!/usr/bin/env python3
"""
Main application for the Hugging Face Memory Project.
Handles conversation interface and memory management.
"""
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Dict, Optional
# Import our advanced conversation model
try:
from src.conversation_model import ConversationModel
except ImportError:
# Fallback import for direct execution
from conversation_model import ConversationModel
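# (The src.* import works when main.py runs as part of the package, e.g.
# `python -m src.main`; the bare import covers direct execution from src/.)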
# Load environment variables (the import itself is what raises ImportError
# when python-dotenv is missing, so it must sit inside the try block)
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    print("⚠️ python-dotenv not available, using default values")
class MemoryAI:
def __init__(self, use_advanced_model: bool = True):
"""Initialize the AI model and memory system."""
        self.model_name = os.getenv("MODEL_NAME", "gpt2")
        self.max_memory = int(os.getenv("MAX_MEMORY_ENTRIES", "100"))
        self.data_dir = os.getenv("DATA_DIR", "data")
        self.models_dir = os.getenv("MODELS_DIR", "models")
        # Load generation parameters from the environment (string defaults so
        # os.getenv always returns the same type)
        self.temperature = float(os.getenv("TEMPERATURE", "0.7"))
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "80"))
        self.top_p = float(os.getenv("TOP_P", "0.9"))
        self.repetition_penalty = float(os.getenv("REPETITION_PENALTY", "1.2"))
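        # Example .env consumed by the lookups above (a sketch; the values
        # shown are the same defaults used when a variable is absent):
        #   MODEL_NAME=gpt2
        #   MAX_MEMORY_ENTRIES=100
        #   DATA_DIR=data
        #   MODELS_DIR=models
        #   TEMPERATURE=0.7
        #   MAX_NEW_TOKENS=80
        #   TOP_P=0.9
        #   REPETITION_PENALTY=1.2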
# Initialize memory storage
self.memories = []
# Initialize conversation model
self.use_advanced_model = use_advanced_model
self.conversation_model = None
if use_advanced_model:
try:
print("Loading advanced conversation model...")
self.conversation_model = ConversationModel()
print("βœ… Advanced conversation model loaded!")
except Exception as e:
print(f"❌ Error loading advanced model: {e}")
print("Falling back to basic model...")
self.use_advanced_model = False
# Load basic model as fallback
if not self.use_advanced_model:
print(f"Loading basic {self.model_name} model...")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
# Move model to GPU if available
if torch.cuda.is_available():
self.model = self.model.to('cuda')
print("Using CUDA (GPU acceleration)")
else:
print("Using CPU")
print(f"Memory capacity: {self.max_memory} entries")
print(f"Generation params - Temp: {self.temperature}, Max tokens: {self.max_new_tokens}")
def add_memory(self, memory_text):
"""Add a memory entry to the system."""
if len(self.memories) >= self.max_memory:
self.memories.pop(0) # Remove oldest memory
self.memories.append(memory_text)
print(f"Memory added. Total memories: {len(self.memories)}")
    def generate_response(self, prompt: str, max_new_tokens: Optional[int] = None,
                          conversation_history: Optional[List[Dict[str, str]]] = None):
"""Generate a response using the AI model with improved quality."""
# Use advanced conversation model if available
if self.use_advanced_model and self.conversation_model:
try:
# Convert memory to conversation history format
conv_history = []
if conversation_history:
for entry in conversation_history:
conv_history.append({"role": entry.get("role", "user"),
"content": entry.get("content", entry.get("text", ""))})
# Generate response using advanced model
response = self.conversation_model.generate_response(prompt, conv_history)
# Add to memories
self.add_memory(f"User: {prompt}")
self.add_memory(f"AI: {response}")
return response
            except Exception as e:
                print(f"Advanced model error: {e}")
                if not hasattr(self, 'model'):
                    # The basic model was never loaded; use a rule-based reply
                    return self._generate_fallback_response(prompt)
# Fallback to basic model
# Improved prompt engineering for conversational AI
if "microsoft/DialoGPT" in self.model_name:
# DialoGPT uses a different format
improved_prompt = prompt
else:
# For other models, use better prompt engineering
improved_prompt = f"{prompt}\n\nAssistant:"
inputs = self.tokenizer(improved_prompt, return_tensors="pt")
# Move inputs to same device as model
if hasattr(self, 'model') and next(self.model.parameters()).is_cuda:
inputs = {k: v.to('cuda') for k, v in inputs.items()}
# Generate with better parameters
outputs = self.model.generate(
**inputs,
            max_new_tokens=max_new_tokens or self.max_new_tokens,
temperature=self.temperature,
top_p=self.top_p,
do_sample=True, # Enable sampling
repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=2,  # Block any repeated bigram
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
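        # Note: temperature rescales the logits (lower = more deterministic),
        # top_p keeps only the smallest token set whose cumulative probability
        # exceeds top_p (nucleus sampling), and repetition_penalty > 1
        # down-weights tokens that have already appeared.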
        # Decode only the newly generated tokens; slicing the decoded string
        # by prompt length can drift when tokenization does not round-trip
        prompt_length = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
# Clean up response
response = self._clean_response(response)
# Fallback for poor responses
if not response or len(response.split()) < 2 or response.startswith("I'm"):
response = self._generate_fallback_response(prompt)
        # Record the exchange, mirroring the advanced-model path
        self.add_memory(f"User: {prompt}")
        self.add_memory(f"AI: {response}")
        return response
def _generate_fallback_response(self, prompt):
"""Generate a fallback response when the model produces poor output."""
# Simple rule-based responses for common questions
        prompt_lower = prompt.lower()
        # Match greetings as whole words so e.g. "this" doesn't trigger on "hi"
        words = [w.strip('.,!?') for w in prompt_lower.split()]
        if any(greeting in words for greeting in ("hello", "hi", "hey")):
            return "Hello! I'm MemoryAI, your conversational assistant. How can I help you today?"
elif "how are you" in prompt_lower:
return "I'm doing well, thank you for asking! As an AI, I'm always ready to chat. How about you?"
elif "your name" in prompt_lower:
return "I'm MemoryAI! I'm designed to remember our conversations and provide helpful responses."
elif "memory" in prompt_lower and ("work" in prompt_lower or "how" in prompt_lower):
return "I remember our past conversations and use that context to provide better, more relevant responses. It's like having a conversation with someone who remembers what you've talked about before!"
elif "thank" in prompt_lower or "thanks" in prompt_lower:
return "You're welcome! I'm happy to help. Is there anything else you'd like to talk about?"
elif "joke" in prompt_lower:
return "Why don't scientists trust atoms? Because they make up everything!"
elif "weather" in prompt_lower:
return "I can't check real-time weather, but I hope it's nice where you are! What city are you in?"
else:
# Generic fallback
return "That's an interesting question! As an AI with memory, I can tell you that we've talked about various topics. What would you like to discuss?"
def _clean_response(self, response):
"""Clean up the AI response for better quality."""
        # Keep very short responses that trail off with dots; otherwise strip
        # trailing punctuation so a single clean terminator is added below
        if not (response.endswith(('...', '..', '.')) and len(response.split()) < 3):
            response = response.rstrip('.,;:!?')
        # Remove a duplicated trailing phrase (up to three words long)
        words = response.split()
        for n in range(1, min(4, len(words) // 2 + 1)):
            if words[-n:] == words[-2 * n:-n]:
                response = ' '.join(words[:-n])
                break
# Capitalize first letter if it's a complete sentence
if len(response) > 0 and response[0].islower():
response = response[0].upper() + response[1:]
# Add punctuation if missing
if len(response) > 0 and response[-1] not in ('.', '!', '?'):
response += '.'
return response
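    # Behavior sketch: "the cat sat the cat sat" -> trailing duplicate phrase
    # trimmed, first letter capitalized, terminator added -> "The cat sat."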
def converse(self):
"""Start a conversation loop with the AI."""
print("πŸ€– MemoryAI - Advanced Conversation Mode")
print("Type 'quit' to exit.")
print("Type '!memories' to see recent memories, '!clear' to clear memories")
print("Type '!summary' for conversation summary, '!reset' to reset conversation")
print("=" * 60)
# Initialize conversation history for advanced model
conversation_history = []
while True:
user_input = input("πŸ‘€ You: ")
if user_input.lower() == 'quit':
print("πŸ€– AI: Goodbye! Have a great day!")
break
# Handle special commands
if user_input.lower() == '!memories':
recent_memories = self.get_recent_memories()
print("πŸ“š Recent memories:")
for i, memory in enumerate(recent_memories, 1):
print(f" {i}. {memory}")
continue
if user_input.lower() == '!clear':
self.clear_memories()
print("πŸ—‘οΈ Memories cleared!")
continue
            if user_input.lower() == '!summary':
                # get_conversation_summary already explains when the advanced
                # model is unavailable, so no extra guard is needed here
                summary = self.get_conversation_summary()
                print(f"📊 {summary}")
                continue
if user_input.lower() == '!reset':
self.reset_conversation()
conversation_history = []
continue
if user_input.strip():
# Generate response with conversation history
response = self.generate_response(user_input, conversation_history=conversation_history)
print(f"πŸ€– AI: {response}")
# Update conversation history
conversation_history.append({"role": "user", "content": user_input})
conversation_history.append({"role": "assistant", "content": response})
# Show conversation stats if using advanced model
if self.use_advanced_model and self.conversation_model:
stats = self.conversation_model.get_conversation_stats()
print(f"πŸ“Š Topic: {stats['current_topic']} | Emotion: {stats['user_emotion']}")
def get_available_models(self):
"""Get a list of commonly available models."""
models = [
"gpt2",
"distilgpt2",
"gpt2-medium",
"gpt2-large",
"EleutherAI/gpt-neo-125M",
"facebook/opt-125m",
"microsoft/DialoGPT-small",
"microsoft/DialoGPT-medium"
]
# Add advanced conversation models
if self.use_advanced_model:
models.extend([
"facebook/blenderbot-400M-distill",
"facebook/blenderbot-1B-distill",
"microsoft/DialoGPT-large"
])
return models
def get_conversation_summary(self) -> str:
"""Get a summary of the current conversation."""
if not self.use_advanced_model or not self.conversation_model:
return "Conversation summary available only with advanced model."
return self.conversation_model.get_conversation_summary()
def find_similar_memories(self, query: str, top_k: int = 3) -> list:
"""Find memories similar to the query using semantic search."""
if not self.use_advanced_model or not self.conversation_model:
return []
return self.conversation_model.find_similar_conversations(query, top_k)
def reset_conversation(self):
"""Reset the conversation state."""
if self.use_advanced_model and self.conversation_model:
self.conversation_model.reset_conversation()
print("Conversation reset successfully!")
def save_memories(self):
"""Save memories to a file."""
        os.makedirs(self.data_dir, exist_ok=True)  # Ensure the data directory exists
        memory_file = os.path.join(self.data_dir, "memories.txt")
with open(memory_file, 'w') as f:
for memory in self.memories:
f.write(memory + "\n")
print(f"Memories saved to {memory_file}")
def load_memories(self):
"""Load memories from a file."""
memory_file = os.path.join(self.data_dir, "memories.txt")
if os.path.exists(memory_file):
with open(memory_file, 'r') as f:
            self.memories = [line.strip() for line in f if line.strip()]
print(f"Loaded {len(self.memories)} memories from {memory_file}")
else:
print("No existing memories found.")
def get_recent_memories(self, count=5):
"""Get the most recent memories."""
return self.memories[-count:] if self.memories else []
def clear_memories(self):
"""Clear all memories."""
self.memories = []
print("All memories cleared.")
if __name__ == "__main__":
    ai = MemoryAI()
    ai.load_memories()  # Load existing memories
    try:
        ai.converse()
    finally:
        ai.save_memories()  # Persist memories even if the loop exits early
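# Example non-interactive usage (a minimal sketch; assumes the basic GPT-2
# path so no advanced ConversationModel weights are required):
#
#     from src.main import MemoryAI
#     ai = MemoryAI(use_advanced_model=False)
#     print(ai.generate_response("Hello! What can you do?"))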