File size: 10,291 Bytes
de16696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
API服务端Bug修复程序
修复 "'str' object has no attribute 'get'" 错误
"""

import os
import shutil
from datetime import datetime

def backup_file(file_path):
    """备份原文件"""
    if os.path.exists(file_path):
        backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        shutil.copy2(file_path, backup_path)
        print(f"✅ 已备份文件: {backup_path}")
        return backup_path
    return None

def fix_stock_analyzer():
    """修复stock_analyzer.py中的bug"""
    file_path = "stock_analyzer.py"
    
    # 备份原文件
    backup_path = backup_file(file_path)
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # 查找需要修复的代码段
        old_code = '''    def get_stock_info(self, stock_code, market_type='A'):
        """获取股票基本信息,使用新的数据访问层"""
        try:
            # 使用新的数据访问层获取数据
            info = data_service.get_stock_basic_info(stock_code, market_type)

            if info:
                # 转换为原有格式以保持兼容性
                result = {
                    '股票名称': info['stock_name'],
                    '行业': info['industry'],
                    '地区': info.get('sector', '未知'),
                    '总市值': info['market_cap'],
                    '市盈率': info['pe_ratio'],
                    '市净率': info['pb_ratio'],
                    '总股本': info['total_share'],
                    '流通股': info['float_share'],
                    '上市时间': info['list_date']
                }'''
        
        # 修复后的代码
        new_code = '''    def get_stock_info(self, stock_code, market_type='A'):
        """获取股票基本信息,使用新的数据访问层"""
        try:
            # 使用新的数据访问层获取数据
            info = data_service.get_stock_basic_info(stock_code, market_type)

            if info:
                # 检查info是否为字典类型,防止'str' object has no attribute 'get'错误
                if not isinstance(info, dict):
                    self.logger.error(f"获取到的股票信息不是字典格式: {type(info)}, 内容: {info}")
                    raise Exception(f"股票信息格式错误: 期望字典,实际为{type(info)}")
                
                # 转换为原有格式以保持兼容性
                result = {
                    '股票名称': info.get('stock_name', '未知'),
                    '行业': info.get('industry', '未知'),
                    '地区': info.get('sector', '未知'),
                    '总市值': info.get('market_cap', 0),
                    '市盈率': info.get('pe_ratio', 0),
                    '市净率': info.get('pb_ratio', 0),
                    '总股本': info.get('total_share', 0),
                    '流通股': info.get('float_share', 0),
                    '上市时间': info.get('list_date', '')
                }'''
        
        if old_code in content:
            content = content.replace(old_code, new_code)
            
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)
            
            print(f"✅ 已修复 {file_path}")
            return True
        else:
            print(f"⚠️ 在 {file_path} 中未找到需要修复的代码段")
            return False
            
    except Exception as e:
        print(f"❌ 修复 {file_path} 失败: {e}")
        # 恢复备份
        if backup_path and os.path.exists(backup_path):
            shutil.copy2(backup_path, file_path)
            print(f"✅ 已恢复备份文件")
        return False

def fix_data_service():
    """修复data_service.py中的潜在问题"""
    file_path = "data_service.py"
    
    # 备份原文件
    backup_path = backup_file(file_path)
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # 查找并修复get_stock_basic_info方法
        old_code = '''                # 获取股票名称
                try:
                    stock_name_df = ak.stock_info_a_code_name()
                    stock_name = stock_name_df[stock_name_df['code'] == stock_code]['name'].iloc[0]
                except:
                    stock_name = info_dict.get('股票简称', '')'''
        
        new_code = '''                # 获取股票名称
                try:
                    stock_name_df = ak.stock_info_a_code_name()
                    stock_name = stock_name_df[stock_name_df['code'] == stock_code]['name'].iloc[0]
                except Exception as e:
                    self.logger.warning(f"获取股票名称失败: {e}")
                    stock_name = info_dict.get('股票简称', '未知')'''
        
        if old_code in content:
            content = content.replace(old_code, new_code)
        
        # 添加数据验证
        validation_code = '''            # 验证返回数据的完整性
            if not isinstance(data, dict):
                self.logger.error(f"数据格式错误: 期望字典,实际为{type(data)}")
                return None
            
            # 确保必要字段存在
            required_fields = ['stock_code', 'stock_name', 'market_type']
            for field in required_fields:
                if field not in data:
                    self.logger.warning(f"缺少必要字段: {field}")
                    data[field] = '' if field != 'stock_code' else stock_code
            
            return data'''
        
        # 在return data之前插入验证代码
        if 'return data' in content and validation_code not in content:
            content = content.replace('            return data', validation_code)
        
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        
        print(f"✅ 已修复 {file_path}")
        return True
        
    except Exception as e:
        print(f"❌ 修复 {file_path} 失败: {e}")
        # 恢复备份
        if backup_path and os.path.exists(backup_path):
            shutil.copy2(backup_path, file_path)
            print(f"✅ 已恢复备份文件")
        return False

def fix_api_endpoints():
    """修复api_endpoints.py中的错误处理"""
    file_path = "api_endpoints.py"
    
    # 备份原文件
    backup_path = backup_file(file_path)
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # 在analyze_stock函数中添加更好的错误处理
        old_error_handling = '''        except Exception as e:
            logger.error(f"分析股票 {normalized_code} 时出错: {str(e)}")
            return APIResponse.error(
                code=ErrorCodes.ANALYSIS_FAILED,
                message=f'股票 {normalized_code} 分析失败',
                details={'error_message': str(e)},
                status_code=500
            )'''
        
        new_error_handling = '''        except Exception as e:
            logger.error(f"分析股票 {normalized_code} 时出错: {str(e)}")
            logger.error(f"错误类型: {type(e).__name__}")
            import traceback
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            
            return APIResponse.error(
                code=ErrorCodes.ANALYSIS_FAILED,
                message=f'股票 {normalized_code} 分析失败',
                details={
                    'error_message': str(e),
                    'error_type': type(e).__name__,
                    'stock_code': normalized_code
                },
                status_code=500
            )'''
        
        if old_error_handling in content:
            content = content.replace(old_error_handling, new_error_handling)
            
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)
            
            print(f"✅ 已修复 {file_path}")
            return True
        else:
            print(f"⚠️ 在 {file_path} 中未找到需要修复的代码段")
            return False
            
    except Exception as e:
        print(f"❌ 修复 {file_path} 失败: {e}")
        # 恢复备份
        if backup_path and os.path.exists(backup_path):
            shutil.copy2(backup_path, file_path)
            print(f"✅ 已恢复备份文件")
        return False

def main():
    """主函数"""
    print("=== API服务端Bug修复程序 ===")
    print(f"开始时间: {datetime.now()}")
    
    print("\n🔧 修复目标:")
    print("- 修复 'str' object has no attribute 'get' 错误")
    print("- 增强数据类型验证")
    print("- 改进错误处理和日志记录")
    
    success_count = 0
    total_fixes = 3
    
    print("\n📝 开始修复...")
    
    # 1. 修复stock_analyzer.py
    print("\n1. 修复 stock_analyzer.py...")
    if fix_stock_analyzer():
        success_count += 1
    
    # 2. 修复data_service.py
    print("\n2. 修复 data_service.py...")
    if fix_data_service():
        success_count += 1
    
    # 3. 修复api_endpoints.py
    print("\n3. 修复 api_endpoints.py...")
    if fix_api_endpoints():
        success_count += 1
    
    print(f"\n=== 修复完成 ===")
    print(f"成功修复: {success_count}/{total_fixes} 个文件")
    print(f"完成时间: {datetime.now()}")
    
    if success_count == total_fixes:
        print("\n🎉 所有修复都已完成!")
        print("\n📋 修复内容:")
        print("✅ 添加了数据类型检查,防止字符串被当作字典使用")
        print("✅ 使用.get()方法安全访问字典键,避免KeyError")
        print("✅ 增强了错误处理和日志记录")
        print("✅ 添加了数据验证和降级处理")
        
        print("\n🚀 下一步:")
        print("1. 重启API服务")
        print("2. 运行批量分析程序测试")
        print("3. 检查日志确认问题已解决")
    else:
        print(f"\n⚠️ 部分修复失败,请检查错误信息")

if __name__ == "__main__":
    main()