"""Chat memory helpers.

Stores short-term, semantic and episodic ``ChatMemory`` rows per chat, builds
text summaries of them, and manages the single ``PendingAction`` row that a
chat may have.
"""

from __future__ import annotations

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from ..config import settings
from ..models import ChatMemory, PendingAction
from .inventory import now_iso

SEMANTIC_HINTS = ("prefiero", "quiero", "no quiero", "recuerda", "me gusta", "no uses", "sin comandos")

# Maximum number of characters persisted for a memory's content or a pending
# action's raw text.
_MAX_STORED_CHARS = 4000
# Maximum number of characters kept for a semantic (preference) memory.
_MAX_SEMANTIC_CHARS = 400


async def _fetch_memories(
    session: AsyncSession,
    chat_id: str,
    kind: str,
    *,
    limit: int | None = None,
    offset: int | None = None,
) -> list[ChatMemory]:
    """Return the ``kind`` memories of ``chat_id`` ordered newest first.

    Args:
        session: Session used to run the query.
        chat_id: Chat identifier; it is converted with ``str()`` before comparing.
        kind: Value matched against ``ChatMemory.kind``.
        limit: When given, maximum number of rows to return.
        offset: When given, number of newest rows to skip.

    Returns:
        The matching rows ordered by descending ``ChatMemory.id``.
    """
    statement = (
        select(ChatMemory)
        .where(ChatMemory.chat_id == str(chat_id), ChatMemory.kind == kind)
        .order_by(ChatMemory.id.desc())
    )
    if offset is not None:
        statement = statement.offset(offset)
    if limit is not None:
        statement = statement.limit(limit)
    result = await session.execute(statement)
    return list(result.scalars().all())


async def remember_message(session: AsyncSession, chat_id: str, role: str, content: str, kind: str = "short") -> None:
    """Store one memory row and, for short-term memories, prune the overflow.

    Blank or ``None`` content is ignored. Stored content is stripped and cut to
    ``_MAX_STORED_CHARS`` characters. The session is flushed but not committed.

    Args:
        session: Session the row is added to.
        chat_id: Chat the memory belongs to.
        role: Author label stored with the message.
        content: Message text.
        kind: Memory kind; only ``"short"`` triggers pruning.
    """
    content = (content or "").strip()
    if not content:
        return

    session.add(
        ChatMemory(
            chat_id=str(chat_id),
            kind=kind,
            role=role,
            content=content[:_MAX_STORED_CHARS],
            created_at=now_iso(),
        )
    )
    # Flush so the new row is visible to the pruning query below.
    await session.flush()

    if kind != "short":
        return

    # Keep twice the configured short limit. Only the rows beyond that window
    # are fetched, instead of loading the whole short-term history on every
    # message. A negative setting is treated as zero.
    keep = max(settings.memory_short_limit * 2, 0)
    for stale in await _fetch_memories(session, chat_id, "short", offset=keep):
        await session.delete(stale)


async def capture_semantic_memory(session: AsyncSession, chat_id: str, user_text: str) -> None:
    """Store ``user_text`` as a semantic memory when it contains a hint phrase.

    The text is matched case-insensitively against ``SEMANTIC_HINTS`` by
    substring. Nothing is stored when no hint matches or when the same text is
    already among the five newest semantic memories of the chat.

    Args:
        session: Session used for the lookup and the insert.
        chat_id: Chat the memory belongs to.
        user_text: Raw user message; ``None`` or empty text is ignored.
    """
    lowered = (user_text or "").lower()
    if not any(hint in lowered for hint in SEMANTIC_HINTS):
        return

    normalized = user_text.strip()[:_MAX_SEMANTIC_CHARS]
    # Only the five newest notes are checked, so an older preference can be
    # recorded again.
    recent = await _fetch_memories(session, chat_id, "semantic", limit=5)
    if any(item.content == normalized for item in recent):
        return

    await remember_message(session, chat_id, "user", normalized, kind="semantic")


