ddoc's picture
Upload 51 files
6242a59
from enum import Enum
from datetime import datetime
from typing import Optional, Union
from sqlalchemy import (
Column,
String,
Text,
Integer,
DateTime,
LargeBinary,
Boolean,
text,
func,
)
from sqlalchemy.orm import Session
from .base import BaseTableManager, Base
from ..models import TaskModel
class TaskStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
DONE = "done"
FAILED = "failed"
INTERRUPTED = "interrupted"
class Task(TaskModel):
script_params: bytes = None
params: str
def __init__(self, priority=int(datetime.utcnow().timestamp() * 1000), **kwargs):
super().__init__(priority=priority, **kwargs)
class Config(TaskModel.__config__):
exclude = ["script_params"]
@staticmethod
def from_table(table: "TaskTable"):
return Task(
id=table.id,
api_task_id=table.api_task_id,
name=table.name,
type=table.type,
params=table.params,
priority=table.priority,
status=table.status,
result=table.result,
bookmarked=table.bookmarked,
created_at=table.created_at,
updated_at=table.updated_at,
)
def to_table(self):
return TaskTable(
id=self.id,
api_task_id=self.api_task_id,
name=self.name,
type=self.type,
params=self.params,
script_params=b"",
priority=self.priority,
status=self.status,
result=self.result,
bookmarked=self.bookmarked,
)
class TaskTable(Base):
__tablename__ = "task"
id = Column(String(64), primary_key=True)
api_task_id = Column(String(64), nullable=True)
name = Column(String(255), nullable=True)
type = Column(String(20), nullable=False) # txt2img or img2txt
params = Column(Text, nullable=False) # task args
script_params = Column(LargeBinary, nullable=False) # script args
priority = Column(Integer, nullable=False)
status = Column(
String(20), nullable=False, default="pending"
) # pending, running, done, failed
result = Column(Text) # task result
bookmarked = Column(Boolean, nullable=True, default=False)
created_at = Column(
DateTime,
nullable=False,
server_default=text("(datetime('now'))"),
)
updated_at = Column(
DateTime,
nullable=False,
server_default=text("(datetime('now'))"),
onupdate=text("(datetime('now'))"),
)
def __repr__(self):
return f"Task(id={self.id!r}, type={self.type!r}, params={self.params!r}, status={self.status!r}, created_at={self.created_at!r})"
class TaskManager(BaseTableManager):
def get_task(self, id: str) -> Union[TaskTable, None]:
session = Session(self.engine)
try:
task = session.get(TaskTable, id)
return Task.from_table(task) if task else None
except Exception as e:
print(f"Exception getting task from database: {e}")
raise e
finally:
session.close()
def get_tasks(
self,
type: str = None,
status: Union[str, list[str]] = None,
bookmarked: bool = None,
api_task_id: str = None,
limit: int = None,
offset: int = None,
order: str = "asc",
) -> list[TaskTable]:
session = Session(self.engine)
try:
query = session.query(TaskTable)
if type:
query = query.filter(TaskTable.type == type)
if status is not None:
if isinstance(status, list):
query = query.filter(TaskTable.status.in_(status))
else:
query = query.filter(TaskTable.status == status)
if api_task_id:
query = query.filter(TaskTable.api_task_id == api_task_id)
if bookmarked == True:
query = query.filter(TaskTable.bookmarked == bookmarked)
else:
query = query.order_by(TaskTable.bookmarked.asc())
query = query.order_by(
TaskTable.priority.asc()
if order == "asc"
else TaskTable.priority.desc()
)
if limit:
query = query.limit(limit)
if offset:
query = query.offset(offset)
all = query.all()
return [Task.from_table(t) for t in all]
except Exception as e:
print(f"Exception getting tasks from database: {e}")
raise e
finally:
session.close()
def count_tasks(
self,
type: str = None,
status: Union[str, list[str]] = None,
api_task_id: str = None,
) -> int:
session = Session(self.engine)
try:
query = session.query(TaskTable)
if type:
query = query.filter(TaskTable.type == type)
if status is not None:
if isinstance(status, list):
query = query.filter(TaskTable.status.in_(status))
else:
query = query.filter(TaskTable.status == status)
if api_task_id:
query = query.filter(TaskTable.api_task_id == api_task_id)
return query.count()
except Exception as e:
print(f"Exception counting tasks from database: {e}")
raise e
finally:
session.close()
def add_task(self, task: Task) -> TaskTable:
session = Session(self.engine)
try:
result = task.to_table()
session.add(result)
session.commit()
return result
except Exception as e:
print(f"Exception adding task to database: {e}")
raise e
finally:
session.close()
def update_task(
self,
id: str,
name: str = None,
status: str = None,
result: str = None,
bookmarked: bool = None,
) -> TaskTable:
session = Session(self.engine)
try:
task = session.get(TaskTable, id)
if task is None:
raise Exception(f"Task with id {id} not found")
if name is not None:
task.name = name
if status is not None:
task.status = status
if result is not None:
task.result = result
if bookmarked is not None:
task.bookmarked = bookmarked
session.commit()
return task
except Exception as e:
print(f"Exception updating task in database: {e}")
raise e
finally:
session.close()
def prioritize_task(self, id: str, priority: int) -> TaskTable:
"""0 means move to top, -1 means move to bottom, otherwise set the exact priority"""
session = Session(self.engine)
try:
result = session.get(TaskTable, id)
if result:
if priority == 0:
result.priority = self.__get_min_priority(status=TaskStatus.PENDING) - 1
elif priority == -1:
result.priority = int(datetime.utcnow().timestamp() * 1000)
else:
self.__move_tasks_down(priority)
session.execute(text("SELECT 1"))
result.priority = priority
session.commit()
return result
else:
raise Exception(f"Task with id {id} not found")
except Exception as e:
print(f"Exception updating task in database: {e}")
raise e
finally:
session.close()
def delete_task(self, id: str):
session = Session(self.engine)
try:
result = session.get(TaskTable, id)
if result:
session.delete(result)
session.commit()
else:
raise Exception(f"Task with id {id} not found")
except Exception as e:
print(f"Exception deleting task from database: {e}")
raise e
finally:
session.close()
def delete_tasks_before(self, before: datetime, all: bool = False):
session = Session(self.engine)
try:
query = session.query(TaskTable).filter(TaskTable.created_at < before)
if not all:
query = query.filter(
TaskTable.status.in_(
[TaskStatus.DONE, TaskStatus.FAILED, TaskStatus.INTERRUPTED]
)
).filter(TaskTable.bookmarked == False)
deleted_rows = query.delete()
session.commit()
return deleted_rows
except Exception as e:
print(f"Exception deleting tasks from database: {e}")
raise e
finally:
session.close()
def __get_min_priority(self, status: str = None) -> int:
session = Session(self.engine)
try:
query = session.query(func.min(TaskTable.priority))
if status is not None:
query = query.filter(TaskTable.status == status)
min_priority = query.scalar()
return min_priority if min_priority else 0
except Exception as e:
print(f"Exception getting min priority from database: {e}")
raise e
finally:
session.close()
def __move_tasks_down(self, priority: int):
session = Session(self.engine)
try:
session.query(TaskTable).filter(TaskTable.priority >= priority).update(
{TaskTable.priority: TaskTable.priority + 1}
)
session.commit()
except Exception as e:
print(f"Exception moving tasks down in database: {e}")
raise e
finally:
session.close()