"""
User session state management.
"""
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from llama_index.core.schema import NodeWithScore
@dataclass
class UserSessionState:
    """
    Isolated state for each user session.

    This class encapsulates all the state that needs to be maintained
    separately for each user in a multi-user environment.

    Attributes:
        session_id: Unique identifier of the session.
        legal_position_json: Generated legal position, or None if none yet.
        search_nodes: Search results, or None if no search results are held.
        custom_prompts: Per-session prompt overrides keyed by prompt type.
        created_at: Naive local timestamp (``datetime.now()``) of creation.
        last_activity: Naive local timestamp of the most recent activity.
    """

    session_id: str
    legal_position_json: Optional[Dict[str, Any]] = None
    search_nodes: Optional[List[NodeWithScore]] = None
    custom_prompts: Dict[str, str] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)

    def update_activity(self) -> None:
        """Update the last activity timestamp to the current local time."""
        self.last_activity = datetime.now()

    def is_expired(self, timeout_minutes: int) -> bool:
        """
        Check if the session has expired.

        Args:
            timeout_minutes: Session timeout in minutes

        Returns:
            True if strictly more than ``timeout_minutes`` have elapsed
            since ``last_activity``.
        """
        timeout = timedelta(minutes=timeout_minutes)
        return datetime.now() - self.last_activity > timeout

    def get_age_minutes(self) -> float:
        """
        Get the age of the session in minutes.

        Returns:
            Minutes (fractional) elapsed since ``created_at``.
        """
        return (datetime.now() - self.created_at).total_seconds() / 60

    def get_idle_minutes(self) -> float:
        """
        Get the idle time of the session in minutes.

        Returns:
            Minutes (fractional) elapsed since ``last_activity``.
        """
        return (datetime.now() - self.last_activity).total_seconds() / 60

    def clear_data(self) -> None:
        """
        Clear all user data but keep session metadata.

        Resets the legal position, search results and custom prompts; keeps
        ``session_id`` and ``created_at`` and refreshes ``last_activity``.
        """
        self.legal_position_json = None
        self.search_nodes = None
        self.custom_prompts = {}
        self.update_activity()

    def has_legal_position(self) -> bool:
        """Check if user has generated a legal position."""
        return self.legal_position_json is not None

    def has_search_results(self) -> bool:
        """Check if user has at least one search result."""
        # None and an empty list are both falsy.
        return bool(self.search_nodes)

    def get_prompt(self, prompt_type: str, default_prompt: str) -> str:
        """
        Get custom prompt or default if not set.

        Args:
            prompt_type: Type of prompt (e.g. 'system', 'legal_position',
                'analysis')
            default_prompt: Value returned when no override is stored

        Returns:
            Custom prompt if set, otherwise ``default_prompt``.
        """
        return self.custom_prompts.get(prompt_type, default_prompt)

    def set_prompt(self, prompt_type: str, prompt_value: str) -> None:
        """
        Set a custom prompt and refresh the activity timestamp.

        Args:
            prompt_type: Type of prompt (e.g. 'system', 'legal_position',
                'analysis')
            prompt_value: Prompt text
        """
        self.custom_prompts[prompt_type] = prompt_value
        self.update_activity()

    def reset_prompts(self) -> None:
        """Reset all custom prompts to defaults and refresh activity."""
        self.custom_prompts = {}
        self.update_activity()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert session to dictionary for storage.

        Each search node is flattened to its id, text, metadata and score.
        ``custom_prompts`` and node metadata are shallow-copied so that
        mutating the returned dictionary does not alter the live session.
        Timestamps are stored as ISO-8601 strings.

        Returns:
            Dictionary representation accepted by ``from_dict``.
        """
        search_nodes_data: Optional[List[Dict[str, Any]]] = None
        # Compare against None (not truthiness) so that an empty result list
        # survives a to_dict/from_dict round trip instead of becoming None.
        if self.search_nodes is not None:
            search_nodes_data = [
                {
                    "node": {
                        "id": scored.node.id_,
                        "text": scored.node.text,
                        "metadata": dict(scored.node.metadata),
                    },
                    "score": scored.score,
                }
                for scored in self.search_nodes
            ]

        return {
            "session_id": self.session_id,
            "legal_position_json": self.legal_position_json,
            "search_nodes": search_nodes_data,
            "custom_prompts": dict(self.custom_prompts),
            "created_at": self.created_at.isoformat(),
            "last_activity": self.last_activity.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'UserSessionState':
        """
        Create session from a dictionary produced by ``to_dict``.

        ``session_id``, ``created_at`` and ``last_activity`` are required.
        ``legal_position_json``, ``search_nodes``, ``custom_prompts`` and,
        per node, ``metadata`` and ``score`` are optional. Missing values
        become None, or an empty dict for prompts and metadata. Search
        nodes are rebuilt as ``Document`` objects wrapped in
        ``NodeWithScore``.

        Args:
            data: Dictionary representation

        Returns:
            UserSessionState instance

        Raises:
            KeyError: If a required key (or a node's ``id``/``text``) is
                missing.
            ValueError: If a timestamp is not a valid ISO-8601 string.
        """
        search_nodes: Optional[List[NodeWithScore]] = None
        raw_nodes = data.get("search_nodes")
        if raw_nodes is not None:
            # Document is only needed when there are nodes to rebuild.
            from llama_index.core.schema import Document

            search_nodes = []
            for item in raw_nodes:
                node_data = item["node"]
                document = Document(
                    id_=node_data["id"],
                    text=node_data["text"],
                    metadata=dict(node_data.get("metadata") or {}),
                )
                search_nodes.append(
                    NodeWithScore(node=document, score=item.get("score"))
                )

        return cls(
            session_id=data["session_id"],
            legal_position_json=data.get("legal_position_json"),
            search_nodes=search_nodes,
            # `or {}` also guards against a stored explicit None.
            custom_prompts=dict(data.get("custom_prompts") or {}),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_activity=datetime.fromisoformat(data["last_activity"]),
        )

    def __str__(self) -> str:
        """Short human-readable summary (truncated id, age, idle, flags)."""
        return (
            f"UserSessionState("
            f"session_id={self.session_id[:8]}..., "
            f"age={self.get_age_minutes():.1f}min, "
            f"idle={self.get_idle_minutes():.1f}min, "
            f"has_position={self.has_legal_position()}, "
            f"has_search={self.has_search_results()}"
            f")"
        )

    def __repr__(self) -> str:
        """Same summary as ``__str__`` (overrides the dataclass repr)."""
        return self.__str__()
def generate_session_id() -> str:
    """
    Produce a new random session identifier.

    Returns:
        A UUID4 rendered in its canonical hyphenated text form.
    """
    fresh = uuid.uuid4()
    return str(fresh)
def create_empty_session(session_id: Optional[str] = None) -> UserSessionState:
    """
    Build a new session that holds no user data yet.

    Args:
        session_id: Identifier to use; only when it is None is a fresh one
            produced by ``generate_session_id``.

    Returns:
        A UserSessionState with every other field at its default value.
    """
    chosen_id = session_id if session_id is not None else generate_session_id()
    return UserSessionState(session_id=chosen_id)