Spaces:
Runtime error
Runtime error
"""
Task queue management

This module provides classes and functions for managing the task queue.

Classes:
    QueueTask: A class representing a task in the queue.
    TaskQueue: A class for managing the task queue.
    TaskOutputs: A container that collects a task's outputs and forwards
        preview updates to the task.
"""
| import uuid | |
| import time | |
| from typing import List, Tuple | |
| import numpy as np | |
| import requests | |
| from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url | |
| from fooocusapi.utils.img_utils import narray_to_base64img | |
| from fooocusapi.utils.logger import logger | |
| from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason | |
| from fooocusapi.parameters import ImageGenerationParams | |
| from fooocusapi.models.common.task import TaskType | |
class QueueTask:
    """
    A class representing a task in the queue.

    Attributes:
        job_id (str): The unique identifier for the task.
        task_type (TaskType): The type of task.
        req_param (ImageGenerationParams): The request parameters of the task.
        is_finished (bool): Indicates whether the task has been completed.
        finish_progress (int): The progress of the task completion (capped at 100).
        in_queue_mills (int): The time the task object was created, in milliseconds.
        start_mills (int): The time the task started, in milliseconds.
        finish_mills (int): The time the task finished, in milliseconds.
        finish_with_error (bool): Indicates whether the task finished with an error.
        task_status (str | None): The status of the task.
        task_step_preview (str | None): The latest step preview for the task.
        task_result (List[ImageGenerationResult] | None): The result of the task.
        error_message (str | None): The error message, if any.
        webhook_url (str | None): The webhook URL, if any.
    """

    job_id: str
    task_type: TaskType
    req_param: ImageGenerationParams
    is_finished: bool = False
    finish_progress: int = 0
    in_queue_mills: int
    start_mills: int = 0
    finish_mills: int = 0
    finish_with_error: bool = False
    task_status: str | None = None
    task_step_preview: str | None = None
    task_result: List[ImageGenerationResult] | None = None
    error_message: str | None = None
    webhook_url: str | None = None  # attribute for individual webhook_url

    def __init__(
        self,
        job_id: str,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ):
        """
        Create a task and stamp the current time as its enqueue time.

        Arguments:
            job_id {str} -- unique identifier of the task
            task_type {TaskType} -- type of the task
            req_param {ImageGenerationParams} -- request parameters
            webhook_url {str} -- optional webhook url for this task only
        """
        self.job_id = job_id
        self.task_type = task_type
        self.req_param = req_param
        self.in_queue_mills = int(round(time.time() * 1000))
        self.webhook_url = webhook_url

    def set_progress(self, progress: int, status: str | None):
        """
        Set progress and status

        Arguments:
            progress {int} -- progress, values above 100 are stored as 100
            status {str} -- status
        """
        self.finish_progress = min(progress, 100)
        self.task_status = status

    def set_step_preview(self, task_step_preview: str | None):
        """
        Set step preview

        Arguments:
            task_step_preview {str} -- step preview
        """
        self.task_step_preview = task_step_preview

    def set_result(
        self,
        task_result: List[ImageGenerationResult],
        finish_with_error: bool,
        error_message: str | None = None,
    ):
        """
        Set task result

        Progress and status are only forced to 100 / "Finished" when the task
        did not finish with an error; otherwise they keep their last values.

        Arguments:
            task_result {List[ImageGenerationResult]} -- task result
            finish_with_error {bool} -- finish with error
            error_message {str} -- error message
        """
        if not finish_with_error:
            self.finish_progress = 100
            self.task_status = "Finished"
        self.task_result = task_result
        self.finish_with_error = finish_with_error
        self.error_message = error_message

    def __str__(self) -> str:
        # Adjacent f-string literals are concatenated by the parser, so the
        # source indentation never ends up in the resulting text (a backslash
        # continuation inside one literal would embed it).
        return (
            f"QueueTask(job_id={self.job_id}, task_type={self.task_type}, "
            f"is_finished={self.is_finished}, finished_progress={self.finish_progress}, "
            f"in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, "
            f"finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, "
            f"error_message={self.error_message}, task_status={self.task_status}, "
            f"task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
        )
