File size: 6,834 Bytes
461adca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
User session state management.
"""
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from llama_index.core.schema import NodeWithScore


@dataclass
class UserSessionState:
    """
    Isolated state for each user session.

    This class encapsulates all the state that needs to be maintained
    separately for each user in a multi-user environment.
    """
    session_id: str
    legal_position_json: Optional[Dict[str, Any]] = None
    # String forward reference: the annotation is not evaluated when the
    # dataclass is created, so defining the class never touches llama_index.
    search_nodes: Optional[List["NodeWithScore"]] = None
    custom_prompts: Dict[str, str] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)

    def update_activity(self) -> None:
        """Update the last activity timestamp to the current local time."""
        self.last_activity = datetime.now()

    def is_expired(self, timeout_minutes: int) -> bool:
        """
        Check if the session has expired.

        Args:
            timeout_minutes: Session timeout in minutes

        Returns:
            True if strictly more than ``timeout_minutes`` have elapsed
            since the last activity
        """
        timeout = timedelta(minutes=timeout_minutes)
        return datetime.now() - self.last_activity > timeout

    def get_age_minutes(self) -> float:
        """
        Get the age of the session in minutes.

        Returns:
            Age in minutes since creation
        """
        return (datetime.now() - self.created_at).total_seconds() / 60

    def get_idle_minutes(self) -> float:
        """
        Get the idle time of the session in minutes.

        Returns:
            Idle time in minutes since last activity
        """
        return (datetime.now() - self.last_activity).total_seconds() / 60

    def clear_data(self) -> None:
        """Clear all user data but keep session metadata (id, created_at)."""
        self.legal_position_json = None
        self.search_nodes = None
        self.custom_prompts = {}
        self.update_activity()

    def has_legal_position(self) -> bool:
        """Check if user has generated a legal position."""
        return self.legal_position_json is not None

    def has_search_results(self) -> bool:
        """Check if user has a non-empty list of search results."""
        return self.search_nodes is not None and len(self.search_nodes) > 0

    def get_prompt(self, prompt_type: str, default_prompt: str) -> str:
        """
        Get custom prompt or default if not set.

        Args:
            prompt_type: Type of prompt ('system', 'legal_position', 'analysis')
            default_prompt: Default prompt value

        Returns:
            Custom prompt if set, otherwise default
        """
        return self.custom_prompts.get(prompt_type, default_prompt)

    def set_prompt(self, prompt_type: str, prompt_value: str) -> None:
        """
        Set custom prompt and refresh the activity timestamp.

        Args:
            prompt_type: Type of prompt ('system', 'legal_position', 'analysis')
            prompt_value: Prompt text
        """
        self.custom_prompts[prompt_type] = prompt_value
        self.update_activity()

    def reset_prompts(self) -> None:
        """Reset all custom prompts to defaults."""
        self.custom_prompts = {}
        self.update_activity()

    @staticmethod
    def _node_to_dict(node_with_score: "NodeWithScore") -> Dict[str, Any]:
        """Serialize one scored node to its id, text, metadata and score."""
        node = node_with_score.node
        return {
            "node": {
                "id": node.id_,
                "text": node.text,
                "metadata": node.metadata,
            },
            "score": node_with_score.score,
        }

    @staticmethod
    def _node_from_dict(item: Dict[str, Any]) -> "NodeWithScore":
        """Rebuild one scored node (as a ``Document``) from ``_node_to_dict`` output."""
        # Deferred import, as in the original: only needed when nodes exist.
        from llama_index.core.schema import Document, NodeWithScore

        node_data = item["node"]
        document = Document(
            id_=node_data["id"],
            text=node_data["text"],
            metadata=node_data["metadata"],
        )
        return NodeWithScore(node=document, score=item["score"])

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert session to dictionary for storage.

        ``custom_prompts`` is copied so that mutating the returned dictionary
        does not alter the live session. An empty ``search_nodes`` list is
        kept as ``[]`` (not collapsed to ``None``) so ``from_dict`` restores
        exactly what was stored.

        Returns:
            Dictionary representation
        """
        search_nodes_data = None
        if self.search_nodes is not None:
            search_nodes_data = [self._node_to_dict(n) for n in self.search_nodes]

        return {
            "session_id": self.session_id,
            "legal_position_json": self.legal_position_json,
            "search_nodes": search_nodes_data,
            "custom_prompts": dict(self.custom_prompts),
            "created_at": self.created_at.isoformat(),
            "last_activity": self.last_activity.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'UserSessionState':
        """
        Create session from dictionary (the inverse of ``to_dict``).

        Optional fields (``legal_position_json``, ``search_nodes``,
        ``custom_prompts``) may be absent or ``None``. ``custom_prompts`` is
        copied so the session never shares the caller's dict.

        Args:
            data: Dictionary representation

        Returns:
            UserSessionState instance

        Raises:
            KeyError: if ``session_id``, ``created_at`` or ``last_activity``
                is missing
            ValueError: if a timestamp is not a valid ISO-format string
        """
        raw_nodes = data.get("search_nodes")
        search_nodes = None
        if raw_nodes is not None:
            search_nodes = [cls._node_from_dict(item) for item in raw_nodes]

        return cls(
            session_id=data["session_id"],
            legal_position_json=data.get("legal_position_json"),
            search_nodes=search_nodes,
            # `or {}` also covers an explicitly stored None.
            custom_prompts=dict(data.get("custom_prompts") or {}),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_activity=datetime.fromisoformat(data["last_activity"]),
        )

    def __str__(self) -> str:
        """String representation."""
        return (
            f"UserSessionState("
            f"session_id={self.session_id[:8]}..., "
            f"age={self.get_age_minutes():.1f}min, "
            f"idle={self.get_idle_minutes():.1f}min, "
            f"has_position={self.has_legal_position()}, "
            f"has_search={self.has_search_results()}"
            f")"
        )

    def __repr__(self) -> str:
        """Detailed string representation."""
        return self.__str__()


def generate_session_id() -> str:
    """
    Return a fresh random session identifier.

    Returns:
        The canonical hyphenated string form of a random (version 4) UUID
    """
    new_id = uuid.uuid4()
    return f"{new_id}"


def create_empty_session(session_id: Optional[str] = None) -> UserSessionState:
    """
    Build a new session with no user data attached.

    Args:
        session_id: Identifier to use. When ``None``, a fresh one is produced
            by ``generate_session_id``; any other value (even an empty
            string) is used as-is.

    Returns:
        New empty UserSessionState
    """
    resolved_id = generate_session_id() if session_id is None else session_id
    return UserSessionState(session_id=resolved_id)