async def remember_episode(session: AsyncSession, chat_id: str, content: str) -> None:
    """Store ``content`` as an episodic memory with the ``"system"`` role."""
    await remember_message(session, chat_id, "system", content, kind="episodic")


async def build_memory_context(session: AsyncSession, chat_id: str) -> str:
    """Build a text summary of the chat's stored memories.

    The summary has up to three sections, each listed oldest to newest: the six
    newest semantic memories, the newest short-term messages (up to
    ``settings.memory_short_limit``), and the four newest episodic memories.
    Sections without rows are omitted.

    Returns:
        The sections joined with newlines, or an empty string when the chat
        has no memories.
    """
    fragments: list[str] = []

    semantic = await _fetch_memories(session, chat_id, "semantic", limit=6)
    if semantic:
        fragments.append("Preferencias y hechos utiles:")
        fragments.extend(f"- {item.content}" for item in reversed(semantic))

    short_messages = await _fetch_memories(session, chat_id, "short", limit=settings.memory_short_limit)
    if short_messages:
        fragments.append("Contexto reciente:")
        fragments.extend(f"- {item.role}: {item.content}" for item in reversed(short_messages))

    episodes = await _fetch_memories(session, chat_id, "episodic", limit=4)
    if episodes:
        fragments.append("Eventos relevantes:")
        fragments.extend(f"- {item.content}" for item in reversed(episodes))

    return "\n".join(fragments).strip()


async def recall_recent_chat(session: AsyncSession, chat_id: str) -> str:
    """Return a user-facing recap of the chat's newest short-term messages.

    Up to ``settings.memory_short_limit`` messages are listed oldest to newest.
    A fixed fallback sentence is returned when there are none.
    """
    messages = await _fetch_memories(session, chat_id, "short", limit=settings.memory_short_limit)
    if not messages:
        return "No tengo contexto reciente guardado en este chat."
    return "Esto es lo ultimo relevante del chat:\n" + "\n".join(
        f"- {item.role}: {item.content}" for item in reversed(messages)
    )


async def get_pending_action(session: AsyncSession, chat_id: str) -> PendingAction | None:
    """Return the pending action stored under ``chat_id``, or ``None``."""
    return await session.get(PendingAction, str(chat_id))


async def set_pending_action(session: AsyncSession, chat_id: str, action_type: str, raw_text: str, question: str) -> None:
    """Create or overwrite the pending action of a chat.

    A new row gets the same timestamp for ``created_at`` and ``updated_at``.
    An existing row keeps its ``created_at`` and has its other fields replaced.
    ``raw_text`` is cut to ``_MAX_STORED_CHARS`` characters. Nothing is flushed
    or committed here.

    Args:
        session: Session used for the lookup and the write.
        chat_id: Chat the action belongs to; used as the lookup key.
        action_type: Label of the action awaiting completion.
        raw_text: Original text that triggered the action.
        question: Question stored with the action.
    """
    item = await session.get(PendingAction, str(chat_id))
    timestamp = now_iso()
    if item is None:
        session.add(
            PendingAction(
                chat_id=str(chat_id),
                action_type=action_type,
                raw_text=raw_text[:_MAX_STORED_CHARS],
                question=question,
                created_at=timestamp,
                updated_at=timestamp,
            )
        )
        return

    item.action_type = action_type
    item.raw_text = raw_text[:_MAX_STORED_CHARS]
    item.question = question
    item.updated_at = timestamp


async def clear_pending_action(session: AsyncSession, chat_id: str) -> None:
    """Delete the pending action of ``chat_id`` if one exists."""
    item = await session.get(PendingAction, str(chat_id))
    if item is not None:
        await session.delete(item)


async def clear_chat_memory(session: AsyncSession, chat_id: str) -> None:
    """Delete every memory row of ``chat_id`` and its pending action."""
    await session.execute(delete(ChatMemory).where(ChatMemory.chat_id == str(chat_id)))
    await clear_pending_action(session, chat_id)