Spaces:
Build error
Build error
File size: 6,862 Bytes
87a665c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | # tasks.py
import asyncio
from typing import Dict
from uuid import uuid4
import json
import logging
from redis.asyncio import Redis
from fastapi import Request
from typing import Dict, List, Optional
from open_webui.env import REDIS_KEY_PREFIX
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# A dictionary to keep track of active tasks
# (task id -> the asyncio.Task running in this process).
tasks: Dict[str, asyncio.Task] = {}
# Maps the `id` passed to create_task() to the list of task ids created
# for it; entries are removed again by cleanup_task().
item_tasks = {}
# Redis hash: field = task id, value = owning item id ('' when none).
REDIS_TASKS_KEY = f'{REDIS_KEY_PREFIX}:tasks'
# Key prefix for per-item Redis sets of task ids; the full key is
# f'{REDIS_ITEM_TASKS_KEY}:{item_id}'.
REDIS_ITEM_TASKS_KEY = f'{REDIS_KEY_PREFIX}:tasks:item'
# Pub/sub channel on which JSON task commands (e.g. 'stop') are published.
REDIS_PUBSUB_CHANNEL = f'{REDIS_KEY_PREFIX}:tasks:commands'
async def redis_task_command_listener(app):
    """
    Listen on the task command channel and cancel local tasks on request.

    Subscribes to REDIS_PUBSUB_CHANNEL using the Redis client stored on
    `app.state.redis` and processes messages until the listen stream ends.
    Each message payload is parsed as JSON; for a command whose 'action' is
    'stop', the task registered in this process under 'task_id' (if any) is
    cancelled. Errors raised while handling a single message are logged and
    do not stop the listener.
    """
    client: Redis = app.state.redis
    subscription = client.pubsub()
    await subscription.subscribe(REDIS_PUBSUB_CHANNEL)

    async for event in subscription.listen():
        # Only real published messages carry a command; skip everything else
        # (e.g. subscribe confirmations).
        if event['type'] == 'message':
            try:
                payload = json.loads(event['data'])
                if payload.get('action') != 'stop':
                    continue
                # Only the process that owns the task has it in `tasks`.
                owned_task = tasks.get(payload.get('task_id'))
                if owned_task:
                    owned_task.cancel()
            except Exception as e:
                log.exception(f'Error handling distributed task command: {e}')
### ------------------------------
### REDIS-ENABLED HANDLERS
### ------------------------------
async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]):
    """
    Record a task in Redis.

    Queues an HSET of `task_id` -> `item_id` (empty string when no item id is
    given) on the tasks hash and, when `item_id` is truthy, an SADD of
    `task_id` into that item's set. Both commands are sent with a single
    pipeline execute.
    """
    batch = redis.pipeline()
    owner = item_id or ''
    batch.hset(REDIS_TASKS_KEY, task_id, owner)
    if item_id:
        item_key = f'{REDIS_ITEM_TASKS_KEY}:{item_id}'
        batch.sadd(item_key, task_id)
    await batch.execute()
async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]):
    """
    Remove a task's bookkeeping from Redis.

    Deletes `task_id` from the tasks hash and, when `item_id` is truthy,
    removes it from that item's task set. Both commands are sent with a
    single pipeline execute. HDEL and SREM are no-ops for missing entries,
    so calling this more than once for the same task is harmless.
    """
    pipe = redis.pipeline()
    pipe.hdel(REDIS_TASKS_KEY, task_id)
    if item_id:
        # Redis removes a set key automatically once its last member is
        # removed, so no SCARD/DELETE follow-up is needed. Such a
        # check-then-delete would also be racy: a task id added by another
        # worker between the two calls would be wiped out by the DELETE.
        pipe.srem(f'{REDIS_ITEM_TASKS_KEY}:{item_id}', task_id)
    await pipe.execute()
async def redis_list_tasks(redis: Redis) -> List[str]:
    """
    Return the field names of the Redis tasks hash (the task ids) as a list.
    """
    task_ids = await redis.hkeys(REDIS_TASKS_KEY)
    return list(task_ids)
async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]:
    """
    Return the members of the Redis task set for `item_id` as a list.
    """
    item_key = f'{REDIS_ITEM_TASKS_KEY}:{item_id}'
    members = await redis.smembers(item_key)
    return list(members)
async def redis_send_command(redis: Redis, command: dict):
    """
    Publish `command`, serialized as JSON, on the task command channel.

    Clients that have a `nodes_manager` attribute are treated as RedisCluster
    clients and are sent a raw PUBLISH command instead of calling publish().
    """
    payload = json.dumps(command)
    is_cluster_client = hasattr(redis, 'nodes_manager')

    if not is_cluster_client:
        await redis.publish(REDIS_PUBSUB_CHANNEL, payload)
        return

    # RedisCluster doesn't expose publish() directly, but the
    # PUBLISH command broadcasts across all cluster nodes server-side.
    await redis.execute_command('PUBLISH', REDIS_PUBSUB_CHANNEL, payload)
async def cleanup_task(redis, task_id: str, id=None):
    """
    Remove a completed or canceled task from the global `tasks` dictionary.

    When `redis` is truthy the task is first removed from Redis. The local
    registries (`tasks` and `item_tasks`) are cleaned up afterwards even if
    the Redis call raises, so a Redis outage cannot leave stale local
    entries behind; the Redis error still propagates to the caller.

    Args:
        redis: Redis client, or a falsy value when Redis is not in use.
        task_id: ID of the task to forget.
        id: The item ID the task was created with, if any.
    """
    try:
        if redis:
            await redis_cleanup_task(redis, task_id, id)
    finally:
        tasks.pop(task_id, None)  # Remove the task if it exists

        # Remove the task from whichever item_tasks entry it was registered
        # under. This is looked up by key rather than guarded by the
        # truthiness of `id`, so tasks registered under a falsy id (e.g.
        # None) are released as well instead of accumulating forever.
        registered = item_tasks.get(id)
        if registered is not None and task_id in registered:
            registered.remove(task_id)
            if not registered:  # No tasks left for this ID: drop the entry
                item_tasks.pop(id, None)
