import uuid
import logging
from datetime import datetime, timezone

from qdrant_client import QdrantClient
from qdrant_client.models import (
    PointStruct,
    Distance,
    VectorParams,
    Filter,
    FieldCondition,
    MatchValue,
)

from app.core.config import settings
from openai import OpenAI

logger = logging.getLogger(__name__)

openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)


class MemoryClient:
    """Per-user text memory stored in a Qdrant collection.

    Texts are embedded with the OpenAI embeddings API and stored as points
    whose payload carries ``user_id``, ``text`` and ``timestamp``. The Qdrant
    client and collection name are class attributes, so all callers share
    one client (created when this module is imported).
    """

    client = QdrantClient(
        url=settings.QDRANT_HOST,
        api_key=settings.QDRANT_API_KEY,
        timeout=3.0,
    )
    collection = settings.QDRANT_COLLECTION

    # Embedding model used for both stored memories and search queries.
    EMBEDDING_MODEL = "text-embedding-3-small"
    # Vector size of the collection; should match EMBEDDING_MODEL's output
    # dimension, since Qdrant expects vectors of exactly this length.
    VECTOR_SIZE = 1536
    # Inputs are cut to this many characters (not tokens) before embedding.
    MAX_EMBED_CHARS = 1500

    @classmethod
    def ensure_collection(cls):
        """Create the collection (cosine distance) if it does not exist yet.

        Best-effort: any error (e.g. Qdrant unreachable) is logged as a
        warning and swallowed so the caller does not crash.
        """
        try:
            existing = {c.name for c in cls.client.get_collections().collections}
            if cls.collection not in existing:
                cls.client.create_collection(
                    collection_name=cls.collection,
                    vectors_config=VectorParams(
                        size=cls.VECTOR_SIZE,
                        distance=Distance.COSINE,
                    ),
                )
        except Exception as e:
            logger.warning("Memory init skipped: %s", e)

    @classmethod
    def embed(cls, text: str):
        """Return the embedding vector for ``text``.

        Only the first ``MAX_EMBED_CHARS`` characters are sent to the API.
        Errors raised by the OpenAI client propagate to the caller.
        """
        res = openai_client.embeddings.create(
            model=cls.EMBEDDING_MODEL,
            input=text[: cls.MAX_EMBED_CHARS],
        )
        return res.data[0].embedding

    @classmethod
    def search_memory(cls, user_id: str, query: str, limit: int = 2):
        """Return texts of up to ``limit`` memories most similar to ``query``.

        Only points whose payload ``user_id`` equals ``user_id`` are searched.
        Points without a payload are skipped; a payload without ``text``
        yields ``""``. Best-effort: on any failure a warning is logged and
        ``[]`` is returned.
        """
        # A blank query (or non-positive limit) cannot produce useful hits;
        # skip the embedding and Qdrant round trips entirely.
        if not query or not query.strip() or limit <= 0:
            return []
        try:
            vector = cls.embed(query)
            res = cls.client.query_points(
                collection_name=cls.collection,
                query=vector,
                limit=limit,
                query_filter=Filter(
                    must=[
                        FieldCondition(
                            key="user_id",
                            match=MatchValue(value=user_id),
                        )
                    ]
                ),
            )
            points = getattr(res, "points", None) or []
            return [
                p.payload.get("text", "")
                for p in points
                if getattr(p, "payload", None)
            ]
        except Exception as e:
            logger.warning("Memory search failed (ignored): %s", e)
            return []

    @classmethod
    def add_memory(cls, user_id: str, text: str):
        """Embed ``text`` and store it as a new memory point for ``user_id``.

        Every call creates a new point with a random UUID id, so storing the
        same text twice creates two points. The full ``text`` is stored in
        the payload even though only its prefix is embedded. Best-effort:
        failures are logged as warnings and not raised.
        """
        if not text or not text.strip():
            logger.debug("Memory insert skipped: empty text")
            return
        try:
            vector = cls.embed(text)
            # Naive UTC ISO-8601 string: the same format datetime.utcnow()
            # produced, without the deprecated call (deprecated in 3.12).
            timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            cls.client.upsert(
                collection_name=cls.collection,
                points=[
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=vector,
                        payload={
                            "user_id": user_id,
                            "text": text,
                            "timestamp": timestamp,
                        },
                    )
                ],
            )
        except Exception as e:
            logger.warning("Memory insert failed: %s", e)