File size: 11,816 Bytes
a105470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# -*- coding: utf-8 -*-
"""
API缓存集成模块
确保API接口与现有MySQL缓存系统良好集成,优化数据获取性能
"""

import logging
import time
import hashlib
import json
from typing import Dict, Any, Optional, List
from functools import wraps
from flask import request, g

# 导入现有的缓存和数据库模块
try:
    from database import get_session, USE_DATABASE, StockData, AnalysisCache
    DATABASE_AVAILABLE = True
except ImportError:
    DATABASE_AVAILABLE = False
    USE_DATABASE = False

logger = logging.getLogger(__name__)


class APICacheManager:
    """API缓存管理器"""
    
    def __init__(self):
        self.cache_ttl = {
            'stock_analysis': 900,      # 个股分析:15分钟
            'portfolio_analysis': 300,  # 组合分析:5分钟
            'batch_score': 600,         # 批量评分:10分钟
            'market_data': 300,         # 市场数据:5分钟
            'fundamental_data': 86400,  # 基本面数据:1天
            'technical_indicators': 900 # 技术指标:15分钟
        }
    
    def generate_cache_key(self, cache_type: str, params: Dict) -> str:
        """生成缓存键"""
        # 创建参数的哈希值
        params_str = json.dumps(params, sort_keys=True, ensure_ascii=False)
        params_hash = hashlib.md5(params_str.encode()).hexdigest()[:16]
        
        # 添加用户等级信息(不同等级可能有不同的分析深度)
        user_tier = getattr(g, 'user_tier', 'free')
        
        return f"api_cache:{cache_type}:{user_tier}:{params_hash}"
    
    def get_cache(self, cache_key: str) -> Optional[Dict]:
        """从缓存获取数据"""
        if not DATABASE_AVAILABLE or not USE_DATABASE:
            return None
        
        try:
            session = get_session()
            cache_record = session.query(AnalysisCache).filter(
                AnalysisCache.cache_key == cache_key,
                AnalysisCache.expires_at > time.time()
            ).first()
            
            if cache_record:
                session.close()
                return {
                    'data': json.loads(cache_record.cache_data),
                    'created_at': cache_record.created_at,
                    'cache_hit': True
                }
            
            session.close()
            return None
            
        except Exception as e:
            logger.error(f"获取缓存数据出错: {e}")
            return None
    
    def set_cache(self, cache_key: str, data: Dict, ttl: int = None) -> bool:
        """设置缓存数据"""
        if not DATABASE_AVAILABLE or not USE_DATABASE:
            return False
        
        try:
            session = get_session()
            
            # 删除旧的缓存记录
            session.query(AnalysisCache).filter(
                AnalysisCache.cache_key == cache_key
            ).delete()
            
            # 创建新的缓存记录
            cache_record = AnalysisCache(
                cache_key=cache_key,
                cache_data=json.dumps(data, ensure_ascii=False),
                created_at=time.time(),
                expires_at=time.time() + (ttl or 900)
            )
            
            session.add(cache_record)
            session.commit()
            session.close()
            
            return True
            
        except Exception as e:
            logger.error(f"设置缓存数据出错: {e}")
            return False
    
    def invalidate_cache(self, pattern: str = None) -> int:
        """清除缓存"""
        if not DATABASE_AVAILABLE or not USE_DATABASE:
            return 0
        
        try:
            session = get_session()
            
            if pattern:
                # 清除匹配模式的缓存
                count = session.query(AnalysisCache).filter(
                    AnalysisCache.cache_key.like(f"%{pattern}%")
                ).delete(synchronize_session=False)
            else:
                # 清除过期缓存
                count = session.query(AnalysisCache).filter(
                    AnalysisCache.expires_at < time.time()
                ).delete(synchronize_session=False)
            
            session.commit()
            session.close()
            
            return count
            
        except Exception as e:
            logger.error(f"清除缓存出错: {e}")
            return 0


# 全局缓存管理器
api_cache_manager = APICacheManager()


def api_cache(cache_type: str, ttl: int = None, key_params: List[str] = None):
    """API缓存装饰器"""
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            # 生成缓存键
            if key_params:
                # 使用指定的参数生成缓存键
                cache_params = {}
                request_data = request.get_json() or {}
                for param in key_params:
                    if param in request_data:
                        cache_params[param] = request_data[param]
            else:
                # 使用所有请求参数
                cache_params = request.get_json() or {}
            
            cache_key = api_cache_manager.generate_cache_key(cache_type, cache_params)
            
            # 尝试从缓存获取数据
            cached_result = api_cache_manager.get_cache(cache_key)
            if cached_result:
                logger.info(f"API缓存命中: {cache_key}")
                return cached_result['data']
            
            # 执行原函数
            start_time = time.time()
            result = f(*args, **kwargs)
            processing_time = time.time() - start_time
            
            # 将结果存入缓存
            cache_ttl = ttl or api_cache_manager.cache_ttl.get(cache_type, 900)
            
            # 只缓存成功的结果
            if hasattr(result, 'status_code') and result.status_code == 200:
                try:
                    result_data = result.get_json()
                    if result_data.get('success'):
                        # 添加缓存元数据
                        if 'meta' not in result_data:
                            result_data['meta'] = {}
                        result_data['meta']['cache_hit'] = False
                        result_data['meta']['processing_time_ms'] = int(processing_time * 1000)
                        
                        api_cache_manager.set_cache(cache_key, result_data, cache_ttl)
                        logger.info(f"API结果已缓存: {cache_key}, TTL: {cache_ttl}秒")
                except Exception as e:
                    logger.error(f"缓存API结果出错: {e}")
            
            return result
        
        return decorated_function
    return decorator


