Spaces:
Runtime error
Runtime error
| import json | |
| try: | |
| from sqlalchemy.orm import declarative_base | |
| except ImportError: | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from langchain.schema import BaseChatMessageHistory | |
| from datetime import datetime | |
| from sqlalchemy import Column, Text, orm, create_engine | |
| from .schemas import create_message_model, create_session_table | |
def get_sessions(engine, model_class, user_id):
    """Return the rows of ``model_class`` matching ``user_id``, newest first.

    Args:
        engine: SQLAlchemy engine used to open a short-lived ORM session.
        model_class: Mapped class exposing ``session_id`` and ``create_by``
            columns.
        user_id: Value compared against ``model_class.session_id``.

    Returns:
        A list of ``model_class`` instances ordered by ``create_by``
        descending.
    """
    with orm.sessionmaker(engine)() as session:
        # NOTE(review): the filter compares ``session_id`` with ``user_id``;
        # confirm against callers whether ``model_class.user_id`` was meant.
        query = (
            session.query(model_class)
            .where(model_class.session_id == user_id)
            .order_by(model_class.create_by.desc())
        )
        # A Query object is not JSON text, so feeding it to json.loads raises
        # TypeError. Execute the query while the session is still open and
        # hand back the materialised rows instead.
        return query.all()
class SessionManager:
    """Manage chat-session rows stored through a ClickHouse SQLAlchemy engine.

    On construction the session table and the message table are created if
    needed; the methods then list, add, modify and remove session rows.
    """

    def __init__(self, session_state, host, port, username, password,
                 db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
        """Connect to the database and make sure both tables exist.

        Args:
            session_state: State container consulted by ``remove_session``;
                it must support ``in`` checks and attribute access
                (presumably Streamlit's session state -- confirm with caller).
            host: Database host name.
            port: Database port.
            username: Database user name.
            password: Database password. It is interpolated into the
                connection URL verbatim.
            db: Database name.
            sess_table: Name of the table holding session rows.
            msg_table: Name of the table holding chat messages.
        """
        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
        self.engine = create_engine(conn_str, echo=False)

        # Each model gets its own declarative base, so each create_all()
        # call only touches that model's table.
        self.sess_model_class = create_session_table(
            sess_table, declarative_base())
        self.sess_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_model(
            msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)

        self.Session = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id):
        """Return the sessions owned by ``user_id``, newest first.

        Args:
            user_id: Value matched against the ``user_id`` column.

        Returns:
            A list of dicts with keys ``"session_id"`` (the part of the
            stored id after the last ``"?"``) and ``"system_prompt"``.
        """
        with self.Session() as session:
            result = (
                session.query(self.sess_model_class)
                .where(self.sess_model_class.user_id == user_id)
                .order_by(self.sess_model_class.create_by.desc())
            )
            return [
                {
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                }
                for r in result
            ]

    def modify_system_prompt(self, session_id, sys_prompt):
        """Set the system prompt of the session identified by ``session_id``.

        Args:
            session_id: Value matched against the ``session_id`` column.
            sys_prompt: New value for the ``system_prompt`` column.
        """
        with self.Session() as session:
            # ORM sessions have no ``update`` method; a bulk update goes
            # through Query.update() with the filter on the session_id column.
            session.query(self.sess_model_class).where(
                self.sess_model_class.session_id == session_id
            ).update({"system_prompt": sys_prompt})
            session.commit()

    def add_session(self, user_id, session_id, system_prompt, **kwargs):
        """Insert a new session row.

        Args:
            user_id: Owner of the session.
            session_id: Identifier stored in the ``session_id`` column.
            system_prompt: Initial system prompt for the session.
            **kwargs: Extra values, stored JSON-encoded in ``additionals``.
        """
        with self.Session() as session:
            elem = self.sess_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    def remove_session(self, session_id):
        """Delete a session row and reset the in-memory state tied to it.

        Args:
            session_id: Value matched against the ``session_id`` column.
        """
        with self.Session() as session:
            session.query(self.sess_model_class).where(
                self.sess_model_class.session_id == session_id).delete()
            # NOTE: rows in the message table are not deleted here.
            # Commit explicitly, as add_session does, so the delete is not
            # discarded when the session closes.
            session.commit()

        if "agent" in self.session_state:
            self.session_state.agent.memory.chat_memory.clear()
        if "file_analyzer" in self.session_state:
            self.session_state.file_analyzer.clear_files()