Spaces:
Sleeping
Sleeping
| """ | |
| Conversation Manager for Multi-turn VQA | |
| Manages conversation state, context, and pronoun resolution | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Any | |
| from datetime import datetime, timedelta | |
| import uuid | |
| import re | |
@dataclass
class ConversationTurn:
    """Represents a single question/answer exchange in a conversation.

    Instances are created by ``ConversationSession.add_turn`` with keyword
    arguments, which requires the ``@dataclass``-generated ``__init__``.
    """

    question: str  # the question as the user asked it
    answer: str  # the answer recorded for this question
    objects_detected: List[str]  # objects reported for this turn (may be empty)
    timestamp: datetime  # when the turn was recorded
    reasoning_chain: Optional[List[str]] = None  # optional reasoning steps
    model_used: Optional[str] = None  # optional identifier of the answering model
@dataclass
class ConversationSession:
    """Represents a complete conversation session about one image.

    Holds the ordered turn history plus the most recent non-empty list of
    detected objects, which is what pronoun resolution uses as referents.
    """

    session_id: str  # unique identifier of this session
    image_path: str  # path of the image this conversation is about
    # Turns in the order they were added (oldest first). The annotation is a
    # forward reference so the class does not depend on definition order.
    history: List["ConversationTurn"] = field(default_factory=list)
    # Objects from the most recent turn that reported any; a turn with no
    # objects leaves the previous list in place.
    current_objects: List[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)  # naive local time
    last_activity: datetime = field(default_factory=datetime.now)  # refreshed by add_turn; drives expiry

    def add_turn(
        self,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ) -> None:
        """Append a new turn and refresh the session's activity state.

        Args:
            question: The user's question.
            answer: The answer given for it.
            objects_detected: Objects detected for this turn; a non-empty
                list replaces ``current_objects``.
            reasoning_chain: Optional reasoning steps.
            model_used: Optional model identifier.
        """
        now = datetime.now()
        # Copy the caller's list so later mutation on either side cannot
        # silently rewrite the stored history or the current referents.
        objects = list(objects_detected or [])
        self.history.append(ConversationTurn(
            question=question,
            answer=answer,
            objects_detected=objects,
            timestamp=now,
            reasoning_chain=reasoning_chain,
            model_used=model_used,
        ))
        # Keep the previous referents when this turn found nothing, so that
        # follow-up pronouns still resolve.
        if objects:
            self.current_objects = list(objects)
        self.last_activity = now

    def get_context_summary(self, max_turns: int = 3) -> str:
        """Summarize the most recent turns as a single string.

        Args:
            max_turns: How many of the latest turns to include (default 3).

        Returns:
            "No previous conversation" when there is no history, otherwise
            entries like "Turn N: Q: ... A: ..." joined by " | ", where N is
            the turn's real 1-based position in the whole conversation.
        """
        if not self.history:
            return "No previous conversation"
        start = max(len(self.history) - max_turns, 0)
        return " | ".join(
            f"Turn {number}: Q: {turn.question} A: {turn.answer}"
            for number, turn in enumerate(self.history[start:], start + 1)
        )

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True if more than ``timeout_minutes`` passed since last activity."""
        expiry_time = self.last_activity + timedelta(minutes=timeout_minutes)
        return datetime.now() > expiry_time