# Strong references to in-flight cleanup tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task with no other
# reference may be garbage-collected before it finishes.
_pending_cleanups = set()


async def create_task(redis, coroutine, id=None):
    """
    Create a new asyncio task and add it to the global task dictionary.

    The task is registered in `tasks` under a freshly generated UUID and,
    when `id` is truthy, in `item_tasks[id]`. When `redis` is truthy the
    task is also recorded in Redis. A done-callback schedules
    `cleanup_task` once the task finishes or is cancelled.

    Args:
        redis: Redis client, or a falsy value when Redis is not in use.
        coroutine: The coroutine to run as a task.
        id: Optional item ID to associate the task with.

    Returns:
        A `(task_id, task)` tuple.
    """
    task_id = str(uuid4())  # Generate a unique ID for the task
    task = asyncio.create_task(coroutine)  # Create the task

    def _schedule_cleanup(_finished):
        # Done-callbacks are synchronous, so the async cleanup runs as its
        # own task; hold a reference to it until it completes.
        cleanup = asyncio.create_task(cleanup_task(redis, task_id, id))
        _pending_cleanups.add(cleanup)
        cleanup.add_done_callback(_pending_cleanups.discard)

    task.add_done_callback(_schedule_cleanup)
    tasks[task_id] = task

    # Only associate the task with an item when an ID is actually provided.
    # This matches the Redis path (which only indexes truthy item ids) and
    # avoids piling up entries under item_tasks[None].
    if id:
        item_tasks.setdefault(id, []).append(task_id)

    if redis:
        await redis_save_task(redis, task_id, id)

    return task_id, task
async def list_tasks(redis):
    """
    List all currently active task IDs.

    Reads from Redis when `redis` is truthy; otherwise returns the IDs of
    the tasks registered in this process.
    """
    if not redis:
        return list(tasks.keys())
    return await redis_list_tasks(redis)
async def list_task_ids_by_item_id(redis, id):
    """
    List all tasks associated with a specific ID.

    Reads from Redis when `redis` is truthy; otherwise uses the in-process
    `item_tasks` registry.

    Returns:
        A new list of task IDs (empty when none are registered).
    """
    if redis:
        return await redis_list_item_tasks(redis, id)
    # Return a copy, not the registry's own list: cleanup_task() mutates
    # that list as tasks finish, which would otherwise change the result
    # under a caller that is still iterating over it.
    return list(item_tasks.get(id, []))
async def stop_task(redis, task_id: str):
    """
    Cancel a running task and remove it from the global task list.

    Without Redis, the task is looked up in this process, cancelled and
    awaited. With Redis, a 'stop' command is published so that whichever
    process owns the task cancels it, and the Redis records for the task
    are removed.

    Returns:
        A dict with a boolean 'status' and a human-readable 'message'.
    """
    if not redis:
        local_task = tasks.pop(task_id, None)
        if not local_task:
            return {'status': False, 'message': f'Task with ID {task_id} not found.'}

        local_task.cancel()  # Request task cancellation
        try:
            await local_task  # Wait for the task to handle the cancellation
        except asyncio.CancelledError:
            # Task successfully canceled
            return {'status': True, 'message': f'Task {task_id} successfully stopped.'}

        if local_task.cancelled() or local_task.done():
            return {'status': True, 'message': f'Task {task_id} successfully cancelled.'}

        return {'status': True, 'message': f'Cancellation requested for {task_id}.'}

    # Look up the item_id before cleanup so we can remove the set entry too
    owner_item_id = await redis.hget(REDIS_TASKS_KEY, task_id)

    # PUBSUB: All instances check if they have this task, and stop if so.
    stop_command = {
        'action': 'stop',
        'task_id': task_id,
    }
    await redis_send_command(redis, stop_command)

    # Always clean Redis directly — hdel/srem are idempotent, safe even
    # if the done_callback on the owning process also fires cleanup.
    await redis_cleanup_task(redis, task_id, owner_item_id or None)

    return {'status': True, 'message': f'Task {task_id} stopped.'}
async def stop_item_tasks(redis: Redis, item_id: str):
    """
    Stop all tasks associated with a specific item ID.

    Tasks are stopped one after another via `stop_task`; the first failing
    result is returned immediately.

    Returns:
        A dict with a boolean 'status' and a human-readable 'message'.
    """
    # Iterate over a snapshot. Without Redis the lookup may hand back the
    # registry's own list, which cleanup callbacks shrink while stop_task()
    # is being awaited; iterating that live list skips tasks, leaving them
    # running while this function reports success.
    task_ids = list(await list_task_ids_by_item_id(redis, item_id))
    if not task_ids:
        return {'status': True, 'message': f'No tasks found for item {item_id}.'}

    for task_id in task_ids:
        result = await stop_task(redis, task_id)
        if not result['status']:
            return result  # Return the first failure

    return {'status': True, 'message': f'All tasks for item {item_id} stopped.'}
async def has_active_tasks(redis, chat_id: str) -> bool:
    """Return True when at least one task is registered for `chat_id`."""
    registered = await list_task_ids_by_item_id(redis, chat_id)
    return bool(registered)
async def get_active_chat_ids(redis, chat_ids: List[str]) -> List[str]:
    """Return the subset of `chat_ids` that have active tasks, in input order."""
    # Checks run sequentially, one chat at a time.
    return [candidate for candidate in chat_ids if await has_active_tasks(redis, candidate)]
|