Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| """ | |
| API响应格式标准化模块 | |
| 提供统一的API响应格式,包括成功响应、错误响应、分页等 | |
| """ | |
| from flask import jsonify, request | |
| from datetime import datetime | |
| import uuid | |
| import time | |
| from typing import Any, Dict, List, Optional, Union | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class APIResponse:
    """Builders for the standard JSON response envelopes.

    All public methods are static: they build a dict, hand it to Flask's
    ``jsonify`` and return the resulting response object with its status
    code set. They must be called while a Flask request is active because
    they read from ``request``.
    """

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string ending in 'Z'."""
        # Imported locally because the module header only imports ``datetime``.
        from datetime import timezone

        # The 'Z' suffix denotes UTC, so the clock must be read in UTC.
        # tzinfo is dropped before formatting so the result ends in a bare
        # 'Z' instead of '+00:00Z'.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return now_utc.isoformat() + 'Z'

    @staticmethod
    def _processing_time_ms() -> Optional[int]:
        """Return milliseconds elapsed since ``request._start_time``.

        Returns ``None`` when the attribute is absent from the current
        request (``timing_middleware`` is what sets it).
        """
        started_at = getattr(request, '_start_time', None)
        if started_at is None:
            return None
        # An explicit None check keeps a legitimate 0 ms from becoming None.
        return int((time.time() - started_at) * 1000)

    @staticmethod
    def success(data: Any = None, message: Optional[str] = None,
                meta: Optional[Dict] = None, status_code: int = 200):
        """Build a success response.

        Args:
            data: Payload placed under the ``data`` key.
            message: Optional message; the ``message`` key is only added
                when this is truthy.
            meta: Extra entries merged into ``meta``; they override the
                generated ``timestamp``/``request_id``/``processing_time_ms``
                on key collision.
            status_code: HTTP status code set on the response.

        Returns:
            The response object produced by ``jsonify``.
        """
        response_data = {
            'success': True,
            'data': data,
            'meta': {
                'timestamp': APIResponse._utc_timestamp(),
                'request_id': str(uuid.uuid4()),
                'processing_time_ms': APIResponse._processing_time_ms(),
                **(meta or {}),
            },
        }
        if message:
            response_data['message'] = message

        response = jsonify(response_data)
        response.status_code = status_code
        return response

    @staticmethod
    def error(code: str, message: str, details: Any = None, status_code: int = 400):
        """Build an error response.

        Args:
            code: Machine-readable error code.
            message: Human-readable error message.
            details: Optional extra information placed under ``error.details``.
            status_code: HTTP status code set on the response.

        Returns:
            The response object produced by ``jsonify``; ``meta`` carries the
            current request's endpoint and method.
        """
        response_data = {
            'success': False,
            'error': {
                'code': code,
                'message': message,
                'details': details,
            },
            'meta': {
                'timestamp': APIResponse._utc_timestamp(),
                'request_id': str(uuid.uuid4()),
                'endpoint': request.endpoint,
                'method': request.method,
            },
        }
        response = jsonify(response_data)
        response.status_code = status_code
        return response

    @staticmethod
    def paginated(data: List, page: int, per_page: int, total: int,
                  meta: Optional[Dict] = None):
        """Build a success response that carries pagination metadata.

        Args:
            data: Items of the current page.
            page: Current page number (``has_prev`` is ``page > 1``).
            per_page: Page size. A non-positive value yields
                ``total_pages == 0`` instead of a division error.
            total: Total number of items across all pages.
            meta: Extra entries merged into ``meta``.

        Returns:
            The response object produced by ``APIResponse.success``.
        """
        # Ceiling division, guarded against per_page <= 0.
        total_pages = (total + per_page - 1) // per_page if per_page > 0 else 0
        has_next = page < total_pages
        has_prev = page > 1

        # timestamp / request_id are generated once, inside success().
        pagination_meta = {
            'pagination': {
                'page': page,
                'per_page': per_page,
                'total': total,
                'total_pages': total_pages,
                'has_next': has_next,
                'has_prev': has_prev,
                'next_page': page + 1 if has_next else None,
                'prev_page': page - 1 if has_prev else None,
            },
            **(meta or {}),
        }
        return APIResponse.success(data=data, meta=pagination_meta)

    @staticmethod
    def task_created(task_id: str, task_type: str, estimated_time: Optional[int] = None):
        """Build the 201 response returned when a task has been created.

        Args:
            task_id: Identifier of the new task; also used to build the
                status and result endpoint paths placed in ``meta``.
            task_type: Type label of the task.
            estimated_time: Value reported as ``estimated_completion_time``.

        Returns:
            The response object produced by ``APIResponse.success``.
        """
        data = {
            'task_id': task_id,
            'task_type': task_type,
            'status': 'pending',
            'created_at': APIResponse._utc_timestamp(),
            'estimated_completion_time': estimated_time,
        }
        meta = {
            'status_endpoint': f'/api/v1/tasks/{task_id}',
            'result_endpoint': f'/api/v1/tasks/{task_id}/result',
        }
        return APIResponse.success(data=data, meta=meta, status_code=201)

    @staticmethod
    def task_status(task_id: str, status: str, progress: Optional[int] = None,
                    estimated_remaining: Optional[int] = None, result: Any = None):
        """Build the response describing the current state of a task.

        Args:
            task_id: Identifier of the task.
            status: Current status string.
            progress: Progress value reported as ``progress``.
            estimated_remaining: Reported as ``estimated_remaining_seconds``.
            result: Task result; the ``result`` key is only added when this
                is not ``None``.

        Returns:
            The response object produced by ``APIResponse.success``.
        """
        data = {
            'task_id': task_id,
            'status': status,
            'progress': progress,
            'estimated_remaining_seconds': estimated_remaining,
            'updated_at': APIResponse._utc_timestamp(),
        }
        if result is not None:
            data['result'] = result
        return APIResponse.success(data=data)
class ErrorCodes:
    """Namespace of the standard error-code strings.

    Each attribute's value is identical to its own name.
    """

    # --- Authentication / authorisation ---
    MISSING_API_KEY = "MISSING_API_KEY"
    INVALID_API_KEY = "INVALID_API_KEY"
    EXPIRED_API_KEY = "EXPIRED_API_KEY"
    INSUFFICIENT_PERMISSIONS = "INSUFFICIENT_PERMISSIONS"
    RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"

    # --- Request format / parameters ---
    INVALID_REQUEST_FORMAT = "INVALID_REQUEST_FORMAT"
    MISSING_REQUIRED_FIELD = "MISSING_REQUIRED_FIELD"
    INVALID_PARAMETER_VALUE = "INVALID_PARAMETER_VALUE"
    INVALID_STOCK_CODE = "INVALID_STOCK_CODE"

    # --- Business logic ---
    STOCK_NOT_FOUND = "STOCK_NOT_FOUND"
    ANALYSIS_FAILED = "ANALYSIS_FAILED"
    INSUFFICIENT_DATA = "INSUFFICIENT_DATA"
    PORTFOLIO_TOO_LARGE = "PORTFOLIO_TOO_LARGE"

    # --- System ---
    INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
    SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
    TIMEOUT_ERROR = "TIMEOUT_ERROR"
    DATABASE_ERROR = "DATABASE_ERROR"

    # --- Tasks ---
    TASK_NOT_FOUND = "TASK_NOT_FOUND"
    TASK_FAILED = "TASK_FAILED"
    TASK_TIMEOUT = "TASK_TIMEOUT"
def handle_validation_error(error_details: Dict):
    """Build the 400 response for a failed request-parameter validation.

    Args:
        error_details: Validation details, forwarded unchanged as the
            error's ``details``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.INVALID_REQUEST_FORMAT``.
    """
    error_code = ErrorCodes.INVALID_REQUEST_FORMAT
    return APIResponse.error(
        error_code,
        '请求参数验证失败',
        details=error_details,
        status_code=400,
    )
def handle_stock_code_error(invalid_codes: List[str]):
    """Build the 400 response for one or more invalid stock codes.

    Args:
        invalid_codes: The rejected codes, echoed back under
            ``details.invalid_codes``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.INVALID_STOCK_CODE``; ``details`` also lists example
        codes under ``supported_formats``.
    """
    example_formats = ['000001.SZ', '600000.SH', '300001.SZ']
    details = {
        'invalid_codes': invalid_codes,
        'supported_formats': example_formats,
    }
    return APIResponse.error(
        ErrorCodes.INVALID_STOCK_CODE,
        '无效的股票代码',
        details=details,
        status_code=400,
    )
def handle_analysis_error(stock_code: str, error_message: str):
    """Build the 500 response for a failed stock analysis.

    Args:
        stock_code: Code of the stock being analysed; interpolated into the
            response message.
        error_message: Failure description, returned under
            ``details.error_message``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.ANALYSIS_FAILED``.
    """
    failure_message = f'股票 {stock_code} 分析失败'
    details = {'error_message': error_message}
    return APIResponse.error(
        ErrorCodes.ANALYSIS_FAILED,
        failure_message,
        details=details,
        status_code=500,
    )
def handle_rate_limit_error(limit: int, reset_time: int):
    """Build the 429 response for a client that exceeded its rate limit.

    Args:
        limit: The limit that was exceeded, echoed under ``details.limit``.
        reset_time: Compared against ``time.time()``, so it is treated as a
            Unix timestamp in seconds; echoed under ``details.reset_time``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.RATE_LIMIT_EXCEEDED``. ``details.retry_after`` is the
        number of whole seconds until ``reset_time``, never negative.
    """
    # Clamp at zero: if reset_time has already passed, a negative wait
    # would be meaningless to the client.
    retry_after = max(0, reset_time - int(time.time()))
    return APIResponse.error(
        code=ErrorCodes.RATE_LIMIT_EXCEEDED,
        message='请求频率超过限制',
        details={
            'limit': limit,
            'reset_time': reset_time,
            'retry_after': retry_after,
        },
        status_code=429,
    )
def handle_task_not_found_error(task_id: str):
    """Build the 404 response for a task id that does not exist.

    Args:
        task_id: The id that was looked up, echoed under ``details.task_id``.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.TASK_NOT_FOUND``.
    """
    details = {'task_id': task_id}
    return APIResponse.error(
        ErrorCodes.TASK_NOT_FOUND,
        '任务不存在',
        details=details,
        status_code=404,
    )
def handle_internal_error(error_message: str = None):
    """Build the 500 response for an internal server error.

    Args:
        error_message: Optional description. When truthy it is returned
            under ``details.error_message``; otherwise ``details`` is None.

    Returns:
        The result of ``APIResponse.error`` with code
        ``ErrorCodes.INTERNAL_SERVER_ERROR``.
    """
    details = None
    if error_message:
        details = {'error_message': error_message}
    return APIResponse.error(
        ErrorCodes.INTERNAL_SERVER_ERROR,
        '内部服务器错误',
        details=details,
        status_code=500,
    )
class ResponseHeaders:
    """Helpers that attach standard HTTP headers to a response object.

    Every method mutates ``response.headers`` in place and returns the same
    response object so calls can be chained.
    """

    @staticmethod
    def add_cors_headers(response):
        """Add CORS headers.

        Sets ``Access-Control-Allow-Origin`` to ``*`` together with the
        allowed methods and request headers.
        """
        response.headers['Access-Control-Allow-Origin'] = '*'
        response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-API-Key, X-HMAC-Signature, X-Timestamp'
        return response

    @staticmethod
    def add_cache_headers(response, max_age: int = 300):
        """Add ``Cache-Control`` and ``ETag`` headers.

        Args:
            response: Object exposing ``headers`` and ``get_data()``; the
                data is hashed, so it must be bytes.
            max_age: Value of the ``max-age`` directive, in seconds.

        Returns:
            The same ``response`` object.
        """
        # Imported locally so this block does not depend on the module header.
        import hashlib

        response.headers['Cache-Control'] = f'public, max-age={max_age}'
        # A content digest is used because the builtin hash() of bytes is
        # randomised per interpreter process, so it would give a different
        # ETag for the same body after a restart or on another worker.
        # The value is wrapped in double quotes as ETag syntax requires.
        digest = hashlib.sha256(response.get_data()).hexdigest()
        response.headers['ETag'] = f'"{digest}"'
        return response

    @staticmethod
    def add_rate_limit_headers(response, limit: int, remaining: int, reset_time: int):
        """Add the ``X-RateLimit-*`` headers.

        Args:
            response: Object exposing ``headers``.
            limit: Written to ``X-RateLimit-Limit``.
            remaining: Written to ``X-RateLimit-Remaining``.
            reset_time: Written to ``X-RateLimit-Reset``.

        Returns:
            The same ``response`` object.
        """
        response.headers['X-RateLimit-Limit'] = str(limit)
        response.headers['X-RateLimit-Remaining'] = str(remaining)
        response.headers['X-RateLimit-Reset'] = str(reset_time)
        return response
def validate_stock_code(stock_code: str) -> bool:
    """Check whether ``stock_code`` has an accepted stock-code format.

    Two shapes are accepted, case-insensitively:

    * six ASCII digits followed by ``.SZ`` or ``.SH`` (e.g. ``000001.SZ``)
    * six ASCII digits alone (e.g. ``000001``)

    Args:
        stock_code: Candidate code. Non-string or empty values are rejected.

    Returns:
        True when the whole string matches one of the shapes above.
    """
    import re

    if not stock_code or not isinstance(stock_code, str):
        return False

    # fullmatch, not match + '$': '$' would also accept a trailing newline.
    # [0-9] with re.ASCII keeps non-ASCII digits and letters from matching.
    pattern = r'[0-9]{6}(?:\.(?:SZ|SH))?'
    return re.fullmatch(pattern, stock_code, re.ASCII | re.IGNORECASE) is not None
def normalize_stock_code(stock_code: str) -> str:
    """Normalise a stock code to upper case with a market suffix.

    The input is upper-cased and stripped of surrounding whitespace. If the
    result is exactly six ASCII digits, a suffix is appended: ``.SH`` when
    the number is 600000 or greater, ``.SZ`` otherwise. Anything else is
    returned upper-cased and stripped but otherwise unchanged.

    Args:
        stock_code: Code to normalise. Falsy values (``None``, ``''``) are
            returned as-is.

    Returns:
        The normalised code.
    """
    import re

    if not stock_code:
        return stock_code

    code = stock_code.upper().strip()

    # ASCII-only digit check: str.isdigit() also accepts characters such as
    # superscript digits, which int() cannot parse and would raise on.
    if re.fullmatch(r'[0-9]{6}', code):
        # Infer the market from the numeric range of the code.
        suffix = 'SH' if int(code) >= 600000 else 'SZ'
        return f"{code}.{suffix}"
    return code
def validate_request_data(data: Dict, required_fields: List[str]) -> Optional[Dict]:
    """Check that a request body is non-empty and has every required field.

    Args:
        data: The request body.
        required_fields: Field names that must be present in ``data`` with a
            value other than ``None``.

    Returns:
        ``None`` when the data is valid; otherwise a dict with a ``message``
        key and, for missing fields, a ``missing_fields`` list in the order
        the fields were requested.
    """
    if not data:
        return {'message': '请求体不能为空'}

    absent = [
        name
        for name in required_fields
        if name not in data or data[name] is None
    ]
    if not absent:
        return None

    return {'message': '缺少必需字段', 'missing_fields': absent}
def timing_middleware(f):
    """Decorator that records when handling of the current request started.

    The wrapper stores ``time.time()`` on ``request._start_time`` before
    calling ``f``; ``APIResponse.success`` reads that attribute to report
    ``processing_time_ms``.

    Args:
        f: The callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``f`` and returns its result.
    """
    # Imported locally so this block does not depend on the module header.
    from functools import wraps

    # wraps() keeps f's __name__/__doc__ on the wrapper. Without it every
    # decorated function is named 'decorated_function', which collides when
    # the names are used as identifiers (e.g. Flask endpoint names).
    @wraps(f)
    def decorated_function(*args, **kwargs):
        request._start_time = time.time()
        return f(*args, **kwargs)

    return decorated_function