| class ConversationManager: | |
| """ | |
| Manages multi-turn conversation sessions for VQA. | |
| Handles context retention, pronoun resolution, and session lifecycle. | |
| """ | |
| PRONOUNS = ['it', 'this', 'that', 'these', 'those', 'they', 'them'] | |
| def __init__(self, session_timeout_minutes: int = 30): | |
| """ | |
| Initialize conversation manager | |
| Args: | |
| session_timeout_minutes: Minutes before a session expires | |
| """ | |
| self.sessions: Dict[str, ConversationSession] = {} | |
| self.session_timeout = session_timeout_minutes | |
| print(f"β Conversation Manager initialized (timeout: {session_timeout_minutes}min)") | |
| def create_session(self, image_path: str, session_id: Optional[str] = None) -> str: | |
| """ | |
| Create a new conversation session | |
| Args: | |
| image_path: Path to the image for this conversation | |
| session_id: Optional custom session ID (generates UUID if not provided) | |
| Returns: | |
| Session ID | |
| """ | |
| if session_id is None: | |
| session_id = str(uuid.uuid4()) | |
| session = ConversationSession( | |
| session_id=session_id, | |
| image_path=image_path | |
| ) | |
| self.sessions[session_id] = session | |
| return session_id | |
| def get_session(self, session_id: str) -> Optional[ConversationSession]: | |
| """ | |
| Get an existing session | |
| Args: | |
| session_id: Session ID to retrieve | |
| Returns: | |
| ConversationSession or None if not found/expired | |
| """ | |
| session = self.sessions.get(session_id) | |
| if session is None: | |
| return None | |
| if session.is_expired(self.session_timeout): | |
| self.delete_session(session_id) | |
| return None | |
| return session | |
| def get_or_create_session( | |
| self, | |
| session_id: Optional[str], | |
| image_path: str | |
| ) -> ConversationSession: | |
| """ | |
| Get existing session or create new one | |
| Args: | |
| session_id: Optional session ID | |
| image_path: Image path for new session | |
| Returns: | |
| ConversationSession | |
| """ | |
| if session_id: | |
| session = self.get_session(session_id) | |
| if session: | |
| return session | |
| new_id = self.create_session(image_path, session_id) | |
| return self.sessions[new_id] | |
| def add_turn( | |
| self, | |
| session_id: str, | |
| question: str, | |
| answer: str, | |
| objects_detected: List[str], | |
| reasoning_chain: Optional[List[str]] = None, | |
| model_used: Optional[str] = None | |
| ) -> bool: | |
| """ | |
| Add a turn to a conversation session | |
| Args: | |
| session_id: Session ID | |
| question: User's question | |
| answer: VQA answer | |
| objects_detected: List of detected objects | |
| reasoning_chain: Optional reasoning steps | |
| model_used: Optional model identifier | |
| Returns: | |
| True if successful, False if session not found | |
| """ | |
| session = self.get_session(session_id) | |
| if session is None: | |
| return False | |
| session.add_turn( | |
| question=question, | |
| answer=answer, | |
| objects_detected=objects_detected, | |
| reasoning_chain=reasoning_chain, | |
| model_used=model_used | |
| ) | |
| return True | |
| def resolve_references( | |
| self, | |
| question: str, | |
| session: ConversationSession | |
| ) -> str: | |
| """ | |
| Resolve pronouns and references in a question using conversation context. | |
| Args: | |
| question: User's question (may contain pronouns) | |
| session: Conversation session with context | |
| Returns: | |
| Question with pronouns resolved | |
| Example: | |
| Input: "Is it healthy?" | |
| Context: Previous object was "apple" | |
| Output: "Is apple healthy?" | |
| """ | |
| if not session.history: | |
| return question | |
| q_lower = question.lower() | |
| has_pronoun = any(pronoun in q_lower.split() for pronoun in self.PRONOUNS) | |
| if not has_pronoun: | |
| return question | |
| recent_objects = session.current_objects | |
| if not recent_objects: | |
| return question | |
| resolved = question | |
| if any(pronoun in q_lower.split() for pronoun in ['it', 'this', 'that']): | |
| primary_object = recent_objects[0] | |
| resolved = re.sub(r'\bit\b', primary_object, resolved, flags=re.IGNORECASE) | |
| resolved = re.sub(r'\bthis\b', primary_object, resolved, flags=re.IGNORECASE) | |
| resolved = re.sub(r'\bthat\b', primary_object, resolved, flags=re.IGNORECASE) | |
| if any(pronoun in q_lower.split() for pronoun in ['these', 'those', 'they', 'them']): | |
| objects_phrase = ', '.join(recent_objects) | |
| resolved = re.sub(r'\bthese\b', objects_phrase, resolved, flags=re.IGNORECASE) | |
| resolved = re.sub(r'\bthose\b', objects_phrase, resolved, flags=re.IGNORECASE) | |
| resolved = re.sub(r'\bthey\b', objects_phrase, resolved, flags=re.IGNORECASE) | |
| resolved = re.sub(r'\bthem\b', objects_phrase, resolved, flags=re.IGNORECASE) | |
| return resolved | |
| def get_context_for_question( | |
| self, | |
| session_id: str, | |
| question: str | |
| ) -> Dict[str, Any]: | |
| """ | |
| Get relevant context for answering a question | |
| Args: | |
| session_id: Session ID | |
| question: Current question | |
| Returns: | |
| Dict with context information | |
| """ | |
| session = self.get_session(session_id) | |
| if session is None: | |
| return { | |
| 'has_context': False, | |
| 'turn_number': 0, | |
| 'previous_objects': [], | |
| 'previous_questions': [] | |
| } | |
| return { | |
| 'has_context': len(session.history) > 0, | |
| 'turn_number': len(session.history) + 1, | |
| 'previous_objects': session.current_objects, | |
| 'previous_questions': [turn.question for turn in session.history[-3:]], | |
| 'previous_answers': [turn.answer for turn in session.history[-3:]], | |
| 'context_summary': session.get_context_summary() | |
| } | |
| def get_history(self, session_id: str) -> Optional[List[Dict[str, Any]]]: | |
| """ | |
| Get conversation history for a session | |
| Args: | |
| session_id: Session ID | |
| Returns: | |
| List of turn dictionaries or None if session not found | |
| """ | |
| session = self.get_session(session_id) | |
| if session is None: | |
| return None | |
| history = [] | |
| for turn in session.history: | |
| history.append({ | |
| 'question': turn.question, | |
| 'answer': turn.answer, | |
| 'objects_detected': turn.objects_detected, | |
| 'timestamp': turn.timestamp.isoformat(), | |
| 'reasoning_chain': turn.reasoning_chain, | |
| 'model_used': turn.model_used | |
| }) | |
| return history | |
| def delete_session(self, session_id: str) -> bool: | |
| """ | |
| Delete a conversation session | |
| Args: | |
| session_id: Session ID to delete | |
| Returns: | |
| True if deleted, False if not found | |
| """ | |
| if session_id in self.sessions: | |
| del self.sessions[session_id] | |
| return True | |
| return False | |
| def cleanup_expired_sessions(self): | |
| """Remove all expired sessions""" | |
| expired_ids = [ | |
| sid for sid, session in self.sessions.items() | |
| if session.is_expired(self.session_timeout) | |
| ] | |
| for sid in expired_ids: | |
| self.delete_session(sid) | |
| return len(expired_ids) | |
| def get_active_sessions_count(self) -> int: | |
| """Get count of active (non-expired) sessions""" | |
| self.cleanup_expired_sessions() | |
| return len(self.sessions) | |
if __name__ == "__main__":
    # Manual smoke test: walk through a short multi-turn conversation and
    # print what the manager resolves and records at each step.
    print("=" * 80)
    print("🧪 Testing Conversation Manager")
    print("=" * 80)
    manager = ConversationManager(session_timeout_minutes=30)

    print("\nTest 1: Multi-turn conversation")
    session_id = manager.create_session("test_image.jpg")
    print(f"Created session: {session_id}")
    manager.add_turn(
        session_id=session_id,
        question="What is this?",
        answer="apple",
        objects_detected=["apple"]
    )
    print("Turn 1: 'What is this?' → 'apple'")
    session = manager.get_session(session_id)
    question_2 = "Is it healthy?"
    resolved_2 = manager.resolve_references(question_2, session)
    print(f"Turn 2: '{question_2}' → Resolved: '{resolved_2}'")
    manager.add_turn(
        session_id=session_id,
        question=question_2,
        answer="Yes, apples are healthy",
        objects_detected=["apple"]
    )
    question_3 = "What color is it?"
    resolved_3 = manager.resolve_references(question_3, session)
    print(f"Turn 3: '{question_3}' → Resolved: '{resolved_3}'")

    print("\nTest 2: Context retrieval")
    context = manager.get_context_for_question(session_id, "Another question")
    print(f"Turn number: {context['turn_number']}")
    print(f"Previous objects: {context['previous_objects']}")
    print(f"Context summary: {context['context_summary']}")

    print("\nTest 3: Conversation history")
    # get_history returns None for a missing/expired session.
    history = manager.get_history(session_id) or []
    for i, turn in enumerate(history, 1):
        print(f" Turn {i}: Q: {turn['question']} | A: {turn['answer']}")
    print("\n" + "=" * 80)
    print("✅ Tests completed!")