Spaces:
Sleeping
Sleeping
File size: 8,079 Bytes
e272f4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import logging
import os
import pickle
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue

from app.config import Config
class SessionStorage:
    """
    Manages session persistence using a hybrid storage approach:

    - Session metadata is pickled to ``<Config.STORAGE_DIR>/<session_id>.pkl``.
    - Vector data lives in a per-session Qdrant collection named
      ``session_<session_id>``.
    - ``load_session`` re-attaches that collection name to the loaded data and
      creates the collection if it is missing.
    """

    def __init__(self):
        """
        Initialize the session storage system.

        Ensures the storage directory exists and connects a Qdrant client
        using ``Config.QDRANT_HOST`` / ``Config.QDRANT_PORT``.

        Raises:
            RuntimeError: if the storage directory cannot be created or the
                Qdrant client cannot be constructed.
        """
        # Created before the try so the except branch can always log, even
        # when Config.create_storage_dir() is what failed.
        self.logger = logging.getLogger(__name__)
        try:
            Config.create_storage_dir()
            # Initialize Qdrant client with configuration from Config
            self.qdrant_client = QdrantClient(
                host=Config.QDRANT_HOST,
                port=Config.QDRANT_PORT,
                prefer_grpc=True,  # Use gRPC for better performance
            )
            self.logger.info("Qdrant client initialized")
        except Exception as e:
            self.logger.exception(f"Storage initialization error: {str(e)}")
            raise RuntimeError("Storage initialization failed") from e

    @staticmethod
    def _collection_name(session_id: str) -> str:
        """Return the Qdrant collection name used for *session_id*."""
        return f"session_{session_id}"

    def get_session_path(self, session_id: str) -> str:
        """
        Get the filesystem path for a session's pickle file.

        Args:
            session_id: Unique session identifier

        Returns:
            str: Full path to session file
        """
        # NOTE(review): session_id is joined into a path unchecked; if it can
        # come from clients, reject path separators / '..' before calling this.
        return os.path.join(Config.STORAGE_DIR, f"{session_id}.pkl")

    def save_session(self, session_id: str, data: Dict):
        """
        Persist session data to disk (excluding the Qdrant collection reference).

        The write goes to a temporary file that is then atomically moved over
        the session file, so a crash mid-write cannot leave a truncated pickle.

        Args:
            session_id: Session identifier
            data: Session data dictionary (not modified)
        """
        session_path = self.get_session_path(session_id)
        # The collection name is re-derived from session_id on load, so it is
        # not persisted. Copy first so the caller's dict is left untouched.
        to_save = data.copy()
        to_save.pop('qdrant_collection', None)

        tmp_path = f"{session_path}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(to_save, f)
            os.replace(tmp_path, session_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the cause.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load_session(self, session_id: str) -> Optional[Dict]:
        """
        Load session data from disk and reconnect to its Qdrant collection.

        Args:
            session_id: Session identifier

        Returns:
            The session data with ``'qdrant_collection'`` set to the session's
            collection name, or ``None`` if no session file exists. The
            collection is created if it is missing.
        """
        session_path = self.get_session_path(session_id)
        try:
            with open(session_path, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code; only load
                # files written by this service.
                data = pickle.load(f)
        except FileNotFoundError:
            return None

        # Restore Qdrant collection reference
        collection_name = self._collection_name(session_id)
        data['qdrant_collection'] = collection_name

        # Ensure collection exists in Qdrant (create if missing)
        if not self.qdrant_client.collection_exists(collection_name):
            self.logger.warning(f"Qdrant collection {collection_name} missing, creating new")
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=Config.EMBEDDING_SIZE,
                    distance=Distance.COSINE,
                ),
            )
        return data

    def delete_session(self, session_id: str):
        """
        Completely remove a session (both its Qdrant collection and its file).

        Qdrant deletion is best-effort: a failure is logged and the session
        file is still removed. A missing session file is not an error.

        Args:
            session_id: Session identifier to delete
        """
        session_path = self.get_session_path(session_id)

        # Delete Qdrant collection first
        collection_name = self._collection_name(session_id)
        try:
            self.qdrant_client.delete_collection(collection_name)
            self.logger.info(f"Deleted Qdrant collection: {collection_name}")
        except Exception as e:
            self.logger.error(f"Error deleting Qdrant collection: {str(e)}")

        # Delete session file (EAFP avoids an exists/remove race)
        try:
            os.remove(session_path)
        except FileNotFoundError:
            pass
