"""
Redis-based job store for distributed job tracking across multiple Gunicorn workers.
"""
import json
import uuid
import redis
from flask import current_app
from datetime import datetime, timedelta
import logging
import re
class RedisJobStore:
    """
    Redis-based job store for tracking background tasks across multiple
    Gunicorn workers.

    Each job is stored as a JSON blob under the key ``job:<job_id>`` with a
    TTL, so stale entries expire automatically on the Redis side.
    """

    # Closed set of statuses a job may take.
    VALID_STATUSES = {'pending', 'processing', 'completed', 'failed', 'cancelled'}

    def __init__(self, redis_url=None, default_ttl_hours=24):
        """
        Initialize Redis job store.

        Args:
            redis_url (str): Redis connection URL. If None, uses
                current_app.config['REDIS_URL'] when available, otherwise
                falls back to a localhost default.
            default_ttl_hours (int): Default time-to-live for jobs in hours.
        """
        self.default_ttl_hours = default_ttl_hours
        if redis_url:
            self.redis_client = redis.from_url(redis_url)
        elif current_app and hasattr(current_app, 'config') and 'REDIS_URL' in current_app.config:
            self.redis_client = redis.from_url(current_app.config['REDIS_URL'])
        else:
            # Default to localhost Redis
            self.redis_client = redis.from_url('redis://localhost:6379/0')

    def _validate_job_id(self, job_id):
        """
        Validate job ID format to prevent key-injection via crafted IDs.

        Args:
            job_id (str): Job ID to validate.

        Returns:
            bool: True if valid, False otherwise.
        """
        if not job_id or not isinstance(job_id, str):
            return False
        # Allow UUID format or alphanumeric with hyphens/underscores, max 64 chars.
        return bool(re.match(r'^[a-zA-Z0-9_-]{1,64}$', job_id))

    def _validate_status(self, status):
        """
        Validate job status value against VALID_STATUSES.

        Args:
            status (str): Status to validate.

        Returns:
            bool: True if valid, False otherwise.
        """
        return status in self.VALID_STATUSES

    def create_job(self, job_id=None, initial_status='pending', initial_data=None):
        """
        Create a new job in Redis.

        Args:
            job_id (str): Job ID. If None, generates a new UUID.
            initial_status (str): Initial job status.
            initial_data (dict): Extra fields merged into the job record
                (may overwrite the standard fields).

        Returns:
            str: Job ID.

        Raises:
            ValueError: If the status or job ID fails validation.
            redis.ConnectionError: If Redis is unreachable.
        """
        if initial_status and not self._validate_status(initial_status):
            raise ValueError(f"Invalid status: {initial_status}. Valid statuses are: {self.VALID_STATUSES}")
        if job_id and not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        if job_id is None:
            job_id = str(uuid.uuid4())
        # Single timestamp so created_at == updated_at on a fresh job.
        now = datetime.utcnow().isoformat()
        job_data = {
            'status': initial_status,
            'result': None,
            'error': None,
            'created_at': now,
            'updated_at': now
        }
        if initial_data:
            job_data.update(initial_data)
        try:
            # Store job data as JSON in Redis with the configured expiration.
            self.redis_client.setex(
                f"job:{job_id}",
                timedelta(hours=self.default_ttl_hours),
                json.dumps(job_data)
            )
        except redis.ConnectionError:
            logging.error(f"Failed to connect to Redis when creating job {job_id}")
            raise
        except Exception as e:
            logging.error(f"Unexpected error when creating job {job_id}: {str(e)}")
            raise
        return job_id

    def get_job(self, job_id):
        """
        Get job data from Redis.

        Args:
            job_id (str): Job ID.

        Returns:
            dict: Job data, or None if not found / on Redis or decode errors
            (errors are logged, not raised).

        Raises:
            ValueError: If the job ID fails validation.
        """
        if not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        try:
            job_data_json = self.redis_client.get(f"job:{job_id}")
            if job_data_json:
                return json.loads(job_data_json)
            return None
        except redis.ConnectionError:
            logging.error(f"Failed to connect to Redis when getting job {job_id}")
            return None
        except json.JSONDecodeError:
            logging.error(f"Failed to decode JSON for job {job_id}")
            return None
        except Exception as e:
            logging.error(f"Unexpected error when getting job {job_id}: {str(e)}")
            return None

    def update_job(self, job_id, status=None, result=None, error=None):
        """
        Update job status and data in Redis using a Lua script so the
        read-modify-write is atomic across workers.

        Args:
            job_id (str): Job ID.
            status (str): New status (must be in VALID_STATUSES).
            result (any): JSON-serializable result data.
            error (str): Error message.

        Returns:
            bool: True if job was updated, False if not found or on error.

        Raises:
            ValueError: If the job ID or status fails validation.
        """
        if not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        if status is not None and not self._validate_status(status):
            raise ValueError(f"Invalid status: {status}. Valid statuses are: {self.VALID_STATUSES}")
        # Atomic read-modify-write. 'nil' marks "field not supplied"; result
        # and error are JSON-encoded on the Python side so a literal string
        # value of "nil" cannot be mistaken for the sentinel.
        lua_script = """
        local job_key = KEYS[1]
        local job_data = redis.call('GET', job_key)
        if not job_data then
            return 0
        end
        local updated_data = cjson.decode(job_data)
        if ARGV[1] ~= 'nil' then
            updated_data.status = ARGV[1]
        end
        if ARGV[2] ~= 'nil' then
            updated_data.result = cjson.decode(ARGV[2])
        end
        if ARGV[3] ~= 'nil' then
            updated_data.error = cjson.decode(ARGV[3])
        end
        updated_data.updated_at = ARGV[4]
        local ttl = redis.call('TTL', job_key)
        if ttl > 0 then
            redis.call('SET', job_key, cjson.encode(updated_data), 'EX', ttl)
        else
            -- TTL is -1 (no expiry) or -2 (expired mid-script); SET with a
            -- non-positive EX would raise, so fall back to the default TTL.
            redis.call('SET', job_key, cjson.encode(updated_data), 'EX', tonumber(ARGV[5]))
        end
        return 1
        """
        try:
            # Prepare arguments for the Lua script.
            status_arg = status if status is not None else 'nil'
            result_arg = json.dumps(result) if result is not None else 'nil'
            error_arg = json.dumps(error) if error is not None else 'nil'
            updated_at_arg = datetime.utcnow().isoformat()
            fallback_ttl = int(self.default_ttl_hours * 3600)
            script = self.redis_client.register_script(lua_script)
            updated = script(keys=[f"job:{job_id}"],
                             args=[status_arg, result_arg, error_arg,
                                   updated_at_arg, fallback_ttl])
            return updated == 1
        except redis.ConnectionError:
            logging.error(f"Failed to connect to Redis when updating job {job_id}")
            return False
        except Exception as e:
            logging.error(f"Unexpected error when updating job {job_id}: {str(e)}")
            return False

    def delete_job(self, job_id):
        """
        Delete a job from Redis.

        Args:
            job_id (str): Job ID.

        Returns:
            bool: True if job was deleted, False if not found or on error.

        Raises:
            ValueError: If the job ID fails validation.
        """
        if not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        try:
            result = self.redis_client.delete(f"job:{job_id}")
            return result > 0
        except redis.ConnectionError:
            logging.error(f"Failed to connect to Redis when deleting job {job_id}")
            return False
        except Exception as e:
            logging.error(f"Unexpected error when deleting job {job_id}: {str(e)}")
            return False

    def cleanup_expired_jobs(self):
        """
        Clean up jobs that have expired based on their creation time.

        Redis evicts keys automatically via TTL, so this is a no-op kept as
        an extension point for custom cleanup logic.
        """
        pass
def get_redis_job_store():
    """
    Return the app-wide RedisJobStore, creating it on first access.

    The instance is cached as an attribute on ``current_app`` so all request
    handlers in this worker share one store.

    Returns:
        RedisJobStore: Redis job store instance.
    """
    # EAFP: reuse the cached store if the attribute already exists.
    try:
        return current_app.redis_job_store
    except AttributeError:
        pass
    url = current_app.config.get('REDIS_URL', 'redis://localhost:6379/0')
    store = RedisJobStore(url)
    current_app.redis_job_store = store
    return store