Spaces:
Build error
Build error
File size: 11,816 Bytes
a105470 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 | # -*- coding: utf-8 -*-
"""
API缓存集成模块
确保API接口与现有MySQL缓存系统良好集成,优化数据获取性能
"""
import logging
import time
import hashlib
import json
from typing import Dict, Any, Optional, List
from functools import wraps
from flask import request, g
# 导入现有的缓存和数据库模块
try:
from database import get_session, USE_DATABASE, StockData, AnalysisCache
DATABASE_AVAILABLE = True
except ImportError:
DATABASE_AVAILABLE = False
USE_DATABASE = False
logger = logging.getLogger(__name__)
class APICacheManager:
"""API缓存管理器"""
def __init__(self):
self.cache_ttl = {
'stock_analysis': 900, # 个股分析:15分钟
'portfolio_analysis': 300, # 组合分析:5分钟
'batch_score': 600, # 批量评分:10分钟
'market_data': 300, # 市场数据:5分钟
'fundamental_data': 86400, # 基本面数据:1天
'technical_indicators': 900 # 技术指标:15分钟
}
def generate_cache_key(self, cache_type: str, params: Dict) -> str:
"""生成缓存键"""
# 创建参数的哈希值
params_str = json.dumps(params, sort_keys=True, ensure_ascii=False)
params_hash = hashlib.md5(params_str.encode()).hexdigest()[:16]
# 添加用户等级信息(不同等级可能有不同的分析深度)
user_tier = getattr(g, 'user_tier', 'free')
return f"api_cache:{cache_type}:{user_tier}:{params_hash}"
def get_cache(self, cache_key: str) -> Optional[Dict]:
"""从缓存获取数据"""
if not DATABASE_AVAILABLE or not USE_DATABASE:
return None
try:
session = get_session()
cache_record = session.query(AnalysisCache).filter(
AnalysisCache.cache_key == cache_key,
AnalysisCache.expires_at > time.time()
).first()
if cache_record:
session.close()
return {
'data': json.loads(cache_record.cache_data),
'created_at': cache_record.created_at,
'cache_hit': True
}
session.close()
return None
except Exception as e:
logger.error(f"获取缓存数据出错: {e}")
return None
def set_cache(self, cache_key: str, data: Dict, ttl: int = None) -> bool:
"""设置缓存数据"""
if not DATABASE_AVAILABLE or not USE_DATABASE:
return False
try:
session = get_session()
# 删除旧的缓存记录
session.query(AnalysisCache).filter(
AnalysisCache.cache_key == cache_key
).delete()
# 创建新的缓存记录
cache_record = AnalysisCache(
cache_key=cache_key,
cache_data=json.dumps(data, ensure_ascii=False),
created_at=time.time(),
expires_at=time.time() + (ttl or 900)
)
session.add(cache_record)
session.commit()
session.close()
return True
except Exception as e:
logger.error(f"设置缓存数据出错: {e}")
return False
def invalidate_cache(self, pattern: str = None) -> int:
"""清除缓存"""
if not DATABASE_AVAILABLE or not USE_DATABASE:
return 0
try:
session = get_session()
if pattern:
# 清除匹配模式的缓存
count = session.query(AnalysisCache).filter(
AnalysisCache.cache_key.like(f"%{pattern}%")
).delete(synchronize_session=False)
else:
# 清除过期缓存
count = session.query(AnalysisCache).filter(
AnalysisCache.expires_at < time.time()
).delete(synchronize_session=False)
session.commit()
session.close()
return count
except Exception as e:
logger.error(f"清除缓存出错: {e}")
return 0
# 全局缓存管理器
api_cache_manager = APICacheManager()
def api_cache(cache_type: str, ttl: int = None, key_params: List[str] = None):
"""API缓存装饰器"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
# 生成缓存键
if key_params:
# 使用指定的参数生成缓存键
cache_params = {}
request_data = request.get_json() or {}
for param in key_params:
if param in request_data:
cache_params[param] = request_data[param]
else:
# 使用所有请求参数
cache_params = request.get_json() or {}
cache_key = api_cache_manager.generate_cache_key(cache_type, cache_params)
# 尝试从缓存获取数据
cached_result = api_cache_manager.get_cache(cache_key)
if cached_result:
logger.info(f"API缓存命中: {cache_key}")
return cached_result['data']
# 执行原函数
start_time = time.time()
result = f(*args, **kwargs)
processing_time = time.time() - start_time
# 将结果存入缓存
cache_ttl = ttl or api_cache_manager.cache_ttl.get(cache_type, 900)
# 只缓存成功的结果
if hasattr(result, 'status_code') and result.status_code == 200:
try:
result_data = result.get_json()
if result_data.get('success'):
# 添加缓存元数据
if 'meta' not in result_data:
result_data['meta'] = {}
result_data['meta']['cache_hit'] = False
result_data['meta']['processing_time_ms'] = int(processing_time * 1000)
api_cache_manager.set_cache(cache_key, result_data, cache_ttl)
logger.info(f"API结果已缓存: {cache_key}, TTL: {cache_ttl}秒")
except Exception as e:
logger.error(f"缓存API结果出错: {e}")
return result
return decorated_function
return decorator
def smart_cache_invalidation(stock_codes: List[str] = None, cache_types: List[str] = None):
"""智能缓存失效"""
try:
if stock_codes:
# 清除特定股票相关的缓存
for stock_code in stock_codes:
pattern = f"stock_code:{stock_code}"
count = api_cache_manager.invalidate_cache(pattern)
logger.info(f"清除股票 {stock_code} 相关缓存: {count} 条")
if cache_types:
# 清除特定类型的缓存
for cache_type in cache_types:
pattern = f"api_cache:{cache_type}"
count = api_cache_manager.invalidate_cache(pattern)
logger.info(f"清除 {cache_type} 类型缓存: {count} 条")
# 清除过期缓存
expired_count = api_cache_manager.invalidate_cache()
if expired_count > 0:
logger.info(f"清除过期缓存: {expired_count} 条")
except Exception as e:
logger.error(f"智能缓存失效出错: {e}")
def get_cache_statistics() -> Dict:
"""获取缓存统计信息"""
if not DATABASE_AVAILABLE or not USE_DATABASE:
return {'error': '数据库不可用'}
try:
session = get_session()
# 总缓存数量
total_count = session.query(AnalysisCache).count()
# 有效缓存数量
valid_count = session.query(AnalysisCache).filter(
AnalysisCache.expires_at > time.time()
).count()
# 过期缓存数量
expired_count = total_count - valid_count
# 按类型统计
type_stats = {}
cache_records = session.query(AnalysisCache).all()
for record in cache_records:
cache_type = record.cache_key.split(':')[1] if ':' in record.cache_key else 'unknown'
if cache_type not in type_stats:
type_stats[cache_type] = {'total': 0, 'valid': 0, 'expired': 0}
type_stats[cache_type]['total'] += 1
if record.expires_at > time.time():
type_stats[cache_type]['valid'] += 1
else:
type_stats[cache_type]['expired'] += 1
session.close()
return {
'total_count': total_count,
'valid_count': valid_count,
'expired_count': expired_count,
'hit_rate': 0.0, # 需要额外统计
'type_statistics': type_stats
}
except Exception as e:
logger.error(f"获取缓存统计出错: {e}")
return {'error': str(e)}
def preload_cache_for_popular_stocks():
"""为热门股票预加载缓存"""
popular_stocks = [
'000001.SZ', '000002.SZ', '600000.SH', '600036.SH', '000858.SZ',
'600519.SH', '000725.SZ', '002415.SZ', '600276.SH', '000568.SZ'
]
logger.info("开始为热门股票预加载缓存")
try:
from stock_analyzer import StockAnalyzer
analyzer = StockAnalyzer()
for stock_code in popular_stocks:
try:
# 预加载个股分析数据
result = analyzer.quick_analyze_stock(stock_code, 'A')
# 生成缓存键并存储
cache_params = {'stock_code': stock_code, 'market_type': 'A'}
cache_key = api_cache_manager.generate_cache_key('stock_analysis', cache_params)
api_cache_manager.set_cache(cache_key, {
'success': True,
'data': result,
'meta': {'preloaded': True}
}, 1800) # 30分钟TTL
logger.info(f"预加载股票 {stock_code} 缓存完成")
except Exception as e:
logger.error(f"预加载股票 {stock_code} 缓存失败: {e}")
logger.info("热门股票缓存预加载完成")
except Exception as e:
logger.error(f"预加载缓存出错: {e}")
def schedule_cache_cleanup():
"""定期清理缓存"""
import threading
import time
def cleanup_worker():
while True:
try:
# 每小时清理一次过期缓存
expired_count = api_cache_manager.invalidate_cache()
if expired_count > 0:
logger.info(f"定期清理过期缓存: {expired_count} 条")
time.sleep(3600) # 1小时
except Exception as e:
logger.error(f"定期缓存清理出错: {e}")
time.sleep(300) # 出错时等待5分钟
cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
cleanup_thread.start()
logger.info("缓存清理调度器已启动")
# 启动缓存清理调度器
if DATABASE_AVAILABLE and USE_DATABASE:
schedule_cache_cleanup()
|