"""
Conversation Manager for Multi-turn VQA

Manages conversation state, context, and pronoun resolution.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import uuid
import re


@dataclass
class ConversationTurn:
    """Represents a single turn in a conversation"""
    question: str
    answer: str
    objects_detected: List[str]
    timestamp: datetime
    reasoning_chain: Optional[List[str]] = None
    model_used: Optional[str] = None


@dataclass
class ConversationSession:
    """Represents a complete conversation session"""
    session_id: str
    image_path: str
    history: List[ConversationTurn] = field(default_factory=list)
    # Objects from the most recent turn that reported any; pronouns resolve to these.
    current_objects: List[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)

    def add_turn(
        self,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ) -> None:
        """
        Append a new turn to the conversation and refresh the activity time.

        Args:
            question: User's question for this turn
            answer: Answer produced for this turn
            objects_detected: Objects detected for this turn (copied, so later
                mutation by the caller does not rewrite stored history)
            reasoning_chain: Optional reasoning steps
            model_used: Optional model identifier

        An empty object list leaves ``current_objects`` unchanged, so pronouns
        keep referring to the last objects that were actually detected.
        """
        objects = list(objects_detected) if objects_detected else []
        now = datetime.now()
        self.history.append(ConversationTurn(
            question=question,
            answer=answer,
            objects_detected=objects,
            timestamp=now,
            reasoning_chain=reasoning_chain,
            model_used=model_used
        ))
        if objects:
            # Separate copy so edits to current_objects never touch the history entry.
            self.current_objects = list(objects)
        self.last_activity = now

    def get_context_summary(self) -> str:
        """
        Summarize the last (up to) three turns of the conversation.

        Returns:
            "No previous conversation" when there is no history, otherwise
            entries of the form "Turn N: Q: <question> A: <answer>" joined by
            " | ". N is the turn's absolute 1-based position in the history,
            so a five-turn session is summarized as turns 3, 4 and 5.
        """
        if not self.history:
            return "No previous conversation"
        recent = self.history[-3:]
        first_number = len(self.history) - len(recent) + 1
        return " | ".join(
            f"Turn {number}: Q: {turn.question} A: {turn.answer}"
            for number, turn in enumerate(recent, start=first_number)
        )

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True when more than ``timeout_minutes`` have passed since last activity."""
        expiry_time = self.last_activity + timedelta(minutes=timeout_minutes)
        return datetime.now() > expiry_time