class QdrantStorage:
    """
    Manages vector storage operations for a single Qdrant collection.

    Handles collection creation plus vector upsert and filtered search.
    """

    def __init__(self, collection_name: str, vector_size: int,
                 host: str = Config.QDRANT_HOST, port: int = Config.QDRANT_PORT):
        """
        Initialize Qdrant storage for a specific collection.

        The collection is created (cosine distance) if it does not already
        exist; an existing collection is reused as-is.

        Args:
            collection_name: Name of the Qdrant collection
            vector_size: Dimensionality of vectors to store
            host: Qdrant server host (default from Config)
            port: Qdrant server port (default from Config)
        """
        self.logger = logging.getLogger(__name__)
        self.collection_name = collection_name
        self.vector_size = vector_size
        # Initialize Qdrant client with gRPC preference
        self.qdrant = QdrantClient(host=host, port=port, prefer_grpc=True)
        self._ensure_collection()

    def _ensure_collection(self):
        """
        Create the collection if it does not exist; otherwise reuse it.

        Never drops an existing collection: errors talking to Qdrant propagate
        instead of triggering a destructive recreate.
        """
        if self.qdrant.collection_exists(self.collection_name):
            self.logger.info(f"Using existing Qdrant collection: {self.collection_name}")
            return

        self.logger.info(f"Creating Qdrant collection: {self.collection_name}")
        self.qdrant.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE,  # Using cosine similarity
            ),
        )

    def add_vectors(self, vectors: List[List[float]], payloads: List[Dict[str, Any]], offset: int = 0):
        """
        Add vectors and associated metadata to the collection.

        Point IDs are ``offset + i`` for the i-th vector, so an upsert that
        reuses an ID overwrites the existing point with that ID.

        Args:
            vectors: List of vector embeddings
            payloads: List of metadata dictionaries, one per vector
            offset: Starting ID for new points (default 0)

        Raises:
            ValueError: if ``vectors`` and ``payloads`` differ in length.
        """
        # zip() would silently drop the unmatched tail, so reject mismatches.
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors and payloads must have the same length "
                f"(got {len(vectors)} and {len(payloads)})"
            )
        if not vectors:
            return

        points = [
            PointStruct(
                id=offset + idx,  # Sequential IDs with optional offset
                vector=vector,
                payload=payload,
            )
            for idx, (vector, payload) in enumerate(zip(vectors, payloads))
        ]
        self.qdrant.upsert(
            collection_name=self.collection_name,
            points=points,
            wait=True,  # Wait for the server to apply the upsert before returning
        )
        self.logger.info(f"Added {len(points)} vectors to Qdrant collection '{self.collection_name}'")

    def search(self, query_vector: List[float], session_id: str, limit: int = 5):
        """
        Search the collection for similar vectors, filtered by session.

        Args:
            query_vector: The vector to compare against
            session_id: Only points whose payload ``session_id`` equals this
                value are considered
            limit: Maximum number of results to return

        Returns:
            List[Dict]: one dict per hit with keys ``id``, ``score``, ``payload``
        """
        # NOTE(review): newer qdrant-client releases deprecate `search` in
        # favour of `query_points`; revisit when upgrading the client.
        session_filter = Filter(
            must=[
                FieldCondition(
                    key="session_id",
                    match=MatchValue(value=session_id),
                )
            ]
        )
        results = self.qdrant.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            query_filter=session_filter,
            limit=limit,
        )
        return [
            {
                "id": hit.id,
                "score": hit.score,
                "payload": hit.payload,
            }
            for hit in results
        ]