| import ast |
| from abc import ABC, abstractmethod |
|
|
| from app.config import config |
| from app.models import const |
|
|
|
|
| |
class BaseState(ABC):
    """Abstract interface for task-state storage backends.

    Subclasses must implement ``update_task``, ``get_task`` and
    ``get_all_tasks``. ``delete_task`` is part of the interface too, but is
    not abstract so that subclasses written against the original
    three-method interface can still be instantiated.
    """

    @abstractmethod
    def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
        """Create or update the record for *task_id*.

        Args:
            task_id: Identifier of the task.
            state: Task state code.
            progress: Task progress value.
            **kwargs: Extra fields stored alongside the task.
        """

    @abstractmethod
    def get_task(self, task_id: str):
        """Return the stored record for *task_id* (backends in this module
        return ``None`` when the task is unknown)."""

    @abstractmethod
    def get_all_tasks(self, page: int, page_size: int):
        """Return ``(tasks, total)`` for the 1-based *page* of *page_size*
        records, where *total* is the number of stored tasks."""

    def delete_task(self, task_id: str):
        """Remove the record for *task_id*.

        Raises:
            NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement delete_task"
        )
|
|
|
|
| |
class MemoryState(BaseState):
    """Task-state store backed by a plain dict held on this instance.

    Records live only in this process's memory: they are lost when the
    process exits and are not visible to other processes.
    """

    def __init__(self):
        # task_id -> record dict with "task_id", "state", "progress" and any
        # extra fields passed to update_task.
        self._tasks = {}

    def get_all_tasks(self, page: int, page_size: int):
        """Return one page of task records and the total number of tasks.

        Args:
            page: 1-based page number.
            page_size: Number of records per page.

        Returns:
            ``(tasks, total)``: *tasks* is a list of record dicts ordered by
            first insertion, *total* is the number of stored tasks.
        """
        start = (page - 1) * page_size
        tasks = list(self._tasks.values())
        return tasks[start:start + page_size], len(tasks)

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Create or overwrite the record for *task_id*.

        The whole record is replaced, so extra fields given in an earlier
        call are dropped unless they are passed again.

        Args:
            task_id: Identifier of the task.
            state: Task state code.
            progress: Progress value; converted with ``int()`` and clamped
                to the range [0, 100].
            **kwargs: Extra fields stored in the record as-is.
        """
        # Clamp at both ends: the original only capped values above 100.
        progress = min(max(int(progress), 0), 100)

        self._tasks[task_id] = {
            "task_id": task_id,
            "state": state,
            "progress": progress,
            **kwargs,
        }

    def get_task(self, task_id: str):
        """Return the record for *task_id*, or ``None`` if it is unknown."""
        return self._tasks.get(task_id, None)

    def delete_task(self, task_id: str):
        """Remove the record for *task_id*; unknown ids are ignored."""
        self._tasks.pop(task_id, None)
|
|
|
|
| |
class RedisState(BaseState):
    """Task-state store that keeps each task in a Redis hash keyed by task_id.

    Every field is written as ``str(value)`` and parsed back with
    ``_convert_to_original_type`` when read.
    """

    def __init__(self, host="localhost", port=6379, db=0, password=None):
        # Imported here so the redis package is only required when this
        # backend is actually instantiated.
        import redis

        self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)

    def get_all_tasks(self, page: int, page_size: int):
        """Return one page of task records and the total number of tasks.

        NOTE(review): every key in the selected Redis db is assumed to be a
        task hash (as before); keys of another type will make HGETALL fail.

        Args:
            page: 1-based page number.
            page_size: Number of records per page.

        Returns:
            ``(tasks, total)``: *tasks* is a list of record dicts ordered by
            key, *total* is the number of keys in the db.
        """
        start = (page - 1) * page_size
        end = start + page_size

        # SCAN has no ordering guarantee and may return a key more than once,
        # so an incremental scan-and-slice cannot compute page offsets or the
        # total reliably. Gather the full key set, de-duplicate it and sort it
        # for a stable order, then slice out the requested page.
        keys = sorted(set(self._redis.scan_iter()))
        total = len(keys)

        tasks = []
        for key in keys[start:end]:
            task = self.get_task(key)
            # Skip keys deleted between the SCAN and the HGETALL.
            if task is not None:
                tasks.append(task)
        return tasks, total

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Create or update the hash for *task_id*.

        Only the given fields are written; fields stored by earlier calls
        and not passed again are left in place.

        Args:
            task_id: Identifier of the task (used as the Redis key).
            state: Task state code.
            progress: Progress value; converted with ``int()`` and clamped
                to the range [0, 100].
            **kwargs: Extra fields, each stored as ``str(value)``.
        """
        progress = min(max(int(progress), 0), 100)

        fields = {
            "task_id": task_id,
            "state": state,
            "progress": progress,
            **kwargs,
        }

        # Send all HSETs in a single MULTI/EXEC pipeline, so a concurrent
        # reader never sees a half-updated task (e.g. new state with stale
        # progress), and it costs one round trip instead of one per field.
        pipe = self._redis.pipeline(transaction=True)
        for field, value in fields.items():
            pipe.hset(task_id, field, str(value))
        pipe.execute()

    def get_task(self, task_id: str):
        """Return the decoded record for *task_id*, or ``None`` if missing."""
        task_data = self._redis.hgetall(task_id)
        if not task_data:
            return None

        return {
            key.decode("utf-8"): self._convert_to_original_type(value)
            for key, value in task_data.items()
        }

    def delete_task(self, task_id: str):
        """Delete the hash for *task_id* (no-op if it does not exist)."""
        self._redis.delete(task_id)

    @staticmethod
    def _convert_to_original_type(value):
        """Convert a raw Redis byte string back to a Python value.

        ``update_task`` stores every field as ``str(value)``. This reverses
        that for Python literals (numbers, bools, ``None``, lists, dicts,
        ...) via ``ast.literal_eval``. Anything else is returned as the
        decoded string. Because the stored text carries no type tag, a
        string that looks like a literal (e.g. ``"123"``) comes back as that
        literal.
        """
        value_str = value.decode("utf-8")

        try:
            # literal_eval never executes code, but besides ValueError and
            # SyntaxError it raises TypeError for unhashable dict keys or set
            # items (e.g. "{[]: 1}") and RecursionError or MemoryError for
            # pathologically nested text. Treat all of these as "not a
            # literal" instead of crashing the read.
            return ast.literal_eval(value_str)
        except (ValueError, TypeError, SyntaxError, RecursionError, MemoryError):
            pass

        # Decimal strings with leading zeros ("007") are not valid Python
        # literals but are still integers. isdecimal() (unlike isdigit())
        # only accepts characters that int() can parse, so e.g. "²" no
        # longer raises ValueError here.
        if value_str.isdecimal():
            return int(value_str)

        return value_str
|
|
|
|
| |
# Backend selection: settings are read from config.app, with local defaults.
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)

# Module-wide task-state store: Redis-backed when enabled, otherwise an
# in-process dict.
if _enable_redis:
    state = RedisState(
        host=_redis_host,
        port=_redis_port,
        db=_redis_db,
        password=_redis_password,
    )
else:
    state = MemoryState()
|
|