# state_manager.py
"""
Global state management and logical expression system for Mimir.

Components:
- GlobalStateManager: Thread-safe state persistence with SQLite + HF dataset backup
- PromptStateManager: Per-turn prompt segment activation tracking
- LogicalExpressions: Regex-based prompt triggers
"""

import os
import re
import sqlite3
import json
import logging
import threading
from contextlib import closing
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any

logger = logging.getLogger(__name__)

# The HF dataset backup is best-effort (every failure in it is caught and
# logged), so a missing optional package degrades to SQLite-only persistence
# instead of making this module unimportable.
try:
    from datasets import load_dataset, Dataset
    from huggingface_hub import HfApi
except ImportError:
    load_dataset = Dataset = HfApi = None
    logger.warning("datasets/huggingface_hub not installed; HF dataset backup disabled")


# ============================================================================
# PROMPT STATE MANAGER
# ============================================================================

class PromptStateManager:
    """
    Manages prompt segment activation state for a single turn.
    Resets to default (all False) at the start of each turn.
    """

    # Prompt segments consumed by the thinking agent.
    _THINKING_PROMPTS = (
        "MATH_THINKING",
        "QUESTION_ANSWER_DESIGN",
        "REASONING_THINKING",
    )

    # Prompt segments consumed by the response agent.
    # NOTE: "VAUGE_INPUT" is a runtime key; its spelling is kept as-is.
    _RESPONSE_PROMPTS = (
        "VAUGE_INPUT",
        "USER_UNDERSTANDING",
        "GENERAL_FORMATTING",
        "LATEX_FORMATTING",
        "GUIDING_TEACHING",
        "STRUCTURE_PRACTICE_QUESTIONS",
        "PRACTICE_QUESTION_FOLLOWUP",
        "TOOL_USE_ENHANCEMENT",
    )

    def __init__(self):
        # Every known segment starts inactive; key order is thinking prompts
        # followed by response prompts.
        self._default_state = {
            name: False
            for name in (*self._THINKING_PROMPTS, *self._RESPONSE_PROMPTS)
        }
        self._current_state = self._default_state.copy()
        logger.info("PromptStateManager initialized")

    def reset(self):
        """Reset all prompt states to False for new turn"""
        self._current_state = self._default_state.copy()
        logger.debug("Prompt state reset for new turn")

    def get_state(self) -> Dict[str, bool]:
        """Get a copy of the current prompt state dictionary"""
        return self._current_state.copy()

    def update(self, prompt_name: str, value: bool):
        """
        Update a specific prompt state.

        Unknown names are logged with a warning and ignored.

        Args:
            prompt_name: Name of prompt segment (must be in default_state)
            value: True to activate, False to deactivate
        """
        if prompt_name not in self._default_state:
            logger.warning(f"Unknown prompt name: {prompt_name}")
            return

        self._current_state[prompt_name] = value
        logger.debug(f"Prompt state updated: {prompt_name} = {value}")

    def update_multiple(self, updates: Dict[str, bool]):
        """
        Update multiple prompt states at once.

        Args:
            updates: Dictionary of {prompt_name: bool} updates
        """
        for prompt_name, value in updates.items():
            self.update(prompt_name, value)

    def is_active(self, prompt_name: str) -> bool:
        """Check if a prompt segment is active (False for unknown names)"""
        return self._current_state.get(prompt_name, False)

    def get_active_prompts(self) -> List[str]:
        """Get list of all currently active prompt names"""
        return [name for name, active in self._current_state.items() if active]

    def get_active_response_prompts(self) -> List[str]:
        """
        Get list of active response agent prompts only.
        Excludes thinking agent prompts.
        """
        return [
            name for name in self._RESPONSE_PROMPTS
            if self._current_state.get(name, False)
        ]

    def get_active_thinking_prompts(self) -> List[str]:
        """
        Get list of active thinking agent prompts only.
        """
        return [
            name for name in self._THINKING_PROMPTS
            if self._current_state.get(name, False)
        ]


# ============================================================================
# LOGICAL EXPRESSIONS
# ============================================================================

