stock-analysis / api_response.py
fromozu's picture
Auto-deploy from GitHub Actions - 6ea75c7
a105470 verified
# -*- coding: utf-8 -*-
"""
API响应格式标准化模块
提供统一的API响应格式,包括成功响应、错误响应、分页等
"""
import functools
import hashlib
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union

from flask import jsonify, request
logger = logging.getLogger(__name__)
class APIResponse:
    """Builders for the API's standard JSON response envelopes.

    Every public method returns the object produced by ``jsonify`` with its
    ``status_code`` set. ``success`` and ``error`` read attributes of
    ``flask.request`` while building the ``meta`` section.
    """

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string ending in 'Z'.

        The clock is read in UTC explicitly; a naive local ``datetime.now()``
        with a 'Z' suffix would mislabel local time as UTC.
        """
        now_utc = datetime.now(timezone.utc)
        # Drop tzinfo so isoformat() omits '+00:00' and 'Z' can be appended.
        return now_utc.replace(tzinfo=None).isoformat() + 'Z'

    @staticmethod
    def _elapsed_ms(started_at: Optional[float]) -> Optional[int]:
        """Return whole milliseconds elapsed since ``started_at``.

        Returns ``None`` when no start time was recorded. A request that
        finished in under one millisecond yields ``0`` rather than ``None``.
        """
        if started_at is None:
            return None
        return int((time.time() - started_at) * 1000)

    @staticmethod
    def _pagination(page: int, per_page: int, total: int) -> Dict:
        """Compute the ``pagination`` metadata dictionary.

        ``total_pages`` is the ceiling of ``total / per_page``; a non-positive
        ``per_page`` yields zero pages instead of raising ZeroDivisionError.
        """
        total_pages = (total + per_page - 1) // per_page if per_page > 0 else 0
        has_next = page < total_pages
        has_prev = page > 1
        return {
            'page': page,
            'per_page': per_page,
            'total': total,
            'total_pages': total_pages,
            'has_next': has_next,
            'has_prev': has_prev,
            'next_page': page + 1 if has_next else None,
            'prev_page': page - 1 if has_prev else None,
        }

    @staticmethod
    def success(data: Any = None, message: str = None, meta: Dict = None, status_code: int = 200):
        """Build a success envelope.

        Args:
            data: Payload stored under ``data``.
            message: Optional text; the ``message`` key is only added when
                this is truthy.
            meta: Extra metadata merged into ``meta``; its keys override the
                generated ``timestamp`` / ``request_id`` /
                ``processing_time_ms`` entries.
            status_code: HTTP status assigned to the response.

        Returns:
            The ``jsonify`` response with ``status_code`` applied.
        """
        envelope = {
            'success': True,
            'data': data,
            'meta': {
                'timestamp': APIResponse._utc_timestamp(),
                'request_id': str(uuid.uuid4()),
                # _start_time is set by timing_middleware when it wraps the view.
                'processing_time_ms': APIResponse._elapsed_ms(
                    getattr(request, '_start_time', None)
                ),
                # Applied last so caller-supplied keys win.
                **(meta or {}),
            },
        }
        if message:
            envelope['message'] = message
        response = jsonify(envelope)
        response.status_code = status_code
        return response

    @staticmethod
    def error(code: str, message: str, details: Any = None, status_code: int = 400):
        """Build an error envelope.

        Args:
            code: Machine-readable error code (see ``ErrorCodes``).
            message: Human-readable description.
            details: Optional extra information stored under
                ``error.details``.
            status_code: HTTP status assigned to the response.

        Returns:
            The ``jsonify`` response with ``status_code`` applied. ``meta``
            records the current request's endpoint and method.
        """
        envelope = {
            'success': False,
            'error': {
                'code': code,
                'message': message,
                'details': details,
            },
            'meta': {
                'timestamp': APIResponse._utc_timestamp(),
                'request_id': str(uuid.uuid4()),
                'endpoint': request.endpoint,
                'method': request.method,
            },
        }
        response = jsonify(envelope)
        response.status_code = status_code
        return response

    @staticmethod
    def paginated(data: List, page: int, per_page: int, total: int, meta: Dict = None):
        """Build a success envelope carrying pagination metadata.

        Args:
            data: Items of the current page.
            page: Current page number; ``has_prev`` is true when it exceeds 1.
            per_page: Page size.
            total: Total number of items across all pages.
            meta: Extra metadata merged after the ``pagination`` entry.

        Returns:
            The result of ``APIResponse.success`` with the combined metadata.
        """
        # timestamp / request_id are generated by success(); no need to
        # build a second pair here.
        pagination_meta = {
            'pagination': APIResponse._pagination(page, per_page, total),
            **(meta or {}),
        }
        return APIResponse.success(data=data, meta=pagination_meta)

    @staticmethod
    def task_created(task_id: str, task_type: str, estimated_time: int = None):
        """Build the 201 response announcing a newly created task.

        Args:
            task_id: Identifier of the task; also interpolated into the
                status and result endpoint paths placed in ``meta``.
            task_type: Kind of task, echoed back in ``data``.
            estimated_time: Optional value stored as
                ``estimated_completion_time``.

        Returns:
            The result of ``APIResponse.success`` with status code 201.
        """
        data = {
            'task_id': task_id,
            'task_type': task_type,
            'status': 'pending',
            'created_at': APIResponse._utc_timestamp(),
            'estimated_completion_time': estimated_time,
        }
        meta = {
            'status_endpoint': f'/api/v1/tasks/{task_id}',
            'result_endpoint': f'/api/v1/tasks/{task_id}/result',
        }
        return APIResponse.success(data=data, meta=meta, status_code=201)

    @staticmethod
    def task_status(task_id: str, status: str, progress: int = None,
                    estimated_remaining: int = None, result: Any = None):
        """Build the response describing a task's current state.

        Args:
            task_id: Identifier of the task.
            status: Current status string.
            progress: Optional progress value.
            estimated_remaining: Optional value stored as
                ``estimated_remaining_seconds``.
            result: Task result; the ``result`` key is only included when
                this is not ``None``.

        Returns:
            The result of ``APIResponse.success`` for the status payload.
        """
        data = {
            'task_id': task_id,
            'status': status,
            'progress': progress,
            'estimated_remaining_seconds': estimated_remaining,
            'updated_at': APIResponse._utc_timestamp(),
        }
        if result is not None:
            data['result'] = result
        return APIResponse.success(data=data)
class ErrorCodes:
    """Namespace of the standard error-code strings used in error responses.

    Each constant's value is identical to its own name.
    """

    # --- Authentication / authorization ---
    MISSING_API_KEY = 'MISSING_API_KEY'
    INVALID_API_KEY = 'INVALID_API_KEY'
    EXPIRED_API_KEY = 'EXPIRED_API_KEY'
    INSUFFICIENT_PERMISSIONS = 'INSUFFICIENT_PERMISSIONS'
    RATE_LIMIT_EXCEEDED = 'RATE_LIMIT_EXCEEDED'

    # --- Malformed or invalid requests ---
    INVALID_REQUEST_FORMAT = 'INVALID_REQUEST_FORMAT'
    MISSING_REQUIRED_FIELD = 'MISSING_REQUIRED_FIELD'
    INVALID_PARAMETER_VALUE = 'INVALID_PARAMETER_VALUE'
    INVALID_STOCK_CODE = 'INVALID_STOCK_CODE'

    # --- Business-logic failures ---
    STOCK_NOT_FOUND = 'STOCK_NOT_FOUND'
    ANALYSIS_FAILED = 'ANALYSIS_FAILED'
    INSUFFICIENT_DATA = 'INSUFFICIENT_DATA'
    PORTFOLIO_TOO_LARGE = 'PORTFOLIO_TOO_LARGE'

    # --- System-level failures ---
    INTERNAL_SERVER_ERROR = 'INTERNAL_SERVER_ERROR'
    SERVICE_UNAVAILABLE = 'SERVICE_UNAVAILABLE'
    TIMEOUT_ERROR = 'TIMEOUT_ERROR'
    DATABASE_ERROR = 'DATABASE_ERROR'

    # --- Asynchronous task failures ---
    TASK_NOT_FOUND = 'TASK_NOT_FOUND'
    TASK_FAILED = 'TASK_FAILED'
    TASK_TIMEOUT = 'TASK_TIMEOUT'
def handle_validation_error(error_details: Dict):
    """Build the 400 response for a request that failed parameter validation.

    Args:
        error_details: Validation findings, passed through unchanged as the
            error's ``details``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.INVALID_REQUEST_FORMAT`` and HTTP status 400.
    """
    failure = {
        'code': ErrorCodes.INVALID_REQUEST_FORMAT,
        'message': '请求参数验证失败',
        'details': error_details,
        'status_code': 400,
    }
    return APIResponse.error(**failure)
def handle_stock_code_error(invalid_codes: List[str]):
    """Build the 400 response for one or more invalid stock codes.

    Args:
        invalid_codes: The rejected codes, echoed back under
            ``details.invalid_codes`` next to a list of example formats.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.INVALID_STOCK_CODE`` and HTTP status 400.
    """
    example_formats = ['000001.SZ', '600000.SH', '300001.SZ']
    details = {
        'invalid_codes': invalid_codes,
        'supported_formats': example_formats,
    }
    return APIResponse.error(
        ErrorCodes.INVALID_STOCK_CODE,
        '无效的股票代码',
        details,
        400,
    )
def handle_analysis_error(stock_code: str, error_message: str):
    """Build the 500 response for a failed analysis of a single stock.

    Args:
        stock_code: Code of the stock, interpolated into the message.
        error_message: Failure description, returned under
            ``details.error_message``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.ANALYSIS_FAILED`` and HTTP status 500.
    """
    summary = f'股票 {stock_code} 分析失败'
    details = {'error_message': error_message}
    return APIResponse.error(
        ErrorCodes.ANALYSIS_FAILED,
        summary,
        details=details,
        status_code=500,
    )
def handle_rate_limit_error(limit: int, reset_time: int):
    """Build the 429 response for a client that exceeded its rate limit.

    Args:
        limit: The limit that was exceeded, echoed under ``details.limit``.
        reset_time: Value compared against ``time.time()`` to derive
            ``retry_after``; echoed under ``details.reset_time``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.RATE_LIMIT_EXCEEDED`` and HTTP status 429.
    """
    # Clamp at zero: once reset_time has passed, a negative wait would be
    # meaningless to the client.
    retry_after = max(0, reset_time - int(time.time()))
    return APIResponse.error(
        code=ErrorCodes.RATE_LIMIT_EXCEEDED,
        message='请求频率超过限制',
        details={
            'limit': limit,
            'reset_time': reset_time,
            'retry_after': retry_after,
        },
        status_code=429,
    )
def handle_task_not_found_error(task_id: str):
    """Build the 404 response for a task id that does not exist.

    Args:
        task_id: The id that was looked up, echoed under ``details.task_id``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.TASK_NOT_FOUND`` and HTTP status 404.
    """
    lookup = {'task_id': task_id}
    return APIResponse.error(
        ErrorCodes.TASK_NOT_FOUND,
        '任务不存在',
        lookup,
        status_code=404,
    )
def handle_internal_error(error_message: str = None):
    """Build the 500 response for an internal server error.

    Args:
        error_message: Optional description. When truthy it is returned
            under ``details.error_message``; otherwise ``details`` is
            ``None``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.INTERNAL_SERVER_ERROR`` and HTTP status 500.
    """
    if error_message:
        details = {'error_message': error_message}
    else:
        details = None
    return APIResponse.error(
        ErrorCodes.INTERNAL_SERVER_ERROR,
        '内部服务器错误',
        details,
        500,
    )
class ResponseHeaders:
    """Helpers that set groups of HTTP headers on a response object.

    Each method mutates ``response.headers`` in place and returns the same
    response so calls can be chained.
    """

    @staticmethod
    def add_cors_headers(response):
        """Set the CORS headers on ``response`` and return it.

        NOTE(review): the origin is the wildcard '*', which lets any site
        read responses; confirm this is intended for an API-key protected
        service.
        """
        response.headers['Access-Control-Allow-Origin'] = '*'
        response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-API-Key, X-HMAC-Signature, X-Timestamp'
        return response

    @staticmethod
    def add_cache_headers(response, max_age: int = 300):
        """Set ``Cache-Control`` and ``ETag`` on ``response`` and return it.

        Args:
            response: Object exposing ``headers`` and ``get_data()``; the
                latter is expected to return the body as bytes.
            max_age: Seconds placed in ``Cache-Control: public, max-age=...``.

        The ETag is a quoted SHA-256 hex digest of the body. The built-in
        ``hash()`` is unsuitable here: for bytes it is salted per process,
        so the same body would get different tags across workers and
        restarts, defeating conditional requests.
        """
        response.headers['Cache-Control'] = f'public, max-age={max_age}'
        body_digest = hashlib.sha256(response.get_data()).hexdigest()
        # Entity tags are quoted strings in the HTTP grammar.
        response.headers['ETag'] = f'"{body_digest}"'
        return response

    @staticmethod
    def add_rate_limit_headers(response, limit: int, remaining: int, reset_time: int):
        """Set the ``X-RateLimit-*`` headers on ``response`` and return it.

        Args:
            response: Object exposing a ``headers`` mapping.
            limit: Value for ``X-RateLimit-Limit``.
            remaining: Value for ``X-RateLimit-Remaining``.
            reset_time: Value for ``X-RateLimit-Reset``.

        All three values are converted with ``str()`` before being stored.
        """
        response.headers['X-RateLimit-Limit'] = str(limit)
        response.headers['X-RateLimit-Remaining'] = str(remaining)
        response.headers['X-RateLimit-Reset'] = str(reset_time)
        return response
def validate_stock_code(stock_code: str) -> bool:
    """Return True when ``stock_code`` has an accepted stock-code format.

    Accepted (case-insensitively): six ASCII digits, optionally followed by
    ``.SZ`` or ``.SH`` -- e.g. ``000001.SZ`` or ``000001``.

    Args:
        stock_code: Candidate code. Anything that is empty or not a ``str``
            is rejected.

    Returns:
        ``True`` if the whole string matches, ``False`` otherwise.
    """
    if not stock_code or not isinstance(stock_code, str):
        return False
    # Kept as a function-local import, as in the original block.
    import re
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, so '000001\n' would slip through. [0-9] rather than
    # \d: \d matches non-ASCII Unicode digits in str patterns.
    return re.fullmatch(r'[0-9]{6}(\.(SZ|SH))?', stock_code.upper()) is not None
def normalize_stock_code(stock_code: str) -> str:
    """Return ``stock_code`` upper-cased, stripped and with a market suffix.

    A bare six-digit code gets ``.SH`` appended when its numeric value is
    at least 600000 and ``.SZ`` otherwise. Any other input is returned
    upper-cased and stripped, without a suffix being added.

    NOTE(review): only the two suffixes above are produced; whether other
    exchanges need their own ranges should be confirmed with the callers.

    Args:
        stock_code: Raw code. Falsy values (``None``, ``''``) are returned
            unchanged.

    Returns:
        The normalized code string.
    """
    if not stock_code:
        return stock_code
    code = stock_code.upper().strip()
    # Check for ASCII digits explicitly: str.isdigit() also accepts
    # characters such as superscript digits, on which int() raises
    # ValueError.
    is_bare_code = len(code) == 6 and all(ch in '0123456789' for ch in code)
    if not is_bare_code:
        return code
    suffix = 'SH' if int(code) >= 600000 else 'SZ'
    return f"{code}.{suffix}"
def validate_request_data(data: Dict, required_fields: List[str]) -> Optional[Dict]:
    """Check that a request body is present and has every required field.

    Args:
        data: Parsed request body.
        required_fields: Keys that must be present in ``data`` with a value
            other than ``None``.

    Returns:
        ``None`` when the body passes. Otherwise a dict describing the
        problem: an empty-body message when ``data`` is falsy, or a
        missing-fields message plus the ``missing_fields`` list (in the
        order given by ``required_fields``).
    """
    if not data:
        return {'message': '请求体不能为空'}

    absent = [
        name for name in required_fields
        if name not in data or data[name] is None
    ]
    if not absent:
        return None

    return {
        'message': '缺少必需字段',
        'missing_fields': absent
    }
def timing_middleware(f):
    """Decorator that records when a request handler starts running.

    The start time is stored as ``request._start_time`` just before ``f``
    is called; ``APIResponse.success`` reads it to fill
    ``processing_time_ms``.

    Args:
        f: The callable to wrap.

    Returns:
        A wrapper with the same call signature that forwards all arguments
        to ``f`` and returns its result.
    """
    # wraps() preserves f's __name__ and __doc__. Without it every decorated
    # view is named 'decorated_function'; Flask derives endpoint names from
    # the function name, so a second decorated view would collide.
    @functools.wraps(f)
    def decorated_function(*args, **kwargs):
        request._start_time = time.time()
        return f(*args, **kwargs)
    return decorated_function