Spaces:
Running
Running
Commit
·
f7c1877
1
Parent(s):
73cc4bb
app.py
CHANGED
|
@@ -194,24 +194,24 @@ def update_stock(category, stock):
|
|
| 194 |
stock_plot: gr.update(value=None)
|
| 195 |
}
|
| 196 |
|
| 197 |
-
def predict_stock(category, stock, stock_item, features, model_type):
|
| 198 |
if not all([category, stock, stock_item]):
|
| 199 |
-
return gr.update(value=None)
|
| 200 |
|
| 201 |
try:
|
| 202 |
url = next((item['網址'] for item in category_dict.get(category, [])
|
| 203 |
if item['類股'] == stock), None)
|
| 204 |
if not url:
|
| 205 |
-
return gr.update(value=None)
|
| 206 |
|
| 207 |
stock_items = get_stock_items(url)
|
| 208 |
stock_code = stock_items.get(stock_item, "")
|
| 209 |
|
| 210 |
if not stock_code:
|
| 211 |
-
return gr.update(value=None)
|
| 212 |
|
| 213 |
# 下載股票數據
|
| 214 |
-
df = yf.download(stock_code, period=
|
| 215 |
if df.empty:
|
| 216 |
raise ValueError("無法獲取股票數據")
|
| 217 |
|
|
@@ -248,13 +248,20 @@ def predict_stock(category, stock, stock_item, features, model_type):
|
|
| 248 |
colors = ['#FF9999', '#66B2FF']
|
| 249 |
labels = ['預測開盤價', '預測收盤價']
|
| 250 |
|
| 251 |
-
for i, (label, color) in enumerate(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
|
| 260 |
ax.set_xlabel('日期', labelpad=10)
|
|
@@ -263,11 +270,11 @@ def predict_stock(category, stock, stock_item, features, model_type):
|
|
| 263 |
ax.grid(True, linestyle='--', alpha=0.7)
|
| 264 |
|
| 265 |
plt.tight_layout()
|
| 266 |
-
return gr.update(value=fig)
|
| 267 |
|
| 268 |
except Exception as e:
|
| 269 |
logging.error(f"預測過程發生錯誤: {str(e)}")
|
| 270 |
-
return gr.update(value=None)
|
| 271 |
|
| 272 |
# 初始化
|
| 273 |
setup_font()
|
|
@@ -330,7 +337,7 @@ with gr.Blocks() as demo:
|
|
| 330 |
|
| 331 |
predict_button.click(
|
| 332 |
predict_stock,
|
| 333 |
-
inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, features_checkboxes, model_type_dropdown],
|
| 334 |
outputs=[stock_plot, status_textbox]
|
| 335 |
)
|
| 336 |
|
|
|
|
| 194 |
stock_plot: gr.update(value=None)
|
| 195 |
}
|
| 196 |
|
| 197 |
+
def predict_stock(category, stock, stock_item, period, features, model_type):
|
| 198 |
if not all([category, stock, stock_item]):
|
| 199 |
+
return gr.update(value=None), "請選擇完整的選項"
|
| 200 |
|
| 201 |
try:
|
| 202 |
url = next((item['網址'] for item in category_dict.get(category, [])
|
| 203 |
if item['類股'] == stock), None)
|
| 204 |
if not url:
|
| 205 |
+
return gr.update(value=None), "無法找到該類股的網址"
|
| 206 |
|
| 207 |
stock_items = get_stock_items(url)
|
| 208 |
stock_code = stock_items.get(stock_item, "")
|
| 209 |
|
| 210 |
if not stock_code:
|
| 211 |
+
return gr.update(value=None), "無法找到該股票的代碼"
|
| 212 |
|
| 213 |
# 下載股票數據
|
| 214 |
+
df = yf.download(stock_code, period=period)
|
| 215 |
if df.empty:
|
| 216 |
raise ValueError("無法獲取股票數據")
|
| 217 |
|
|
|
|
| 248 |
colors = ['#FF9999', '#66B2FF']
|
| 249 |
labels = ['預測開盤價', '預測收盤價']
|
| 250 |
|
| 251 |
+
for i, (label, color) in enumerate(labels):
|
| 252 |
+
if model_type == "Prophet":
|
| 253 |
+
ax.plot(date_labels, all_predictions, label='預測收盤價', marker='o', color=colors[1], linewidth=2)
|
| 254 |
+
for j, value in enumerate(all_predictions):
|
| 255 |
+
ax.annotate(f'{value:.2f}', (date_labels[j], value),
|
| 256 |
+
textcoords="offset points", xytext=(0,10),
|
| 257 |
+
ha='center', va='bottom')
|
| 258 |
+
break
|
| 259 |
+
else:
|
| 260 |
+
ax.plot(date_labels, all_predictions[:, i], label=label, marker='o', color=color, linewidth=2)
|
| 261 |
+
for j, value in enumerate(all_predictions[:, i]):
|
| 262 |
+
ax.annotate(f'{value:.2f}', (date_labels[j], value),
|
| 263 |
+
textcoords="offset points", xytext=(0,10),
|
| 264 |
+
ha='center', va='bottom')
|
| 265 |
|
| 266 |
ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
|
| 267 |
ax.set_xlabel('日期', labelpad=10)
|
|
|
|
| 270 |
ax.grid(True, linestyle='--', alpha=0.7)
|
| 271 |
|
| 272 |
plt.tight_layout()
|
| 273 |
+
return gr.update(value=fig), "預測成功"
|
| 274 |
|
| 275 |
except Exception as e:
|
| 276 |
logging.error(f"預測過程發生錯誤: {str(e)}")
|
| 277 |
+
return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
|
| 278 |
|
| 279 |
# 初始化
|
| 280 |
setup_font()
|
|
|
|
| 337 |
|
| 338 |
predict_button.click(
|
| 339 |
predict_stock,
|
| 340 |
+
inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkboxes, model_type_dropdown],
|
| 341 |
outputs=[stock_plot, status_textbox]
|
| 342 |
)
|
| 343 |
|