class LogicalExpressions:
    """
    Regex-based logical expressions for prompt trigger detection.
    Analyzes user input to activate appropriate prompt segments.
    """

    def __init__(self):
        # Math-related keywords
        self.math_regex = r'\b(math|calculus|algebra|geometry|equation|formula|solve|calculate|derivative|integral|trigonometry|statistics|probability)\b'

        # Additional regex patterns can be added here

        logger.info("LogicalExpressions initialized")

    def check_math_keywords(self, user_input: str) -> bool:
        """
        Check if user input contains mathematical keywords.

        The match is case-insensitive and whole-word. This method only
        reports the match; apply_all_checks() is what turns a match into
        LATEX_FORMATTING being activated.

        Args:
            user_input: User's message

        Returns:
            True if math keywords detected
        """
        result = bool(re.search(self.math_regex, user_input, re.IGNORECASE))
        if result:
            logger.debug(f"Math keywords detected in: '{user_input[:50]}...'")
        return result

    def apply_all_checks(self, user_input: str, prompt_state: PromptStateManager):
        """
        Apply all logical expression checks and update prompt_state.

        Args:
            user_input: User's message
            prompt_state: PromptStateManager instance to update
        """
        # GENERAL_FORMATTING is always applied
        prompt_state.update("GENERAL_FORMATTING", True)

        # Check for math keywords
        if self.check_math_keywords(user_input):
            prompt_state.update("LATEX_FORMATTING", True)

        # Additional checks can be added here as needed

        logger.debug(f"Logical expressions applied. Active prompts: {prompt_state.get_active_prompts()}")


# ============================================================================
# GLOBAL STATE MANAGER
# ============================================================================