class ConversationManager:
    """
    Manages multi-turn conversation sessions for VQA.

    Handles context retention, pronoun resolution, and session lifecycle.
    Sessions are kept in an in-memory dict keyed by session ID; no locking is
    performed.
    """

    PRONOUNS = ['it', 'this', 'that', 'these', 'those', 'they', 'them']
    # Pronouns replaced by the single most recent object.
    _SINGULAR_PRONOUNS = ('it', 'this', 'that')
    # Pronouns replaced by a comma-separated list of all recent objects.
    _PLURAL_PRONOUNS = ('these', 'those', 'they', 'them')

    def __init__(self, session_timeout_minutes: int = 30):
        """
        Initialize conversation manager

        Args:
            session_timeout_minutes: Minutes of inactivity before a session expires
        """
        self.sessions: Dict[str, ConversationSession] = {}
        self.session_timeout = session_timeout_minutes
        print(f"✅ Conversation Manager initialized (timeout: {session_timeout_minutes}min)")

    def create_session(self, image_path: str, session_id: Optional[str] = None) -> str:
        """
        Create a new conversation session

        Args:
            image_path: Path to the image for this conversation
            session_id: Optional custom session ID (generates UUID if not provided).
                An existing session with the same ID is replaced.

        Returns:
            Session ID
        """
        if session_id is None:
            session_id = str(uuid.uuid4())
        self.sessions[session_id] = ConversationSession(
            session_id=session_id,
            image_path=image_path
        )
        return session_id

    def get_session(self, session_id: str) -> Optional[ConversationSession]:
        """
        Get an existing session

        Args:
            session_id: Session ID to retrieve

        Returns:
            ConversationSession or None if not found/expired
            (an expired session is deleted as a side effect)
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None
        if session.is_expired(self.session_timeout):
            self.delete_session(session_id)
            return None
        return session

    def get_or_create_session(
        self,
        session_id: Optional[str],
        image_path: str
    ) -> ConversationSession:
        """
        Get existing session or create new one

        Args:
            session_id: Optional session ID; a missing/expired ID is reused for
                the new session, while None or "" gets a generated UUID
            image_path: Image path for new session

        Returns:
            ConversationSession
        """
        if session_id:
            session = self.get_session(session_id)
            if session:
                return session
        # `or None` makes an empty-string ID fall back to a generated UUID.
        new_id = self.create_session(image_path, session_id or None)
        return self.sessions[new_id]

    def add_turn(
        self,
        session_id: str,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ) -> bool:
        """
        Add a turn to a conversation session

        Args:
            session_id: Session ID
            question: User's question
            answer: VQA answer
            objects_detected: List of detected objects
            reasoning_chain: Optional reasoning steps
            model_used: Optional model identifier

        Returns:
            True if successful, False if session not found or expired
        """
        session = self.get_session(session_id)
        if session is None:
            return False
        session.add_turn(
            question=question,
            answer=answer,
            objects_detected=objects_detected,
            reasoning_chain=reasoning_chain,
            model_used=model_used
        )
        return True

    def resolve_references(
        self,
        question: str,
        session: ConversationSession
    ) -> str:
        """
        Resolve pronouns and references in a question using conversation context.

        Singular pronouns (it/this/that) become the first current object;
        plural pronouns (these/those/they/them) become all current objects
        joined by ", ". Matching is case-insensitive and on whole words, so
        trailing punctuation ("is it?") does not prevent resolution.

        Args:
            question: User's question (may contain pronouns)
            session: Conversation session with context

        Returns:
            Question with pronouns resolved; unchanged when the session has no
            history, no current objects, or the question has no pronoun.

        Example:
            Input: "Is it healthy?"
            Context: Previous object was "apple"
            Output: "Is apple healthy?"
        """
        recent_objects = session.current_objects
        if not session.history or not recent_objects:
            return question

        pronoun_pattern = re.compile(
            r'\b(?:' + '|'.join(re.escape(p) for p in self.PRONOUNS) + r')\b',
            re.IGNORECASE
        )
        if pronoun_pattern.search(question) is None:
            return question

        primary_object = recent_objects[0]
        objects_phrase = ', '.join(recent_objects)

        def _replacement(match: "re.Match[str]") -> str:
            # A callable replacement is inserted literally (no backslash/group
            # expansion), and a single pass never re-substitutes pronouns that
            # appear inside an inserted object name.
            word = match.group(0).lower()
            if word in self._SINGULAR_PRONOUNS:
                return primary_object
            if word in self._PLURAL_PRONOUNS:
                return objects_phrase
            return match.group(0)

        return pronoun_pattern.sub(_replacement, question)

    def get_context_for_question(
        self,
        session_id: str,
        question: str
    ) -> Dict[str, Any]:
        """
        Get relevant context for answering a question

        Args:
            session_id: Session ID
            question: Current question (currently unused)

        Returns:
            Dict with keys has_context, turn_number, previous_objects,
            previous_questions, previous_answers and context_summary. The same
            keys are present whether or not the session exists.
        """
        session = self.get_session(session_id)
        if session is None:
            return {
                'has_context': False,
                'turn_number': 0,
                'previous_objects': [],
                'previous_questions': [],
                'previous_answers': [],
                'context_summary': "No previous conversation"
            }
        recent = session.history[-3:]
        return {
            'has_context': len(session.history) > 0,
            'turn_number': len(session.history) + 1,
            # Copy so callers cannot mutate the session's state through the result.
            'previous_objects': list(session.current_objects),
            'previous_questions': [turn.question for turn in recent],
            'previous_answers': [turn.answer for turn in recent],
            'context_summary': session.get_context_summary()
        }

    def get_history(self, session_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get conversation history for a session

        Args:
            session_id: Session ID

        Returns:
            List of turn dictionaries (timestamps as ISO-8601 strings) or None
            if session not found/expired
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        return [
            {
                'question': turn.question,
                'answer': turn.answer,
                'objects_detected': turn.objects_detected,
                'timestamp': turn.timestamp.isoformat(),
                'reasoning_chain': turn.reasoning_chain,
                'model_used': turn.model_used
            }
            for turn in session.history
        ]

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a conversation session

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found
        """
        if session_id in self.sessions:
            del self.sessions[session_id]
            return True
        return False

    def cleanup_expired_sessions(self) -> int:
        """Remove all expired sessions and return how many were removed."""
        expired_ids = [
            sid for sid, session in self.sessions.items()
            if session.is_expired(self.session_timeout)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        return len(expired_ids)

    def get_active_sessions_count(self) -> int:
        """Get count of active (non-expired) sessions, pruning expired ones first."""
        self.cleanup_expired_sessions()
        return len(self.sessions)


if __name__ == "__main__":
    print("=" * 80)
    print("🧪 Testing Conversation Manager")
    print("=" * 80)

    manager = ConversationManager(session_timeout_minutes=30)

    print("\n📝 Test 1: Multi-turn conversation")
    session_id = manager.create_session("test_image.jpg")
    print(f"Created session: {session_id}")

    manager.add_turn(
        session_id=session_id,
        question="What is this?",
        answer="apple",
        objects_detected=["apple"]
    )
    print("Turn 1: 'What is this?' → 'apple'")

    session = manager.get_session(session_id)
    question_2 = "Is it healthy?"
    resolved_2 = manager.resolve_references(question_2, session)
    print(f"Turn 2: '{question_2}' → Resolved: '{resolved_2}'")

    manager.add_turn(
        session_id=session_id,
        question=question_2,
        answer="Yes, apples are healthy",
        objects_detected=["apple"]
    )

    question_3 = "What color is it?"
    resolved_3 = manager.resolve_references(question_3, session)
    print(f"Turn 3: '{question_3}' → Resolved: '{resolved_3}'")

    print("\n📝 Test 2: Context retrieval")
    context = manager.get_context_for_question(session_id, "Another question")
    print(f"Turn number: {context['turn_number']}")
    print(f"Previous objects: {context['previous_objects']}")
    print(f"Context summary: {context['context_summary']}")

    print("\n📝 Test 3: Conversation history")
    history = manager.get_history(session_id)
    for i, turn in enumerate(history, 1):
        print(f"  Turn {i}: Q: {turn['question']} | A: {turn['answer']}")

    print("\n" + "=" * 80)
    print("✅ Tests completed!")