Spaces:
Paused
Paused
import logging
import random
from typing import List, Optional, Union, Tuple

from fastapi import Depends, Query
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select, and_
# Default reminder hours for a user: hours joined by commas, the format that
# create_schedule_array parses. Also the value reset_schedule restores.
default_schedule: str = "8,20"
class UserBase(SQLModel):
    """Fields shared by the user table model and its read/create schemas."""

    # At most one row per Telegram chat.
    __table_args__ = (UniqueConstraint("telegram_chat_id"),)

    name: str
    telegram_chat_id: int
    # Written by update_gmt; presumably an hour offset from GMT -- TODO confirm.
    gmt: Optional[int] = 0
    # Cleared by leave_user and set again by join_user.
    active: Optional[bool] = True
    # Comma-separated hours, e.g. "8,20".
    scheduled_hours: Optional[str] = default_schedule
class User(UserBase, table=True):
    """Database table of bot users."""

    # Surrogate primary key; None until the row has been inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Reminder rows whose user_id points at this user.
    reminders: List["Reminder"] = Relationship(back_populates="user")
class UserRead(UserBase):
    """Read schema for a stored user: the shared fields plus a required id."""

    id: int
class UserCreate(UserBase):
    """Input schema for creating a user; adds nothing to UserBase."""
class ReminderBase(SQLModel):
    """Fields shared by the reminder table model and its schemas."""

    # The remembered text itself.
    reminder: str
    # Primary key of the owning row in the "user" table.
    user_id: Optional[int] = Field(default=None, foreign_key="user.id")
class Reminder(ReminderBase, table=True):
    """Database table of stored memories, each belonging to at most one user."""

    # Surrogate primary key; None until the row has been inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Optional because user_id is nullable, and because removing a reminder from
    # User.reminders (delete_last_memory does this via pop()) clears this side.
    user: Optional[User] = Relationship(back_populates="reminders")
class ReminderRead(ReminderBase):
    """Read schema for a stored reminder: the shared fields plus a required id."""

    id: int
class ReminderCreate(ReminderBase):
    """Input schema for a new reminder; adds nothing to ReminderBase."""
# SQLite database file; a relative path, so it resolves against the
# process's working directory.
sqlite_file_name = "database.db"
sqlite_url = "sqlite:///" + sqlite_file_name
# Allow a connection to be used from threads other than the one that opened it.
connect_args = dict(check_same_thread=False)
engine = create_engine(
    sqlite_url,
    echo=False,
    connect_args=connect_args,
)
def create_db_and_tables():
    """Create every table registered on SQLModel.metadata using the module engine."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)
def get_session():
    """
    Yield a Session bound to the module engine.

    The session is closed only when this generator is resumed past the yield or
    closed. The ``next(get_session())`` defaults used by this module do neither,
    so those sessions stay open for the life of the process.
    """
    db_session = Session(engine)
    with db_session:
        yield db_session
def join_user(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[User]:
    """
    Register the user, or reactivate them if they had left.

    Args:
        user: Its telegram_chat_id locates an existing row; the whole object is
            stored when no such row exists.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The newly created or reactivated User. None when the user is already
        active, or when db_create_user fails and returns None.
    """
    by_chat_id = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    existing = session.exec(by_chat_id).first()

    if existing is None:
        return db_create_user(user, session)
    if existing.active:
        # Already a member: nothing to change.
        return None

    existing.active = True
    session.add(existing)
    session.commit()
    session.refresh(existing)
    return existing
def leave_user(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[User]:
    """
    Deactivate the user.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The deactivated User. None if the user is unknown or already inactive.
    """
    by_chat_id = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    existing = session.exec(by_chat_id).first()

    # Unknown or already-inactive users are left untouched.
    if existing is None or not existing.active:
        return None

    existing.active = False
    session.add(existing)
    session.commit()
    session.refresh(existing)
    return existing
def get_user_status(
    telegram_chat_id: int,
    session: Session = next(get_session()),
) -> Tuple[Optional[int], Optional[bool]]:
    """
    Get the status of the user with the given Telegram chat id.

    Args:
        telegram_chat_id: Chat id to look up.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        ``(gmt, active)`` for the user, or ``(None, None)`` when no such user
        exists. It is always a 2-tuple, never None, so callers can unpack it
        directly.
    """
    found_user = session.exec(
        select(User).where(User.telegram_chat_id == telegram_chat_id)
    ).first()
    if found_user is None:
        return None, None
    return found_user.gmt, found_user.active
def select_random_memory(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[Union[Reminder, bool]]:
    """
    Pick one memory at random from the user's memory-vault.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        A randomly chosen Reminder. False if the user has no memories; None if
        the user is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    owner = session.exec(lookup).first()
    if owner is None:
        return None

    memories = owner.reminders
    if not memories:
        return False
    return random.choice(memories)
