File size: 11,802 Bytes
bb8f662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""
Conversation Manager for Multi-turn VQA
Manages conversation state, context, and pronoun resolution
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import uuid
import re
@dataclass
class ConversationTurn:
    """Represents a single question/answer turn in a conversation.

    Instances are built by ``ConversationSession.add_turn``, which stamps
    ``timestamp`` with ``datetime.now()`` (a naive local time).
    """
    # The question as asked for this turn.
    question: str
    # The answer produced for the question.
    answer: str
    # Object labels detected while answering this turn.
    objects_detected: List[str]
    # When the turn was recorded.
    timestamp: datetime
    # Optional intermediate reasoning steps behind the answer.
    reasoning_chain: Optional[List[str]] = None
    # Optional identifier of the model that produced the answer.
    model_used: Optional[str] = None
@dataclass
class ConversationSession:
    """Represents a complete conversation session about one image.

    Holds the ordered turn history, the objects reported by the most recent
    turn that detected any, and activity timestamps used for expiry.
    """
    session_id: str
    image_path: str
    # Turns in the order they were added (oldest first).
    history: List["ConversationTurn"] = field(default_factory=list)
    # Objects from the latest turn whose detection list was non-empty.
    current_objects: List[str] = field(default_factory=list)
    # Both timestamps are naive local times from datetime.now().
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)

    def add_turn(
        self,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ):
        """Append a new turn and refresh the session's activity state.

        Args:
            question: The question asked in this turn.
            answer: The answer given.
            objects_detected: Object labels detected for this turn. The list
                is copied, so later changes by the caller do not leak in.
            reasoning_chain: Optional reasoning steps.
            model_used: Optional model identifier.
        """
        # One clock read so the turn timestamp and last_activity agree.
        now = datetime.now()
        # Defensive copy: storing the caller's list would let external
        # mutation rewrite recorded history and current_objects.
        detected = list(objects_detected or [])
        self.history.append(ConversationTurn(
            question=question,
            answer=answer,
            objects_detected=detected,
            timestamp=now,
            reasoning_chain=reasoning_chain,
            model_used=model_used,
        ))
        # An empty detection leaves the previously known objects in place.
        if detected:
            self.current_objects = list(detected)
        self.last_activity = now

    def get_context_summary(self, max_turns: int = 3) -> str:
        """Summarize the most recent turns as a single line.

        Args:
            max_turns: How many of the latest turns to include (default 3).

        Returns:
            ``"Turn N: Q: ... A: ..."`` entries joined by ``" | "``, where N
            is the turn's absolute 1-based position in the history, or
            ``"No previous conversation"`` when there is no history.

        Raises:
            ValueError: If ``max_turns`` is not positive.
        """
        if max_turns <= 0:
            # history[-0:] would silently return the whole history.
            raise ValueError("max_turns must be positive")
        if not self.history:
            return "No previous conversation"
        recent = self.history[-max_turns:]
        # Number turns by their real position, not by position in the window.
        first_number = len(self.history) - len(recent) + 1
        return " | ".join(
            f"Turn {number}: Q: {turn.question} A: {turn.answer}"
            for number, turn in enumerate(recent, first_number)
        )

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True if no activity occurred within ``timeout_minutes``."""
        expiry_time = self.last_activity + timedelta(minutes=timeout_minutes)
        return datetime.now() > expiry_time
class ConversationManager:
    """
    Manages multi-turn conversation sessions for VQA.
    Handles context retention, pronoun resolution, and session lifecycle.

    Sessions live in an in-memory dict keyed by session ID. Expired sessions
    are removed lazily (whenever ``get_session`` encounters one) or in bulk
    via ``cleanup_expired_sessions``.
    """
    # Pronouns replaced by the single most recent object.
    SINGULAR_PRONOUNS = ('it', 'this', 'that')
    # Pronouns replaced by a comma-separated list of all recent objects.
    PLURAL_PRONOUNS = ('these', 'those', 'they', 'them')
    # Every recognised pronoun, kept as a public list for existing callers.
    PRONOUNS = list(SINGULAR_PRONOUNS + PLURAL_PRONOUNS)
    # One case-insensitive whole-word pattern: it matches pronouns even with
    # punctuation attached ("is it?"), and a single substitution pass means
    # inserted object names are never re-scanned for further pronouns.
    _PRONOUN_RE = re.compile(
        r'\b(?:' + '|'.join(SINGULAR_PRONOUNS + PLURAL_PRONOUNS) + r')\b',
        re.IGNORECASE,
    )

    def __init__(self, session_timeout_minutes: int = 30):
        """
        Initialize conversation manager
        Args:
            session_timeout_minutes: Minutes of inactivity before a session expires
        """
        self.sessions: Dict[str, "ConversationSession"] = {}
        self.session_timeout = session_timeout_minutes
        print(f"✅ Conversation Manager initialized (timeout: {session_timeout_minutes}min)")

    def create_session(self, image_path: str, session_id: Optional[str] = None) -> str:
        """
        Create a new conversation session
        Args:
            image_path: Path to the image for this conversation
            session_id: Optional custom session ID (generates UUID if not provided).
                An existing session with the same ID is replaced.
        Returns:
            Session ID
        """
        if session_id is None:
            session_id = str(uuid.uuid4())
        session = ConversationSession(
            session_id=session_id,
            image_path=image_path
        )
        self.sessions[session_id] = session
        return session_id

    def get_session(self, session_id: str) -> Optional["ConversationSession"]:
        """
        Get an existing session
        Args:
            session_id: Session ID to retrieve
        Returns:
            ConversationSession or None if not found/expired. An expired
            session is deleted as a side effect.
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None
        if session.is_expired(self.session_timeout):
            self.delete_session(session_id)
            return None
        return session

    def get_or_create_session(
        self,
        session_id: Optional[str],
        image_path: str
    ) -> "ConversationSession":
        """
        Get existing session or create new one
        Args:
            session_id: Optional session ID. If given but unknown or expired,
                a new session is created under this same ID.
            image_path: Image path for new session
        Returns:
            ConversationSession
        """
        if session_id:
            session = self.get_session(session_id)
            if session:
                return session
        new_id = self.create_session(image_path, session_id)
        return self.sessions[new_id]

    def add_turn(
        self,
        session_id: str,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ) -> bool:
        """
        Add a turn to a conversation session
        Args:
            session_id: Session ID
            question: User's question
            answer: VQA answer
            objects_detected: List of detected objects
            reasoning_chain: Optional reasoning steps
            model_used: Optional model identifier
        Returns:
            True if successful, False if session not found or expired
        """
        session = self.get_session(session_id)
        if session is None:
            return False
        session.add_turn(
            question=question,
            answer=answer,
            objects_detected=objects_detected,
            reasoning_chain=reasoning_chain,
            model_used=model_used
        )
        return True

    def resolve_references(
        self,
        question: str,
        session: "ConversationSession"
    ) -> str:
        """
        Resolve pronouns and references in a question using conversation context.

        Singular pronouns (it/this/that) become the first of the session's
        current objects; plural pronouns (these/those/they/them) become all of
        them joined with ", ". Matching is whole-word and case-insensitive, so
        punctuation next to a pronoun does not prevent resolution. Every
        occurrence is replaced, including determiner uses such as
        "this apple", so the output can be redundant ("apple apple").
        Args:
            question: User's question (may contain pronouns)
            session: Conversation session with context
        Returns:
            Question with pronouns resolved; unchanged when the session has no
            history or no current objects.
        Example:
            Input: "What color is it?"
            Context: Previous object was "apple"
            Output: "What color is apple?"
        """
        if not session.history:
            return question
        recent_objects = session.current_objects
        if not recent_objects:
            return question

        def _replacement(match: "re.Match[str]") -> str:
            # A callable replacement is inserted literally, so object names
            # containing backslashes are not parsed as regex escapes.
            if match.group(0).lower() in self.SINGULAR_PRONOUNS:
                return recent_objects[0]
            return ', '.join(recent_objects)

        return self._PRONOUN_RE.sub(_replacement, question)

    def get_context_for_question(
        self,
        session_id: str,
        question: str
    ) -> Dict[str, Any]:
        """
        Get relevant context for answering a question
        Args:
            session_id: Session ID
            question: Current question (currently not used to filter context)
        Returns:
            Dict with keys 'has_context', 'turn_number', 'previous_objects',
            'previous_questions', 'previous_answers' and 'context_summary'.
            The same keys are present whether or not the session exists.
        """
        session = self.get_session(session_id)
        if session is None:
            return {
                'has_context': False,
                'turn_number': 0,
                'previous_objects': [],
                'previous_questions': [],
                'previous_answers': [],
                'context_summary': "No previous conversation"
            }
        recent_turns = session.history[-3:]
        return {
            'has_context': bool(session.history),
            'turn_number': len(session.history) + 1,
            # Copy so callers cannot mutate the session's live state.
            'previous_objects': list(session.current_objects),
            'previous_questions': [turn.question for turn in recent_turns],
            'previous_answers': [turn.answer for turn in recent_turns],
            'context_summary': session.get_context_summary()
        }

    def get_history(self, session_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get conversation history for a session
        Args:
            session_id: Session ID
        Returns:
            List of turn dictionaries (timestamps as ISO-8601 strings) or None
            if session not found or expired
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        return [
            {
                'question': turn.question,
                'answer': turn.answer,
                'objects_detected': turn.objects_detected,
                'timestamp': turn.timestamp.isoformat(),
                'reasoning_chain': turn.reasoning_chain,
                'model_used': turn.model_used
            }
            for turn in session.history
        ]

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a conversation session
        Args:
            session_id: Session ID to delete
        Returns:
            True if deleted, False if not found
        """
        if session_id in self.sessions:
            del self.sessions[session_id]
            return True
        return False

    def cleanup_expired_sessions(self) -> int:
        """Remove all expired sessions and return how many were removed."""
        # Collect first: deleting while iterating the dict would raise.
        expired_ids = [
            sid for sid, session in self.sessions.items()
            if session.is_expired(self.session_timeout)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        return len(expired_ids)

    def get_active_sessions_count(self) -> int:
        """Get count of active (non-expired) sessions, purging expired ones."""
        self.cleanup_expired_sessions()
        return len(self.sessions)
if __name__ == "__main__":
    # Smoke demo: one in-memory session exercising turn recording, pronoun
    # resolution, context retrieval and history export.
    print("=" * 80)
    print("🧪 Testing Conversation Manager")
    print("=" * 80)
    manager = ConversationManager(session_timeout_minutes=30)

    print("\n📝 Test 1: Multi-turn conversation")
    session_id = manager.create_session("test_image.jpg")
    print(f"Created session: {session_id}")
    manager.add_turn(
        session_id=session_id,
        question="What is this?",
        answer="apple",
        objects_detected=["apple"]
    )
    print("Turn 1: 'What is this?' → 'apple'")
    session = manager.get_session(session_id)
    question_2 = "Is it healthy?"
    resolved_2 = manager.resolve_references(question_2, session)
    print(f"Turn 2: '{question_2}' → Resolved: '{resolved_2}'")
    manager.add_turn(
        session_id=session_id,
        question=question_2,
        answer="Yes, apples are healthy",
        objects_detected=["apple"]
    )
    question_3 = "What color is it?"
    resolved_3 = manager.resolve_references(question_3, session)
    print(f"Turn 3: '{question_3}' → Resolved: '{resolved_3}'")

    print("\n📝 Test 2: Context retrieval")
    context = manager.get_context_for_question(session_id, "Another question")
    print(f"Turn number: {context['turn_number']}")
    print(f"Previous objects: {context['previous_objects']}")
    print(f"Context summary: {context['context_summary']}")

    print("\n📝 Test 3: Conversation history")
    # get_history returns None for a missing/expired session; iterate nothing then.
    history = manager.get_history(session_id) or []
    for i, turn in enumerate(history, 1):
        print(f"  Turn {i}: Q: {turn['question']} | A: {turn['answer']}")
    print("\n" + "=" * 80)
    print("✅ Tests completed!")