Spaces:
Build error
Build error
| from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, Boolean | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import relationship, sessionmaker | |
| from datetime import datetime | |
| import uuid | |
# Declarative base shared by the ORM models below; Base.metadata collects their
# tables so DatabaseManager can create them with create_all().
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 (moved to sqlalchemy.orm.declarative_base) — consider switching.
Base = declarative_base()
class ChatSession(Base):
    """One chat conversation opened by a doctor.

    ``id`` is the internal integer primary key. ``session_id`` is the public
    UUID string that ``ChatMessage.session_id`` references.
    """

    __tablename__ = 'chat_sessions'

    id = Column(Integer, primary_key=True)
    # Public identifier: a random UUID4 rendered as a 36-character string.
    session_id = Column(String(36), unique=True, default=lambda: str(uuid.uuid4()))
    doctor_name = Column(String(100), nullable=False)
    # Free-form label; DatabaseManager.create_session fills it with
    # "<doctor_name>_<UTC timestamp>".
    user_identifier = Column(String(150))
    # Naive UTC datetimes (datetime.utcnow).
    started_at = Column(DateTime, default=datetime.utcnow)
    # onupdate only fires when an UPDATE is issued against this row.
    last_activity = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    messages = relationship("ChatMessage", back_populates="session")

    @classmethod
    def get_sessions_by_doctor(cls, session, doctor_name):
        """Return all sessions belonging to ``doctor_name``.

        Args:
            session: An open SQLAlchemy ``Session`` used to run the query.
            doctor_name: Exact doctor name to match.

        Returns:
            A list of ``ChatSession`` rows (empty if none match).
        """
        return session.query(cls).filter(cls.doctor_name == doctor_name).all()
class ChatMessage(Base):
    """A single message stored under a ChatSession."""
    __tablename__ = 'chat_messages'
    id = Column(Integer, primary_key=True)
    # References the public UUID column (chat_sessions.session_id), not chat_sessions.id.
    session_id = Column(String(36), ForeignKey('chat_sessions.session_id'))
    # Naive UTC creation time; DatabaseManager.get_session_history orders by it.
    timestamp = Column(DateTime, default=datetime.utcnow)
    # True when the user wrote the message; False otherwise (presumably an assistant reply — confirm with caller).
    is_user = Column(Boolean, default=True)
    message = Column(Text)
    # Optional text describing sources; its format is decided by the caller.
    sources_used = Column(Text, nullable=True)
    session = relationship("ChatSession", back_populates="messages")
class DatabaseManager:
    """Small persistence layer for chat sessions and their messages.

    Every public method opens its own short-lived SQLAlchemy session and closes
    it before returning. ORM objects handed back to callers are therefore
    detached: their column attributes are readable, but lazy relationships
    (``ChatSession.messages``, ``ChatMessage.session``) cannot be loaded from them.
    """

    def __init__(self, db_url="sqlite:///chat_history.db"):
        """Connect to ``db_url`` and create any missing tables.

        Args:
            db_url: SQLAlchemy database URL. The default is a SQLite file
                named ``chat_history.db`` in the current working directory.
        """
        self.engine = create_engine(db_url)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    def create_session(self, doctor_name):
        """Create a new chat session for ``doctor_name``.

        Args:
            doctor_name: Name of the doctor who owns the session.

        Returns:
            The new session's public UUID string (``ChatSession.session_id``).
        """
        session = self.Session()
        try:
            # One timestamp for the identifier and both datetime columns, so
            # they agree exactly.
            now = datetime.utcnow()
            # %f is microseconds; it keeps identifiers distinct for sessions
            # started within the same second.
            user_identifier = f"{doctor_name}_{now.strftime('%Y%m%d_%H%M%S_%f')}"
            chat_session = ChatSession(
                doctor_name=doctor_name,
                user_identifier=user_identifier,
                started_at=now,
                last_activity=now,
            )
            session.add(chat_session)
            # Flush runs the INSERT, which fires the UUID column default; read
            # it before commit so no refresh query is needed afterwards.
            session.flush()
            new_session_id = chat_session.session_id
            session.commit()
            return new_session_id
        finally:
            session.close()

    def get_doctor_sessions(self, doctor_name):
        """Return all sessions for ``doctor_name``, most recently active first.

        Args:
            doctor_name: Exact doctor name to match.

        Returns:
            A list of detached ``ChatSession`` rows.
        """
        session = self.Session()
        try:
            return session.query(ChatSession)\
                .filter(ChatSession.doctor_name == doctor_name)\
                .order_by(ChatSession.last_activity.desc()).all()
        finally:
            session.close()

    def log_message(self, session_id, message, is_user=True, sources=None):
        """Store a message in a chat session and mark that session as active.

        Args:
            session_id: Public UUID string of the owning ``ChatSession``.
            message: Message text.
            is_user: True if the user authored the message.
            sources: Optional text saved in ``ChatMessage.sources_used``.
        """
        session = self.Session()
        try:
            now = datetime.utcnow()
            session.add(ChatMessage(
                session_id=session_id,
                message=message,
                is_user=is_user,
                sources_used=sources,
                timestamp=now,
            ))
            # Nothing else ever UPDATEs a chat_sessions row, so its
            # last_activity onupdate hook would never fire. Bump it here so the
            # "most recent first" ordering in get_doctor_sessions is meaningful.
            session.query(ChatSession)\
                .filter(ChatSession.session_id == session_id)\
                .update({ChatSession.last_activity: now},
                        synchronize_session=False)
            session.commit()
        finally:
            session.close()

    def get_session_history(self, session_id):
        """Return all messages of a session in chronological order.

        Args:
            session_id: Public UUID string of the ``ChatSession``.

        Returns:
            A list of detached ``ChatMessage`` rows ordered by ``timestamp``
            (empty if the session has no messages or does not exist).
        """
        session = self.Session()
        try:
            return session.query(ChatMessage)\
                .filter(ChatMessage.session_id == session_id)\
                .order_by(ChatMessage.timestamp).all()
        finally:
            session.close()