def list_memories(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[List[Reminder]]:
    """
    Return the whole memory-vault of the user.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The user's reminders collection (possibly empty), or None if the user
        is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    owner = session.exec(lookup).first()
    return None if owner is None else owner.reminders
def add_memory(
    user: UserCreate,
    memory: str,
    session: Session = next(get_session()),
) -> Optional[Union[Reminder, bool]]:
    """
    Store a new memory in the user's memory-vault, refusing exact duplicates.

    Args:
        user: Only its telegram_chat_id is used, to find the owner.
        memory: Text to store.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The newly stored Reminder. False if the user already has a memory with
        exactly this text; None if the user is unknown.
    """
    owner = session.exec(
        select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    ).first()
    if owner is None:
        return None

    duplicate_query = select(Reminder).where(
        and_(Reminder.user == owner, Reminder.reminder == memory)
    )
    if session.exec(duplicate_query).first() is not None:
        return False

    # Build through the create schema so the input goes through its validation.
    new_memory = Reminder.from_orm(ReminderCreate(reminder=memory, user_id=owner.id))
    session.add(new_memory)
    session.commit()
    session.refresh(new_memory)
    return new_memory
def delete_memory(
    user: UserCreate,
    memory_id: int,
    session: Session = next(get_session()),
) -> Union[bool, str, None]:
    """
    Delete one memory from the user's memory-vault by its position.

    Args:
        user: Only its telegram_chat_id is used, to find the owner.
        memory_id: Zero-based position in ``User.reminders``. The relationship
            declares no ``order_by``, so the order is whatever the database
            returns.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The text of the deleted memory. False if ``memory_id`` is out of range
        (including negative values); None if the user is unknown.
    """
    found_user = session.exec(
        select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    ).first()
    if found_user is None:
        return None

    reminders = found_user.reminders
    # Reject negative positions explicitly: Python indexing would otherwise
    # count them from the end and delete the wrong memory.
    if not 0 <= memory_id < len(reminders):
        return False

    reminder = reminders[memory_id]
    memory = reminder.reminder  # capture before the row is deleted/expired
    session.delete(reminder)
    session.commit()
    return memory
def delete_last_memory(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Union[bool, str, None]:
    """
    Delete the last memory in the user's memory-vault.

    Args:
        user: Only its telegram_chat_id is used, to find the owner.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The text of the deleted memory. False if the user has no memories;
        None if the user is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    found_user = session.exec(lookup).first()
    if found_user is None:
        return None
    if not found_user.reminders:
        return False

    # Detach the final reminder from the collection, then delete its row.
    last_reminder = found_user.reminders.pop()
    deleted_text = last_reminder.reminder
    session.delete(last_reminder)
    session.commit()
    return deleted_text
def update_gmt(
    user: UserCreate,
    gmt: int,
    session: Session = next(get_session()),
) -> Optional[User]:
    """
    Store a new GMT value for the user.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        gmt: Value written to ``User.gmt``.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The updated User, or None if the user is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    target = session.exec(lookup).first()
    if target is None:
        return None

    target.gmt = gmt
    session.add(target)
    session.commit()
    session.refresh(target)
    return target
def get_schedule(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[str]:
    """
    Return the user's schedule string.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        ``User.scheduled_hours`` (comma-separated hours), or None if the user is
        unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    target = session.exec(lookup).first()
    return target.scheduled_hours if target is not None else None
def reset_schedule(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[str]:
    """
    Restore the user's schedule to ``default_schedule``.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The schedule string after the reset, or None if the user is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    target = session.exec(lookup).first()
    if target is None:
        return None

    target.scheduled_hours = default_schedule
    session.add(target)
    session.commit()
    session.refresh(target)
    return target.scheduled_hours
def remove_hour_from_schedule(
    user: UserCreate,
    hour: int,
    session: Session = next(get_session()),
) -> Optional[str]:
    """
    Remove every occurrence of ``hour`` from the user's schedule.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        hour: Hour to drop; all matching entries are removed.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The new schedule string (remaining hours in their original order), or
        None if the user is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    target = session.exec(lookup).first()
    if target is None:
        return None

    current_hours = create_schedule_array(target.scheduled_hours)
    kept_hours = [h for h in current_hours if hour != h]
    target.scheduled_hours = ",".join(map(str, kept_hours))
    session.add(target)
    session.commit()
    session.refresh(target)
    return target.scheduled_hours
def add_hours_to_the_schedule(
    user: UserCreate,
    schedule_list: List[int],
    session: Session = next(get_session()),
) -> Optional[str]:
    """
    Merge extra hours into the user's schedule and store it sorted.

    Duplicates are kept: adding an hour that is already scheduled stores it
    twice.

    Args:
        user: Only its telegram_chat_id is used, to find the stored row.
        schedule_list: Hours to add.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The new, ascending schedule string, or None if the user is unknown.
    """
    lookup = select(User).where(User.telegram_chat_id == user.telegram_chat_id)
    target = session.exec(lookup).first()
    if target is None:
        return None

    merged = create_schedule_array(target.scheduled_hours) + list(schedule_list)
    target.scheduled_hours = ",".join(str(h) for h in sorted(merged))
    session.add(target)
    session.commit()
    session.refresh(target)
    return target.scheduled_hours
def db_create_user(
    user: UserCreate,
    session: Session = next(get_session()),
) -> Optional[User]:
    """
    Insert a new User row built from ``user``.

    Args:
        user: Validated input copied into a new User table row.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.

    Returns:
        The stored User, or None if building or committing it failed (for
        example, a second row with the same telegram_chat_id).
    """
    try:
        db_user = User.from_orm(user)
        session.add(db_user)
        session.commit()
        session.refresh(db_user)
    except Exception:
        # The session outlives this call (shared default or the caller's).
        # After a failed flush/commit SQLAlchemy refuses further use until the
        # session is rolled back, so restore it before reporting failure.
        session.rollback()
        logging.getLogger(__name__).exception(
            "Failed to create user with telegram_chat_id=%s",
            getattr(user, "telegram_chat_id", None),
        )
        return None
    return db_user
def db_read_users(
    *,
    only_active_users: bool = True,
    session: Session = next(get_session()),
    offset: int = 0,
    limit: int = 100,
) -> List[User]:
    """
    Return one page of users.

    Args:
        only_active_users: When true, keep only rows whose ``active`` is truthy.
        session: Defaults to a Session created once, at import time, and reused
            by every call that omits it.
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The matching users as a list.
    """
    query = select(User)
    if only_active_users:
        query = query.where(User.active)
    return session.exec(query.offset(offset).limit(limit)).all()
def create_schedule_array(schedule_str: Optional[str]) -> List[int]:
    """
    Parse a comma-separated schedule string into a list of hours.

    ``User.scheduled_hours`` is nullable, so both ``None`` and ``""`` yield an
    empty list. Empty segments, such as one left by a trailing comma, are
    skipped instead of crashing ``int()``.

    Args:
        schedule_str: Hours joined by commas, e.g. ``"8,20"``.

    Returns:
        The hours as integers, in the order they appear.

    Raises:
        ValueError: If a non-empty segment is not an integer.
    """
    if not schedule_str:
        return []
    return [int(part) for part in schedule_str.split(",") if part.strip()]