def smart_cache_invalidation(stock_codes: List[str] = None, cache_types: List[str] = None):
    """智能缓存失效"""
    try:
        if stock_codes:
            # 清除特定股票相关的缓存
            for stock_code in stock_codes:
                pattern = f"stock_code:{stock_code}"
                count = api_cache_manager.invalidate_cache(pattern)
                logger.info(f"清除股票 {stock_code} 相关缓存: {count} 条")
        
        if cache_types:
            # 清除特定类型的缓存
            for cache_type in cache_types:
                pattern = f"api_cache:{cache_type}"
                count = api_cache_manager.invalidate_cache(pattern)
                logger.info(f"清除 {cache_type} 类型缓存: {count} 条")
        
        # 清除过期缓存
        expired_count = api_cache_manager.invalidate_cache()
        if expired_count > 0:
            logger.info(f"清除过期缓存: {expired_count} 条")
            
    except Exception as e:
        logger.error(f"智能缓存失效出错: {e}")


def get_cache_statistics() -> Dict:
    """获取缓存统计信息"""
    if not DATABASE_AVAILABLE or not USE_DATABASE:
        return {'error': '数据库不可用'}
    
    try:
        session = get_session()
        
        # 总缓存数量
        total_count = session.query(AnalysisCache).count()
        
        # 有效缓存数量
        valid_count = session.query(AnalysisCache).filter(
            AnalysisCache.expires_at > time.time()
        ).count()
        
        # 过期缓存数量
        expired_count = total_count - valid_count
        
        # 按类型统计
        type_stats = {}
        cache_records = session.query(AnalysisCache).all()
        
        for record in cache_records:
            cache_type = record.cache_key.split(':')[1] if ':' in record.cache_key else 'unknown'
            if cache_type not in type_stats:
                type_stats[cache_type] = {'total': 0, 'valid': 0, 'expired': 0}
            
            type_stats[cache_type]['total'] += 1
            if record.expires_at > time.time():
                type_stats[cache_type]['valid'] += 1
            else:
                type_stats[cache_type]['expired'] += 1
        
        session.close()
        
        return {
            'total_count': total_count,
            'valid_count': valid_count,
            'expired_count': expired_count,
            'hit_rate': 0.0,  # 需要额外统计
            'type_statistics': type_stats
        }
        
    except Exception as e:
        logger.error(f"获取缓存统计出错: {e}")
        return {'error': str(e)}


def preload_cache_for_popular_stocks():
    """为热门股票预加载缓存"""
    popular_stocks = [
        '000001.SZ', '000002.SZ', '600000.SH', '600036.SH', '000858.SZ',
        '600519.SH', '000725.SZ', '002415.SZ', '600276.SH', '000568.SZ'
    ]
    
    logger.info("开始为热门股票预加载缓存")
    
    try:
        from stock_analyzer import StockAnalyzer
        analyzer = StockAnalyzer()
        
        for stock_code in popular_stocks:
            try:
                # 预加载个股分析数据
                result = analyzer.quick_analyze_stock(stock_code, 'A')
                
                # 生成缓存键并存储
                cache_params = {'stock_code': stock_code, 'market_type': 'A'}
                cache_key = api_cache_manager.generate_cache_key('stock_analysis', cache_params)
                
                api_cache_manager.set_cache(cache_key, {
                    'success': True,
                    'data': result,
                    'meta': {'preloaded': True}
                }, 1800)  # 30分钟TTL
                
                logger.info(f"预加载股票 {stock_code} 缓存完成")
                
            except Exception as e:
                logger.error(f"预加载股票 {stock_code} 缓存失败: {e}")
        
        logger.info("热门股票缓存预加载完成")
        
    except Exception as e:
        logger.error(f"预加载缓存出错: {e}")


def schedule_cache_cleanup():
    """定期清理缓存"""
    import threading
    import time
    
    def cleanup_worker():
        while True:
            try:
                # 每小时清理一次过期缓存
                expired_count = api_cache_manager.invalidate_cache()
                if expired_count > 0:
                    logger.info(f"定期清理过期缓存: {expired_count} 条")
                
                time.sleep(3600)  # 1小时
                
            except Exception as e:
                logger.error(f"定期缓存清理出错: {e}")
                time.sleep(300)  # 出错时等待5分钟
    
    cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
    cleanup_thread.start()
    logger.info("缓存清理调度器已启动")


# 启动缓存清理调度器
if DATABASE_AVAILABLE and USE_DATABASE:
    schedule_cache_cleanup()