Lin / backend /utils /redis_job_store.py
Zelyanoth's picture
add redis for job queuing
48e5de1
"""
Redis-based job store for distributed job tracking across multiple Gunicorn workers.
"""
import json
import uuid
import redis
from flask import current_app
from datetime import datetime, timedelta
import logging
import re
class RedisJobStore:
    """
    Redis-based job store for tracking background tasks across multiple Gunicorn workers.

    Each job is stored as a JSON document under the key ``job:<job_id>`` with
    an expiry, so every worker process that talks to the same Redis instance
    sees the same job state.
    """

    # Valid job statuses
    VALID_STATUSES = {'pending', 'processing', 'completed', 'failed', 'cancelled'}

    # Job IDs: UUIDs or 1-64 ASCII letters/digits/hyphens/underscores.
    # Always used with ``fullmatch`` -- a ``^...$`` pattern with ``re.match``
    # would also accept an ID followed by a trailing newline.
    _JOB_ID_PATTERN = re.compile(r'[a-zA-Z0-9_-]{1,64}')

    # Atomic read-modify-write of a job document.
    #   KEYS[1]            job key
    #   ARGV[1], ARGV[2]   '1'/'0' flag + new status
    #   ARGV[3], ARGV[4]   '1'/'0' flag + JSON-encoded new result
    #   ARGV[5], ARGV[6]   '1'/'0' flag + new error message
    #   ARGV[7]            new updated_at timestamp
    # Explicit flags are used instead of an in-band sentinel string so that
    # no legitimate value can be mistaken for "not provided".
    # The remaining lifetime is read with PTTL and re-applied with PX; a key
    # without expiry (PTTL == -1) is written back without one, and a key in
    # its last millisecond gets PX 1, because SET rejects non-positive
    # expire times.
    # NOTE: the document round-trips through Redis' cjson, which re-encodes
    # empty JSON arrays as empty objects.
    _UPDATE_JOB_LUA = """
    local job_key = KEYS[1]
    local raw = redis.call('GET', job_key)
    if not raw then
        return 0
    end
    local job = cjson.decode(raw)
    if ARGV[1] == '1' then
        job.status = ARGV[2]
    end
    if ARGV[3] == '1' then
        job.result = cjson.decode(ARGV[4])
    end
    if ARGV[5] == '1' then
        job.error = ARGV[6]
    end
    job.updated_at = ARGV[7]
    local payload = cjson.encode(job)
    local ttl_ms = redis.call('PTTL', job_key)
    if ttl_ms < 0 then
        redis.call('SET', job_key, payload)
    else
        redis.call('SET', job_key, payload, 'PX', math.max(ttl_ms, 1))
    end
    return 1
    """

    def __init__(self, redis_url=None, default_ttl_hours=24):
        """
        Initialize Redis job store.

        Args:
            redis_url (str): Redis connection URL. If None, uses
                current_app.config['REDIS_URL'] when available, otherwise
                falls back to a local Redis at redis://localhost:6379/0.
            default_ttl_hours (int): Default time-to-live for jobs in hours
        """
        self.default_ttl_hours = default_ttl_hours
        # Lazily created redis-py Script object for update_job.
        self._update_script = None
        if redis_url:
            self.redis_client = redis.from_url(redis_url)
        elif current_app and hasattr(current_app, 'config') and 'REDIS_URL' in current_app.config:
            self.redis_client = redis.from_url(current_app.config['REDIS_URL'])
        else:
            # Default to localhost Redis
            self.redis_client = redis.from_url('redis://localhost:6379/0')

    def _validate_job_id(self, job_id):
        """
        Validate job ID format to prevent injection attacks.

        Args:
            job_id (str): Job ID to validate

        Returns:
            bool: True if valid, False otherwise
        """
        if not job_id or not isinstance(job_id, str):
            return False
        # fullmatch: the entire string must consist of allowed characters.
        return self._JOB_ID_PATTERN.fullmatch(job_id) is not None

    def _validate_status(self, status):
        """
        Validate job status value.

        Args:
            status (str): Status to validate

        Returns:
            bool: True if valid, False otherwise
        """
        # The type guard keeps unhashable inputs from raising TypeError
        # in the set membership test.
        return isinstance(status, str) and status in self.VALID_STATUSES

    def create_job(self, job_id=None, initial_status='pending', initial_data=None):
        """
        Create a new job in Redis.

        Args:
            job_id (str): Job ID. If None, generates a new UUID.
            initial_status (str): Initial job status. None skips validation
                and is stored as-is.
            initial_data (dict): Initial job data, merged over the default
                fields (status, result, error, created_at, updated_at).

        Returns:
            str: Job ID

        Raises:
            ValueError: If the status or the supplied job ID is invalid.
            redis.ConnectionError: If Redis cannot be reached (logged, then
                re-raised).
            Exception: Any other storage/serialization error (logged, then
                re-raised).
        """
        if initial_status is not None and not self._validate_status(initial_status):
            raise ValueError(f"Invalid status: {initial_status}. Valid statuses are: {self.VALID_STATUSES}")

        # Only None means "generate one"; every other value, including the
        # empty string, must pass validation.
        if job_id is None:
            job_id = str(uuid.uuid4())
        elif not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")

        # Read the clock once so both timestamps are identical at creation.
        now = datetime.utcnow().isoformat()
        job_data = {
            'status': initial_status,
            'result': None,
            'error': None,
            'created_at': now,
            'updated_at': now,
        }
        if initial_data:
            job_data.update(initial_data)

        try:
            # Store job data as JSON in Redis, expiring after default_ttl_hours
            self.redis_client.setex(
                f"job:{job_id}",
                timedelta(hours=self.default_ttl_hours),
                json.dumps(job_data)
            )
        except redis.ConnectionError:
            logging.error("Failed to connect to Redis when creating job %s", job_id)
            raise
        except Exception as e:
            logging.error("Unexpected error when creating job %s: %s", job_id, e)
            raise
        return job_id

    def get_job(self, job_id):
        """
        Get job data from Redis.

        Args:
            job_id (str): Job ID

        Returns:
            dict: Job data, or None if the job is not found or if reading or
            decoding it failed (failures are logged).

        Raises:
            ValueError: If the job ID format is invalid.
        """
        if not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        try:
            job_data_json = self.redis_client.get(f"job:{job_id}")
            if job_data_json:
                return json.loads(job_data_json)
            return None
        except redis.ConnectionError:
            logging.error("Failed to connect to Redis when getting job %s", job_id)
            return None
        except json.JSONDecodeError:
            logging.error("Failed to decode JSON for job %s", job_id)
            return None
        except Exception as e:
            logging.error("Unexpected error when getting job %s: %s", job_id, e)
            return None

    def _get_update_script(self):
        """Return the registered update script, creating it on first use."""
        script = getattr(self, '_update_script', None)
        if script is None:
            script = self.redis_client.register_script(self._UPDATE_JOB_LUA)
            self._update_script = script
        return script

    def update_job(self, job_id, status=None, result=None, error=None):
        """
        Update job status and data in Redis using atomic operations to prevent race conditions.

        Fields passed as None are left unchanged. The job's remaining
        time-to-live is preserved.

        Args:
            job_id (str): Job ID
            status (str): New status
            result (any): Result data (must be JSON-serializable)
            error (str): Error message (converted with str())

        Returns:
            bool: True if job was updated, False if not found or if the
            update failed (failures are logged).

        Raises:
            ValueError: If the job ID format or the status is invalid.
        """
        if not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        if status is not None and not self._validate_status(status):
            raise ValueError(f"Invalid status: {status}. Valid statuses are: {self.VALID_STATUSES}")

        try:
            # Each optional field travels as a ('1'/'0' flag, value) pair.
            script_args = [
                '1' if status is not None else '0',
                status if status is not None else '',
                '1' if result is not None else '0',
                json.dumps(result) if result is not None else '',
                '1' if error is not None else '0',
                str(error) if error is not None else '',
                datetime.utcnow().isoformat(),
            ]
            outcome = self._get_update_script()(keys=[f"job:{job_id}"], args=script_args)
            return outcome == 1
        except redis.ConnectionError:
            logging.error("Failed to connect to Redis when updating job %s", job_id)
            return False
        except Exception as e:
            logging.error("Unexpected error when updating job %s: %s", job_id, e)
            return False

    def delete_job(self, job_id):
        """
        Delete a job from Redis.

        Args:
            job_id (str): Job ID

        Returns:
            bool: True if job was deleted, False if not found or if the
            deletion failed (failures are logged).

        Raises:
            ValueError: If the job ID format is invalid.
        """
        if not self._validate_job_id(job_id):
            raise ValueError(f"Invalid job ID format: {job_id}")
        try:
            removed = self.redis_client.delete(f"job:{job_id}")
            return removed > 0
        except redis.ConnectionError:
            logging.error("Failed to connect to Redis when deleting job %s", job_id)
            return False
        except Exception as e:
            logging.error("Unexpected error when deleting job %s: %s", job_id, e)
            return False

    def cleanup_expired_jobs(self):
        """
        Clean up jobs that have expired based on their creation time.

        This is a placeholder method and currently does nothing: jobs are
        written with a Redis expiry, so Redis removes them on its own.
        """
        # Redis handles TTL automatically, so this is mostly for documentation
        # In a production system, you might want to implement custom cleanup logic
        pass
def get_redis_job_store():
    """
    Return the RedisJobStore attached to the current Flask app.

    On first use the store is built from the app's ``REDIS_URL`` setting
    (defaulting to ``redis://localhost:6379/0``) and saved on the app as
    ``redis_job_store``; later calls return that saved instance.

    Returns:
        RedisJobStore: Redis job store instance
    """
    app = current_app
    # Fast path: a store has already been attached to this app.
    if hasattr(app, 'redis_job_store'):
        return app.redis_job_store

    url = app.config.get('REDIS_URL', 'redis://localhost:6379/0')
    store = RedisJobStore(url)
    app.redis_job_store = store
    return store