Spaces:
Sleeping
Sleeping
"""
SQLAlchemy ORM models and async database engine/session management.
"""
import os
import uuid
from datetime import datetime, timezone
from typing import Optional, List

from sqlalchemy import create_engine, Column, String, Text, DateTime, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker as async_sessionmaker
# Connection URL for the async engine, overridable through the DATABASE_URL
# environment variable. Without it, a SQLite file (aiosqlite driver) at the
# relative path ./algebra_chat.db is used.
DATABASE_URL = os.environ.get(
    "DATABASE_URL", "sqlite+aiosqlite:///./algebra_chat.db"
)

# Declarative base class that every ORM model in this module inherits from.
Base = declarative_base()
class Conversation(Base):
    """Conversation/Session model.

    A conversation owns a list of ``Message`` rows through ``messages``.
    Because of ``cascade="all, delete-orphan"``, deleting a conversation via
    the ORM also deletes its messages, and messages removed from the list are
    deleted as orphans.
    """

    __tablename__ = "conversations"

    # UUID4 string primary key, generated client-side when the row is inserted.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    title = Column(String(255), nullable=True)
    # Timestamps are naive datetimes holding UTC. datetime.utcnow() is
    # deprecated since Python 3.12, so the same value is produced from an aware
    # now() with tzinfo dropped, keeping the naive DateTime column unchanged.
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
    # Refreshed by SQLAlchemy whenever this row itself is UPDATEd.
    # NOTE(review): inserting a Message does not update this row, so this
    # column presumably won't move when new messages arrive — confirm intent.
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
    messages = relationship(
        "Message", back_populates="conversation", cascade="all, delete-orphan"
    )
class Message(Base):
    """Message model for chat history.

    Each message belongs to exactly one ``Conversation`` (``conversation_id``
    is a non-null foreign key to ``conversations.id``).
    """

    __tablename__ = "messages"

    # UUID4 string primary key, generated client-side when the row is inserted.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Indexed because messages are fetched by conversation (e.g. when loading
    # Conversation.messages). Only affects newly created tables: create_all
    # does not alter tables that already exist.
    conversation_id = Column(
        String(36), ForeignKey("conversations.id"), nullable=False, index=True
    )
    role = Column(String(20), nullable=False)  # 'user' or 'assistant' (not enforced)
    content = Column(Text, nullable=False)
    image_data = Column(Text, nullable=True)  # Base64 encoded image
    # Naive UTC timestamp; avoids the deprecated datetime.utcnow() (3.12+).
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
    conversation = relationship("Conversation", back_populates="messages")
# Async engine and session factory.
# One engine (SQL echo disabled) is shared by the whole module.
engine = create_async_engine(DATABASE_URL, echo=False)

# Factory producing AsyncSession objects bound to ``engine``. With
# expire_on_commit=False, ORM objects keep their loaded attribute values after
# a commit instead of being expired.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def init_db():
    """Initialize database tables.

    For a file-backed SQLite database, first make sure the directory holding
    the database file exists (best effort: failures are printed, not raised).
    In-memory SQLite and non-SQLite backends skip that step. Then create every
    table registered on ``Base.metadata``; ``create_all`` leaves tables that
    already exist untouched.
    """
    # Use the engine's parsed URL rather than string surgery on DATABASE_URL:
    # a substring test for "sqlite" also matched other backends whose URL
    # happened to contain it, and a path-less URL such as "sqlite+aiosqlite://"
    # used to produce a bogus directory name ("sqlite+aiosqlite:").
    url = engine.url
    db_path = url.database
    if url.get_backend_name() == "sqlite" and db_path and db_path != ":memory:":
        db_dir = os.path.dirname(db_path)
        if db_dir and not os.path.exists(db_dir):
            try:
                os.makedirs(db_dir, exist_ok=True)
                print(f"Created database directory: {db_dir}")
            except OSError as e:
                # Best effort: report and continue. If the directory is really
                # missing, opening the database below will likely fail too.
                print(f"Error creating database directory {db_dir}: {e}")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
async def get_db():
    """Yield a single ``AsyncSession`` created by ``AsyncSessionLocal``.

    The session is entered as an async context manager, so it is closed when
    the generator resumes past the ``yield`` or is closed.
    NOTE(review): shaped like a per-request dependency (e.g. FastAPI
    ``Depends``) — confirm against callers.
    """
    pending = AsyncSessionLocal()
    async with pending as db_session:
        yield db_session