yzzzzzzz committed on
Commit
3081294
·
1 Parent(s): ff88a84

Upload 20 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo1.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 zwq2018
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
SW2021_industry_L1.csv ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "index_code","industry_name","level","industry_code","is_pub","parent_code"
2
+ "801010.SI","农林牧渔","L1","110000","1","0"
3
+ "801030.SI","基础化工","L1","220000","1","0"
4
+ "801040.SI","钢铁","L1","230000","1","0"
5
+ "801050.SI","有色金属","L1","240000","1","0"
6
+ "801080.SI","电子","L1","270000","1","0"
7
+ "801880.SI","汽车","L1","280000","1","0"
8
+ "801110.SI","家用电器","L1","330000","1","0"
9
+ "801120.SI","食品饮料","L1","340000","1","0"
10
+ "801130.SI","纺织服饰","L1","350000","1","0"
11
+ "801140.SI","轻工制造","L1","360000","1","0"
12
+ "801150.SI","医药生物","L1","370000","1","0"
13
+ "801160.SI","公用事业","L1","410000","1","0"
14
+ "801170.SI","交通运输","L1","420000","1","0"
15
+ "801180.SI","房地产","L1","430000","1","0"
16
+ "801200.SI","商贸零售","L1","450000","1","0"
17
+ "801210.SI","社会服务","L1","460000","1","0"
18
+ "801780.SI","银行","L1","480000","1","0"
19
+ "801790.SI","非银金融","L1","490000","1","0"
20
+ "801230.SI","综合","L1","510000","1","0"
21
+ "801710.SI","建筑材料","L1","610000","1","0"
22
+ "801720.SI","建筑装饰","L1","620000","1","0"
23
+ "801730.SI","电力设备","L1","630000","1","0"
24
+ "801890.SI","机械设备","L1","640000","1","0"
25
+ "801740.SI","国防军工","L1","650000","1","0"
26
+ "801750.SI","计算机","L1","710000","1","0"
27
+ "801760.SI","传媒","L1","720000","1","0"
28
+ "801770.SI","通信","L1","730000","1","0"
29
+ "801950.SI","煤炭","L1","740000","1","0"
30
+ "801960.SI","石油石化","L1","750000","1","0"
31
+ "801970.SI","环保","L1","760000","1","0"
32
+ "801980.SI","美容护理","L1","770000","1","0"
SW2021_industry_L2.csv ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "index_code","industry_name","level","industry_code","is_pub","parent_code"
2
+ "801016.SI","种植业","L2","110100","1","110000"
3
+ "801015.SI","渔业","L2","110200","1","110000"
4
+ "801011.SI","林业Ⅱ","L2","110300","0","110000"
5
+ "801014.SI","饲料","L2","110400","1","110000"
6
+ "801012.SI","农产品加工","L2","110500","1","110000"
7
+ "801017.SI","养殖业","L2","110700","1","110000"
8
+ "801018.SI","动物保健Ⅱ","L2","110800","1","110000"
9
+ "801019.SI","农业综合Ⅱ","L2","110900","0","110000"
10
+ "801033.SI","化学原料","L2","220200","1","220000"
11
+ "801034.SI","化学制品","L2","220300","1","220000"
12
+ "801032.SI","化学纤维","L2","220400","1","220000"
13
+ "801036.SI","塑料","L2","220500","1","220000"
14
+ "801037.SI","橡胶","L2","220600","1","220000"
15
+ "801038.SI","农化制品","L2","220800","1","220000"
16
+ "801039.SI","非金属材料Ⅱ","L2","220900","1","220000"
17
+ "801043.SI","冶钢原料","L2","230300","1","230000"
18
+ "801044.SI","普钢","L2","230400","1","230000"
19
+ "801045.SI","特钢Ⅱ","L2","230500","1","230000"
20
+ "801051.SI","金属新材料","L2","240200","1","240000"
21
+ "801055.SI","工业金属","L2","240300","1","240000"
22
+ "801053.SI","贵金属","L2","240400","1","240000"
23
+ "801054.SI","小金属","L2","240500","1","240000"
24
+ "801056.SI","能源金属","L2","240600","1","240000"
25
+ "801081.SI","半导体","L2","270100","1","270000"
26
+ "801083.SI","元件","L2","270200","1","270000"
27
+ "801084.SI","光学光电子","L2","270300","1","270000"
28
+ "801082.SI","其他电子Ⅱ","L2","270400","1","270000"
29
+ "801085.SI","消费电子","L2","270500","1","270000"
30
+ "801086.SI","电子化学品Ⅱ","L2","270600","1","270000"
31
+ "801093.SI","汽车零部件","L2","280200","1","280000"
32
+ "801092.SI","汽车服务","L2","280300","1","280000"
33
+ "801881.SI","摩托车及其他","L2","280400","1","280000"
34
+ "801095.SI","乘用车","L2","280500","1","280000"
35
+ "801096.SI","商用车","L2","280600","1","280000"
36
+ "801111.SI","白色家电","L2","330100","1","330000"
37
+ "801112.SI","黑色家电","L2","330200","1","330000"
38
+ "801113.SI","小家电","L2","330300","1","330000"
39
+ "801114.SI","厨卫电器","L2","330400","1","330000"
40
+ "801115.SI","照明设备Ⅱ","L2","330500","1","330000"
41
+ "801116.SI","家电零部件Ⅱ","L2","330600","1","330000"
42
+ "801117.SI","其他家电Ⅱ","L2","330700","0","330000"
43
+ "801124.SI","食品加工","L2","340400","1","340000"
44
+ "801125.SI","白酒Ⅱ","L2","340500","1","340000"
45
+ "801126.SI","非白酒","L2","340600","1","340000"
46
+ "801127.SI","饮料乳品","L2","340700","1","340000"
47
+ "801128.SI","休闲食品","L2","340800","1","340000"
48
+ "801129.SI","调味发酵品Ⅱ","L2","340900","1","340000"
49
+ "801131.SI","纺织制造","L2","350100","1","350000"
50
+ "801132.SI","服装家纺","L2","350200","1","350000"
51
+ "801133.SI","饰品","L2","350300","1","350000"
52
+ "801143.SI","造纸","L2","360100","1","360000"
53
+ "801141.SI","包装印刷","L2","360200","1","360000"
54
+ "801142.SI","家居用品","L2","360300","1","360000"
55
+ "801145.SI","文娱用品","L2","360500","1","360000"
56
+ "801151.SI","化学制药","L2","370100","1","370000"
57
+ "801155.SI","中药Ⅱ","L2","370200","1","370000"
58
+ "801152.SI","生物制品","L2","370300","1","370000"
59
+ "801154.SI","医药商业","L2","370400","1","370000"
60
+ "801153.SI","医疗器械","L2","370500","1","370000"
61
+ "801156.SI","医疗服务","L2","370600","1","370000"
62
+ "801161.SI","电力","L2","410100","1","410000"
63
+ "801163.SI","燃气Ⅱ","L2","410300","1","410000"
64
+ "801178.SI","物流","L2","420800","1","420000"
65
+ "801179.SI","铁路公路","L2","420900","1","420000"
66
+ "801991.SI","航空机场","L2","421000","1","420000"
67
+ "801992.SI","航运港口","L2","421100","1","420000"
68
+ "801181.SI","房地产开发","L2","430100","1","430000"
69
+ "801183.SI","房地产服务","L2","430300","1","430000"
70
+ "801202.SI","贸易Ⅱ","L2","450200","1","450000"
71
+ "801203.SI","一般零售","L2","450300","1","450000"
72
+ "801204.SI","专业连锁Ⅱ","L2","450400","1","450000"
73
+ "801206.SI","互联网电商","L2","450600","1","450000"
74
+ "801207.SI","旅游零售Ⅱ","L2","450700","0","450000"
75
+ "801216.SI","体育Ⅱ","L2","460600","0","460000"
76
+ "801217.SI","本地生活服务Ⅱ","L2","460700","0","460000"
77
+ "801218.SI","专业服务","L2","460800","1","460000"
78
+ "801219.SI","酒店餐饮","L2","460900","1","460000"
79
+ "801993.SI","旅游及景区","L2","461000","1","460000"
80
+ "801994.SI","教育","L2","461100","1","460000"
81
+ "801782.SI","国有大型银行Ⅱ","L2","480200","1","480000"
82
+ "801783.SI","股份制银行Ⅱ","L2","480300","1","480000"
83
+ "801784.SI","城商行Ⅱ","L2","480400","1","480000"
84
+ "801785.SI","农商行Ⅱ","L2","480500","1","480000"
85
+ "801786.SI","其他银行Ⅱ","L2","480600","0","480000"
86
+ "801193.SI","证券Ⅱ","L2","490100","1","490000"
87
+ "801194.SI","保险Ⅱ","L2","490200","1","490000"
88
+ "801191.SI","多元金融","L2","490300","1","490000"
89
+ "801231.SI","综合Ⅱ","L2","510100","1","510000"
90
+ "801711.SI","水泥","L2","610100","1","610000"
91
+ "801712.SI","玻璃玻纤","L2","610200","1","610000"
92
+ "801713.SI","装修建材","L2","610300","1","610000"
93
+ "801721.SI","房屋建设Ⅱ","L2","620100","1","620000"
94
+ "801722.SI","装修装饰Ⅱ","L2","620200","1","620000"
95
+ "801723.SI","基础建设","L2","620300","1","620000"
96
+ "801724.SI","专业工程","L2","620400","1","620000"
97
+ "801726.SI","工程咨询服务Ⅱ","L2","620600","1","620000"
98
+ "801731.SI","电机Ⅱ","L2","630100","1","630000"
99
+ "801733.SI","其他电源设备Ⅱ","L2","630300","1","630000"
100
+ "801735.SI","光伏设备","L2","630500","1","630000"
101
+ "801736.SI","风电设备","L2","630600","1","630000"
102
+ "801737.SI","电池","L2","630700","1","630000"
103
+ "801738.SI","电网设备","L2","630800","1","630000"
104
+ "801072.SI","通用设备","L2","640100","1","640000"
105
+ "801074.SI","专用设备","L2","640200","1","640000"
106
+ "801076.SI","轨交设备Ⅱ","L2","640500","1","640000"
107
+ "801077.SI","工程机械","L2","640600","1","640000"
108
+ "801078.SI","自动化设备","L2","640700","1","640000"
109
+ "801741.SI","航天装备Ⅱ","L2","650100","1","650000"
110
+ "801742.SI","航空装备Ⅱ","L2","650200","1","650000"
111
+ "801743.SI","地面兵装Ⅱ","L2","650300","1","650000"
112
+ "801744.SI","航海装备Ⅱ","L2","650400","1","650000"
113
+ "801745.SI","军工电子Ⅱ","L2","650500","1","650000"
114
+ "801101.SI","计算机设备","L2","710100","1","710000"
115
+ "801103.SI","IT服务Ⅱ","L2","710300","1","710000"
116
+ "801104.SI","软件开发","L2","710400","1","710000"
117
+ "801764.SI","游戏Ⅱ","L2","720400","1","720000"
118
+ "801765.SI","广告营销","L2","720500","1","720000"
119
+ "801766.SI","影视院线","L2","720600","1","720000"
120
+ "801767.SI","数字媒体","L2","720700","1","720000"
121
+ "801768.SI","社交Ⅱ","L2","720800","0","720000"
122
+ "801769.SI","出版","L2","720900","1","720000"
123
+ "801995.SI","电视广播Ⅱ","L2","721000","1","720000"
124
+ "801223.SI","通信服务","L2","730100","1","730000"
125
+ "801102.SI","通信设备","L2","730200","1","730000"
126
+ "801951.SI","煤炭开采","L2","740100","1","740000"
127
+ "801952.SI","焦炭Ⅱ","L2","740200","1","740000"
128
+ "801961.SI","油气开采Ⅱ","L2","750100","0","750000"
129
+ "801962.SI","油服工程","L2","750200","1","750000"
130
+ "801963.SI","炼化及贸易","L2","750300","1","750000"
131
+ "801971.SI","环境治理","L2","760100","1","760000"
132
+ "801972.SI","环保设备Ⅱ","L2","760200","1","760000"
133
+ "801981.SI","个护用品","L2","770100","1","770000"
134
+ "801982.SI","化妆品","L2","770200","1","770000"
135
+ "801983.SI","医疗美容","L2","770300","0","770000"
SW2021_industry_L3.csv ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "index_code","industry_name","level","industry_code","is_pub","parent_code"
2
+ "850111.SI","种子","L3","110101","1","110100"
3
+ "850112.SI","粮食种植","L3","110102","0","110100"
4
+ "850113.SI","其他种植业","L3","110103","1","110100"
5
+ "850114.SI","食用菌","L3","110104","0","110100"
6
+ "850121.SI","海洋捕捞","L3","110201","0","110200"
7
+ "850122.SI","水产养殖","L3","110202","1","110200"
8
+ "850131.SI","林业Ⅲ","L3","110301","0","110300"
9
+ "850142.SI","畜禽饲料","L3","110402","1","110400"
10
+ "850143.SI","水产饲料","L3","110403","0","110400"
11
+ "850144.SI","宠物食品","L3","110404","0","110400"
12
+ "850151.SI","果蔬加工","L3","110501","1","110500"
13
+ "850152.SI","粮油加工","L3","110502","1","110500"
14
+ "850154.SI","其他农产品加工","L3","110504","1","110500"
15
+ "850172.SI","生猪养殖","L3","110702","1","110700"
16
+ "850173.SI","肉鸡养殖","L3","110703","1","110700"
17
+ "850174.SI","其他养殖","L3","110704","0","110700"
18
+ "850181.SI","动物保健Ⅲ","L3","110801","1","110800"
19
+ "850191.SI","农业综合Ⅲ","L3","110901","0","110900"
20
+ "850321.SI","纯碱","L3","220201","0","220200"
21
+ "850322.SI","氯碱","L3","220202","1","220200"
22
+ "850323.SI","无机盐","L3","220203","1","220200"
23
+ "850324.SI","其他化学原料","L3","220204","1","220200"
24
+ "850325.SI","煤化工","L3","220205","1","220200"
25
+ "850326.SI","钛白粉","L3","220206","1","220200"
26
+ "850335.SI","涂料油墨","L3","220305","1","220300"
27
+ "850337.SI","民爆制品","L3","220307","1","220300"
28
+ "850338.SI","纺织化学制品","L3","220308","1","220300"
29
+ "850339.SI","其他化学制品","L3","220309","1","220300"
30
+ "850382.SI","氟化工","L3","220311","1","220300"
31
+ "850372.SI","聚氨酯","L3","220313","1","220300"
32
+ "850135.SI","食品及饲料添加剂","L3","220315","1","220300"
33
+ "850136.SI","有机硅","L3","220316","1","220300"
34
+ "850137.SI","胶黏剂及胶带","L3","220317","0","220300"
35
+ "850341.SI","涤纶","L3","220401","1","220400"
36
+ "850343.SI","粘胶","L3","220403","1","220400"
37
+ "850344.SI","其他化学纤维","L3","220404","0","220400"
38
+ "850345.SI","氨纶","L3","220405","0","220400"
39
+ "850346.SI","锦纶","L3","220406","0","220400"
40
+ "850351.SI","其他塑料制品","L3","220501","1","220500"
41
+ "850353.SI","改性塑料","L3","220503","1","220500"
42
+ "850354.SI","合成树脂","L3","220504","1","220500"
43
+ "850355.SI","膜材料","L3","220505","1","220500"
44
+ "850362.SI","其他橡胶制品","L3","220602","1","220600"
45
+ "850363.SI","炭黑","L3","220603","1","220600"
46
+ "850364.SI","橡胶助剂","L3","220604","0","220600"
47
+ "850331.SI","氮肥","L3","220801","1","220800"
48
+ "850332.SI","磷肥及磷化工","L3","220802","1","220800"
49
+ "850333.SI","农药","L3","220803","1","220800"
50
+ "850336.SI","钾肥","L3","220804","0","220800"
51
+ "850381.SI","复合肥","L3","220805","1","220800"
52
+ "850523.SI","非金属材料Ⅲ","L3","220901","1","220900"
53
+ "850431.SI","铁矿石","L3","230301","0","230300"
54
+ "850432.SI","冶钢辅料","L3","230302","0","230300"
55
+ "850441.SI","长材","L3","230401","0","230400"
56
+ "850442.SI","板材","L3","230402","1","230400"
57
+ "850443.SI","钢铁管材","L3","230403","0","230400"
58
+ "850401.SI","特钢Ⅲ","L3","230501","1","230500"
59
+ "850521.SI","其他金属新材料","L3","240201","1","240200"
60
+ "850522.SI","磁性材料","L3","240202","1","240200"
61
+ "850551.SI","铝","L3","240301","1","240300"
62
+ "850552.SI","铜","L3","240302","1","240300"
63
+ "850553.SI","铅锌","L3","240303","1","240300"
64
+ "850531.SI","黄金","L3","240401","1","240400"
65
+ "850532.SI","白银","L3","240402","0","240400"
66
+ "850541.SI","稀土","L3","240501","0","240500"
67
+ "850542.SI","钨","L3","240502","0","240500"
68
+ "850544.SI","其他小金属","L3","240504","1","240500"
69
+ "850545.SI","钼","L3","240505","0","240500"
70
+ "850561.SI","钴","L3","240601","0","240600"
71
+ "850562.SI","镍","L3","240602","0","240600"
72
+ "850543.SI","锂","L3","240603","0","240600"
73
+ "850812.SI","分立器件","L3","270102","1","270100"
74
+ "850813.SI","半导体材料","L3","270103","1","270100"
75
+ "850814.SI","数字芯片设计","L3","270104","1","270100"
76
+ "850815.SI","模拟芯片设计","L3","270105","1","270100"
77
+ "850816.SI","集成电路制造","L3","270106","0","270100"
78
+ "850817.SI","集成电路封测","L3","270107","1","270100"
79
+ "850818.SI","半导体设备","L3","270108","1","270100"
80
+ "850822.SI","印制电路板","L3","270202","1","270200"
81
+ "850823.SI","被动元件","L3","270203","1","270200"
82
+ "850831.SI","面板","L3","270301","1","270300"
83
+ "850832.SI","LED","L3","270302","1","270300"
84
+ "850833.SI","光学元件","L3","270303","1","270300"
85
+ "850841.SI","其他电子Ⅲ","L3","270401","1","270400"
86
+ "850853.SI","品牌消费电子","L3","270503","1","270500"
87
+ "850854.SI","消费电子零部件及组装","L3","270504","1","270500"
88
+ "850861.SI","电子化学品Ⅲ","L3","270601","1","270600"
89
+ "850922.SI","车身附件及饰件","L3","280202","1","280200"
90
+ "850923.SI","底盘与发动机系统","L3","280203","1","280200"
91
+ "850924.SI","轮胎轮毂","L3","280204","1","280200"
92
+ "850925.SI","其他汽车零部件","L3","280205","1","280200"
93
+ "850926.SI","汽车电子电气系统","L3","280206","1","280200"
94
+ "850232.SI","汽车经销商","L3","280302","1","280300"
95
+ "850233.SI","汽车综合服务","L3","280303","1","280300"
96
+ "858811.SI","其他运输设备","L3","280401","1","280400"
97
+ "858812.SI","摩托车","L3","280402","1","280400"
98
+ "850951.SI","电动乘用车","L3","280501","0","280500"
99
+ "850952.SI","综合乘用车","L3","280502","1","280500"
100
+ "850912.SI","商用载货车","L3","280601","1","280600"
101
+ "850913.SI","商用载客车","L3","280602","1","280600"
102
+ "851112.SI","空调","L3","330102","1","330100"
103
+ "851116.SI","冰洗","L3","330106","1","330100"
104
+ "851121.SI","彩电","L3","330201","0","330200"
105
+ "851122.SI","其他黑色家电","L3","330202","1","330200"
106
+ "851131.SI","厨房小家电","L3","330301","1","330300"
107
+ "851132.SI","清洁小家电","L3","330302","0","330300"
108
+ "851133.SI","个护小家电","L3","330303","0","330300"
109
+ "851141.SI","厨房电器","L3","330401","1","330400"
110
+ "851142.SI","卫浴电器","L3","330402","0","330400"
111
+ "851151.SI","照明设备Ⅲ","L3","330501","1","330500"
112
+ "851161.SI","家电零部件Ⅲ","L3","330601","1","330600"
113
+ "851171.SI","其他家电Ⅲ","L3","330701","0","330700"
114
+ "851241.SI","肉制品","L3","340401","1","340400"
115
+ "851244.SI","其他食品","L3","340404","0","340400"
116
+ "851246.SI","预加工食品","L3","340406","1","340400"
117
+ "851247.SI","保健品","L3","340407","1","340400"
118
+ "851251.SI","白酒Ⅲ","L3","340501","1","340500"
119
+ "851232.SI","啤酒","L3","340601","1","340600"
120
+ "851233.SI","其他酒类","L3","340602","1","340600"
121
+ "851271.SI","软饮料","L3","340701","1","340700"
122
+ "851243.SI","乳品","L3","340702","1","340700"
123
+ "851281.SI","零食","L3","340801","1","340800"
124
+ "851282.SI","烘焙食品","L3","340802","1","340800"
125
+ "851283.SI","熟食","L3","340803","0","340800"
126
+ "851242.SI","调味发酵品Ⅲ","L3","340901","1","340900"
127
+ "851312.SI","棉纺","L3","350102","1","350100"
128
+ "851314.SI","印染","L3","350104","1","350100"
129
+ "851315.SI","辅料","L3","350105","1","350100"
130
+ "851316.SI","其他纺织","L3","350106","1","350100"
131
+ "851317.SI","纺织鞋类制造","L3","350107","0","350100"
132
+ "851325.SI","鞋帽及其他","L3","350205","1","350200"
133
+ "851326.SI","家纺","L3","350206","1","350200"
134
+ "851328.SI","运动服装","L3","350208","0","350200"
135
+ "851329.SI","非运动服装","L3","350209","1","350200"
136
+ "851331.SI","钟表珠宝","L3","350301","1","350300"
137
+ "851332.SI","多品类奢侈品","L3","350302","0","350300"
138
+ "851333.SI","其他饰品","L3","350303","0","350300"
139
+ "851412.SI","大宗用纸","L3","360102","1","360100"
140
+ "851413.SI","特种纸","L3","360103","1","360100"
141
+ "851422.SI","印刷","L3","360202","1","360200"
142
+ "851423.SI","金属包装","L3","360203","1","360200"
143
+ "851424.SI","塑料包装","L3","360204","1","360200"
144
+ "851425.SI","纸包装","L3","360205","1","360200"
145
+ "851426.SI","综合包装","L3","360206","0","360200"
146
+ "851436.SI","瓷砖地板","L3","360306","1","360300"
147
+ "851437.SI","成品家居","L3","360307","1","360300"
148
+ "851438.SI","定制家居","L3","360308","1","360300"
149
+ "851439.SI","卫浴制品","L3","360309","1","360300"
150
+ "851491.SI","其他家居用品","L3","360311","1","360300"
151
+ "851451.SI","文化用品","L3","360501","0","360500"
152
+ "851452.SI","娱乐用品","L3","360502","1","360500"
153
+ "851511.SI","原料药","L3","370101","1","370100"
154
+ "851512.SI","化学制剂","L3","370102","1","370100"
155
+ "851521.SI","中药Ⅲ","L3","370201","1","370200"
156
+ "851522.SI","血液制品","L3","370302","1","370300"
157
+ "851523.SI","疫苗","L3","370303","1","370300"
158
+ "851524.SI","其他生物制品","L3","370304","1","370300"
159
+ "851542.SI","医药流通","L3","370402","1","370400"
160
+ "851543.SI","线下药店","L3","370403","1","370400"
161
+ "851544.SI","互联网药店","L3","370404","0","370400"
162
+ "851532.SI","医疗设备","L3","370502","1","370500"
163
+ "851533.SI","医疗耗材","L3","370503","1","370500"
164
+ "851534.SI","体外诊断","L3","370504","1","370500"
165
+ "851562.SI","诊断服务","L3","370602","0","370600"
166
+ "851563.SI","医疗研发外包","L3","370603","1","370600"
167
+ "851564.SI","医院","L3","370604","1","370600"
168
+ "851565.SI","其他医疗服务","L3","370605","0","370600"
169
+ "851611.SI","火力发电","L3","410101","1","410100"
170
+ "851612.SI","水力发电","L3","410102","1","410100"
171
+ "851614.SI","热力服务","L3","410104","1","410100"
172
+ "851616.SI","光伏发电","L3","410106","1","410100"
173
+ "851617.SI","风力发电","L3","410107","1","410100"
174
+ "851618.SI","核力发电","L3","410108","0","410100"
175
+ "851619.SI","其他能源发电","L3","410109","0","410100"
176
+ "851610.SI","电能综合服务","L3","410110","1","410100"
177
+ "851631.SI","燃气Ⅲ","L3","410301","1","410300"
178
+ "851782.SI","原材料供应链服务","L3","420802","1","420800"
179
+ "851783.SI","中间产品及消费品供应链服务","L3","420803","1","420800"
180
+ "851784.SI","快递","L3","420804","1","420800"
181
+ "851785.SI","跨境物流","L3","420805","1","420800"
182
+ "851786.SI","仓储物流","L3","420806","1","420800"
183
+ "851787.SI","公路货运","L3","420807","1","420800"
184
+ "851731.SI","高速公路","L3","420901","1","420900"
185
+ "851721.SI","公交","L3","420902","1","420900"
186
+ "851771.SI","铁路运输","L3","420903","1","420900"
187
+ "851741.SI","航空运输","L3","421001","1","421000"
188
+ "851751.SI","机场","L3","421002","0","421000"
189
+ "851761.SI","航运","L3","421101","1","421100"
190
+ "851711.SI","港口","L3","421102","1","421100"
191
+ "851811.SI","住宅开发","L3","430101","1","430100"
192
+ "851812.SI","商业地产","L3","430102","1","430100"
193
+ "851813.SI","产业地产","L3","430103","1","430100"
194
+ "851831.SI","物业管理","L3","430301","1","430300"
195
+ "851832.SI","房产租赁经纪","L3","430302","0","430300"
196
+ "851833.SI","房地产综合服务","L3","430303","0","430300"
197
+ "852021.SI","贸易Ⅲ","L3","450201","1","450200"
198
+ "852031.SI","百货","L3","450301","1","450300"
199
+ "852032.SI","超市","L3","450302","1","450300"
200
+ "852033.SI","多业态零售","L3","450303","1","450300"
201
+ "852034.SI","商业物业经营","L3","450304","1","450300"
202
+ "852041.SI","专业连锁Ⅲ","L3","450401","1","450400"
203
+ "852061.SI","综合电商","L3","450601","0","450600"
204
+ "852062.SI","跨境电商","L3","450602","1","450600"
205
+ "852063.SI","电商服务","L3","450603","1","450600"
206
+ "852071.SI","旅游零售Ⅲ","L3","450701","0","450700"
207
+ "852161.SI","体育Ⅲ","L3","460601","0","460600"
208
+ "852171.SI","本地生活服务Ⅲ","L3","460701","0","460700"
209
+ "852181.SI","人力资源服务","L3","460801","0","460800"
210
+ "852182.SI","检测服务","L3","460802","1","460800"
211
+ "852183.SI","会展服务","L3","460803","1","460800"
212
+ "852184.SI","其他专业服务","L3","460804","0","460800"
213
+ "852121.SI","酒店","L3","460901","1","460900"
214
+ "852141.SI","餐饮","L3","460902","0","460900"
215
+ "859931.SI","博彩","L3","461001","0","461000"
216
+ "852111.SI","人工景区","L3","461002","1","461000"
217
+ "852112.SI","自然景区","L3","461003","1","461000"
218
+ "852131.SI","旅游综合","L3","461004","1","461000"
219
+ "859851.SI","学历教育","L3","461101","0","461100"
220
+ "859852.SI","培训教育","L3","461102","1","461100"
221
+ "859853.SI","教育运营及其他","L3","461103","0","461100"
222
+ "857821.SI","国有大型银行Ⅲ","L3","480201","1","480200"
223
+ "857831.SI","股份制银行Ⅲ","L3","480301","1","480300"
224
+ "857841.SI","城商行Ⅲ","L3","480401","1","480400"
225
+ "857851.SI","农商行Ⅲ","L3","480501","1","480500"
226
+ "857861.SI","其他银行Ⅲ","L3","480601","0","480600"
227
+ "851931.SI","证券Ⅲ","L3","490101","1","490100"
228
+ "851941.SI","保险Ⅲ","L3","490201","1","490200"
229
+ "851922.SI","金融控股","L3","490302","1","490300"
230
+ "851923.SI","期货","L3","490303","0","490300"
231
+ "851924.SI","信托","L3","490304","0","490300"
232
+ "851925.SI","租赁","L3","490305","0","490300"
233
+ "851926.SI","金融信息服务","L3","490306","0","490300"
234
+ "851927.SI","资产管理","L3","490307","1","490300"
235
+ "851928.SI","其他多元金融","L3","490308","0","490300"
236
+ "852311.SI","综合Ⅲ","L3","510101","1","510100"
237
+ "857111.SI","水泥制造","L3","610101","1","610100"
238
+ "857112.SI","水泥制品","L3","610102","1","610100"
239
+ "857121.SI","玻璃制造","L3","610201","1","610200"
240
+ "857122.SI","玻纤制造","L3","610202","1","610200"
241
+ "850615.SI","耐火材料","L3","610301","1","610300"
242
+ "850616.SI","管材","L3","610302","1","610300"
243
+ "850614.SI","其他建材","L3","610303","1","610300"
244
+ "850617.SI","防水材料","L3","610304","0","610300"
245
+ "850618.SI","涂料","L3","610305","0","610300"
246
+ "850623.SI","房屋建设Ⅲ","L3","620101","1","620100"
247
+ "857221.SI","装修装饰Ⅲ","L3","620201","1","620200"
248
+ "857236.SI","基建市政工程","L3","620306","1","620300"
249
+ "857251.SI","园林工程","L3","620307","1","620300"
250
+ "857241.SI","钢结构","L3","620401","1","620400"
251
+ "857242.SI","化学工程","L3","620402","1","620400"
252
+ "857243.SI","国际工程","L3","620403","1","620400"
253
+ "857244.SI","其他专业工程","L3","620404","1","620400"
254
+ "857261.SI","工程咨询服务Ⅲ","L3","620601","1","620600"
255
+ "850741.SI","电机Ⅲ","L3","630101","1","630100"
256
+ "857331.SI","综合电力设备商","L3","630301","0","630300"
257
+ "857334.SI","火电设备","L3","630304","1","630300"
258
+ "857336.SI","其他电源设备Ⅲ","L3","630306","1","630300"
259
+ "857351.SI","硅料硅片","L3","630501","0","630500"
260
+ "857352.SI","光伏电池组件","L3","630502","1","630500"
261
+ "857353.SI","逆变器","L3","630503","0","630500"
262
+ "857354.SI","光伏辅材","L3","630504","1","630500"
263
+ "857355.SI","光伏加工设备","L3","630505","1","630500"
264
+ "857361.SI","风电整机","L3","630601","0","630600"
265
+ "857362.SI","风电零部件","L3","630602","1","630600"
266
+ "857371.SI","锂电池","L3","630701","1","630700"
267
+ "857372.SI","电池化学品","L3","630702","1","630700"
268
+ "857373.SI","锂电专用设备","L3","630703","1","630700"
269
+ "857374.SI","燃料电池","L3","630704","0","630700"
270
+ "857375.SI","蓄电池及其他电池","L3","630705","1","630700"
271
+ "857381.SI","输变电设备","L3","630801","1","630800"
272
+ "857382.SI","配电设备","L3","630802","1","630800"
273
+ "857321.SI","电网自动化设备","L3","630803","1","630800"
274
+ "857323.SI","电工仪器仪表","L3","630804","1","630800"
275
+ "857344.SI","线缆部件及其他","L3","630805","1","630800"
276
+ "850711.SI","机床工具","L3","640101","1","640100"
277
+ "850713.SI","磨具磨料","L3","640103","1","640100"
278
+ "850715.SI","制冷空调设备","L3","640105","1","640100"
279
+ "850716.SI","其他通用设备","L3","640106","1","640100"
280
+ "850731.SI","仪器仪表","L3","640107","1","640100"
281
+ "850751.SI","金属制品","L3","640108","1","640100"
282
+ "850725.SI","能源及重型设备","L3","640203","1","640200"
283
+ "850728.SI","楼宇设备","L3","640204","1","640200"
284
+ "850721.SI","纺织服装设备","L3","640206","1","640200"
285
+ "850723.SI","农用机械","L3","640207","0","640200"
286
+ "850726.SI","印刷包装机械","L3","640208","1","640200"
287
+ "850727.SI","其他专用设备","L3","640209","1","640200"
288
+ "850936.SI","轨交设备Ⅲ","L3","640501","1","640500"
289
+ "850771.SI","工程机械整机","L3","640601","1","640600"
290
+ "850772.SI","工程机械器件","L3","640602","1","640600"
291
+ "850781.SI","机器人","L3","640701","1","640700"
292
+ "850782.SI","工控设备","L3","640702","1","640700"
293
+ "850783.SI","激光设备","L3","640703","1","640700"
294
+ "850784.SI","其他自动化设备","L3","640704","1","640700"
295
+ "857411.SI","航天装备Ⅲ","L3","650101","1","650100"
296
+ "857421.SI","航空装备Ⅲ","L3","650201","1","650200"
297
+ "857431.SI","地面兵装Ⅲ","L3","650301","1","650300"
298
+ "850935.SI","航海装备Ⅲ","L3","650401","1","650400"
299
+ "857451.SI","军工电子Ⅲ","L3","650501","1","650500"
300
+ "850702.SI","安防设备","L3","710102","1","710100"
301
+ "850703.SI","其他计算机设备","L3","710103","1","710100"
302
+ "852226.SI","IT服务Ⅲ","L3","710301","1","710300"
303
+ "851041.SI","垂直应用软件","L3","710401","1","710400"
304
+ "851042.SI","横向通用软件","L3","710402","1","710400"
305
+ "857641.SI","游戏Ⅲ","L3","720401","1","720400"
306
+ "857651.SI","营销代理","L3","720501","1","720500"
307
+ "857652.SI","广告媒体","L3","720502","0","720500"
308
+ "857661.SI","影视动漫制作","L3","720601","1","720600"
309
+ "857662.SI","院线","L3","720602","0","720600"
310
+ "857671.SI","视频媒体","L3","720701","0","720700"
311
+ "857672.SI","音频媒体","L3","720702","0","720700"
312
+ "857673.SI","图片媒体","L3","720703","0","720700"
313
+ "857674.SI","门户网站","L3","720704","1","720700"
314
+ "857675.SI","文字媒体","L3","720705","0","720700"
315
+ "857676.SI","其他数字媒体","L3","720706","0","720700"
316
+ "857681.SI","社交Ⅲ","L3","720801","0","720800"
317
+ "857691.SI","教育出版","L3","720901","1","720900"
318
+ "857692.SI","大众出版","L3","720902","1","720900"
319
+ "857693.SI","其他出版","L3","720903","0","720900"
320
+ "859951.SI","电视广播Ⅲ","L3","721001","1","721000"
321
+ "852212.SI","电信运营商","L3","730102","0","730100"
322
+ "852213.SI","通信工程及服务","L3","730103","1","730100"
323
+ "852214.SI","通信应用增值服务","L3","730104","1","730100"
324
+ "851024.SI","通信网络设备及器件","L3","730204","1","730200"
325
+ "851025.SI","通信线缆及配套","L3","730205","1","730200"
326
+ "851026.SI","通信终端及配件","L3","730206","1","730200"
327
+ "851027.SI","其他通信设备","L3","730207","1","730200"
328
+ "859511.SI","动力煤","L3","740101","1","740100"
329
+ "859512.SI","焦煤","L3","740102","1","740100"
330
+ "859521.SI","焦炭Ⅲ","L3","740201","1","740200"
331
+ "859611.SI","油气开采Ⅲ","L3","750101","0","750100"
332
+ "859621.SI","油田服务","L3","750201","1","750200"
333
+ "859622.SI","油气及炼化工程","L3","750202","1","750200"
334
+ "859631.SI","炼油化工","L3","750301","1","750300"
335
+ "859632.SI","油品石化贸易","L3","750302","1","750300"
336
+ "859633.SI","其他石化","L3","750303","1","750300"
337
+ "859711.SI","大气治理","L3","760101","1","760100"
338
+ "859712.SI","水务及水治理","L3","760102","1","760100"
339
+ "859713.SI","固废治理","L3","760103","1","760100"
340
+ "859714.SI","综合环境治理","L3","760104","1","760100"
341
+ "859721.SI","环保设备Ⅲ","L3","760201","1","760200"
342
+ "859811.SI","生活用纸","L3","770101","1","770100"
343
+ "859812.SI","洗护用品","L3","770102","0","770100"
344
+ "859821.SI","化妆品制造及其他","L3","770201","1","770200"
345
+ "859822.SI","品牌化妆品","L3","770202","1","770200"
346
+ "859831.SI","医美耗材","L3","770301","0","770300"
347
+ "859832.SI","医美服务","L3","770302","0","770300"
app.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from main import run, add_to_queue, gradio_interface
7
+ import io
8
+ import sys
9
+ import time
10
+ import os
11
+ import pandas as pd
12
+ OPENAI_KEY = None
13
+ css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
14
+ #header {text-align: center;}
15
+ #col-chatbox {flex: 1; max-height: min(750px, 100%);}
16
+ #label {font-size: 4em; padding: 0.5em; margin: 0;}
17
+ .scroll-hide {overflow-y: scroll; max-height: 100px;}
18
+ .wrap {max-height: 680px;}
19
+ .message {font-size: 3em;}
20
+ .message-wrap {max-height: min(700px, 100vh);}
21
+ body {
22
+ background-color: #ADD8E6;
23
+ }
24
+ """
25
+
26
+ # plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
27
+ # plt.rcParams['axes.unicode_minus'] = False
28
+
29
+ plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei', 'Noto Sans CJK']
30
+ plt.rcParams['axes.unicode_minus'] = False
31
+
32
+
33
+ example_stock =['给我画一下可孚医疗2022年年中到今天的股价','北向资金今年的每日流入和累计流入','看一下近三年宁德时代和贵州茅台的pb变化','画一下五粮液和泸州老窖从2019年年初到2022年年中的收益率走势','成都银行近一年的k线图和kdj指标','比较下沪深300,创业板指,中证1000指数今年的收益率','今年上证50所有成分股的收益率是多少']
34
+ example_economic =['中国过去十年的cpi走势是什么','过去五年中国的货币供应量走势,并且打印保存','我想看看现在的新闻或者最新的消息','我想看看中国近十年gdp的走势','预测中国未来12个季度的GDP增速']
35
+ example_fund =['易方达的张坤管理了几个基金','基金经理周海栋名下的所有基金今年的收益率情况','我想看看周海栋管理的华商优势行业的近三年来的的净值曲线','比较下华商优势行业和易方达蓝筹精选这两只基金的近三年的收益率']
36
+ example_company =['介绍下贵州茅台,这公司是干什么的,主营业务是什么','我想比较下工商银行和贵州茅台近十年的净资产回报率','今年一季度上证50的成分股的归母净利润同比增速分别是']
37
+
38
class Client:
    """Per-session holder of the OpenAI / Azure-OpenAI credentials.

    An instance is stored in the Gradio session state; ``set_key`` records the
    credentials typed into the UI and ``run`` streams the results of a query.
    """

    def __init__(self) -> None:
        self.OPENAI_KEY = OPENAI_KEY
        self.OPENAI_API_BASED_AZURE = None
        self.OPENAI_ENGINE_AZURE = None
        self.OPENAI_API_KEY_AZURE = None
        self.stop = False  # Stop flag polled by run() between yielded results.

    def set_key(self, openai_key, openai_key_azure, api_base_azure, engine_azure):
        """Store the credentials and return them in UI-field order.

        Returns:
            Tuple ``(openai_key, azure_key, azure_api_base, azure_engine)``.
        """
        self.OPENAI_KEY = openai_key
        self.OPENAI_API_BASED_AZURE = api_base_azure
        self.OPENAI_API_KEY_AZURE = openai_key_azure
        self.OPENAI_ENGINE_AZURE = engine_azure
        return self.OPENAI_KEY, self.OPENAI_API_KEY_AZURE, self.OPENAI_API_BASED_AZURE, self.OPENAI_ENGINE_AZURE

    def run(self, messages):
        """Validate the stored keys, then stream results for ``messages``.

        Yields:
            Tuples ``(solving_step, image, summary, dataframe)``. When the
            credentials are missing or malformed a single tuple carrying an
            error message and placeholder image/table is yielded instead.
        """
        # The keys start out as None until set_key() is called; normalise them
        # to '' so the checks below cannot fail on len()/startswith().
        openai_key = self.OPENAI_KEY or ''
        azure_key = self.OPENAI_API_KEY_AZURE or ''

        if not openai_key and not azure_key:
            yield '', np.zeros((100, 100, 3), dtype=np.uint8), "Please set your OpenAI API key first!!!", pd.DataFrame()
            return
        if not openai_key.startswith('sk') and not azure_key:
            yield '', np.zeros((100, 100, 3), dtype=np.uint8), "Your openai key is incorrect!!!", pd.DataFrame()
            return

        gen = gradio_interface(messages, openai_key, azure_key, self.OPENAI_API_BASED_AZURE, self.OPENAI_ENGINE_AZURE)
        # Pull results one at a time so the stop flag is honoured between them.
        while not self.stop:
            try:
                yield next(gen)
            except StopIteration:
                print("StopIteration")
                break
68
+
69
+ # yield from gradio_interface(messages, self.OPENAI_KEY)
70
+ #return finally_text, img, output, df
71
+
72
+
73
+
74
+
75
+
76
+ with gr.Blocks() as demo:
77
+ state = gr.State(value={"client": Client()})
78
def change_textbox(query):
    """Return a Gradio update that puts the selected example into the input box."""
    # Show the textbox as two lines, visible, pre-filled with the chosen query.
    textbox_update = gr.update(value=query, lines=2, visible=True)
    return textbox_update
81
+ # 图片框显示
82
+
83
+ with gr.Row():
84
+ gr.Markdown(
85
+ """
86
+ # Hello Data-Copilot ! 😀
87
+ A powerful AI system connects humans and data.
88
+ The current version only supports Chinese financial data, in the future we will support for other country data
89
+ """)
90
+
91
+
92
+ if not OPENAI_KEY:
93
+ with gr.Row().style():
94
+ with gr.Column(scale=0.9):
95
+ gr.Markdown(
96
+ """
97
+ You can use gpt35 from openai or from openai-azure.
98
+ """)
99
+ openai_api_key = gr.Textbox(
100
+ show_label=False,
101
+ placeholder="Set your OpenAI API key here and press Submit (e.g. sk-xxx)",
102
+ lines=1,
103
+ type="password"
104
+ ).style(container=False)
105
+
106
+ with gr.Row():
107
+ openai_api_key_azure = gr.Textbox(
108
+ show_label=False,
109
+ placeholder="Set your Azure-OpenAI key",
110
+ lines=1,
111
+ type="password"
112
+ ).style(container=False)
113
+ openai_api_base_azure = gr.Textbox(
114
+ show_label=False,
115
+ placeholder="Azure-OpenAI api_base (e.g. https://zwq0525.openai.azure.com)",
116
+ lines=1,
117
+ type="password"
118
+ ).style(container=False)
119
+ openai_api_engine_azure = gr.Textbox(
120
+ show_label=False,
121
+ placeholder="Azure-OpenAI engine here (e.g. gpt35)",
122
+ lines=1,
123
+ type="password"
124
+ ).style(container=False)
125
+
126
+
127
+ gr.Markdown(
128
+ """
129
+ It is recommended to use the Openai paid API or Azure-OpenAI service, because the free Openai API will be limited by the access speed and 3 Requests per minute (very slow).
130
+ """)
131
+
132
+
133
+ with gr.Column(scale=0.1, min_width=0):
134
+ btn1 = gr.Button("OK").style(height= '100px')
135
+
136
+ with gr.Row():
137
+ with gr.Column(scale=0.9):
138
+ input_text = gr.inputs.Textbox(lines=1, placeholder='Please input your problem...', label='what do you want to find?')
139
+
140
+ with gr.Column(scale=0.1, min_width=0):
141
+ start_btn = gr.Button("Start").style(full_height=True)
142
+ # end_btn = gr.Button("Stop").style(full_height=True)
143
+
144
+
145
+ gr.Markdown(
146
+ """
147
+ # Try these examples ➡️➡️
148
+ """)
149
+ with gr.Row():
150
+
151
+ example_selector1 = gr.Dropdown(choices=example_stock, interactive=True,
152
+ label="查股票 Query stock:", show_label=True)
153
+ example_selector2 = gr.Dropdown(choices=example_economic, interactive=True,
154
+ label="查经济 Query Economy:", show_label=True)
155
+ example_selector3 = gr.Dropdown(choices=example_company, interactive=True,
156
+ label="查公司 Query Company:", show_label=True)
157
+ example_selector4 = gr.Dropdown(choices=example_fund, interactive=True,
158
+ label="查基金 Query Fund:", show_label=True)
159
+
160
+
161
+
162
def set_key(state, openai_api_key,openai_api_key_azure, openai_api_base_azure, openai_api_engine_azure):
    """Forward the credentials from the UI to the session's client object.

    Returns whatever the client's ``set_key`` returns, which Gradio writes
    back into the four credential textboxes.
    """
    client = state["client"]
    return client.set_key(
        openai_api_key,
        openai_api_key_azure,
        openai_api_base_azure,
        openai_api_engine_azure,
    )
164
+
165
+
166
def run(state, chatbot):
    """Stream the session client's results for the submitted query.

    Each item produced by the client is unpacked into its four parts and
    re-yielded as ``(solving_step, image, summary, dataframe)``.
    """
    client = state["client"]
    for step_text, picture, summary, table in client.run(chatbot):
        yield (step_text, picture, summary, table)
173
+
174
+
175
+ # def stop(state):
176
+ # print('Stop signal received!')
177
+ # state["client"].stop = True
178
+
179
+
180
+
181
+
182
+ with gr.Row():
183
+ with gr.Column(scale=0.3, min_width="500px", max_width="500px", min_height="500px", max_height="500px"):
184
+ Res = gr.Textbox(label="Summary and Result: ")
185
+ with gr.Column(scale=0.7, min_width="500px", max_width="500px", min_height="500px", max_height="500px"):
186
+ solving_step = gr.Textbox(label="Solving Step: ", lines=5)
187
+
188
+
189
+ img = gr.outputs.Image(type='numpy')
190
+ df = gr.outputs.Dataframe(type='pandas')
191
+ with gr.Row():
192
+ gr.Markdown(
193
+ """
194
+ [Tushare](https://tushare.pro/) provides financial data support for our Data-Copilot.
195
+
196
+ [OpenAI](https://openai.com/) provides the powerful Chatgpt model for our Data-Copilot.
197
+ """)
198
+
199
+
200
+ outputs = [solving_step ,img, Res, df]
201
+ #设置change事件
202
+ example_selector1.change(fn = change_textbox, inputs = example_selector1, outputs = input_text)
203
+ example_selector2.change(fn = change_textbox, inputs = example_selector2, outputs = input_text)
204
+ example_selector3.change(fn = change_textbox, inputs = example_selector3, outputs = input_text)
205
+ example_selector4.change(fn = change_textbox, inputs = example_selector4, outputs = input_text)
206
+
207
+
208
+ if not OPENAI_KEY:
209
+ openai_api_key.submit(set_key, [state, openai_api_key, openai_api_key_azure,openai_api_base_azure, openai_api_engine_azure], [openai_api_key, openai_api_key_azure,openai_api_base_azure, openai_api_engine_azure])
210
+ btn1.click(set_key, [state, openai_api_key, openai_api_key_azure,openai_api_base_azure, openai_api_engine_azure], [openai_api_key,openai_api_key_azure, openai_api_base_azure, openai_api_engine_azure])
211
+
212
+ start_btn.click(fn = run, inputs = [state, input_text], outputs=outputs)
213
+ # end_btn.click(stop, state)
214
+
215
+
216
+
217
+ demo.queue()
218
+ demo.launch()
219
+
220
+
data_copilot.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea54a5f7129656e16353f7bded4041dec38c000d43d8ecf5fdd5fd1fb6c16802
3
+ size 188585873
demo1.png ADDED

Git LFS Details

  • SHA256: 285716bd79be5a5c096d461c24346e2a3293463e1b71bf5a47e222406426dd66
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
flowchart.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ I read through the source code of main.py and drew the flowchart below. I hope it helps anyone who wants to read the source code; corrections of any errors are welcome. 😊
2
+ ![image](https://github.com/zwq2018/Data-Copilot/assets/7806683/14dcafc1-fd5a-4593-964f-ba6119c37d23)
lab_gpt4_call.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import requests
3
+ import openai
4
+ import tiktoken
5
+ import os
6
+ import time
7
+ from functools import wraps
8
+
9
+ import threading
10
+
11
+
12
def retry(exception_to_check, tries=3, delay=5, backoff=1):
    """
    Build a decorator that re-invokes a function when it raises.

    exception_to_check: Exception type (or tuple of types) that triggers a retry.
    tries: Total number of attempts; the last attempt is not guarded, so its
        exception propagates to the caller.
    delay: Seconds to sleep before the first retry.
    backoff: Factor the delay is multiplied by after every failed attempt.
    """

    def deco_retry(f):
        @wraps(f)
        def f_retry(*args, **kwargs):
            attempts_left = tries
            pause = delay
            while True:
                if attempts_left <= 1:
                    # Final attempt: let any exception reach the caller.
                    return f(*args, **kwargs)
                try:
                    return f(*args, **kwargs)
                except exception_to_check as e:
                    print(f"{str(e)}, Retrying in {pause} seconds...")
                    time.sleep(pause)
                attempts_left -= 1
                pause *= backoff

        return f_retry  # true decorator

    return deco_retry
39
+
40
def timeout_decorator(timeout, max_retries=None):
    """
    Build a decorator that runs a function in a worker thread with a time limit.

    timeout: Seconds to wait for one attempt before treating it as timed out.
    max_retries: Number of extra attempts after a timeout. ``None`` (default)
        retries indefinitely, matching the historical behaviour; with an
        integer, a ``TimeoutException`` is raised once the retries are used up.

    Exceptions raised by the wrapped function are re-raised in the caller.
    A timed-out attempt cannot be cancelled; its thread is left running in
    the background while a new attempt is started.
    """
    class TimeoutException(Exception):
        pass

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            retries_done = 0
            # Loop instead of recursing so repeated timeouts cannot grow the stack.
            while True:
                # One slot per attempt, so a late finisher from an abandoned
                # attempt cannot overwrite the current attempt's outcome.
                outcome = [TimeoutException('Function call timed out')]

                def target(outcome=outcome):
                    try:
                        outcome[0] = func(*args, **kwargs)
                    except Exception as e:
                        outcome[0] = e

                # Daemon thread: an abandoned, still-blocked call must not
                # keep the interpreter alive at shutdown.
                thread = threading.Thread(target=target, daemon=True)
                thread.start()
                thread.join(timeout)
                if not thread.is_alive():
                    if isinstance(outcome[0], Exception):
                        raise outcome[0]
                    return outcome[0]

                if max_retries is not None and retries_done >= max_retries:
                    raise TimeoutException(
                        f"Function {func.__name__} timed out after {retries_done + 1} attempt(s)")
                retries_done += 1
                print(f"Function {func.__name__} timed out, retrying...")
        return wrapper
    return decorator
65
+
66
+
67
def send_chat_request(request):
    """POST a single-turn chat completion to the fixed HTTP endpoint.

    request: Text sent as the only user message.

    Returns the ``message`` object of the first choice in the JSON reply.
    Raises ``Exception`` when the HTTP status code is not 200.
    """
    endpoint = 'http://10.15.82.10:8006/v1/chat/completions'
    # Models mentioned for this endpoint: gpt4, gpt4-32k and gpt-3.5-turbo.
    payload = {
        'model': 'gpt-3.5-turbo',
        'messages': [{"role": "user", "content": request}],
        'temperature': 0.7,
        'top_p': 0.95,
        'frequency_penalty': 0,
        'presence_penalty': 0,
        'max_tokens': 8000,
        'stream': False,
        'stop': None,
    }
    headers = {'Content-Type': 'application/json'}

    response = requests.post(endpoint, headers=headers, data=json.dumps(payload))

    if response.status_code != 200:
        raise Exception(f"Request failed with status code {response.status_code}: {response.text}")

    reply = json.loads(response.text)
    return reply['choices'][0]['message']
103
+
104
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Count (and print) how many tokens ``string`` has under the named tiktoken encoding."""
    token_ids = tiktoken.get_encoding(encoding_name).encode(string)
    num_tokens = len(token_ids)
    print('num_tokens:',num_tokens)
    return num_tokens
110
+
111
@timeout_decorator(45)
def send_chat_request_Azure(query, openai_key, api_base, engine):
    """Send ``query`` to an Azure-OpenAI chat deployment and return the reply text.

    query: User prompt.
    openai_key: Azure-OpenAI API key.
    api_base: Azure-OpenAI endpoint URL.
    engine: Azure deployment (engine) name.

    Note: this configures the module-level ``openai`` settings on every call.
    """
    openai.api_type = "azure"
    openai.api_version = "2023-03-15-preview"
    openai.api_base = api_base
    openai.api_key = openai_key

    # Leave the remainder of an 8000-token budget for the completion.
    prompt_tokens = num_tokens_from_string(query,'cl100k_base')
    max_token_num = 8000 - prompt_tokens

    # NOTE(review): the original comment said "set the timeout to 10 seconds"
    # but the value assigned is 1 — confirm which is intended.
    openai.api_request_timeout = 1

    chat_messages = [
        {"role": "system", "content": "You are an useful AI assistant that helps people solve the problem step by step."},
        {"role": "user", "content": "" + query},
    ]
    response = openai.ChatCompletion.create(
        engine = engine,
        messages=chat_messages,
        temperature=0.0,
        max_tokens=max_token_num,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None)

    return response['choices'][0]['message']['content']
139
+ #Note: The openai-python library support for Azure OpenAI is in preview.
140
+
141
+
142
+
143
@retry(Exception, tries=10, delay=20, backoff=2)
@timeout_decorator(45)
def send_official_call(query, openai_key='', api_base='', engine=''):
    """Send ``query`` to the official OpenAI chat API and return the reply text.

    ``api_base`` and ``engine`` are accepted only so the signature matches
    ``send_chat_request_Azure``; they are not used here.
    """
    # Log the call time in a human-readable form.
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    openai.api_key = openai_key

    chat_messages = [
        {"role": "system", "content": "You are an useful AI assistant that helps people solve the problem step by step."},
        {"role": "user", "content": "" + query},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages = chat_messages,
        temperature=0.1,
        top_p=0.1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None)

    return response['choices'][0]['message']['content']
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
main.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 导入tushare
2
+ import tushare as ts
3
+ import matplotlib.pyplot as plt
4
+ import pandas as pd
5
+ import os
6
+ import json
7
+ from matplotlib.ticker import MaxNLocator
8
+ import matplotlib.font_manager as fm
9
+ from lab_gpt4_call import send_chat_request,send_chat_request_Azure,send_official_call
10
+ #import ast
11
+ import re
12
+ from tool import *
13
+ import tiktoken
14
+ import concurrent.futures
15
+ import datetime
16
+ from PIL import Image
17
+ from io import BytesIO
18
+ import queue
19
+ import datetime
20
+ from threading import Thread
21
+ plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
22
+ plt.rcParams['axes.unicode_minus'] = False
23
+ import openai
24
+
25
+
26
+ # To override the Thread method
27
class MyThread(Thread):
    """Thread that remembers the return value (or exception) of its target."""

    def __init__(self, target, args):
        super(MyThread, self).__init__()
        self.func = target
        self.args = args
        # Pre-initialise both slots so get_result() is always safe to call.
        self.result = None
        self.exception = None

    def run(self):
        """Call the target and record its return value or the exception it raised."""
        try:
            self.result = self.func(*self.args)
        except Exception as exc:
            # Keep the failure so the consumer sees the real error instead of
            # a missing-attribute error on `result`.
            self.exception = exc

    def get_result(self):
        """Return the target's result, re-raising the target's exception if it failed.

        Returns None when the thread has not finished yet.
        """
        if self.exception is not None:
            raise self.exception
        return self.result
39
+
40
+
41
+
42
+
43
def parse_and_exe(call_dict, result_buffer, parallel_step: str='1'):
    """
    Execute one planned function call and store its result.
    :param call_dict: dict holding 'arg<N>', 'function<N>', 'output<N>' and
        'description<N>' entries, where <N> is ``parallel_step``
    :param result_buffer: dict mapping result names to (value, description) tuples
    :param parallel_step: str suffix selecting which parallel call to run ('' for none)
    :return: ``result_buffer`` with the new (result, description) stored under the output name
    """
    raw_args = call_dict['arg' + parallel_step]

    # String arguments that mention 'result' or 'input' are placeholders for
    # earlier outputs; swap them for the stored value.
    resolved_args = []
    for candidate in raw_args:
        is_reference = isinstance(candidate, str) and ('result' in candidate or 'input' in candidate)
        resolved_args.append(result_buffer[candidate][0] if is_reference else candidate)

    func_name = call_dict['function' + parallel_step]
    output_key = call_dict['output' + parallel_step]
    description = call_dict['description' + parallel_step]

    # NOTE(review): eval() runs text produced by the language model; consider
    # resolving names through an explicit whitelist of tool functions instead.
    if func_name == 'loop_rank':
        # The second argument of loop_rank names another function; resolve it.
        resolved_args[1] = eval(resolved_args[1])
    tool = eval(func_name)
    result_buffer[output_key] = (tool(*resolved_args), description)  # e.g. 'result1': (df1, desc)
    return result_buffer
61
+
62
def load_tool_and_prompt(tool_lib, tool_prompt ):
    '''
    Read the tool-description and prompt JSON files and flatten them into one prompt.
    :param tool_lib: Path to the JSON file mapping tool names to descriptions
    :param tool_prompt: Path to the JSON file holding the prompt sections; it must
        contain a "Function Library:" entry when the tool file is non-empty
    :return: Flattened prompt string: every section as "key value" followed by a blank line
    '''
    # Read both files as UTF-8 explicitly so non-ASCII prompts load the same
    # way regardless of the platform's default encoding.
    with open(tool_lib, 'r', encoding='utf-8') as f:
        tool_descriptions = json.load(f)

    with open(tool_prompt, 'r', encoding='utf-8') as f:
        prompt_sections = json.load(f)

    # Append every tool description to the "Function Library:" section.
    if tool_descriptions:
        prompt_sections["Function Library:"] = prompt_sections["Function Library:"] + ''.join(
            name + " " + description + '\n\n'
            for name, description in tool_descriptions.items()
        )

    return ''.join(
        title + ' ' + body + '\n\n'
        for title, body in prompt_sections.items()
    )
89
+
90
# Queue carrying intermediate progress text from the worker to the UI generator.
intermediate_results = queue.Queue()


def add_to_queue(intermediate_result):
    """Callback: wrap the progress text in a fixed prefix and enqueue it."""
    message = "After planing, the intermediate result is {}".format(intermediate_result)
    intermediate_results.put(message)
95
+
96
+
97
+
98
def check_RPM(run_time_list, new_time, max_RPM=1):
    """Track recent call times and report how long the caller should rest.

    run_time_list: Mutable list of the most recent call timestamps (at most 3
        are kept); it is updated in place.
    new_time: ``datetime`` of the call being made now.
    max_RPM: Threshold in seconds: a rest is requested when the oldest of the
        3 tracked calls is less than this many seconds old.
        NOTE(review): the name suggests requests-per-minute and the intent
        appears to be "3 calls per 60 s", but the value is compared against
        elapsed seconds — confirm before relying on the default.

    Returns 0 when no rest is needed, otherwise the number of seconds to sleep
    (60 minus the elapsed whole seconds).
    """
    if len(run_time_list) < 3:
        run_time_list.append(new_time)
        return 0

    # total_seconds() covers the whole gap; `.seconds` alone drops full days,
    # so a gap of exactly N days would look like 0 seconds.
    elapsed = int((new_time - run_time_list[0]).total_seconds())

    # Slide the window in every case: drop the oldest call, record this one.
    run_time_list.pop(0)
    run_time_list.append(new_time)

    if elapsed < max_RPM:
        sleep_time = 60 - elapsed
        print('sleep_time:', sleep_time)
        return sleep_time
    return 0
116
+
117
def run(instruction, add_to_queue=None, send_chat_request_Azure = send_official_call, openai_key = '', api_base='', engine=''):
    """Answer a natural-language data request through a five-stage LLM pipeline.

    Stages: (1) intent detection rewrites the instruction, (2) task planning
    splits it into tasks, (3) tool selection/execution runs the first task's
    function calls, (4) visualization runs the second task's function calls,
    (5) a summary of the whole process is requested from the model.

    :param instruction: the user's request text
    :param add_to_queue: optional callback receiving the accumulated progress
        text after each stage
    :param send_chat_request_Azure: callable used for every model request; it is
        called as f(prompt, openai_key=..., api_base=..., engine=...)
    :param openai_key: API key forwarded to the request function
    :param api_base: endpoint forwarded to the request function
    :param engine: engine name forwarded to the request function
    :return: tuple (output_text, image, output_result, df): the progress log,
        the current matplotlib figure rendered as a PIL image, the model's
        summary, and the last DataFrame produced by the visualization stage
        (empty if none)

    The plan must contain at least two tasks (the second is used as the
    visualization task) and each tool response must contain exactly one '###'
    separator; otherwise the indexing / unpacking below raises.
    """
    output_text = ''
    ################################# Step-1:Task select ###########################################
    current_time = datetime.datetime.now()
    formatted_time = current_time.strftime("%Y-%m-%d")
    # If the time has not exceeded 3 PM, use yesterday's data.
    if current_time.hour < 15:
        formatted_time = (current_time - datetime.timedelta(days=1)).strftime("%Y-%m-%d")

    print('===============================Intent Detecting===========================================')
    with open('./prompt_lib/prompt_intent_detection.json', 'r') as f:
        prompt_task_dict = json.load(f)
    # Flatten the prompt file into "key: value" paragraphs.
    prompt_intent_detection = ''
    for key, value in prompt_task_dict.items():
        prompt_intent_detection = prompt_intent_detection + key + ": " + value+ '\n\n'

    # Prefix the instruction with the reference date (the Chinese literal reads "today's date is").
    prompt_intent_detection = prompt_intent_detection + '\n\n' + 'Instruction:' + '今天的日期是'+ formatted_time +', '+ instruction + ' ###New Instruction: '
    # Rate limiting via check_RPM is currently disabled:
    # current_time = datetime.datetime.now()
    # sleep_time = check_RPM(run_time, current_time)
    # if sleep_time > 0:
    #     time.sleep(sleep_time)
    response = send_chat_request_Azure(prompt_intent_detection, openai_key=openai_key, api_base=api_base, engine=engine)

    new_instruction = response
    print('new_instruction:', new_instruction)
    output_text = output_text + '\n======Intent Detecting Stage=====\n\n'
    output_text = output_text + new_instruction +'\n\n'

    if add_to_queue is not None:
        add_to_queue(output_text)

    event_happen = True  # not read anywhere else in this function
    print('===============================Task Planing===========================================')
    output_text= output_text + '=====Task Planing Stage=====\n\n'

    with open('./prompt_lib/prompt_task.json', 'r') as f:
        prompt_task_dict = json.load(f)
    prompt_task = ''
    for key, value in prompt_task_dict.items():
        prompt_task = prompt_task + key + ": " + value+ '\n\n'

    prompt_task = prompt_task + '\n\n' + 'Instruction:' + new_instruction + ' ###Plan:'
    # Rate limiting via check_RPM is currently disabled:
    # current_time = datetime.datetime.now()
    # sleep_time = check_RPM(run_time, current_time)
    # if sleep_time > 0:
    #     time.sleep(sleep_time)

    response = send_chat_request_Azure(prompt_task, openai_key=openai_key,api_base=api_base,engine=engine)

    task_select = response
    # Extract every "taskN={...}" fragment from the plan text.
    pattern = r"(task\d+=)(\{[^}]*\})"
    matches = re.findall(pattern, task_select)
    task_plan = {}
    for task in matches:
        task_step, task_select = task
        task_select = task_select.replace("'", "\"")  # Replace single quotes with double quotes.
        task_select = json.loads(task_select)
        # Only the first key/value pair of each task object is used.
        task_name = list(task_select.keys())[0]
        task_instruction = list(task_select.values())[0]

        task_plan[task_name] = task_instruction

    # Log the task plan.
    for key, value in task_plan.items():
        print(key, ':', value)
        output_text = output_text + key + ': ' + str(value) + '\n'

    output_text = output_text +'\n'
    if add_to_queue is not None:
        add_to_queue(output_text)

    ################################# Step-2:Tool select and use ###########################################
    print('===============================Tool select and using Stage===========================================')
    output_text = output_text + '======Tool select and using Stage======\n\n'
    # The first planned task selects which tool/prompt JSON files to load;
    # the part of its name before '_task' is the file-name stem.
    task_name = list(task_plan.keys())[0].split('_task')[0]
    task_instruction = list(task_plan.values())[0]

    tool_lib = './tool_lib/' + 'tool_' + task_name + '.json'
    tool_prompt = './prompt_lib/' + 'prompt_' + task_name + '.json'
    prompt_flat = load_tool_and_prompt(tool_lib, tool_prompt)
    prompt_flat = prompt_flat + '\n\n' +'Instruction :'+ task_instruction+ ' ###Function Call'

    # Expected response shape: "step1={...},step2={...}, ###Output:{...}", where each
    # step object holds arg<N>/function<N>/output<N>/description<N> entries.
    # Rate limiting via check_RPM is currently disabled:
    # current_time = datetime.datetime.now()
    # sleep_time = check_RPM(run_time, current_time)
    # if sleep_time > 0:
    #     time.sleep(sleep_time)

    response = send_chat_request_Azure(prompt_flat, openai_key=openai_key,api_base=api_base, engine=engine)

    # Keep only the part before '###'; raises ValueError unless exactly one '###' is present.
    call_steps, _ = response.split('###')
    pattern = r"(step\d+=)(\{[^}]*\})"
    matches = re.findall(pattern, call_steps)
    result_buffer = {}  # The stored format is as follows: {'result1': (000001.SH, 'Stock code of China Ping An'), 'result2': (df2, 'Stock data of China Ping An from January to June 2021')}.
    output_buffer = []  # Store the variable names [result5, result6] that will be passed as the final output to the next task.

    for match in matches:
        step, content = match
        content = content.replace("'", "\"")  # Replace single quotes with double quotes.
        print('==================')
        print("\n\nstep:", step)
        print('content:',content)
        call_dict = json.loads(content)
        # Every parallel call contributes four keys (arg, function, output, description).
        print('It has parallel steps:', len(call_dict) / 4)
        output_text = output_text + step + ': ' + str(call_dict) + '\n\n'

        # Run the parallel calls of this step concurrently in a thread pool.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit tasks to thread pool
            futures = {executor.submit(parse_and_exe, call_dict, result_buffer, str(parallel_step))
                       for parallel_step in range(1, int(len(call_dict) / 4) + 1)}

            # Collect results as they become available
            for idx, future in enumerate(concurrent.futures.as_completed(futures)):
                # A failing call is only printed; later steps still run.
                try:
                    result = future.result()
                    # Print the completion order (not the parallel step's own number).
                    print('parallel step:', idx+1)
                    # print(list(result[1].keys())[0])
                    # print(list(result[1].values())[0])
                except Exception as exc:
                    print(f'Generated an exception: {exc}')

        if step == matches[-1][0]:
            # Current task's final step. Save the output of the final step.
            for parallel_step in range(1, int(len(call_dict) / 4) + 1):
                output_buffer.append(call_dict['output' + str(parallel_step)])
    output_text = output_text + '\n'
    if add_to_queue is not None:
        add_to_queue(output_text)

    ################################# Step-3:visualization ###########################################
    print('===============================Visualization Stage===========================================')
    output_text = output_text + '======Visualization Stage====\n\n'
    # The second planned task is taken to be the visualization task.
    task_name = list(task_plan.keys())[1].split('_task')[0]  #visualization_task
    #task_name = 'visualization'
    task_instruction = list(task_plan.values())[1]  #''

    tool_lib = './tool_lib/' + 'tool_' + task_name + '.json'
    tool_prompt = './prompt_lib/' + 'prompt_' + task_name + '.json'

    # Re-expose the previous stage's final outputs as input1, input2, ...;
    # the model only sees their descriptions (Previous_result).
    result_buffer_viz={}
    Previous_result = {}
    for output_name in output_buffer:
        rename = 'input'+ str(output_buffer.index(output_name)+1)
        Previous_result[rename] = result_buffer[output_name][1]
        result_buffer_viz[rename] = result_buffer[output_name]

    prompt_flat = load_tool_and_prompt(tool_lib, tool_prompt)
    prompt_flat = prompt_flat + '\n\n' +'Instruction: '+ task_instruction + ', Previous_result: '+ str(Previous_result) + ' ###Function Call'

    # Rate limiting via check_RPM is currently disabled:
    # current_time = datetime.datetime.now()
    # sleep_time = check_RPM(run_time, current_time)
    # if sleep_time > 0:
    #     time.sleep(sleep_time)

    response = send_chat_request_Azure(prompt_flat, openai_key=openai_key, api_base=api_base, engine=engine)
    call_steps, _ = response.split('###')
    pattern = r"(step\d+=)(\{[^}]*\})"
    matches = re.findall(pattern, call_steps)
    for match in matches:
        step, content = match
        content = content.replace("'", "\"")  # Replace single quotes with double quotes.
        print('==================')
        print("\n\nstep:", step)
        print('content:',content)
        call_dict = json.loads(content)
        print('It has parallel steps:', len(call_dict) / 4)
        # Visualization steps run sequentially and use un-suffixed keys ('arg', 'function', ...).
        result_buffer_viz = parse_and_exe(call_dict, result_buffer_viz, parallel_step = '' )
        output_text = output_text + step + ': ' + str(call_dict) + '\n\n'

    if add_to_queue is not None:
        add_to_queue(output_text)

    finally_output = list(result_buffer_viz.values())  # (value, description) tuples, e.g. plt.Axes

    # Describe every visualization-stage value for the summary prompt.
    df = pd.DataFrame()
    str_out = output_text + 'Finally result: '
    for ax in finally_output:
        if isinstance(ax[0], plt.Axes):  # If the output is plt.Axes, display it.
            plt.grid()
            #plt.show()
            str_out = str_out + ax[1]+ ':' + 'plt.Axes' + '\n\n'
        elif isinstance(ax[0], pd.DataFrame):
            # Only the last DataFrame encountered is returned.
            df = ax[0]
            str_out = str_out + ax[1]+ ':' + 'pd.DataFrame' + '\n\n'

        else:
            str_out = str_out + str(ax[1])+ ':' + str(ax[0]) + '\n\n'

    print('===============================Summary Stage===========================================')
    # Chinese few-shot prompt asking for a first-person summary of the planning and solving process.
    output_prompt = "请用第一人称总结一下整个任务规划和解决过程,并且输出结果,用[Task]表示每个规划任务,用\{function\}表示每个任务里调用的函数." + \
                    "示例1:###我用将您的问题拆分成两个任务,首先第一个任务[stock_task],我依次获取五粮液和贵州茅台从2013年5月20日到2023年5月20日的净资产回报率roe的时序数据. \n然后第二个任务[visualization_task],我用折线图绘制五粮液和贵州茅台从2013年5月20日到2023年5月20日的净资产回报率,并计算它们的平均值和中位数. \n\n在第一个任务中我分别使用了2个工具函数\{get_stock_code\},\{get_Financial_data_from_time_range\}获取到两只股票的roe数据,在第二个任务里我们使用折线图\{plot_stock_data\}工具函数来绘制他们的roe十年走势,最后并计算了两只股票十年ROE的中位数\{output_median_col\}和均值\{output_mean_col\}.\n\n最后贵州茅台的ROE的均值和中位数是\{\},{},五粮液的ROE的均值和中位数是\{\},\{\}###" + \
                    "示例2:###我用将您的问题拆分成两个任务,首先第一个任务[stock_task],我依次获取20230101到20230520这段时间北向资金每日净流入和每日累计流入时序数据,第二个任务是[visualization_task],因此我在同一张图里同时绘制北向资金20230101到20230520的每日净流入柱状图和每日累计流入的折线图 \n\n为了完成第一个任务中我分别使用了2个工具函数\{get_north_south_money\},\{calculate_stock_index\}分别获取到北上资金的每日净流入量和每日的累计净流入量,第二个任务里我们使用折线图\{plot_stock_data\}绘制来两个指标的变化走势.\n\n最后我们给您提供了包含两个指标的折线图和数据表格." + \
                    "示例3:###我用将您的问题拆分成两个任务,首先第一个任务[economic_task],我爬取了上市公司贵州茅台和其主营业务介绍信息. \n然后第二个任务[visualization_task],我用表格打印贵州茅台及其相关信息. \n\n在第一个任务中我分别使用了1个工具函数\{get_company_info\} 获取到贵州茅台的公司信息,在第二个任务里我们使用折线图\{print_save_table\}工具函数来输出表格.\n"
    output_result = send_chat_request_Azure(output_prompt + str_out + '###', openai_key=openai_key, api_base=api_base,engine=engine)
    print(output_result)
    # Render the current matplotlib figure into an in-memory PNG and load it as a PIL image.
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    image = Image.open(buf)

    return output_text, image, output_result, df
344
+
345
+
346
def gradio_interface(query, openai_key, openai_key_azure, api_base,engine):
    """Run the pipeline in a background thread and stream progress to Gradio.

    :param query: the user's request text
    :param openai_key: official OpenAI key (expected to start with 'sk') or ''
    :param openai_key_azure: Azure-OpenAI key or ''
    :param api_base: Azure-OpenAI endpoint (used only with the Azure key)
    :param engine: Azure-OpenAI engine name (used only with the Azure key)
    :return: generator of (solving_step, image, summary, dataframe) tuples;
        intermediate items carry a placeholder image/table and 'Running',
        the last item carries the final result or an error message
    """
    # Imported locally so this function does not depend on these names
    # reaching the module through a star import.
    import time
    import numpy as np

    placeholder_image = np.zeros((100, 100, 3), dtype=np.uint8)  # Create a placeholder image.
    placeholder_dataframe = pd.DataFrame()

    # Create a new thread to run the function.
    if openai_key.startswith('sk') and openai_key_azure == '':
        print('send_official_call')
        thread = MyThread(target=run, args=(query, add_to_queue, send_official_call, openai_key))
    elif openai_key == '' and len(openai_key_azure) > 0:
        print('send_chat_request_Azure')
        thread = MyThread(target=run, args=(query, add_to_queue, send_chat_request_Azure, openai_key_azure, api_base, engine))
    else:
        # Neither credential combination is usable (e.g. both keys supplied):
        # report it instead of failing on an unbound `thread`.
        yield ('', placeholder_image,
               'Please provide either a valid OpenAI API key (sk-...) or an Azure-OpenAI key, not both.',
               placeholder_dataframe)
        return
    thread.start()

    # Wait for the worker and display the intermediate results as they arrive.
    while thread.is_alive():
        while not intermediate_results.empty():
            yield intermediate_results.get(), placeholder_image, 'Running' , placeholder_dataframe
        time.sleep(0.1)  # Avoid excessive resource consumption.

    # Flush progress messages queued just before the worker finished, so they
    # do not leak into the next request.
    while not intermediate_results.empty():
        yield intermediate_results.get(), placeholder_image, 'Running' , placeholder_dataframe

    try:
        finally_text, img, output, df = thread.get_result()
    except Exception as exc:
        # The worker failed; surface the failure in the UI.
        yield '', placeholder_image, f'Task failed: {exc}', placeholder_dataframe
        return
    # Return the final result.
    yield finally_text, img, output, df
366
+ # Return the final result.
367
+
368
+
369
+
370
# Default request used when the module is run as a script
# (Chinese: "forecast China's GDP growth rate for the next 4 quarters").
instruction = '预测未来中国4个季度的GDP增长率'

if __name__ == '__main__':

    # Choose the request function: the official OpenAI API is used here.
    #openai_call = send_chat_request_Azure  #
    openai_call = send_official_call  #
    openai_key = os.getenv("OPENAI_KEY")

    # run() returns (output_text, image, output_result, df); unpack in that
    # order so the summary, not the DataFrame, is what gets printed.
    output, image, output_result, df = run(instruction, send_chat_request_Azure = openai_call, openai_key=openai_key, api_base='', engine='')
    print(output_result)
    plt.show()
384
+
385
+
386
+
387
+
388
+
389
+
390
+
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas==1.4.3
2
+ matplotlib==3.5.2
3
+ matplotlib-inline==0.1.3
4
+ numpy==1.22.4
5
+ Pillow==9.1.1
6
+ tushare==1.2.89
7
+ mplfinance==0.12.9b7
8
+ typed-ast==1.5.4
9
+ typer==0.4.0
10
+ typing_extensions==4.5.0
11
+ scikit-learn==1.0
12
+ scipy==1.7.3
13
+ tiktoken==0.4.0
14
+ openai==0.27.0
15
+ requests==2.28.0
16
+ requests-oauthlib==1.3.0
tool.py ADDED
@@ -0,0 +1,1931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tushare as ts
2
+ import matplotlib.pyplot as plt
3
+ import pandas as pd
4
+ import os
5
+ import random
6
+ from matplotlib.ticker import MaxNLocator
7
+ #from prettytable import PrettyTable
8
+ #from blessed import Terminal
9
+ import time
10
+ from datetime import datetime, timedelta
11
+ import numpy as np
12
+ import mplfinance as mpf
13
+
14
+ from typing import Optional
15
+ import matplotlib.font_manager as fm
16
+ from matplotlib.lines import Line2D
17
+ from typing import Union, Any
18
+ from sklearn.linear_model import LinearRegression
19
+
20
+
21
+ # plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
22
+ # plt.rcParams['axes.unicode_minus'] = False
23
+
24
+
25
# Font used for every label so that Chinese text renders in the figures.
font_path = './fonts/SimHei.ttf'
font_prop = fm.FontProperties(fname=font_path)


# Prefer the token from the TUSHARE_TOKEN environment variable. The literal
# below is only a placeholder fallback; previously it unconditionally
# overwrote the environment value, so the variable was never honoured.
tushare_token = os.getenv('TUSHARE_TOKEN') or 'ssss'
pro = ts.pro_api(tushare_token)
32
+
33
+ # def last_month_end(date_str:str=''):
34
+ # date_obj = datetime.strptime(date_str, '%Y%m%d')
35
+ # current_month = date_obj.month
36
+ # current_year = date_obj.year
37
+ #
38
+ # if current_month == 1:
39
+ # last_month = 12
40
+ # last_year = current_year - 1
41
+ # else:
42
+ # last_month = current_month - 1
43
+ # last_year = current_year
44
+ #
45
+ # if date_obj.month != (date_obj + timedelta(days=1)).month:
46
+ # last_month_end_date = date_obj
47
+ # else:
48
+ # last_day_of_last_month = (date_obj.replace(day=1) - timedelta(days=1)).day
49
+ # last_month_end_date = datetime(last_year, last_month, last_day_of_last_month)
50
+ #
51
+ # return last_month_end_date.strftime('%Y%m%d')
52
+
53
+
54
+
55
def get_last_year_date(date_str: str = '') -> str:
    """
    Return the calendar date exactly one year before ``date_str``.

    The year is decremented directly rather than subtracting 365 days, so the
    month and day are preserved even when the span crosses a leap day
    (e.g. '20210101' -> '20200101'). 29 February maps to 28 February of the
    previous year.

    Args:
    - date_str: string, the input date in the format YYYYMMDD

    Returns:
    - string, the date one year prior to the input date in the format YYYYMMDD

    Raises:
    - ValueError: if ``date_str`` is not a valid YYYYMMDD date (this includes
      the empty-string default).
    """
    dt = datetime.strptime(date_str, '%Y%m%d')
    try:
        one_year_ago = dt.replace(year=dt.year - 1)
    except ValueError:
        # Only reachable for 29 February: the previous year is never a leap
        # year, so fall back to the last day of that February.
        one_year_ago = dt.replace(year=dt.year - 1, day=28)

    return one_year_ago.strftime('%Y%m%d')
73
+
74
+
75
def get_adj_factor(stock_code: str = '', start_date: str = '', end_date: str = '') -> pd.DataFrame:
    """
    Query the ``adj_factor`` endpoint of the module-level ``pro`` client.

    Args:
    - stock_code: string, the stock code to query, e.g. '000001.SZ'
    - start_date: string, first date of the range in the format YYYYMMDD
    - end_date: string, last date of the range in the format YYYYMMDD

    Returns:
    - dataframe, whatever ``pro.adj_factor`` returns for the three requested
      fields: ts_code (stock code), trade_date and adj_factor (adjustment factor)

    Example: get_adj_factor('000001.SZ', '20220101', '20220501')
    """
    # trade_date / limit / offset are sent empty, exactly as before.
    query = {
        "ts_code": stock_code,
        "trade_date": "",
        "start_date": start_date,
        "end_date": end_date,
        "limit": "",
        "offset": "",
    }
    wanted_fields = ["ts_code", "trade_date", "adj_factor"]

    return pro.adj_factor(**query, fields=wanted_fields)
108
+
109
def get_stock_code(stock_name: str) -> Optional[str]:
    """
    Look up the stock code for a stock name in the local basics CSV.

    For example, get_stock_code('贵州茅台') returns '600519.SH' when that row
    is present in the file.

    Args:
    - stock_name (str): The name to match against the ``name`` column.

    Returns:
    - The ``ts_code`` of the first matching row, or None when no row matches
      or the expected columns are absent.
    """
    df = pd.read_csv('tushare_stock_basic_20230421210721.csv')
    try:
        return df.loc[df.name == stock_name].ts_code.iloc[0]
    except (IndexError, AttributeError):
        # IndexError: no row with that name. AttributeError: the CSV lacks the
        # `name` / `ts_code` column. Anything else (e.g. KeyboardInterrupt) is
        # no longer swallowed as it was by the former bare `except:`.
        return None
119
+
120
+
121
+
122
+
123
def get_stock_name_from_code(stock_code: str) -> str:
    """
    Translate a stock code into its stock name using the local basics CSV.

    For example, get_stock_name_from_code('600519.SH') returns '贵州茅台' when
    that row is present in the file.

    Args:
    - stock_code (str): The code to match against the ``ts_code`` column.

    Returns:
    - str: The ``name`` value of the first matching row.
    """
    basics = pd.read_csv('tushare_stock_basic_20230421210721.csv')
    is_match = basics.ts_code == stock_code
    matching_rows = basics.loc[is_match]

    # Taking position 0 fails with IndexError when nothing matched.
    return matching_rows.name.iloc[0]
140
+
141
def get_stock_prices_data(stock_name: str='', start_date: str='', end_date: str='', freq:str='daily') -> pd.DataFrame:
    """
    Retrieve adjusted daily/weekly/monthly price data for a stock over a period,
    e.g. get_stock_prices_data('贵州茅台','20200120','20220222','daily').

    Args:
    - stock_name (str): Stock name, resolved to a code with ``get_stock_code``.
    - start_date (str): The start date in the format 'YYYYMMDD'.
    - end_date (str): The end date in 'YYYYMMDD'.
    - freq (str): The frequency of the price data, one of 'daily', 'weekly' or 'monthly'.

    Returns:
    - pd.DataFrame: The price rows merged with their adjustment factor and with
      the local stock-basics CSV, sorted by trade_date ascending. Requested price
      columns are stock_code, trade_date, open, high, low, close, pre_close
      (previous close), change, pct_chg (percentage change), vol (volume) and
      amount (turnover); open/high/low/close are multiplied by adj_factor.

    Raises:
    - ValueError: if ``freq`` is not one of the supported frequencies
      (previously this surfaced as an UnboundLocalError).
    """
    # The requested column order differs between the daily endpoint and the
    # weekly/monthly ones; both orders are kept exactly as they were.
    daily_fields = [
        "ts_code", "trade_date", "open", "high", "low", "close",
        "pre_close", "change", "pct_chg", "vol", "amount",
    ]
    periodic_fields = [
        "ts_code", "trade_date", "close", "open", "high", "low",
        "pre_close", "change", "pct_chg", "vol", "amount",
    ]
    fields_by_freq = {
        'daily': daily_fields,
        'weekly': periodic_fields,
        'monthly': periodic_fields,
    }
    if freq not in fields_by_freq:
        raise ValueError(
            f"Unsupported freq {freq!r}; expected one of {sorted(fields_by_freq)}"
        )

    stock_code = get_stock_code(stock_name)

    # Each frequency name is also the name of the endpoint on `pro`.
    fetch = getattr(pro, freq)
    stock_data = fetch(**{
        "ts_code": stock_code,
        "trade_date": '',
        "start_date": start_date,
        "end_date": end_date,
        "limit": "",
        "offset": ""
    }, fields=fields_by_freq[freq])

    adj_f = get_adj_factor(stock_code, start_date, end_date)
    stock_data = pd.merge(stock_data, adj_f, on=['ts_code', 'trade_date'])
    # Multiply open, high, low and close by the row's adjustment factor to
    # obtain adjusted prices.
    stock_data[['open', 'high', 'low', 'close']] *= stock_data['adj_factor'].values.reshape(-1, 1)

    df = pd.read_csv('tushare_stock_basic_20230421210721.csv')
    stock_data_merged = pd.merge(stock_data, df, on='ts_code')
    stock_data_merged.rename(columns={'ts_code': 'stock_code', 'name': 'stock_name'}, inplace=True)
    # Sort by date in ascending order.
    stock_data_merged = stock_data_merged.sort_values(by='trade_date', ascending=True)
    return stock_data_merged
236
+
237
+
238
+
239
def get_stock_technical_data(stock_name: str, start_date: str, end_date: str) -> pd.DataFrame:
    """
    Collect daily technical and trading indicators for one stock.

    Two endpoints of the module-level ``pro`` client are queried (``stk_factor``
    and ``daily_basic``); their rows are joined on (ts_code, trade_date) and then
    joined with the local stock-basics CSV. These indicators are usually plotted
    as subplots of a k-line chart.

    Args:
        stock_name (str): Stock name, resolved to a code with ``get_stock_code``.
        start_date (str): Start date "YYYYMMDD"
        end_date (str): End date "YYYYMMDD"

    Returns:
        pd.DataFrame: Rows sorted by trade_date ascending, with ``symbol`` dropped
        and ts_code/name renamed to stock_code/stock_name. Requested indicators:
        close, macd_dif, macd_dea, macd, kdj_k, kdj_d, kdj_j, rsi_6, rsi_12,
        rsi_24, boll_upper, boll_mid, boll_lower, cci, turnover_rate,
        turnover_rate_f, volume_ratio, pe_ttm (P/E), pb (P/B), ps_ttm, dv_ttm,
        total_share, float_share, free_share, total_mv, circ_mv.
    """
    stock_code = get_stock_code(stock_name)

    # Technical factors
    factor_fields = [
        "ts_code", "trade_date", "close",
        "macd_dif", "macd_dea", "macd",
        "kdj_k", "kdj_d", "kdj_j",
        "rsi_6", "rsi_12", "rsi_24",
        "boll_upper", "boll_mid", "boll_lower",
        "cci",
    ]
    technical = pro.stk_factor(
        ts_code=stock_code,
        start_date=start_date,
        end_date=end_date,
        trade_date='',
        limit="",
        offset="",
        fields=factor_fields,
    )

    # Trading factors
    basic_fields = [
        "ts_code", "trade_date",
        "turnover_rate", "turnover_rate_f", "volume_ratio",
        "pe_ttm", "pb", "ps_ttm", "dv_ttm",
        "total_share", "float_share", "free_share",
        "total_mv", "circ_mv",
    ]
    trading = pro.daily_basic(
        ts_code=stock_code,
        trade_date='',
        start_date=start_date,
        end_date=end_date,
        limit="",
        offset="",
        fields=basic_fields,
    )

    combined = pd.merge(technical, trading, on=['ts_code', 'trade_date'])
    basics = pd.read_csv('tushare_stock_basic_20230421210721.csv')
    enriched = pd.merge(combined, basics, on='ts_code')
    enriched = enriched.sort_values(by='trade_date', ascending=True)

    enriched = enriched.drop(['symbol'], axis=1)
    enriched = enriched.rename(columns={'ts_code': 'stock_code', 'name': 'stock_name'})

    return enriched
318
+
319
+
320
def plot_stock_data(stock_data: pd.DataFrame, ax: Optional[plt.Axes] = None, figure_type: str = 'line', title_name: str ='') -> plt.Axes:
    """
    Plot a three-column stock DataFrame as a line or bar chart.

    Args:
    - stock_data: pandas DataFrame, the stock data to plot. It is read by position:
        - Column 1: trade date in 'YYYYMMDD'
        - Column 2: Stock name or code
        - Column 3: Index value (numeric)
      Time-series data has one distinct name and several dates (dates go on the
      x axis). Cross-sectional data has one distinct date and several names
      (names go on the x axis).
    - ax: matplotlib Axes object to draw on; a new figure/axes is created when None.
      The axes used is made the current pyplot axes.
    - figure_type: the type of figure (either 'line' or 'bar'); any other value
      draws no series.
    - title_name: title of the axes.

    Returns:
    - matplotlib Axes object, the axes containing the plot

    Raises:
    - ValueError: if the data is neither time-series nor cross-sectional
      (e.g. a single row, or several names over several dates). Previously this
      surfaced as an UnboundLocalError.
    """
    index_name = stock_data.columns[2]
    date_list = stock_data.iloc[:, 0]
    name_list = stock_data.iloc[:, 1]
    values = stock_data.iloc[:, 2]

    single_name = name_list.nunique() == 1
    single_date = date_list.nunique() == 1
    if single_name and not single_date:
        # Time series data: one stock, trade dates along the x axis.
        unchanged_var = name_list.iloc[0]
        x_dim = date_list
        x_name = stock_data.columns[0]
    elif single_date and not single_name:
        # Cross-sectional data: one trade date, stock names along the x axis.
        unchanged_var = date_list.iloc[0]
        x_dim = name_list
        x_name = stock_data.columns[1]
    else:
        raise ValueError(
            "stock_data must be either time-series (one name, several dates) "
            "or cross-sectional (one date, several names)"
        )

    start_x_dim, end_x_dim = x_dim.iloc[0], x_dim.iloc[-1]
    start_y = stock_data.iloc[0, 2]
    end_y = stock_data.iloc[-1, 2]

    # Random colour with the green channel capped at 100/255.
    color = (
        random.randint(0, 255) / 255.0,
        random.randint(0, 100) / 255.0,
        random.randint(0, 255) / 255.0,
    )

    if ax is None:
        _, ax = plt.subplots()
    # The pyplot-level calls below act on the *current* axes; make sure that is
    # the axes we draw on, so a caller-supplied `ax` is not bypassed.
    plt.sca(ax)

    series_label = f'{unchanged_var}_{index_name}'

    def _annotate_endpoints():
        # Label the first and last data point with "<name>:<value> @<x>".
        ax.annotate(f'{unchanged_var}:{round(start_y, 2)} @{start_x_dim}', xy=(start_x_dim, start_y),
                    xytext=(start_x_dim, start_y),
                    textcoords='data', fontsize=14, color=color, horizontalalignment='right', fontproperties=font_prop)
        ax.annotate(f'{unchanged_var}:{round(end_y, 2)} @{end_x_dim}', xy=(end_x_dim, end_y),
                    xytext=(end_x_dim, end_y),
                    textcoords='data', fontsize=14, color=color, horizontalalignment='left', fontproperties=font_prop)

    if figure_type == 'line':
        ax.plot(x_dim, values, label=series_label, color=color, linewidth=3)
        ax.scatter(x_dim, values, color=color, s=3)  # Add markers to the data points
        _annotate_endpoints()
    elif figure_type == 'bar':
        ax.bar(x_dim, values, label=series_label, width=0.3, color=color)
        _annotate_endpoints()

    plt.xticks(x_dim, rotation=45)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True, prune=None, nbins=100))

    plt.xlabel(x_name, fontproperties=font_prop, fontsize=18)
    plt.ylabel(f'{index_name}', fontproperties=font_prop, fontsize=16)
    ax.set_title(title_name, fontproperties=font_prop, fontsize=16)
    plt.legend(prop=font_prop)  # Show the legend
    fig = plt.gcf()
    fig.set_size_inches(18, 12)

    return ax
417
+
418
+
419
def query_fund_Manager(Manager_name: str) -> pd.DataFrame:
    """
    Look up the funds associated with a fund manager.

    The ``fund_manager`` endpoint of the module-level ``pro`` client is queried
    by manager name. Requested fields: fund code (ts_code), announcement date
    (ann_date), manager name, gender, birth year, education, nationality,
    begin date, end date and resume. Each fund code is then resolved to a fund
    name with ``query_fund_name_or_code``; rows without a name are dropped.

    Args:
        Manager_name (str): The name of the fund manager.

    Returns:
        df (DataFrame): Columns fund_name, fund_code, ann_date, manager_name,
        begin_date and end_date.
    """
    requested_fields = [
        "ts_code", "ann_date", "name", "gender", "birth_year",
        "edu", "nationality", "begin_date", "end_date", "resume",
    ]
    managers = pro.fund_manager(
        ts_code="",
        ann_date="",
        name=Manager_name,
        offset="",
        limit="",
        fields=requested_fields,
    )

    managers = managers.rename(columns={'ts_code': 'fund_code', 'name': 'manager_name'})
    # Resolve every fund code to its fund name; unresolved codes are removed below.
    managers['fund_name'] = managers['fund_code'].apply(lambda code: query_fund_name_or_code('', code))
    managers = managers.dropna(subset=['fund_name'])

    output_columns = ['fund_name', 'fund_code', 'ann_date', 'manager_name', 'begin_date', 'end_date']
    return managers[output_columns]
461
+
462
+
463
+ # def save_stock_prices_to_csv(stock_prices: pd.DataFrame, stock_name: str, file_path: str) -> None:
464
+ #
465
+ # """
466
+ # Saves the price data of a specific stock symbol during a specific time period to a local CSV file.
467
+ #
468
+ # Args:
469
+ # - stock_prices (pd.DataFrame): A pandas dataframe that contains the daily price data for the given stock symbol during the specified time period.
470
+ # - stock_name (str): The name of the stock.
471
+ # - file_path (str): The file path where the CSV file will be saved.
472
+ #
473
+ # Returns:
474
+ # - None: The function only saves the CSV file to the specified file path.
475
+ # """
476
+ # # The function checks if the directory to save the CSV file exists and creates it if it does not exist.
477
+ # # The function then saves the price data of the specified stock symbol during the specified time period to a local CSV file with the name {stock_name}_price_data.csv in the specified file path.
478
+ #
479
+ #
480
+ # if not os.path.exists(file_path):
481
+ # os.makedirs(file_path)
482
+ #
483
+ #
484
+ # file_path = f"{file_path}{stock_name}_stock_prices.csv"
485
+ # stock_prices.to_csv(file_path, index_label='Date')
486
+ # print(f"Stock prices for {stock_name} saved to {file_path}")
487
+
488
+
489
def calculate_stock_index(stock_data: pd.DataFrame, index:str='close') -> Optional[pd.DataFrame]:
    """
    Extract or compute one index from a stock/fund price or indicator table.

    Args:
        stock_data (pd.DataFrame): Input table. Note that it is modified in place:
            an ``index_name`` column is renamed to ``stock_name`` when no
            ``stock_name`` column exists, and the cumulative-earnings option adds
            a ``cumulative_earnings_rate`` column.
        index (str, optional): The index to return, matched case-insensitively
            (it is lower-cased first). Options:
            - 'Cumulative_Earnings_Rate': cumulative return in percent computed
              from ``pct_chg``.
            - 'candle_K': OHLCV columns for a candlestick chart.
            - 'macd', 'rsi', 'boll', 'kdj', 'cci', '换手率' (turnover rate),
              '市值' (market value): fixed groups of indicator columns.
            - any other column name present in ``stock_data``.

    Returns:
        A copy of the selected columns. In general three columns: 'trade_date',
        the name column ('fund_name' if present, otherwise 'stock_name') and the
        index value. For 'candle_K': 'trade_date', 'Open', 'High', 'Low', 'Close',
        'Volume', 'stock_name'. For the indicator groups: 'trade_date' plus the
        group's columns. None when ``index`` matches nothing.

    Raises:
        ValueError: if a name column is required but the table has neither
            'stock_name' nor 'fund_name' (previously an UnboundLocalError).
    """
    if 'stock_name' not in stock_data.columns and 'index_name' in stock_data.columns:
        stock_data.rename(columns={'index_name': 'stock_name'}, inplace=True)

    index = index.lower()

    def _with_name_column(value_column):
        # 'fund_name' wins over 'stock_name' when both are present.
        if 'fund_name' in stock_data.columns:
            name_column = 'fund_name'
        elif 'stock_name' in stock_data.columns:
            name_column = 'stock_name'
        else:
            raise ValueError("stock_data needs a 'stock_name' or 'fund_name' column")
        return stock_data[['trade_date', name_column, value_column]].copy()

    # Indicator families that are returned as a fixed group of columns.
    indicator_groups = {
        'macd': ['trade_date', 'macd', 'macd_dea', 'macd_dif'],
        'rsi': ['trade_date', 'rsi_6', 'rsi_12'],
        'boll': ['trade_date', 'boll_upper', 'boll_lower', 'boll_mid'],
        'kdj': ['trade_date', 'kdj_k', 'kdj_d', 'kdj_j'],
        'cci': ['trade_date', 'cci'],
        '换手率': ['trade_date', 'turnover_rate', 'turnover_rate_f'],
        '市值': ['trade_date', 'total_mv', 'circ_mv'],
    }

    if index == 'cumulative_earnings_rate':
        # Compound the daily percentage changes and express the result in percent.
        stock_data[index] = (1 + stock_data['pct_chg'] / 100.).cumprod() - 1.
        stock_data[index] = stock_data[index] * 100.
        return _with_name_column(index)

    if index == 'candle_k':
        renamed = stock_data.rename(
            columns={'open': 'Open', 'high': 'High', 'low': 'Low', 'close': 'Close',
                     'vol': 'Volume'})
        return renamed[['trade_date', 'Open', 'High', 'Low', 'Close', 'Volume', 'stock_name']].copy()

    if index in indicator_groups:
        return stock_data[indicator_groups[index]].copy()

    if index in stock_data.columns:
        # Everything else is returned as three columns: date, name, indicator.
        return _with_name_column(index)

    # Unknown index: keep the historical behaviour of returning None.
    return None
568
+
569
+
570
+
571
def rank_index_cross_section(stock_data: pd.DataFrame, Top_k: int = -1, ascending: bool = False) -> pd.DataFrame:
    """
    Sort cross-sectional data by its last column and keep the leading rows.

    Args:
        stock_data : DataFrame containing the cross-sectional data. The last column
            is the variable to sort by; a 'stock_name' column is required for
            de-duplication.
        Top_k : The number of rows to retain after sorting. Any negative value
            (default: -1) retains all rows.
        ascending: Whether to sort in ascending order. (Default: False)

    Returns:
        stock_data_selected : The sorted, truncated rows with duplicate stock names
            removed (first occurrence kept). Same columns as the input.
    """
    index = stock_data.columns[-1]
    ranked = stock_data.sort_values(by=index, ascending=ascending)

    # A plain `ranked[:Top_k]` with the default -1 silently dropped the last
    # row; only slice when an actual non-negative limit was requested.
    if Top_k >= 0:
        ranked = ranked[:Top_k]

    stock_data_selected = ranked.drop_duplicates(subset=['stock_name'], keep='first')
    return stock_data_selected
590
+
591
+
592
def get_company_info(stock_name: str='') -> pd.DataFrame:
    """
    Fetch company profile information for a stock and label it in Chinese.

    The ``stock_company`` endpoint of the module-level ``pro`` client is queried
    for: stock code, exchange code, chairman (legal representative), general
    manager, board secretary, registered capital, setup date, province, city,
    website, email, number of employees, business scope, main business and
    products, introduction, office and announcement date.

    Args:
    - stock_name (str): The name of the stock, resolved with ``get_stock_code``.

    Returns:
    - pd.DataFrame: The company information with column names translated to
      Chinese and a leading '股票名称' (stock name) column holding ``stock_name``.
    """
    requested_fields = [
        "ts_code", "exchange", "chairman", "manager", "secretary", "reg_capital", "setup_date", "province", "city",
        "website", "email", "employees", "business_scope", "main_business", "introduction", "office", "ann_date",
    ]
    company = pro.stock_company(
        ts_code=get_stock_code(stock_name),
        exchange="",
        status="",
        limit="",
        offset="",
        fields=requested_fields,
    )

    # English field name -> Chinese column label shown to the user.
    column_labels = {
        'ts_code': '股票代码',
        'exchange': '交易所代码',
        'chairman': '法人代表',
        'manager': '总经理',
        'secretary': '董秘',
        'reg_capital': '注册资本',
        'setup_date': '注册日期',
        'province': '所在省份',
        'city': '所在城市',
        'introduction': '公司介绍',
        'website': '公司主页',
        'email': '电子邮件',
        'office': '办公室',
        'ann_date': '公告日期',
        'business_scope': '经营范围',
        'employees': '员工人数',
        'main_business': '主要业务及产品',
    }

    labelled = company.rename(columns=column_labels)
    labelled.insert(0, '股票名称', stock_name)

    return labelled
643
+
644
+
645
+
646
+
647
+
648
+ # def get_Financial_data(stock_code: str, report_date: str, financial_index: str = '' ) -> pd.DataFrame:
649
+ # # report_date的格式为"YYYYMMDD",包括"yyyy0331"为一季报,"yyyy0630"为半年报,"yyyy0930"为三季报,"yyyy1231"为年报
650
+ # # index包含: # current_ratio 流动比率 # quick_ratio 速动比率 # netprofit_margin 销售净利率 # grossprofit_margin 销售毛利率 # roe 净资产收益率 # roe_dt 净资产收益率(扣除非经常损益)
651
+ # # roa 总资产报酬率 # debt_to_assets 资产负债率 # roa_yearly 年化总资产净利率 # q_dtprofit 扣除非经常损益后的单季度净利润 # q_eps 每股收益(单季度)
652
+ # # q_netprofit_margin 销售净利率(单季度) # q_gsprofit_margin 销售毛利率(单季度) # basic_eps_yoy 基本每股收益同比增长率(%) # netprofit_yoy 归属母公司股东的净利润同比增长率(%) # q_netprofit_yoy 归属母公司股东的净利润同比增长率(%)(单季度) # q_netprofit_qoq 归属母公司股东的净利润环比增长率(%)(单季度) # equity_yoy 净资产同比增长率
653
+ # """
654
+ # Retrieves financial data for a specific stock within a given date range.
655
+ #
656
+ # Args:
657
+ # stock_code (str): The stock code or symbol of the company for which financial data is requested.
658
+ # report_date (str): The report date in the format "YYYYMMDD" .
659
+ # financial_index (str, optional): The financial indicator to be queried. If not specified, all available financial
660
+ # indicators will be included.
661
+ #
662
+ # Returns:
663
+ # pd.DataFrame: A DataFrame containing the financial data for the specified stock and date range. The DataFrame
664
+ # consists of the following columns: "stock_name",
665
+ # "trade_date" (reporting period), and the requested financial indicator(s).
666
+ #
667
+ # """
668
+ # stock_data = pro.fina_indicator(**{
669
+ # "ts_code": stock_code,
670
+ # "ann_date": "",
671
+ # "start_date": '',
672
+ # "end_date": '',
673
+ # "period": report_date,
674
+ # "update_flag": "1",
675
+ # "limit": "",
676
+ # "offset": ""
677
+ # }, fields=["ts_code","end_date", financial_index])
678
+ #
679
+ # stock_name = get_stock_name_from_code(stock_code)
680
+ # stock_data['stock_name'] = stock_name
681
+ # stock_data = stock_data.sort_values(by='end_date', ascending=True) # 按照日期升序排列
682
+ # # 把end_data列改名为trade_date
683
+ # stock_data.rename(columns={'end_date': 'trade_date'}, inplace=True)
684
+ # stock_financial_data = stock_data[['stock_name', 'trade_date', financial_index]]
685
+ # return stock_financial_data
686
+
687
+
688
def get_Financial_data_from_time_range(stock_name:str, start_date:str, end_date:str, financial_index:str='') -> pd.DataFrame:
    """
    Retrieve one financial indicator for a stock over a range of report dates.

    With start_date='20190101', end_date='20221231', financial_index='roe' the
    result holds the ROE values reported across 2019-2022. To query a single
    quarterly or annual report, pass the same report date twice: "yyyy0331" is
    the Q1 report, "yyyy0630" the half-year report, "yyyy0930" the Q3 report and
    "yyyy1231" the annual report.

    Indicators listed by the original author:
        current_ratio, quick_ratio, netprofit_margin (net profit margin),
        grossprofit_margin (gross profit margin), roe (return on equity),
        roe_dt (ROE excluding non-recurring gains/losses), roa (return on total
        assets), debt_to_assets, roa_yearly (annualised net return on total
        assets), q_dtprofit (single-quarter net profit excluding non-recurring
        items), q_eps (single-quarter EPS), q_netprofit_margin and
        q_gsprofit_margin (single-quarter net/gross margin), basic_eps_yoy (basic
        EPS YoY growth %), netprofit_yoy (parent-attributable net profit YoY
        growth %), q_netprofit_yoy / q_netprofit_qoq (the same for a single
        quarter, YoY / QoQ), equity_yoy (net assets YoY growth).

    Args:
        stock_name (str): The stock name; it is resolved to a code with
            ``get_stock_code`` and copied into the result's ``stock_name`` column.
        start_date (str): The start date of the data range in the format "YYYYMMDD".
        end_date (str): The end date of the data range in the format "YYYYMMDD".
        financial_index (str, optional): The financial indicator to be queried.

    Returns:
        pd.DataFrame: Columns ``stock_name``, ``trade_date`` (the report's
        end_date) and the requested indicator, sorted by date ascending.
    """
    request = {
        "ts_code": get_stock_code(stock_name),
        "ann_date": "",
        "start_date": start_date,
        "end_date": end_date,
        "period": '',
        "update_flag": "1",
        "limit": "",
        "offset": "",
    }
    raw = pro.fina_indicator(**request, fields=["ts_code", "end_date", financial_index])

    # Tag the rows with the caller's stock name, order them chronologically and
    # expose the report date under the name used by the other helpers.
    raw['stock_name'] = stock_name
    ordered = raw.sort_values(by='end_date', ascending=True)
    ordered = ordered.rename(columns={'end_date': 'trade_date'})

    return ordered[['stock_name', 'trade_date', financial_index]]
726
+
727
+
728
def get_GDP_data(start_quarter:str='', end_quarter:str='', index:str='gdp_yoy') -> pd.DataFrame:
    """
    Fetch one Chinese GDP series for a span of quarters.

    Selectable series (column names requested from the API):
        gdp      cumulative GDP (100 million yuan)
        gdp_yoy  year-on-year growth for the quarter (%)
        pi       primary industry, cumulative value (100 million yuan)
        pi_yoy   primary industry, year-on-year growth (%)
        si       secondary industry, cumulative value (100 million yuan)
        si_yoy   secondary industry, year-on-year growth (%)
        ti       tertiary industry, cumulative value (100 million yuan)
        ti_yoy   tertiary industry, year-on-year growth (%)

    Args:
        - start_quarter (str): First quarter of the query. The original note gives
          the format as YYYYMMDD -- TODO confirm against the cn_gdp endpoint.
        - end_quarter (str): Last quarter of the query, same format.
        - index (str): Series to keep. Default is `gdp_yoy`.

    Returns:
        - pd.DataFrame: Three columns -- `quarter`, `country` (always 'China') and
          the selected series -- sorted by quarter in ascending order.
    """
    gdp_fields = ["quarter", "gdp", "gdp_yoy", "pi", "pi_yoy", "si", "si_yoy", "ti", "ti_yoy"]
    query = {
        "q": '',
        "start_q": start_quarter,
        "end_q": end_quarter,
        "limit": "",
        "offset": ""
    }
    table = pro.cn_gdp(**query, fields=gdp_fields)

    table = table.sort_values(by='quarter', ascending=True)
    table['country'] = 'China'
    return table[['quarter', 'country', index]].copy()
767
+
768
def get_cpi_ppi_currency_supply_data(start_month: str = '', end_month: str = '', type: str = 'cpi', index: str = '') -> pd.DataFrame:
    """
    Fetch one monthly Chinese macro series: CPI, PPI or money supply.

    The result has three columns: `month`, `country` (always 'China') and the
    selected series. When `index` is not a column returned for the chosen
    `type`, the default series of that type is used instead.

    type='cpi' (default series `nt_yoy`), 12 series named <scope>_<measure>:
        scope:   nt = national, town = urban, cnt = rural
        measure: val = value for the month, yoy = year-on-year (%),
                 mom = month-on-month (%), accu = cumulative value

    type='ppi' (default series `ppi_yoy`), 30 series named ppi[_<group>]_<measure>:
        group:   (none) = all industrial products
                 mp     = means of production
                 mp_qm  = means of production: mining and quarrying
                 mp_rm  = means of production: raw materials
                 mp_p   = means of production: processing
                 cg     = consumer goods
                 cg_f   = consumer goods: food
                 cg_c   = consumer goods: clothing
                 cg_adu = consumer goods: general daily-use articles
                 cg_dcg = consumer goods: durable consumer goods
        measure: yoy = year-on-year for the month, mom = month-on-month,
                 accu = cumulative year-on-year

    type='currency_supply' (default series `m2_yoy`), 9 series:
        m0, m1, m2 (100 million yuan), each also as <name>_yoy (%) and <name>_mom (%)

    Args:
        - start_month (str): First month of the query. The original note gives the
          format as YYYYMMDD -- TODO confirm against the endpoint.
        - end_month (str): Last month of the query, same format.
        - type (str): Which family of series to query: 'cpi', 'ppi' or
          'currency_supply'.
        - index (str): Series to keep; falls back to the type's default when it is
          not among the returned columns.

    Returns:
        - pd.DataFrame: Columns `month`, `country` and the selected series, sorted
          by month in ascending order.

    Raises:
        ValueError: If `type` is not one of the three supported values. (The
            previous code failed later with an UnboundLocalError in that case.)
    """
    # type -> (endpoint attribute on `pro`, value columns, default series)
    series_catalog = {
        'cpi': (
            'cn_cpi',
            ["nt_val", "nt_yoy", "nt_mom", "nt_accu", "town_val", "town_yoy", "town_mom",
             "town_accu", "cnt_val", "cnt_yoy", "cnt_mom", "cnt_accu"],
            'nt_yoy',
        ),
        'ppi': (
            'cn_ppi',
            ["ppi_yoy", "ppi_mp_yoy", "ppi_mp_qm_yoy", "ppi_mp_rm_yoy", "ppi_mp_p_yoy", "ppi_cg_yoy",
             "ppi_cg_f_yoy", "ppi_cg_c_yoy", "ppi_cg_adu_yoy", "ppi_cg_dcg_yoy",
             "ppi_mom", "ppi_mp_mom", "ppi_mp_qm_mom", "ppi_mp_rm_mom", "ppi_mp_p_mom", "ppi_cg_mom",
             "ppi_cg_f_mom", "ppi_cg_c_mom", "ppi_cg_adu_mom", "ppi_cg_dcg_mom",
             "ppi_accu", "ppi_mp_accu", "ppi_mp_qm_accu", "ppi_mp_rm_accu", "ppi_mp_p_accu", "ppi_cg_accu",
             "ppi_cg_f_accu", "ppi_cg_c_accu", "ppi_cg_adu_accu", "ppi_cg_dcg_accu"],
            'ppi_yoy',
        ),
        'currency_supply': (
            'cn_m',
            ["m0", "m0_yoy", "m0_mom", "m1", "m1_yoy", "m1_mom", "m2", "m2_yoy", "m2_mom"],
            'm2_yoy',
        ),
    }

    # Validate before any request is made so a typo fails fast and clearly.
    if type not in series_catalog:
        raise ValueError(
            f"type must be one of {sorted(series_catalog)}, got {type!r}"
        )
    endpoint_name, value_columns, default_index = series_catalog[type]

    df = getattr(pro, endpoint_name)(**{
        "m": '',
        "start_m": start_month,
        "end_m": end_month,
        "limit": "",
        "offset": ""
    }, fields=["month"] + value_columns)

    # Unknown (or empty) series name: fall back to the type's default.
    if index not in df.columns:
        index = default_index

    df = df.sort_values(by='month', ascending=True)
    df['country'] = 'China'
    df = df[['month', 'country', index]].copy()
    return df
881
+
882
def predict_next_value(df: pd.DataFrame, pred_index: str = 'nt_yoy', pred_num: int = 1) -> pd.DataFrame:
    """
    Extrapolate a column with a straight-line fit and append the forecasts as new rows.

    The values of `pred_index` are regressed against their row position
    (0, 1, 2, ...) with a linear regression, and the fitted line is evaluated
    at the next `pred_num` positions.

    Parameters:
        df (pandas.DataFrame): The input DataFrame. It is not modified.
        pred_index (str): The name of the column to predict.
        pred_num (int): The number of future values to predict. Integral floats
            such as 2.0 are accepted and converted. The default used to be the
            float `1.`, which made every default call fail inside `range()`.

    Returns:
        pandas.DataFrame: A new DataFrame holding the original rows followed by
        one row per forecast. In each added row the `pred_index` column holds the
        predicted value and every other column holds the marker 'pred<i>'
        (i = 1, 2, ...). The index is renumbered from 0.
    """
    pred_num = int(pred_num)  # range()/arange need a true integer
    observed = df[pred_index].values

    # Row position is the single explanatory variable.
    x = np.arange(len(observed)).reshape(-1, 1)
    y = observed.reshape(-1, 1)

    model = LinearRegression()
    model.fit(x, y)

    future_positions = np.arange(len(observed), len(observed) + pred_num).reshape(-1, 1)
    forecasts = model.predict(future_positions).flatten()

    new_rows = [
        {col: (value if col == pred_index else 'pred' + str(step)) for col in df.columns}
        for step, value in enumerate(forecasts, 1)
    ]
    # DataFrame.append was removed in pandas 2.0; concat works on old and new versions.
    return pd.concat([df, pd.DataFrame(new_rows, columns=df.columns)], ignore_index=True)
918
+
919
+
920
+
921
+
922
+
923
+
924
def get_latest_new_from_web(src: str = 'sina') -> pd.DataFrame:
    """
    Fetch the latest news items from one financial news source.

    Supported values of `src`:
        'sina'        Sina Finance real-time news
        '10jqka'      10jqka (Tonghuashun) financial news
        'eastmoney'   Eastmoney financial news
        'yuncaijing'  Yuncaijing news

    Args:
        src (str): The news source to query. Default is 'sina'.

    Returns:
        pd.DataFrame: Two columns, `datetime` and `content`. Every cell is a
        string prefixed with its own column name, e.g. '[content]: ...'.
    """
    news_query = {
        "start_date": '',
        "end_date": '',
        "src": src,
        "limit": "",
        "offset": ""
    }
    news = pro.news(**news_query, fields=["datetime", "content"])

    def _prefix_with_column_name(column):
        # `column.name` is the column label, so each cell becomes "[label]: value".
        return '[' + column.name + ']' + ': ' + column.astype(str)

    return news.apply(_prefix_with_column_name)
953
+
954
+
955
+ # def show_dynamic_table(df: pd.DataFrame) -> None:
956
+ # '''
957
+ # This function displays a dynamic table in the terminal window, where each row of the input DataFrame is shown one by one.
958
+ # Arguments:
959
+ # df: A Pandas DataFrame containing the data to be displayed in the dynamic table.
960
+ #
961
+ # Returns: None. This function does not return anything.
962
+ #
963
+ # '''
964
+ #
965
+ # return df
966
+ # # table = PrettyTable(df.columns.tolist(),align='l')
967
+ #
968
+ # # 将 DataFrame 的数据添加到表格中
969
+ # # for row in df.itertuples(index=False):
970
+ # # table.add_row(row)
971
+ #
972
+ # # 初始化终端
973
+ # # term = Terminal()
974
+ # #
975
+ # # # 在终端窗口中滚动显示表格
976
+ # # with term.fullscreen():
977
+ # # with term.cbreak():
978
+ # # print(term.clear())
979
+ # # with term.location(0, 0):
980
+ # # # 将表格分解为多行,并遍历每一行
981
+ # # lines = str(table).split('\n')
982
+ # # for i, line in enumerate(lines):
983
+ # # with term.location(0, i):
984
+ # # print(line)
985
+ # # time.sleep(1)
986
+ # #
987
+ # # while True:
988
+ # # # 读取输入
989
+ # # key = term.inkey(timeout=0.1)
990
+ # #
991
+ # # # 如果收到q键,则退出
992
+ # # if key.lower() == 'q':
993
+ # # break
994
+
995
+
996
def get_index_constituent(index_name: str = '', start_date:str ='', end_date:str ='') -> pd.DataFrame:
    """
    Query the constituent stocks of a broad-based index (e.g. 中证500) or of a
    Shenwan (申万, "SW") industry index.

    SW industry indices are selected with a name of the form
    '<level prefix><industry name>', where the level prefix is '申万一级行业',
    '申万二级行业' or '申万三级行业'. The industry name is looked up in the
    matching SW2021_industry_L{1,2,3}.csv file to obtain the index code. Any
    other name is treated as a broad-based index and resolved through the
    index_basic endpoint.

    args:
        index_name: the name of the index.
        start_date: the start date in "YYYYMMDD".
        end_date: the end date in "YYYYMMDD".

    return:
        A pandas DataFrame containing the following columns:
        index_code
        index_name
        stock_code: the code of the constituent stock.
        stock_name: the name of the constituent stock.
        weight: the weight of the constituent stock (NaN for SW industry
            indices, for which no weight is requested).

    raises:
        ValueError: if the name contains '申万' but none of the three level
            prefixes (previously this surfaced as an UnboundLocalError).
    """
    output_columns = ['index_code', "index_name", 'stock_code', 'stock_name', 'weight']

    if '申万' in index_name:
        # Level prefix -> CSV listing that level's industries and index codes.
        sw_level_tables = {
            '申万一级行业': 'SW2021_industry_L1.csv',
            '申万二级行业': 'SW2021_industry_L2.csv',
            '申万三级行业': 'SW2021_industry_L3.csv',
        }
        level_prefix = next((prefix for prefix in sw_level_tables if prefix in index_name), None)
        if level_prefix is None:
            raise ValueError(
                f"SW index name {index_name!r} must start with one of {list(sw_level_tables)}"
            )

        # Keep only the industry name that follows the (6-character) level prefix.
        index_name = index_name[len(level_prefix):]
        industries = pd.read_csv(sw_level_tables[level_prefix])
        index_code = industries[industries['industry_name'] == index_name]['index_code'].iloc[0]

        print('The industry code for ', index_name, ' is: ', index_code)

        # Fetch the membership history of the industry index.
        df = pro.index_member(**{
            "index_code": index_code,  # e.g. '851251.SI'
            "is_new": "",
            "ts_code": "",
            "limit": "",
            "offset": ""
        }, fields=[
            "index_code",
            "con_code",
            "in_date",
            "out_date",
            "is_new",
            "index_name",
            "con_name"
        ])

        # Keep stocks that joined on or before start_date and had not left
        # before end_date (a missing out_date means "still a member").
        df = df[(df['in_date'] <= start_date)]
        df = df[(df['out_date'] >= end_date) | (df['out_date'].isnull())]

        # rename() returns a new frame, so the column added next does not write
        # into a filtered view of the API result.
        df = df.rename(columns={'con_code': 'stock_code', 'con_name': 'stock_name'})
        df['weight'] = np.nan

        df = df[output_columns]

    else:  # broad-based index
        # Step 1: resolve the index name to its code.
        df1 = pro.index_basic(**{
            "ts_code": "",
            "market": "",
            "publisher": "",
            "category": "",
            "name": index_name,
            "limit": "",
            "offset": ""
        }, fields=[
            "ts_code",
            "name",
        ])

        index_code = df1["ts_code"][0]
        print(f'index_code for basic index {index_name} is {index_code}')

        # Step 2: Retrieve the constituents of an index based on the index code and given date.
        df = pro.index_weight(**{
            "index_code": index_code,
            "trade_date": '',
            "start_date": start_date,
            "end_date": end_date,
            "limit": "",
            "offset": ""
        }, fields=[
            "index_code",
            "con_code",
            "trade_date",
            "weight"
        ])
        df['index_name'] = index_name
        # Keep a single snapshot: the date found in the row labelled 0.
        # NOTE(review): the original comment calls this "the last trading day",
        # which presumes the API returns the newest date first -- confirm.
        last_day = df['trade_date'][0]
        df = df[df['trade_date'] == last_day]

        df_stock = pd.read_csv('tushare_stock_basic_20230421210721.csv')
        # Merge based on the stock code.
        df = pd.merge(df, df_stock, how='left', left_on='con_code', right_on='ts_code')
        df = df.drop(columns=['symbol', 'area', 'con_code'])
        df.sort_values(by='weight', ascending=False, inplace=True)
        df.rename(columns={'name': 'stock_name'}, inplace=True)
        df.rename(columns={'ts_code': 'stock_code'}, inplace=True)
        # Drops every row with any missing value, e.g. constituents that were
        # not found in the local stock list.
        df.dropna(axis=0, how='any', inplace=True)

        df = df[output_columns]

    return df
1112
+
1113
+ # Determine whether the given name is a stock or a fund.,
1114
def is_fund(ts_name: str = '') -> bool:
    """
    Tell whether a name refers to a fund rather than a stock.

    The name is looked up once with get_stock_code() and once with
    query_fund_name_or_code().

    Args:
        ts_name (str): Name to classify.

    Returns:
        bool: True only when the fund lookup succeeds and the stock lookup
        returns None. False otherwise, including the ambiguous cases where both
        lookups or neither lookup return a value (the previous code returned
        None implicitly there and called each lookup twice).
    """
    stock_code = get_stock_code(ts_name)
    fund_code = query_fund_name_or_code(ts_name)
    return stock_code is None and fund_code is not None
1120
+
1121
+
1122
+
1123
+
1124
def calculate_earning_between_two_time(stock_name: str = '', start_date: str = '', end_date: str = '', index: str = 'close') -> float:
    """
    Calculates the rate of return for a specified stock/fund between two dates.

    The first and last rows of the price history for the period are compared
    with cal_dt(), i.e. the result is (end - start) / start.

    Args:
        stock_name: stock_name or fund_name
        start_date: start of the period, passed through to the data query.
        end_date: end of the period, passed through to the data query.
        index (str): The column used to calculate the return, e.g. 'open' or
            'close' for stocks. For funds an empty string selects 'adj_nav'.
            NOTE(review): the default 'close' is not replaced for funds, so a
            fund queried with the default relies on the fund data having a
            'close' column -- confirm against query_fund_data().

    Returns:
        float: The rate of return between the two dates, or None when the data
        is missing or the requested column cannot be read.
    """
    if is_fund(stock_name):
        fund_code = query_fund_name_or_code(stock_name)
        stock_data = query_fund_data(fund_code, start_date, end_date)
        if index == '':
            index = 'adj_nav'
    else:
        stock_data = get_stock_prices_data(stock_name, start_date, end_date, 'daily')

    try:
        end_price = stock_data.iloc[-1][index]
        start_price = stock_data.iloc[0][index]
        earning = cal_dt(end_price, start_price)
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Covers: no DataFrame returned (None), empty history, unknown column,
        # non-numeric prices. The handler used to print the undefined name
        # `ts_code`, which raised NameError instead of returning None.
        print(stock_name, start_date, end_date)
        print('##################### 该股票没有数据 #####################')
        return None

    return earning
1157
+
1158
+
1159
def loop_rank(df: pd.DataFrame, func: callable, *args, **kwargs) -> pd.DataFrame:
    """
    Apply `func` to every entry of the first column and rank the entries by the result.

    For each row, the value in the first column is turned into a name (codes
    are translated to names first), `func(name, *args, **kwargs)` is called,
    and the numeric outcome is stored in a 'new_feature' column.

    Args:
        df: DataFrame whose first column is one of 'stock_name', 'stock_code',
            'fund_name' or 'fund_code'. A 'new_feature' column is added to this
            frame in place.
        func : The function to be applied to each row: func(name, *args, **kwargs).
            It should return either a float or a DataFrame. For a DataFrame the
            value is taken from the column named by the last positional
            argument (`args[-1]`), at index label 0, rounded to 2 decimals.
            Results of any other type are ignored and the row is dropped.
        *args: Additional positional arguments for `func` function.
        **kwargs: Additional keyword arguments for `func` function.

    Returns:
        pd.DataFrame: Three columns -- 'unchanged' (a constant column holding the
        name of the first column), the first column itself, and 'new_feature' --
        restricted to rows that produced a value and sorted by 'new_feature' in
        descending order.

    Raises:
        ValueError: If the first column is not one of the supported names, if
            `func` raises, or if the value cannot be extracted from a DataFrame
            result. The original exception is attached as the cause.
    """
    df['new_feature'] = None
    loop_var = df.columns[0]

    for _, row in df.iterrows():
        res = None
        var = row[loop_var]
        if var is None:
            continue

        # Everything downstream works on names, so translate codes first.
        if loop_var == 'stock_name':
            stock_name = var
        elif loop_var == 'stock_code':
            stock_name = get_stock_name_from_code(var)
        elif loop_var == 'fund_name':
            stock_name = var
        elif loop_var == 'fund_code':
            stock_name = query_fund_name_or_code('', var)
        else:
            # Previously reached only through an unbound-variable error caught
            # by a bare except; same exception type and message, now explicit.
            raise ValueError('#####################Error for func#####################')

        # Pause between calls. NOTE(review): presumably to stay under the data
        # provider's request rate limit -- confirm.
        time.sleep(0.4)
        try:
            res = func(stock_name, *args, **kwargs)
        except Exception as exc:
            raise ValueError('#####################Error for func#####################') from exc

        # `res` should be one continuous value per entry (e.g. a return rate).
        # A float is used directly; from a DataFrame the requested column's
        # value at label 0 is taken.
        if isinstance(res, pd.DataFrame) and not res.empty:
            try:
                res = round(res.loc[:, args[-1]][0], 2)
                df.loc[df[loop_var] == var, 'new_feature'] = res
            except Exception as exc:
                raise ValueError('##################### Error ######################') from exc
        elif isinstance(res, float):
            df.loc[df[loop_var] == var, 'new_feature'] = res
        print(var, res)

    # Remove the rows where the new_feature column is empty.
    df = df.dropna(subset=['new_feature'])
    stock_data = df.sort_values(by='new_feature', ascending=False)
    # Constant leading column that records which column was iterated over.
    stock_data.insert(0, 'unchanged', loop_var)
    stock_data = stock_data.loc[:, [stock_data.columns[0], loop_var, 'new_feature']]

    return stock_data
1217
+
1218
def output_mean_median_col(data: pd.DataFrame, col: str = 'new_feature') -> float:
    """
    Summarise one column by its mean and median.

    Args:
        data: DataFrame containing the column.
        col: Name of the column to summarise. Default is 'new_feature'.

    Returns:
        A `(mean, median)` tuple, each rounded to 2 decimals. (The `float`
        return annotation is kept for compatibility; the value is a tuple.)
    """
    column = data[col]
    average = round(column.mean(), 2)
    middle = round(column.median(), 2)
    return (average, middle)
1226
+
1227
+
1228
+ # def output_median_col(data: pd.DataFrame, col: str, title_name: str = '') -> float:
1229
+ # # It calculates the median value for the specified column and returns the median as a float value.
1230
+ #
1231
+ # median = round(data[col].median(), 2)
1232
+ # #print(title_name, median)
1233
+ #
1234
+ # return median
1235
+
1236
+
1237
def output_weighted_mean_col(data: pd.DataFrame, col: str, weight_col: pd.Series) -> float:
    """
    Compute the weighted average of a column, divided by 100.

    Args:
        data (pd.DataFrame): Cross-sectional or time-series data holding the feature column.
        col (str): Name of the feature column to average.
        weight_col (pd.Series): One weight per row of `data`.

    Returns:
        float: The weighted average of `data[col]`, divided by 100 and rounded
        to 2 decimals.
    """
    weighted_average = np.average(data[col], weights=weight_col)
    scaled = weighted_average / 100.
    return round(scaled, 2)
1253
+
1254
+
1255
+
1256
def get_index_data(index_name: str = '', start_date: str = '', end_date: str = '', freq: str = 'daily') -> pd.DataFrame:
    """
    This function retrieves daily, weekly, or monthly data for a given stock index.

    The index name is first resolved to its code through the index_basic
    endpoint; the price history is then requested from the endpoint matching
    `freq`.

    Arguments:
    - index_name: Name of the index
    - start_date: Start date in 'YYYYMMDD'
    - end_date: End date in 'YYYYMMDD'
    - freq: Frequency 'daily', 'weekly', or 'monthly'

    Returns:
        A DataFrame sorted by trade_date in ascending order with the columns:
        trade_date, ts_code, close, open, high, low, pre_close (previous
        period's close), change (price change), pct_chg (percentage change),
        vol (volume), amount (turnover), index_name (the queried index name).

    Raises:
        ValueError: If `freq` is not 'daily', 'weekly' or 'monthly'. (The
            previous code failed later with an UnboundLocalError.)
    """
    # The three frequencies differ only in the endpoint that is called.
    endpoint_by_freq = {
        'daily': 'index_daily',
        'weekly': 'index_weekly',
        'monthly': 'index_monthly',
    }
    # Validate before any request is made so a typo fails fast and clearly.
    if freq not in endpoint_by_freq:
        raise ValueError(
            f"freq must be one of {list(endpoint_by_freq)}, got {freq!r}"
        )

    df1 = pro.index_basic(**{
        "ts_code": "",
        "market": "",
        "publisher": "",
        "category": "",
        "name": index_name,
        "limit": "",
        "offset": ""
    }, fields=[
        "ts_code",
        "name",
    ])

    index_code = df1["ts_code"][0]
    print(f'index_code for index {index_name} is {index_code}')

    df = getattr(pro, endpoint_by_freq[freq])(**{
        "ts_code": index_code,
        "trade_date": '',
        "start_date": start_date,
        "end_date": end_date,
        "limit": "",
        "offset": ""
    }, fields=[
        "trade_date",
        "ts_code",
        "close",
        "open",
        "high",
        "low",
        "pre_close",
        "change",
        "pct_chg",
        "vol",
        "amount"
    ])

    df = df.sort_values(by='trade_date', ascending=True)
    df['index_name'] = index_name
    return df
1353
+
1354
+
1355
+
1356
+
1357
+
1358
def get_north_south_money(start_date: str = '', end_date: str = '', trade_date: str = '') -> pd.DataFrame:
    """
    Fetch Stock Connect money flows between the mainland (A-share) and Hong Kong (H-share) markets.

    Args:
        start_date: Start of the range, passed to the moneyflow_hsgt endpoint.
        end_date: End of the range, passed to the endpoint.
        trade_date: A single trading day, passed to the endpoint.

    Returns:
        pd.DataFrame: Sorted by trade_date in ascending order, with the columns
        below. Every flow column is the API value divided by 100; the original
        notes give the resulting unit as 100 million yuan.
            trade_date              trading date
            ggt_ss                  Southbound Connect via Shanghai
            ggt_sz                  Southbound Connect via Shenzhen
            hgt                     Shanghai Connect (northbound)
            sgt                     Shenzhen Connect (northbound)
            north_money             northbound flow = hgt + sgt
            south_money             southbound flow = ggt_ss + ggt_sz
            stock_name              fixed to 'A-H', standing for A-shares and H-shares
            accumulate_north_money  running total of north_money
            accumulate_south_money  running total of south_money
    """
    flow_columns = ['ggt_ss', 'ggt_sz', 'hgt', 'sgt', 'north_money', 'south_money']
    query = {
        "trade_date": trade_date,
        "start_date": start_date,
        "end_date": end_date,
        "limit": "",
        "offset": ""
    }
    flows = pro.moneyflow_hsgt(**query, fields=["trade_date"] + flow_columns)

    flows[flow_columns] = flows[flow_columns] / 100.0
    flows = flows.sort_values(by='trade_date', ascending=True)
    flows['stock_name'] = 'A-H'
    # Running totals are taken after sorting, so they accumulate oldest to newest.
    for direction in ('north', 'south'):
        flows['accumulate_' + direction + '_money'] = flows[direction + '_money'].cumsum()
    return flows
1394
+
1395
+
1396
+
1397
def plot_k_line(stock_data: pd.DataFrame, title: str = '') -> list:
    """
    Plots a K-line (candlestick) chart of stock price and volume.

    The chart includes 5/10/20-period moving averages. Every column that comes
    after the 'stock_name' column is treated as a technical indicator and drawn
    as an extra plot: columns whose name contains 'boll' are overlaid on the
    price panel, a column named exactly 'macd' is drawn as bars in panel 2, and
    all other indicator columns are drawn as lines in panel 2.

    Args:
        stock_data : A pandas DataFrame in which each row is one trading period.
            The original notes list 'trade_date' (as 'YYYYMMDD'), 'open',
            'close', 'high', 'low' and 'volume' as required; the code also
            requires a 'stock_name' column. Optional indicator columns such as
            'macd', 'kdj', 'rsi', 'cci', 'boll', 'pe_ttm' or 'turnover_rate'
            must follow 'stock_name'. The frame is copied and left unmodified
            (earlier versions converted and re-indexed it in place, so a second
            call on the same frame failed).
        title : The title of the K-line chart.

    Returns:
        The axes returned by `mpf.plot(..., returnfig=True)`. The figure is
        resized to 20x16 inches and grid lines are enabled on every axis; the
        figure is not shown.
    """
    # Work on a copy so the caller's frame keeps its 'trade_date' column.
    stock_data = stock_data.copy()
    stock_data['trade_date'] = pd.to_datetime(stock_data['trade_date'], format='%Y%m%d')
    stock_data.set_index('trade_date', inplace=True)

    # Red for rising candles, black for falling ones.
    custom_style = mpf.make_marketcolors(up='r', down='k', inherit=True)
    china_style = mpf.make_mpf_style(marketcolors=custom_style)

    # Indicator columns are the ones located after the 'stock_name' column.
    index_list = stock_data.columns[stock_data.columns.get_loc('stock_name') + 1:]
    index_df = stock_data[index_list]

    color_list = ['green', 'blue', 'red', 'yellow', 'black', 'purple', 'orange', 'pink', 'brown', 'gray']
    add_plot = []
    custom_lines = []
    for position, indicator in enumerate(index_list):
        # Cycle through the palette so more than 10 indicators do not raise IndexError.
        color = color_list[position % len(color_list)]
        if 'boll' in indicator:
            # Bollinger-style bands share the price panel.
            sub_plot = mpf.make_addplot(index_df[indicator], panel=0, ylabel=indicator, color=color, type='line', secondary_y=True)
        elif indicator == 'macd':
            sub_plot = mpf.make_addplot(index_df[indicator], panel=2, ylabel=indicator, color=color, type='bar', secondary_y=False)
        else:
            sub_plot = mpf.make_addplot(index_df[indicator], panel=2, ylabel=indicator, color=color, type='line', secondary_y=False)

        add_plot.append(sub_plot)
        # Proxy artist used only to build the indicator legend.
        custom_lines.append(Line2D([0], [0], color=color, lw=1, linestyle='dashed'))

    mav_colors = ['red', 'green', 'blue']

    fig, axes = mpf.plot(stock_data, type='candle', volume=True, title=title, mav=(5, 10, 20), mavcolors=mav_colors, style=china_style, addplot=add_plot, returnfig=True)

    # Legend for the moving averages on the price panel.
    mav_labels = ['5-day MA', '10-day MA', '20-day MA']
    legend_lines = [Line2D([0], [0], color=color, lw=2) for color in mav_colors]
    axes[0].legend(legend_lines, mav_labels)

    if len(index_list) >= 1:
        if len(index_list) == 1:
            label = index_list[0]
        else:
            # Indicator families share a prefix (text before the first '_').
            # Use the first column's prefix; the earlier set-based pick was
            # arbitrary when several families were present.
            label = index_list[0].split('_')[0]

        if 'boll' in label:
            # Replaces the moving-average legend on the price panel.
            axes[0].legend(custom_lines, index_list, loc='lower right')
        elif len(index_list) > 1:
            axes[-2].set_ylabel(label)
            axes[-2].legend(custom_lines, index_list, loc='lower right')

    fig.set_size_inches(20, 16)
    for ax in axes:
        ax.grid(True)

    return axes
1484
+
1485
+
1486
def cal_dt(num_at_time_2: float = 0.0, num_at_time_1: float = 0.0) -> float:
    """
    Compute the relative change of a metric between two points in time.

    Args:
        - num_at_time_2: the metric value at time 2 (end time)
        - num_at_time_1: the metric value at time 1 (start time)

    Returns:
        - float: (num_at_time_2 - num_at_time_1) / num_at_time_1 rounded to 4
          decimals, i.e. a fraction (0.1 means +10%). A start value of 0 is
          replaced by 1e-10 to avoid dividing by zero.
    """
    baseline = num_at_time_1 if num_at_time_1 != 0 else 0.0000000001
    relative_change = (num_at_time_2 - baseline) / baseline
    return round(relative_change, 4)
1501
+
1502
+
1503
def query_fund_info(fund_code: str = '') -> pd.DataFrame:
    """
    Look up the descriptive record of a fund by its code.

    Args:
        fund_code (str, optional): Fund code. Defaults to ''.

    Returns:
        df (DataFrame): One row per matching fund with these columns (the API's
        'ts_code' and 'name' are renamed to 'fund_code' and 'fund_name'):
            fund_code       fund code
            fund_name       short name
            management      management company
            custodian       custodian
            fund_type       investment type
            found_date      establishment date
            due_date        maturity date
            list_date       listing date
            issue_date      issuance date
            delist_date     delisting date
            issue_amount    units issued (100 million)
            m_fee           management fee
            c_fee           custodian fee
            duration_year   fund duration
            p_value         face value
            min_amount      minimum investment (10 thousand yuan)
            benchmark       performance benchmark
            status          D = delisted, I = issuing, L = listed
            invest_type     investment style
            type            fund type
            purc_startdate  first day of regular purchases
            redm_startdate  first day of regular redemptions
            market          E = exchange-traded, O = off-exchange
    """
    requested_fields = [
        "ts_code", "name", "management", "custodian", "fund_type",
        "found_date", "due_date", "list_date", "issue_date", "delist_date",
        "issue_amount", "m_fee", "c_fee", "duration_year", "p_value",
        "min_amount", "benchmark", "status", "invest_type", "type",
        "purc_startdate", "redm_startdate", "market",
    ]
    query = {
        "ts_code": fund_code,
        "market": "",
        "update_flag": "",
        "offset": "",
        "limit": "",
        "status": "",
        "name": ""
    }
    funds = pro.fund_basic(**query, fields=requested_fields)

    # Use the fund-specific column names shared by the other fund helpers.
    funds.rename(columns={'ts_code': 'fund_code', 'name': 'fund_name'}, inplace=True)
    return funds
1558
+
1559
def query_fund_data(fund_code: str = '', start_date: str = '', end_date: str = '') -> pd.DataFrame:
    """
    Retrieves the net-asset-value history of a fund for a date range.

    Args:
        fund_code (str, optional): Fund code. Defaults to ''.
        start_date (str, optional): Start date in YYYYMMDD format. Defaults to ''.
        end_date (str, optional): End date in YYYYMMDD format. Defaults to ''.

    Returns:
        df (DataFrame): Sorted by trade_date in ascending order, with columns:
            fund_code       fund code (the API's 'ts_code')
            ann_date        announcement date
            trade_date      net-asset-value date (the API's 'nav_date')
            unit_nav        unit net asset value
            accum_nav       accumulated net asset value
            accum_div       accumulated dividends
            net_asset       net asset value
            total_netasset  total net asset value
            adj_nav         adjusted unit net asset value
            update_flag     as returned by the API
            fund_name       name resolved from the fund code
            pct_chg         change of adj_nav relative to the previous row; the
                            first (oldest) row is set to 0.0
        None is returned when resolving the fund name or preparing the frame
        fails.
    """
    df = pro.fund_nav(**{
        "ts_code": fund_code,
        "nav_date": "",
        "offset": "",
        "limit": "",
        "market": "",
        "start_date": start_date,
        "end_date": end_date
    }, fields=[
        "ts_code",
        "ann_date",
        "nav_date",
        "unit_nav",
        "accum_nav",
        "accum_div",
        "net_asset",
        "total_netasset",
        "adj_nav",
        "update_flag"
    ])
    try:
        fund_name = query_fund_name_or_code(fund_code=fund_code)
        df['fund_name'] = fund_name
        df.rename(columns={'ts_code': 'fund_code'}, inplace=True)
        df.rename(columns={'nav_date': 'trade_date'}, inplace=True)
        df.sort_values(by='trade_date', ascending=True, inplace=True)
    except Exception:
        # Best effort: report the code (message: "fund code does not exist")
        # and hand back None instead of propagating the error.
        print(fund_code,'基金代码不存在')
        return None

    df['pct_chg'] = df['adj_nav'].pct_change()
    # The oldest row has no predecessor, so its change is defined as 0.
    # Address it by position: after sorting, index label 0 is not necessarily
    # the first row, and on an empty frame `.loc[0, ...]` would add a row.
    if not df.empty:
        df.iloc[0, df.columns.get_loc('pct_chg')] = 0.0

    return df
1615
+
1616
def query_fund_name_or_code(fund_name: str = '', fund_code: str = '') -> Union[str, None]:
    """
    Translates a fund name into its fund code, or a fund code into its fund name.

    The lookup is done against the local snapshot './tushare_fund_basic_all.csv'
    (columns 'ts_code' and 'name'), not against the online API.

    Args:
        fund_name (str, optional): Fund name. Defaults to ''.
        fund_code (str, optional): Fund code. Defaults to ''.

    Returns:
        The fund code if only fund_name is provided; the fund name if only fund_code is
        provided. None if no matching row exists, or if both / neither argument is given.

    Raises:
        Whatever pd.read_csv raises if the snapshot file cannot be read.
    """
    # Exactly one of the two arguments must be supplied; it decides the lookup direction.
    if fund_name != '' and fund_code == '':
        key_col, key, value_col = 'name', fund_name, 'ts_code'
    elif fund_code != '' and fund_name == '':
        key_col, key, value_col = 'ts_code', fund_code, 'name'
    else:
        return None

    df = pd.read_csv('./tushare_fund_basic_all.csv')
    try:
        # First exact match wins.
        return df[df[key_col] == key][value_col].values[0]
    except (IndexError, KeyError):
        # IndexError: no row matched; KeyError: snapshot lacks the expected column.
        return None
1664
+
1665
+
1666
+
1667
def print_save_table(df: pd.DataFrame, title_name: str, save: bool = False, file_path: str = './output/') -> pd.DataFrame:
    """
    Optionally saves the dataframe as '<file_path>/<title_name>.csv' and returns it.

    Printing the table (previously done with PrettyTable) is currently disabled, so the
    only effects are creating the output directory and, on request, writing the CSV.

    Args:
        - df: the dataframe to be saved to a CSV file
        - title_name: the name of the table; used as the CSV file name (without extension)
        - save: whether to save the table to a CSV file
        - file_path: the directory where the CSV file should be saved; None disables all
          file-system work

    Returns:
        The input dataframe, unchanged.
    """
    # Check for None before touching the file system (os.path functions reject None).
    if file_path is None:
        return df

    # The directory is created even when save is False, as before; exist_ok avoids the
    # check-then-create race.
    os.makedirs(file_path, exist_ok=True)

    if save:
        # os.path.join keeps the file inside the directory even without a trailing slash.
        df.to_csv(os.path.join(file_path, title_name + '.csv'), index=False)
    return df
1700
+
1701
+
1702
+
1703
+ #
1704
def merge_indicator_for_same_stock(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
    """
    Joins two indicator tables that describe the same stock into a single table.

    Every column name that appears in both inputs is used as a join key. Tables that
    belong to two different stocks are not meant to be combined with this function.

    Args:
        df1: DataFrame holding some indicators for stock A.
        df2: DataFrame holding other indicators for stock A.

    Returns:
        pd.DataFrame: The result of pd.merge on all shared columns, carrying the
        indicators of both inputs.

    Raises:
        ValueError: If the two DataFrames share no column name.
    """
    shared = set(df1.columns) & set(df2.columns)
    if not shared:
        raise ValueError('The two dataframes have no columns in common.')
    # Join on every shared column name.
    return pd.merge(df1, df2, on=list(shared))
1725
+
1726
def select_value_by_column(df1: pd.DataFrame, col_name: str = '', row_index: int = -1) -> Union[pd.DataFrame, Any]:
    """
    Picks either one whole column or one single cell out of a DataFrame.

    Args:
        df1: The source DataFrame.
        col_name: Name of the column to read.
        row_index: Key used to index into the selected column. The default -1 is a
            sentinel meaning "return the whole column".

    Returns:
        Union[pd.DataFrame, Any]: df1[col_name].to_frame() when row_index is -1,
        otherwise df1[col_name][row_index].
    """
    column = df1[col_name]
    if row_index != -1:
        # A concrete row was requested: return that single value.
        return column[row_index]
    # Sentinel value: hand back the full column as a one-column DataFrame.
    return column.to_frame()
1744
+
1745
+
1746
+
1747
if __name__ == "__main__":
    # Ad-hoc demo / smoke test of the module, organised as three layers:
    #   step 1 - data query, step 2 - data processing, step 3 - output (text / chart / table).
    stock_name = '成都银行'
    start = '20230104'
    end = '20230504'
    fund_name = "华商优势行业"  # alternative: '易方达蓝筹精选'
    title_name = '上证50成分股收益率'
    ax = None  # set by the plotting recipes below; stays None when nothing is plotted

    # ---------------- step 1: data query layer ----------------
    res = is_fund('易方达蓝筹精选')
    stock_code = get_stock_code(stock_name)
    # The GDP query and forecast were commented out while print_save_table(df_gdp, ...)
    # below was still active, which raised NameError. They are restored here.
    # NOTE(review): get_GDP_data / predict_next_value are assumed to be defined earlier in
    # this module with these signatures - confirm.
    df_gdp = get_GDP_data('2001Q1', '2023Q1', 'gdp_yoy')

    # ---------------- step 2: data processing layer (cross-section or time series) ----------------
    df_gdp = predict_next_value(df_gdp, 'gdp_yoy', 4)

    # ---------------- step 3: visualization layer (text, charts, tables) ----------------
    print_save_table(df_gdp,'GDP预测',True)

    # Further recipes kept for reference (uncomment to run).
    #
    # Top-10 constituents of the '上证50' index by return over a period:
    # stock_data = get_index_constituent('上证50', '20230101', '20230508')
    # stock_list = select_value_by_column(stock_data, 'stock_name', -1)
    # res_earning = loop_rank(stock_list, calculate_earning_between_two_time, '20230101', '20230508')
    # res_earning_top_n = rank_index_cross_section(res_earning, 10, False)
    # ax = plot_stock_data(res_earning_top_n, ax, 'bar', title_name)
    #
    # Adjusted NAV curve of a single fund:
    # fund_code = query_fund_name_or_code(fund_name, '')
    # fund_data = query_fund_data(fund_code, start, end)
    # fund_index = calculate_stock_index(fund_data, 'adj_nav')
    # ax = plot_stock_data(fund_index, ax, 'line', title_name)
    #
    # Returns of every fund run by one manager:
    # fund_df = query_fund_Manager('张坤')
    # fund_code = select_value_by_column(fund_df, 'fund_code', -1)
    # res_earning = loop_rank(fund_code, calculate_earning_between_two_time, start, end, 'adj_nav')
    # ax = plot_stock_data(res_earning, None, 'bar', '张坤管理各个基金收益率')
    #
    # Company profile table:
    # company_df = get_company_info('贵州茅台')
    # print_save_table(company_df,'gzmt', False)

    # Only open a window when one of the plotting recipes produced an axis.
    if ax is not None:
        plt.grid()
        plt.show()
1915
+
1916
+
1917
+
1918
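+ # Example query: among the funds managed by fund manager xxx, what is the size of the fund with the highest return? ---- search for the fund manager, rank by return, select the one with the highest return, show the fund information
1919
+ # Example query: information on the stock with the largest gain over the past ten years among all stocks in the food & beverage industry ---- search for the industry (industry classification -> find the industry code -> find the constituent stocks by industry code), rank by return, select the one with the largest gain, show the stock information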
1920
+
1921
+
1922
+
1923
+
1924
+
1925
+
1926
+
1927
+
1928
+
1929
+
1930
+
1931
+
tushare_daily_20230421211129.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_fund_basic_20230508193747.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_fund_basic_20230516041211.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_fund_basic_20230605184116.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_fund_basic_20230605184607.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_fund_basic_all.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_index_basic_20230427223903.csv ADDED
The diff for this file is too large to render. See raw diff
 
tushare_stock_basic_20230421210721.csv ADDED
The diff for this file is too large to render. See raw diff