# -*- coding: utf-8 -*- """ API缓存集成模块 确保API接口与现有MySQL缓存系统良好集成,优化数据获取性能 """ import logging import time import hashlib import json from typing import Dict, Any, Optional, List from functools import wraps from flask import request, g # 导入现有的缓存和数据库模块 try: from database import get_session, USE_DATABASE, StockData, AnalysisCache DATABASE_AVAILABLE = True except ImportError: DATABASE_AVAILABLE = False USE_DATABASE = False logger = logging.getLogger(__name__) class APICacheManager: """API缓存管理器""" def __init__(self): self.cache_ttl = { 'stock_analysis': 900, # 个股分析:15分钟 'portfolio_analysis': 300, # 组合分析:5分钟 'batch_score': 600, # 批量评分:10分钟 'market_data': 300, # 市场数据:5分钟 'fundamental_data': 86400, # 基本面数据:1天 'technical_indicators': 900 # 技术指标:15分钟 } def generate_cache_key(self, cache_type: str, params: Dict) -> str: """生成缓存键""" # 创建参数的哈希值 params_str = json.dumps(params, sort_keys=True, ensure_ascii=False) params_hash = hashlib.md5(params_str.encode()).hexdigest()[:16] # 添加用户等级信息(不同等级可能有不同的分析深度) user_tier = getattr(g, 'user_tier', 'free') return f"api_cache:{cache_type}:{user_tier}:{params_hash}" def get_cache(self, cache_key: str) -> Optional[Dict]: """从缓存获取数据""" if not DATABASE_AVAILABLE or not USE_DATABASE: return None try: session = get_session() cache_record = session.query(AnalysisCache).filter( AnalysisCache.cache_key == cache_key, AnalysisCache.expires_at > time.time() ).first() if cache_record: session.close() return { 'data': json.loads(cache_record.cache_data), 'created_at': cache_record.created_at, 'cache_hit': True } session.close() return None except Exception as e: logger.error(f"获取缓存数据出错: {e}") return None def set_cache(self, cache_key: str, data: Dict, ttl: int = None) -> bool: """设置缓存数据""" if not DATABASE_AVAILABLE or not USE_DATABASE: return False try: session = get_session() # 删除旧的缓存记录 session.query(AnalysisCache).filter( AnalysisCache.cache_key == cache_key ).delete() # 创建新的缓存记录 cache_record = AnalysisCache( cache_key=cache_key, cache_data=json.dumps(data, ensure_ascii=False), created_at=time.time(), expires_at=time.time() + (ttl or 900) ) session.add(cache_record) session.commit() session.close() return True except Exception as e: logger.error(f"设置缓存数据出错: {e}") return False def invalidate_cache(self, pattern: str = None) -> int: """清除缓存""" if not DATABASE_AVAILABLE or not USE_DATABASE: return 0 try: session = get_session() if pattern: # 清除匹配模式的缓存 count = session.query(AnalysisCache).filter( AnalysisCache.cache_key.like(f"%{pattern}%") ).delete(synchronize_session=False) else: # 清除过期缓存 count = session.query(AnalysisCache).filter( AnalysisCache.expires_at < time.time() ).delete(synchronize_session=False) session.commit() session.close() return count except Exception as e: logger.error(f"清除缓存出错: {e}") return 0 # 全局缓存管理器 api_cache_manager = APICacheManager() def api_cache(cache_type: str, ttl: int = None, key_params: List[str] = None): """API缓存装饰器""" def decorator(f): @wraps(f) def decorated_function(*args, **kwargs): # 生成缓存键 if key_params: # 使用指定的参数生成缓存键 cache_params = {} request_data = request.get_json() or {} for param in key_params: if param in request_data: cache_params[param] = request_data[param] else: # 使用所有请求参数 cache_params = request.get_json() or {} cache_key = api_cache_manager.generate_cache_key(cache_type, cache_params) # 尝试从缓存获取数据 cached_result = api_cache_manager.get_cache(cache_key) if cached_result: logger.info(f"API缓存命中: {cache_key}") return cached_result['data'] # 执行原函数 start_time = time.time() result = f(*args, **kwargs) processing_time = time.time() - start_time # 将结果存入缓存 cache_ttl = ttl or api_cache_manager.cache_ttl.get(cache_type, 900) # 只缓存成功的结果 if hasattr(result, 'status_code') and result.status_code == 200: try: result_data = result.get_json() if result_data.get('success'): # 添加缓存元数据 if 'meta' not in result_data: result_data['meta'] = {} result_data['meta']['cache_hit'] = False result_data['meta']['processing_time_ms'] = int(processing_time * 1000) api_cache_manager.set_cache(cache_key, result_data, cache_ttl) logger.info(f"API结果已缓存: {cache_key}, TTL: {cache_ttl}秒") except Exception as e: logger.error(f"缓存API结果出错: {e}") return result return decorated_function return decorator def smart_cache_invalidation(stock_codes: List[str] = None, cache_types: List[str] = None): """智能缓存失效""" try: if stock_codes: # 清除特定股票相关的缓存 for stock_code in stock_codes: pattern = f"stock_code:{stock_code}" count = api_cache_manager.invalidate_cache(pattern) logger.info(f"清除股票 {stock_code} 相关缓存: {count} 条") if cache_types: # 清除特定类型的缓存 for cache_type in cache_types: pattern = f"api_cache:{cache_type}" count = api_cache_manager.invalidate_cache(pattern) logger.info(f"清除 {cache_type} 类型缓存: {count} 条") # 清除过期缓存 expired_count = api_cache_manager.invalidate_cache() if expired_count > 0: logger.info(f"清除过期缓存: {expired_count} 条") except Exception as e: logger.error(f"智能缓存失效出错: {e}") def get_cache_statistics() -> Dict: """获取缓存统计信息""" if not DATABASE_AVAILABLE or not USE_DATABASE: return {'error': '数据库不可用'} try: session = get_session() # 总缓存数量 total_count = session.query(AnalysisCache).count() # 有效缓存数量 valid_count = session.query(AnalysisCache).filter( AnalysisCache.expires_at > time.time() ).count() # 过期缓存数量 expired_count = total_count - valid_count # 按类型统计 type_stats = {} cache_records = session.query(AnalysisCache).all() for record in cache_records: cache_type = record.cache_key.split(':')[1] if ':' in record.cache_key else 'unknown' if cache_type not in type_stats: type_stats[cache_type] = {'total': 0, 'valid': 0, 'expired': 0} type_stats[cache_type]['total'] += 1 if record.expires_at > time.time(): type_stats[cache_type]['valid'] += 1 else: type_stats[cache_type]['expired'] += 1 session.close() return { 'total_count': total_count, 'valid_count': valid_count, 'expired_count': expired_count, 'hit_rate': 0.0, # 需要额外统计 'type_statistics': type_stats } except Exception as e: logger.error(f"获取缓存统计出错: {e}") return {'error': str(e)} def preload_cache_for_popular_stocks(): """为热门股票预加载缓存""" popular_stocks = [ '000001.SZ', '000002.SZ', '600000.SH', '600036.SH', '000858.SZ', '600519.SH', '000725.SZ', '002415.SZ', '600276.SH', '000568.SZ' ] logger.info("开始为热门股票预加载缓存") try: from stock_analyzer import StockAnalyzer analyzer = StockAnalyzer() for stock_code in popular_stocks: try: # 预加载个股分析数据 result = analyzer.quick_analyze_stock(stock_code, 'A') # 生成缓存键并存储 cache_params = {'stock_code': stock_code, 'market_type': 'A'} cache_key = api_cache_manager.generate_cache_key('stock_analysis', cache_params) api_cache_manager.set_cache(cache_key, { 'success': True, 'data': result, 'meta': {'preloaded': True} }, 1800) # 30分钟TTL logger.info(f"预加载股票 {stock_code} 缓存完成") except Exception as e: logger.error(f"预加载股票 {stock_code} 缓存失败: {e}") logger.info("热门股票缓存预加载完成") except Exception as e: logger.error(f"预加载缓存出错: {e}") def schedule_cache_cleanup(): """定期清理缓存""" import threading import time def cleanup_worker(): while True: try: # 每小时清理一次过期缓存 expired_count = api_cache_manager.invalidate_cache() if expired_count > 0: logger.info(f"定期清理过期缓存: {expired_count} 条") time.sleep(3600) # 1小时 except Exception as e: logger.error(f"定期缓存清理出错: {e}") time.sleep(300) # 出错时等待5分钟 cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) cleanup_thread.start() logger.info("缓存清理调度器已启动") # 启动缓存清理调度器 if DATABASE_AVAILABLE and USE_DATABASE: schedule_cache_cleanup()