class GlobalStateManager:
    """
    Thread-safe global state manager with SQLite persistence and HF dataset backup.
    Now includes PromptStateManager for per-turn prompt segment tracking.
    """

    # Maximum number of export records kept per session.
    _MAX_EXPORT_HISTORY = 20

    # DDL executed on startup; every statement is idempotent.
    _SCHEMA = (
        """
            CREATE TABLE IF NOT EXISTS conversations (
                session_id TEXT PRIMARY KEY,
                chat_history TEXT,
                conversation_state TEXT,
                last_accessed TEXT,
                created TEXT
            )
        """,
        """
            CREATE TABLE IF NOT EXISTS analytics (
                session_id TEXT PRIMARY KEY,
                project_stats TEXT,
                recent_interactions TEXT,
                dashboard_html TEXT,
                last_refresh TEXT,
                export_history TEXT
            )
        """,
        """
            CREATE TABLE IF NOT EXISTS evaluations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,
                timestamp TEXT,
                metric_type TEXT,
                metric_data TEXT
            )
        """,
        """
            CREATE TABLE IF NOT EXISTS classifications (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,
                timestamp TEXT,
                user_input TEXT,
                prediction_data TEXT,
                features TEXT
            )
        """,
    )

    def __init__(self, db_path="mimir_analytics.db", dataset_repo="jdesiree/mimir_analytics"):
        """
        Create the manager, initialise the SQLite schema and load saved state.

        Args:
            db_path: Path of the SQLite database file.
            dataset_repo: HF dataset repository used for backup. Pass None
                (or an empty string) to disable HF sync entirely.
        """
        self._db_path = db_path
        self.dataset_repo = dataset_repo
        self.hf_token = os.getenv("HF_TOKEN")

        # Existing state caches
        self._states = {}
        self._analytics_cache = {}
        self._ml_models_cache = {}
        self._evaluation_cache = {}

        # Thread safety. Must be re-entrant: several public methods call
        # other public methods (e.g. get_evaluation_state) while already
        # holding the lock, which deadlocks with a plain Lock.
        self._lock = threading.RLock()

        # Cleanup settings (all intervals are in seconds)
        self._cleanup_interval = 3600
        self._max_age = 24 * 3600
        self._last_cleanup = datetime.now()
        self._last_hf_backup = datetime.now()
        self._hf_backup_interval = 3600

        # NEW: Prompt state management
        self._prompt_state_manager = PromptStateManager()

        # Initialize existing systems
        self._init_database()
        self._load_from_database()
        self._load_from_hf_dataset()

        logger.info("GlobalStateManager initialized with PromptStateManager")

    # ========================================================================
    # PROMPT STATE MANAGEMENT
    # ========================================================================

    def get_prompt_state_manager(self) -> PromptStateManager:
        """Get the prompt state manager for current turn"""
        return self._prompt_state_manager

    def reset_prompt_state(self):
        """Reset prompt state for new turn"""
        self._prompt_state_manager.reset()
        logger.debug("Prompt state reset for new turn")

    def get_prompt_state(self) -> Dict[str, bool]:
        """Get current prompt state dictionary"""
        return self._prompt_state_manager.get_state()

    def update_prompt_state(self, prompt_name: str, value: bool):
        """Update specific prompt state"""
        self._prompt_state_manager.update(prompt_name, value)

    def update_prompt_states(self, updates: Dict[str, bool]):
        """Update multiple prompt states"""
        self._prompt_state_manager.update_multiple(updates)

    # ========================================================================
    # PERSISTENCE HELPERS
    # ========================================================================

    @staticmethod
    def _coerce_datetime(value: Any, default: Optional[datetime] = None) -> Optional[datetime]:
        """
        Return value as a datetime, or default when it cannot be parsed.

        The HF backup serialises datetimes with str(), so timestamps read
        back from it arrive as strings and must be converted before any
        date arithmetic or .isoformat() call.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value)
            except ValueError:
                pass
        return default

    def _hf_enabled(self) -> bool:
        """True when a dataset repo is configured and the HF libraries imported"""
        return bool(self.dataset_repo) and load_dataset is not None and Dataset is not None

    def _connect(self):
        """
        Open a SQLite connection wrapped so that `with` closes it.

        sqlite3's own context manager only commits or rolls back; it does
        not close the connection.
        """
        return closing(sqlite3.connect(self._db_path))

    def _init_database(self):
        """Initialize SQLite database for persistent storage"""
        with self._connect() as conn:
            cursor = conn.cursor()
            for statement in self._SCHEMA:
                cursor.execute(statement)
            conn.commit()

    def _load_from_database(self):
        """
        Load all data from SQLite on startup.

        Any failure is logged and swallowed; rows loaded before the failure
        stay in the caches.
        """
        try:
            with self._connect() as conn:
                cursor = conn.cursor()

                cursor.execute("SELECT * FROM conversations")
                for row in cursor.fetchall():
                    session_id = row[0]
                    self._states[session_id] = {
                        'chat_history': json.loads(row[1]),
                        'conversation_state': json.loads(row[2]),
                        'last_accessed': datetime.fromisoformat(row[3]),
                        'created': datetime.fromisoformat(row[4])
                    }

                cursor.execute("SELECT * FROM analytics")
                for row in cursor.fetchall():
                    session_id = row[0]
                    self._analytics_cache[session_id] = {
                        'project_stats': json.loads(row[1]),
                        'recent_interactions': json.loads(row[2]),
                        'dashboard_html': row[3],
                        'last_refresh': datetime.fromisoformat(row[4]) if row[4] else None,
                        'export_history': json.loads(row[5]),
                        'last_accessed': datetime.now()
                    }

            logger.info(f"Loaded {len(self._states)} conversations and {len(self._analytics_cache)} analytics from database")
        except Exception as e:
            logger.error(f"Error loading from database: {e}")

    def _load_from_hf_dataset(self):
        """
        Load data from HF dataset on startup.

        Entries from the dataset replace entries with the same session id
        that were loaded from SQLite. Timestamps are converted back to
        datetime objects. Any failure is logged and swallowed.
        """
        if not self._hf_enabled():
            logger.info("HF dataset sync disabled; skipping load")
            return

        try:
            ds = load_dataset(self.dataset_repo, split="train", token=self.hf_token)
            now = datetime.now()
            for item in ds:
                data_type = item['data_type']
                if data_type not in ('conversation', 'analytics'):
                    continue

                session_id = item['session_id']
                data = json.loads(item['data'])
                if not isinstance(data, dict):
                    continue

                data['last_accessed'] = self._coerce_datetime(data.get('last_accessed'), now)
                if data_type == 'conversation':
                    data['created'] = self._coerce_datetime(data.get('created'), now)
                    self._states[session_id] = data
                else:
                    data['last_refresh'] = self._coerce_datetime(data.get('last_refresh'))
                    self._analytics_cache[session_id] = data

            logger.info(f"Loaded data from HF dataset {self.dataset_repo}")
        except Exception as e:
            logger.warning(f"Could not load from HF dataset: {e}")

    def _save_to_database_conversations(self, session_id):
        """
        Save conversation to SQLite.

        Does nothing when the session is not cached. The row is built while
        holding the lock because this method also runs on background threads.
        """
        with self._lock:
            state = self._states.get(session_id)
            if state is None:
                return

            now = datetime.now()
            row = (
                session_id,
                json.dumps(state.get('chat_history', [])),
                json.dumps(state.get('conversation_state', [])),
                self._coerce_datetime(state.get('last_accessed'), now).isoformat(),
                self._coerce_datetime(state.get('created'), now).isoformat()
            )

        with self._connect() as conn:
            conn.execute("""
                INSERT OR REPLACE INTO conversations
                (session_id, chat_history, conversation_state, last_accessed, created)
                VALUES (?, ?, ?, ?, ?)
            """, row)
            conn.commit()

    def _save_to_database_analytics(self, session_id):
        """
        Save analytics to SQLite.

        Does nothing when the session has no analytics entry.
        """
        with self._lock:
            analytics = self._analytics_cache.get(session_id)
            if analytics is None:
                return

            last_refresh = self._coerce_datetime(analytics.get('last_refresh'))
            row = (
                session_id,
                json.dumps(analytics.get('project_stats', {})),
                json.dumps(analytics.get('recent_interactions', [])),
                analytics.get('dashboard_html', ''),
                last_refresh.isoformat() if last_refresh else None,
                json.dumps(analytics.get('export_history', []))
            )

        with self._connect() as conn:
            conn.execute("""
                INSERT OR REPLACE INTO analytics
                (session_id, project_stats, recent_interactions, dashboard_html, last_refresh, export_history)
                VALUES (?, ?, ?, ?, ?, ?)
            """, row)
            conn.commit()

    def _backup_to_hf_dataset(self):
        """
        Backup all data to HF dataset.

        Skipped when HF sync is disabled or when less than
        _hf_backup_interval seconds have passed since the last backup.
        Any failure is logged and swallowed.
        """
        if not self._hf_enabled():
            return

        # total_seconds(), not .seconds: .seconds wraps around every 24 hours.
        if (datetime.now() - self._last_hf_backup).total_seconds() < self._hf_backup_interval:
            return

        try:
            # Snapshot under the lock so concurrent updates cannot change the
            # dicts while they are being iterated.
            with self._lock:
                data_items = [
                    {
                        'session_id': session_id,
                        'data_type': 'conversation',
                        'data': json.dumps(state, default=str),
                        'timestamp': datetime.now().isoformat()
                    }
                    for session_id, state in self._states.items()
                ]
                data_items.extend(
                    {
                        'session_id': session_id,
                        'data_type': 'analytics',
                        'data': json.dumps(analytics, default=str),
                        'timestamp': datetime.now().isoformat()
                    }
                    for session_id, analytics in self._analytics_cache.items()
                )

            if data_items:
                ds = Dataset.from_list(data_items)
                ds.push_to_hub(self.dataset_repo, token=self.hf_token)
                self._last_hf_backup = datetime.now()
                logger.info(f"Backed up {len(data_items)} items to HF dataset")
        except Exception as e:
            logger.error(f"Error backing up to HF dataset: {e}")

    def _cleanup_old_states(self):
        """
        Remove old unused states to prevent memory leaks.

        Runs at most once per _cleanup_interval seconds and evicts
        conversation states not accessed for more than _max_age seconds.
        Only the in-memory cache is affected.
        """
        now = datetime.now()
        if (now - self._last_cleanup).total_seconds() < self._cleanup_interval:
            return

        with self._lock:
            expired_keys = [
                session_id
                for session_id, state_data in self._states.items()
                if (now - self._coerce_datetime(state_data.get('last_accessed'), now)).total_seconds() > self._max_age
            ]

            for key in expired_keys:
                del self._states[key]
                logger.info(f"Cleaned up expired state: {key}")

            self._last_cleanup = now

    # ========================================================================
    # CONVERSATION STATE METHODS
    # ========================================================================

    def get_session_id(self, request=None):
        """
        Generate or retrieve session ID.

        Currently always returns "default_session"; request is ignored.
        """
        return "default_session"

    def get_conversation_state(self, session_id=None):
        """
        Get conversation state for a session.

        Creates an empty state for unknown sessions and refreshes
        last_accessed for known ones. Returns a shallow copy.
        """
        if session_id is None:
            session_id = self.get_session_id()

        self._cleanup_old_states()

        with self._lock:
            if session_id not in self._states:
                self._states[session_id] = {
                    'chat_history': [],
                    'conversation_state': [],
                    'last_accessed': datetime.now(),
                    'created': datetime.now()
                }
            else:
                self._states[session_id]['last_accessed'] = datetime.now()

            return self._states[session_id].copy()

    def update_conversation_state(self, chat_history, conversation_state, session_id=None):
        """
        Update conversation state for a session.

        The SQLite save (and the HF backup, when due) run on daemon threads
        so the caller is not blocked by I/O.
        """
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            now = datetime.now()
            # Record a creation time for sessions first seen here.
            state = self._states.setdefault(session_id, {'created': now})
            state.update({
                'chat_history': chat_history.copy() if chat_history else [],
                'conversation_state': conversation_state.copy() if conversation_state else [],
                'last_accessed': now
            })

        threading.Thread(target=self._save_to_database_conversations,
                         args=(session_id,), daemon=True).start()

        if (datetime.now() - self._last_hf_backup).total_seconds() >= self._hf_backup_interval:
            threading.Thread(target=self._backup_to_hf_dataset, daemon=True).start()

    def reset_conversation_state(self, session_id=None):
        """Reset conversation state for a session and persist the empty state"""
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            if session_id in self._states:
                self._states[session_id].update({
                    'chat_history': [],
                    'conversation_state': [],
                    'last_accessed': datetime.now()
                })

        self._save_to_database_conversations(session_id)

    def get_all_sessions(self):
        """Get all active sessions (for analytics)"""
        self._cleanup_old_states()
        with self._lock:
            return list(self._states.keys())

    # ========================================================================
    # ANALYTICS STATE METHODS
    # ========================================================================

    def get_analytics_state(self, session_id=None):
        """
        Get analytics state for a session.

        Creates a default entry for unknown sessions and refreshes
        last_accessed for known ones. Returns a shallow copy.
        """
        if session_id is None:
            session_id = self.get_session_id()

        self._cleanup_old_states()

        with self._lock:
            if session_id not in self._analytics_cache:
                self._analytics_cache[session_id] = {
                    'project_stats': {
                        "total_conversations": None,
                        "avg_session_length": None,
                        "success_rate": None,
                        "model_type": "Phi-3-mini (Fine-tuned)",
                        "last_updated": None
                    },
                    'recent_interactions': [],
                    'dashboard_html': None,
                    'last_refresh': None,
                    'export_history': [],
                    'database_status': 'unknown',
                    'error_state': None,
                    'last_accessed': datetime.now()
                }
            else:
                self._analytics_cache[session_id]['last_accessed'] = datetime.now()

            return self._analytics_cache[session_id].copy()

    def update_analytics_state(self, project_stats=None, recent_interactions=None,
                               dashboard_html=None, error_state=None, session_id=None):
        """
        Update analytics state for a session.

        Only arguments that are not None are applied. Setting project_stats
        also refreshes last_refresh. The entry is then saved to SQLite and an
        HF backup is attempted.
        """
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            entry = self._analytics_cache.setdefault(session_id, {})
            current_time = datetime.now()

            if project_stats is not None:
                entry['project_stats'] = project_stats.copy()
                entry['last_refresh'] = current_time

            if recent_interactions is not None:
                entry['recent_interactions'] = recent_interactions.copy()

            if dashboard_html is not None:
                entry['dashboard_html'] = dashboard_html

            if error_state is not None:
                entry['error_state'] = error_state

            entry['last_accessed'] = current_time

        self._save_to_database_analytics(session_id)
        self._backup_to_hf_dataset()

    def add_export_record(self, export_type, filename, success=True, session_id=None):
        """
        Add export record to analytics state.

        Only the most recent _MAX_EXPORT_HISTORY records are kept.
        """
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            if session_id not in self._analytics_cache:
                # Creates the default entry; safe because the lock is re-entrant.
                self.get_analytics_state(session_id)

            export_record = {
                'timestamp': datetime.now().isoformat(),
                'type': export_type,
                'filename': filename,
                'success': success
            }

            entry = self._analytics_cache[session_id]
            history = entry.setdefault('export_history', [])
            history.append(export_record)

            if len(history) > self._MAX_EXPORT_HISTORY:
                entry['export_history'] = history[-self._MAX_EXPORT_HISTORY:]

        self._save_to_database_analytics(session_id)

    # ========================================================================
    # ML MODEL CACHE METHODS
    # ========================================================================

    def get_ml_model_cache(self, model_type: str = "prompt_classifier"):
        """Get the cache entry for an ML model, or None when not cached"""
        with self._lock:
            return self._ml_models_cache.get(model_type, None)

    def cache_ml_model(self, model, model_type: str = "prompt_classifier", metadata: dict = None):
        """Cache a trained ML model together with its metadata"""
        with self._lock:
            self._ml_models_cache[model_type] = {
                'model': model,
                'cached_at': datetime.now(),
                'metadata': metadata or {},
                'access_count': 0
            }
            logger.info(f"ML model '{model_type}' cached successfully")

    # ========================================================================
    # EVALUATION STATE METHODS
    # ========================================================================

    def get_evaluation_state(self, session_id=None):
        """
        Get evaluation state for a session.

        Creates an empty evaluation record for unknown sessions. Returns a
        shallow copy.
        """
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            if session_id not in self._evaluation_cache:
                self._evaluation_cache[session_id] = {
                    'educational_quality_scores': [],
                    'rag_performance_metrics': [],
                    'prompt_classification_accuracy': [],
                    'user_feedback_history': [],
                    'aggregate_metrics': {
                        'avg_educational_quality': 0.0,
                        'avg_rag_relevance': 0.0,
                        'classifier_accuracy_rate': 0.0,
                        'user_satisfaction_rate': 0.0
                    },
                    'evaluation_session_count': 0,
                    'last_updated': datetime.now()
                }

            return self._evaluation_cache[session_id].copy()

    def add_educational_quality_score(self, user_query: str, response: str,
                                      metrics: dict, session_id=None):
        """
        Add educational quality evaluation result.

        Stores the first 100 characters of the query, the response length
        and the metric values; overall_score is the mean of semantic_quality
        and educational_score.
        """
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            if session_id not in self._evaluation_cache:
                self.get_evaluation_state(session_id)

            semantic_quality = metrics.get('semantic_quality', 0.0)
            educational_score = metrics.get('educational_score', 0.0)
            quality_record = {
                'timestamp': datetime.now().isoformat(),
                'user_query': user_query[:100],
                'response_length': len(response),
                'semantic_quality': semantic_quality,
                'educational_score': educational_score,
                'response_time': metrics.get('response_time', 0.0),
                'overall_score': (semantic_quality + educational_score) / 2
            }

            self._evaluation_cache[session_id]['educational_quality_scores'].append(quality_record)
            self._update_aggregate_metrics(session_id)

    def add_prompt_classification_result(self, predicted_mode: str, was_successful: bool,
                                         metadata: dict = None, session_id=None):
        """Add prompt classification accuracy result"""
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            if session_id not in self._evaluation_cache:
                self.get_evaluation_state(session_id)

            classification_record = {
                'timestamp': datetime.now().isoformat(),
                'predicted_mode': predicted_mode,
                'was_successful': was_successful,
                'accuracy_score': 1.0 if was_successful else 0.0,
                'metadata': metadata or {}
            }

            self._evaluation_cache[session_id]['prompt_classification_accuracy'].append(classification_record)
            self._update_aggregate_metrics(session_id)

    def add_user_feedback(self, response_id: str, feedback_type: str,
                          conversation_context: dict = None, session_id=None):
        """
        Add user feedback result.

        Feedback of type 'thumbs_up' scores 1.0; anything else scores 0.0.
        """
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            if session_id not in self._evaluation_cache:
                self.get_evaluation_state(session_id)

            feedback_record = {
                'timestamp': datetime.now().isoformat(),
                'response_id': response_id,
                'feedback_type': feedback_type,
                'satisfaction_score': 1.0 if feedback_type == 'thumbs_up' else 0.0,
                'conversation_context': conversation_context or {}
            }

            self._evaluation_cache[session_id]['user_feedback_history'].append(feedback_record)
            self._update_aggregate_metrics(session_id)

    def _update_aggregate_metrics(self, session_id: str):
        """
        Update aggregate metrics for a session.

        Recomputes each average from the full history of its record list.
        Does not take the lock itself; the add_* methods call it while
        holding the lock.
        """
        eval_state = self._evaluation_cache[session_id]
        aggregates = eval_state['aggregate_metrics']

        quality_scores = eval_state['educational_quality_scores']
        if quality_scores:
            aggregates['avg_educational_quality'] = (
                sum(score['overall_score'] for score in quality_scores) / len(quality_scores)
            )

        classifications = eval_state['prompt_classification_accuracy']
        if classifications:
            aggregates['classifier_accuracy_rate'] = (
                sum(result['accuracy_score'] for result in classifications) / len(classifications)
            )

        feedback_history = eval_state['user_feedback_history']
        if feedback_history:
            aggregates['user_satisfaction_rate'] = (
                sum(feedback['satisfaction_score'] for feedback in feedback_history) / len(feedback_history)
            )

        eval_state['last_updated'] = datetime.now()
        eval_state['evaluation_session_count'] += 1

    def get_evaluation_summary(self, session_id=None, include_history: bool = False):
        """
        Get evaluation summary for analytics.

        Args:
            session_id: Session to summarise (defaults to get_session_id()).
            include_history: When True, also include the last 10 records of
                each evaluation type under the 'history' key.
        """
        if session_id is None:
            session_id = self.get_session_id()

        eval_state = self.get_evaluation_state(session_id)

        summary = {
            'aggregate_metrics': eval_state['aggregate_metrics'],
            'total_evaluations': {
                'educational_quality': len(eval_state['educational_quality_scores']),
                'classification_accuracy': len(eval_state['prompt_classification_accuracy']),
                'user_feedback': len(eval_state['user_feedback_history'])
            },
            'last_updated': eval_state['last_updated'],
            'session_evaluation_count': eval_state['evaluation_session_count']
        }

        if include_history:
            summary['history'] = {
                'recent_educational_scores': eval_state['educational_quality_scores'][-10:],
                'recent_classification_results': eval_state['prompt_classification_accuracy'][-10:],
                'recent_user_feedback': eval_state['user_feedback_history'][-10:]
            }

        return summary

    # ========================================================================
    # UTILITY METHODS
    # ========================================================================

    def get_cache_status(self, session_id=None):
        """Get cache status for debugging"""
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            analytics_cached = session_id in self._analytics_cache
            conversation_cached = session_id in self._states

            cache_info = {
                'session_id': session_id,
                'analytics_cached': analytics_cached,
                'conversation_cached': conversation_cached,
                'total_analytics_sessions': len(self._analytics_cache),
                'total_conversation_sessions': len(self._states),
                'prompt_state_active_count': len(self._prompt_state_manager.get_active_prompts())
            }

            if analytics_cached:
                analytics_state = self._analytics_cache[session_id]
                cache_info['analytics_last_refresh'] = analytics_state.get('last_refresh')
                cache_info['analytics_has_data'] = bool(
                    analytics_state.get('project_stats', {}).get('total_conversations')
                )

            if conversation_cached:
                conversation_state = self._states[session_id]
                cache_info['conversation_length'] = len(conversation_state.get('conversation_state', []))
                cache_info['chat_history_length'] = len(conversation_state.get('chat_history', []))

            return cache_info

    def reset_analytics_state(self, session_id=None):
        """Reset analytics state for a session (removes the cached entry)"""
        if session_id is None:
            session_id = self.get_session_id()

        with self._lock:
            self._analytics_cache.pop(session_id, None)

    def clear_all_states(self):
        """Clear all in-memory states - use with caution"""
        with self._lock:
            self._states.clear()
            self._analytics_cache.clear()
            self._ml_models_cache.clear()
            self._evaluation_cache.clear()
            self._prompt_state_manager.reset()
            logger.info("All global states cleared")