class TaskQueue:
    """
    TaskQueue is a queue of tasks that are waiting to be processed.

    Attributes:
        queue: List[QueueTask] -- pending tasks, oldest first
        history: List[QueueTask] -- finished tasks, oldest first
        last_job_id: str -- job_id of the most recently added task
        webhook_url: str -- default webhook url, used when a task has none
        persistent: bool -- whether finished tasks are saved to the database
    """

    # Declared here for type checkers only. The lists themselves are created
    # per instance in __init__; a list assigned at class level would be shared
    # by every TaskQueue instance.
    queue: List[QueueTask]
    history: List[QueueTask]
    last_job_id: str | None = None
    webhook_url: str | None = None
    persistent: bool = False

    def __init__(
        self,
        queue_size: int,
        history_size: int,
        webhook_url: str | None = None,
        persistent: bool | None = False,
    ):
        """
        :param queue_size: maximum number of pending tasks accepted by add_task
        :param history_size: maximum number of finished tasks kept, 0 disables
            history cleaning
        :param webhook_url: default webhook url
        :param persistent: save finished tasks to the database, None means False
        """
        self.queue = []
        self.history = []
        self.queue_size = queue_size
        self.history_size = history_size
        self.webhook_url = webhook_url
        self.persistent = False if persistent is None else persistent

    def add_task(
        self,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ) -> QueueTask | None:
        """
        Create and add task to queue
        :param task_type: task type
        :param req_param: request parameters, a dict is converted to ImageGenerationParams
        :param webhook_url: webhook url
        :returns: The created task, or None if the queue size limit is reached
        """
        if len(self.queue) >= self.queue_size:
            return None

        if isinstance(req_param, dict):
            req_param = ImageGenerationParams(**req_param)

        job_id = str(uuid.uuid4())
        task = QueueTask(
            job_id=job_id,
            task_type=task_type,
            req_param=req_param,
            webhook_url=webhook_url,
        )
        self.queue.append(task)
        self.last_job_id = job_id
        return task

    def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
        """
        Get task by job_id
        :param job_id: job id
        :param include_history: whether to include history tasks
        :returns: The task with the given job_id, or None if not found
        """
        # Pending tasks take precedence over history entries.
        candidates = self.queue + self.history if include_history else self.queue
        return next((task for task in candidates if task.job_id == job_id), None)

    def is_task_ready_to_start(self, job_id: str) -> bool:
        """
        Check if the task is ready to start
        :param job_id: job id
        :returns: True if the task is at the head of the queue, False otherwise
        """
        return bool(self.queue) and self.queue[0].job_id == job_id

    def is_task_finished(self, job_id: str) -> bool:
        """
        Check if the task is finished
        :param job_id: job id
        :returns: True if the task is finished, False if it is unfinished or unknown
        """
        task = self.get_task(job_id, True)
        return False if task is None else task.is_finished

    def start_task(self, job_id: str):
        """
        Start task by job_id, recording the start time. Unknown ids are ignored.
        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is not None:
            task.start_mills = int(round(time.time() * 1000))

    def finish_task(self, job_id: str):
        """
        Finish task by job_id

        Marks the task finished, calls the webhook (if any), moves the task
        from the queue to the history, saves it to the database when the
        queue is persistent, and finally cleans the history.
        Ids that are not in the queue are ignored.
        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is None:
            return

        task.is_finished = True
        task.finish_mills = int(round(time.time() * 1000))

        data = self._build_result_data(task)

        # Use the task's webhook_url if available, else use the default
        webhook_url = task.webhook_url or self.webhook_url
        if webhook_url:
            self._send_webhook(webhook_url, data)

        # Move task to history
        self.queue.remove(task)
        self.history.append(task)

        # save history to database
        if self.persistent:
            self._save_history(task, data)

        self._clean_history()

    @staticmethod
    def _build_result_data(task: QueueTask) -> dict:
        """Build the job_id / job_result payload describing a finished task."""
        data = {"job_id": task.job_id, "job_result": []}
        if isinstance(task.task_result, list):
            data["job_result"] = [
                {
                    "url": get_file_serve_url(item.im) if item.im else None,
                    "seed": item.seed if item.seed else "-1",
                }
                for item in task.task_result
            ]
        return data

    @staticmethod
    def _send_webhook(webhook_url: str, data: dict):
        """POST the result payload to the webhook; failures are only reported."""
        try:
            res = requests.post(webhook_url, json=data, timeout=15)
            print(f"Call webhook response status: {res.status_code}")
        except Exception as e:
            # Deliberately best-effort: a broken webhook must not prevent the
            # task from being moved to the history.
            print("Call webhook error:", e)

    @staticmethod
    def _save_history(task: QueueTask, data: dict):
        """Save a finished task to the database through the SQL client."""
        has_result = isinstance(task.task_result, list) and len(task.task_result) > 0
        if not has_result:
            # There is no first result to read a finish reason from; indexing
            # task_result here would raise after the task was already moved.
            logger.std_info(
                f"[TaskQueue] Task {task.job_id} has no result, skip saving history"
            )
            return

        # Imported only when persistence is enabled; presumably so the SQL
        # client is not required otherwise -- TODO confirm.
        from fooocusapi.sql_client import add_history

        # Results without an image have a None url, which str.join rejects.
        result_urls = [job["url"] for job in data["job_result"] if job["url"]]
        add_history(
            params=task.req_param.to_dict(),
            task_info=dict(
                task_type=task.task_type.value,
                task_id=task.job_id,
                task_in_queue_mills=task.in_queue_mills,
                task_start_mills=task.start_mills,
                task_finish_mills=task.finish_mills,
            ),
            result_url=",".join(result_urls),
            finish_reason=task.task_result[0].finish_reason.value,
        )

    def _clean_history(self):
        """Drop the oldest history entry, and its output files, when over the limit."""
        # A history_size of 0 disables cleaning.
        if self.history_size == 0 or len(self.history) <= self.history_size:
            return

        removed_task = self.history.pop(0)
        if isinstance(removed_task.task_result, list):
            for item in removed_task.task_result:
                if (
                    isinstance(item, ImageGenerationResult)
                    and item.finish_reason == GenerationFinishReason.success
                    and item.im is not None
                ):
                    delete_output_file(item.im)
        logger.std_info(
            f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
        )
class TaskOutputs:
    """
    TaskOutputs is a container for task outputs

    Every appended output is stored, and "preview" outputs are additionally
    forwarded to the task as progress / step preview updates.
    """

    # Declared for type checkers only. The list is created per instance in
    # __init__; a list assigned at class level would be shared (and grow)
    # across every task's TaskOutputs.
    outputs: List[List[object]]

    def __init__(self, task: QueueTask):
        """
        :param task: the task that receives progress and preview updates
        """
        self.task = task
        self.outputs = []

    def append(self, args: List[object]):
        """
        Append output to task outputs list

        When args looks like ["preview", (progress, status[, image])], the
        progress and status are set on the task, and an image that is a numpy
        array is converted to base64 and set as the task's step preview.
        :param args: output arguments
        """
        self.outputs.append(args)

        if len(args) < 2:
            return

        payload = args[1]
        is_preview = (
            args[0] == "preview"
            and isinstance(payload, tuple)
            and len(payload) >= 2
        )
        if not is_preview:
            return

        self.task.set_progress(payload[0], payload[1])

        if len(payload) >= 3 and isinstance(payload[2], np.ndarray):
            self.task.set_step_preview(narray_to_base64img(payload[2]))