import logging
import os
import sqlite3

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from sqlalchemy import create_engine, Column, Integer, String, Text, ForeignKey
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
|
|
# SQLAlchemy URL of the SQLite file that holds the persisted chat history.
DATABASE_URL = "sqlite:///persist/chat_history.db"

# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
|
|
class Session(Base):
    """ORM model for one chat session, stored in the ``sessions`` table."""

    __tablename__ = "sessions"

    # Integer primary key; this is what Message.session_id points at.
    id = Column(Integer, primary_key=True)
    # String identifier used to look the session up; required and unique.
    session_id = Column(String, nullable=False, unique=True)
    # Messages attached to this session (paired with Message.session).
    messages = relationship("Message", back_populates="session")
|
|
class Message(Base):
    """ORM model for a single chat message, stored in the ``messages`` table."""

    __tablename__ = "messages"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Foreign key to sessions.id (the integer key, not Session.session_id).
    session_id = Column(Integer, ForeignKey("sessions.id"), nullable=False)
    # Role label of the message author; required.
    role = Column(String, nullable=False)
    # Message body; required.
    content = Column(Text, nullable=False)
    # Owning session (paired with Session.messages).
    session = relationship("Session", back_populates="messages")
|
|
| |
# Engine bound to the database named in DATABASE_URL.
engine = create_engine(DATABASE_URL)

# SQLite creates a missing database file on first connect, but not a missing
# parent directory. Create the directory up front so a fresh checkout does
# not fail with "unable to open database file" in create_all() below.
_db_file = engine.url.database
if _db_file and _db_file != ":memory:":
    os.makedirs(os.path.dirname(os.path.abspath(_db_file)), exist_ok=True)

# Create the tables declared on Base (sessions, messages).
Base.metadata.create_all(engine)

# Factory for ORM sessions bound to the engine above.
SessionLocal = sessionmaker(bind=engine)
|
|
def get_db():
    """Yield a new database session and close it when the generator finishes.

    This is a generator: the session is created on the first ``next()`` and
    closed when the generator is resumed, closed, or finalized.
    """
    # The session's context manager closes it on exit, including when the
    # generator is closed while suspended at the yield.
    with SessionLocal() as db_session:
        yield db_session
|
|
| |
def save_message(session_id: str, role: str, content: str):
    """Persist one chat message, creating its session row on first use.

    Args:
        session_id: String identifier matched against ``Session.session_id``.
        role: Stored as-is in ``Message.role``.
        content: Stored as-is in ``Message.content``.

    Returns:
        None. Database errors are handled on a best-effort basis: the
        transaction is rolled back and the error is logged, not raised, so
        the message is not saved in that case.
    """
    # Keep a reference to the generator. With a bare ``next(get_db())`` the
    # generator is discarded immediately, which runs its cleanup and closes
    # the session before it is used.
    db_gen = get_db()
    db = next(db_gen)
    try:
        session = db.query(Session).filter(Session.session_id == session_id).first()
        if session is None:
            session = Session(session_id=session_id)
            db.add(session)
            # flush() assigns session.id without committing, so a new session
            # and its first message are written in a single transaction.
            db.flush()

        db.add(Message(session_id=session.id, role=role, content=content))
        db.commit()
    except SQLAlchemyError:
        db.rollback()
        logging.getLogger(__name__).exception(
            "Failed to save message for session %r", session_id
        )
    finally:
        # Closing the generator runs get_db's cleanup, which closes the session.
        db_gen.close()
|
|
| |
def load_session_history(session_id: str) -> BaseChatMessageHistory:
    """Build a chat history from the messages stored for ``session_id``.

    Args:
        session_id: String identifier matched against ``Session.session_id``.

    Returns:
        A new ``ChatMessageHistory``. It is empty when the session does not
        exist; if a database error occurs the error is logged and whatever
        was loaded so far is returned.
    """
    chat_history = ChatMessageHistory()
    # Keep a reference to the generator. With a bare ``next(get_db())`` the
    # generator is discarded immediately, which runs its cleanup and closes
    # the session before it is used.
    db_gen = get_db()
    db = next(db_gen)
    try:
        session = db.query(Session).filter(Session.session_id == session_id).first()
        if session is not None:
            # The relationship declares no ordering, so sort by primary key
            # to replay messages in insertion order.
            for message in sorted(session.messages, key=lambda row: row.id):
                # Stored as plain dicts with "role" and "content" keys, the
                # shape save_all_sessions reads back.
                chat_history.add_message(
                    {"role": message.role, "content": message.content}
                )
    except SQLAlchemyError:
        logging.getLogger(__name__).exception(
            "Failed to load history for session %r", session_id
        )
    finally:
        # Closing the generator runs get_db's cleanup, which closes the session.
        db_gen.close()

    return chat_history
|
|
| |
# In-memory cache of chat histories keyed by session_id. It is defined here
# because no definition of ``store`` is visible in this part of the module;
# a later module-level assignment to ``store`` would simply replace it.
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the cached chat history for ``session_id``, loading it on a miss.

    Args:
        session_id: Key into the module-level ``store`` cache.

    Returns:
        The history held in ``store``. On the first request for a session it
        is loaded with ``load_session_history`` and cached.
    """
    if session_id not in store:
        store[session_id] = load_session_history(session_id)
    return store[session_id]
|
|
| |
def save_all_sessions():
    """Write every message of every cached session to the database.

    Iterates the module-level ``store`` and calls ``save_message`` once per
    message.

    NOTE(review): every cached message is written, including ones that
    ``load_session_history`` read from the database, so calling this
    repeatedly appears to insert duplicate rows. Confirm the intended usage.
    """
    for session_id, chat_history in store.items():
        for message in chat_history.messages:
            if isinstance(message, dict):
                # Shape produced by load_session_history.
                role, content = message["role"], message["content"]
            else:
                # LangChain message objects are not subscriptable; they
                # expose the role as ``type`` and the body as ``content``.
                role, content = message.type, message.content
            save_message(session_id, role, content)
| |
| |