Spaces:
Build error
Build error
| # implementations/async_memory.py | |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | |
| from sqlalchemy.orm import sessionmaker | |
| from app.settings import DatabaseSettings, MemorySettings | |
| from app.memory.memory import ConversationMemoryInterface | |
| from app.utils.token_counter import SimpleTokenCounter, TikTokenCounter | |
| from app.memory.models.base import Base | |
| from app.memory.models.message import Message | |
| from app.memory.models.user import User | |
| from typing import List, Dict, Optional | |
| from datetime import datetime | |
| from zoneinfo import ZoneInfo | |
| from sqlalchemy.future import select | |
class AsyncPostgresConversationMemory(ConversationMemoryInterface):
    """Conversation memory persisted in a SQL database through SQLAlchemy's asyncio API.

    Messages are stored as ``Message`` rows linked to a ``User``. After every
    insert, the oldest stored messages are deleted until the total token count
    of all stored messages is within ``memory_settings.token_limit``.
    """

    def __init__(self, db_settings: DatabaseSettings, memory_settings: MemorySettings):
        """Create the async engine, the session factory and the token counter.

        Args:
            db_settings: Connection URL plus pool sizing (``pool_size``,
                ``max_overflow``, ``pool_timeout``) for the async engine.
            memory_settings: ``token_limit`` used for trimming, and the token
                counter choice: ``"tiktoken"`` selects ``TikTokenCounter`` for
                ``model_name``; any other value selects ``SimpleTokenCounter``.
        """
        self.engine = create_async_engine(
            db_settings.url,
            pool_size=db_settings.pool_size,
            max_overflow=db_settings.max_overflow,
            pool_timeout=db_settings.pool_timeout,
        )
        # expire_on_commit=False keeps attributes of committed objects readable
        # without a refresh; implicit lazy I/O is not allowed under AsyncSession.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )
        self.token_limit = memory_settings.token_limit
        if memory_settings.token_counter == "tiktoken":
            self.token_counter = TikTokenCounter(memory_settings.model_name)
        else:
            self.token_counter = SimpleTokenCounter()

    async def initialize(self):
        """Initialize the database by creating all tables."""
        # create_all runs synchronously on the async connection via run_sync and
        # (by default) skips tables that already exist.
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def close(self) -> None:
        """Dispose of the engine's connection pool, closing its pooled connections.

        Intended for application shutdown so connections are released while the
        event loop is still running.
        """
        await self.engine.dispose()

    async def add_message(
        self,
        username: str,
        role: str,
        message: str,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Store a message for an existing user, then trim memory to the token limit.

        Args:
            username: Username of the message author; the user must already exist.
            role: Role label stored with the message.
            message: Message text.
            timestamp: Time of the message; defaults to the current time in the
                Asia/Jakarta time zone.

        Raises:
            ValueError: If no user with ``username`` exists.
        """
        async with self.async_session() as session:
            result = await session.execute(select(User).filter_by(username=username))
            user = result.scalars().first()
            if user is None:
                raise ValueError(f"User with username '{username}' not found")
            if timestamp is None:
                timestamp = datetime.now(ZoneInfo("Asia/Jakarta"))
            session.add(
                Message(user_id=user.id, role=role, message=message, timestamp=timestamp)
            )
            await session.commit()
            # NOTE(review): trimming is store-wide (all users), so one user's new
            # message can evict other users' oldest messages - confirm intended.
            await self.trim_memory_if_needed(session)

    async def get_all_history(self) -> List[Dict]:
        """Return every stored message (all users) in chronological order.

        Returns:
            ``{"role": ..., "content": ...}`` dicts.
            NOTE(review): ``get_history`` uses the key ``"parts"`` instead.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(Message).order_by(Message.timestamp)
            )
            messages = result.scalars().all()
            return [{"role": msg.role, "content": msg.message} for msg in messages]

    async def get_history(
        self,
        username: Optional[str] = None,
        token_limit: Optional[int] = None,
        last_n: Optional[int] = None,
    ) -> List[Dict]:
        """Return the most recent messages, oldest first, within optional limits.

        Messages are collected newest-first until adding the next one would
        exceed ``token_limit`` or ``last_n`` messages have been collected. The
        newest message is always included (when any exist), even if it alone
        exceeds ``token_limit``.

        Args:
            username: Only include this user's messages; ``None`` means all users.
            token_limit: Maximum total tokens of the returned messages (see above).
            last_n: Maximum number of messages; ``0`` or less returns ``[]``.

        Returns:
            ``{"role": ..., "parts": ...}`` dicts in chronological order.
        """
        # Without this guard the newest message was returned even for last_n=0.
        if last_n is not None and last_n <= 0:
            return []
        async with self.async_session() as session:
            query = select(Message).order_by(Message.timestamp)
            if username is not None:
                query = query.join(User).filter(User.username == username)
            result = await session.execute(query)
            messages = result.scalars().all()

            selected = []
            total_tokens = 0
            for msg in reversed(messages):
                tokens = self.token_counter.count_tokens(msg.message)
                # The newest message is always kept so the caller gets some
                # context; after that, stop once the budget would be exceeded.
                if (
                    selected
                    and token_limit is not None
                    and total_tokens + tokens > token_limit
                ):
                    break
                selected.append(msg)
                total_tokens += tokens
                if last_n is not None and len(selected) >= last_n:
                    break
            selected.reverse()  # back to chronological order
            return [{"role": msg.role, "parts": msg.message} for msg in selected]

    async def clear_memory(self) -> None:
        """Delete every stored message (for all users)."""
        from sqlalchemy import delete  # Select has no .delete(); use the DELETE construct

        async with self.async_session() as session:
            await session.execute(delete(Message))
            await session.commit()

    async def get_total_tokens(self) -> int:
        """Return the token count summed over all stored messages (all users)."""
        async with self.async_session() as session:
            # Only the text column is needed; avoid loading full ORM rows.
            result = await session.execute(select(Message.message))
            texts = result.scalars().all()
        return sum(self.token_counter.count_tokens(text) for text in texts)

    async def trim_memory_if_needed(self, session: AsyncSession) -> None:
        """Delete the oldest messages until the total tokens fit ``self.token_limit``.

        Considers all stored messages (every user) ordered by timestamp, and
        commits on ``session`` even when nothing is deleted.

        Args:
            session: Open session to query, delete and commit with.
        """
        result = await session.execute(select(Message).order_by(Message.timestamp))
        messages = result.scalars().all()
        # Count each message once (the old version recounted while popping from
        # the front of the list, which was also O(n^2)).
        token_counts = [self.token_counter.count_tokens(msg.message) for msg in messages]
        total_tokens = sum(token_counts)
        for msg, tokens in zip(messages, token_counts):
            if total_tokens <= self.token_limit:
                break
            await session.delete(msg)
            total_tokens -= tokens
        await session.commit()