KYTHY committed on
Commit
5665b46
·
verified ·
1 Parent(s): 1cc85d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -66
app.py CHANGED
@@ -9,7 +9,7 @@ import plotly.graph_objects as go
9
  from plotly.subplots import make_subplots
10
  import yfinance as yf
11
  import torch
12
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
13
 
14
  # --------------------------
15
  # CONFIG
@@ -29,21 +29,19 @@ def load_finbert():
29
  tokenizer, model = load_finbert()
30
 
31
  # --------------------------
32
- # โหลด Zero-shot classifier สำหรับธีมข่าว
33
- # --------------------------
34
- @st.cache_resource
35
- def load_theme_classifier():
36
- return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
37
-
38
- theme_classifier = load_theme_classifier()
39
- candidate_labels = ["Stock Movement", "Earnings", "M&A", "Regulation", "Product Launch", "Market Analysis"]
40
-
41
- # --------------------------
42
- # โหลด Pegasus สำหรับสรุปข่าว
43
  # --------------------------
44
  @st.cache_resource
45
  def load_summarizer():
46
- return pipeline("summarization", model="Nerdward/financial-summarization-pegasus-finetuned-pytorch-model")
 
 
 
 
 
 
 
 
47
 
48
  summarizer = load_summarizer()
49
 
@@ -72,49 +70,12 @@ def analyze_text(text):
72
  score = (-1 * probs[0]) + (0 * probs[1]) + (1 * probs[2])
73
  return float(score)
74
 
75
- def summarize_texts(news_texts):
76
- """สรุปข่าวแต่ละข่าว 1 พารากราฟ พร้อม progress bar"""
77
- summaries = []
78
- progress_text = st.empty()
79
- progress_bar = st.progress(0)
80
- total = len(news_texts)
81
-
82
- for i, text in enumerate(news_texts):
83
- if not text.strip():
84
- summaries.append("")
85
- else:
86
- try:
87
- summary = summarizer(text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
88
- summaries.append(summary)
89
- except:
90
- summaries.append(text)
91
- progress_text.text(f"กำลังสรุปข่าว {i+1}/{total}")
92
- progress_bar.progress((i+1)/total)
93
- progress_bar.empty()
94
- progress_text.empty()
95
- return summaries
96
-
97
- def summarize_themes(news_texts):
98
- """สรุปธีมข่าวแต่ละข่าว พร้อม progress bar"""
99
- themes = []
100
- progress_text = st.empty()
101
- progress_bar = st.progress(0)
102
- total = len(news_texts)
103
-
104
- for i, text in enumerate(news_texts):
105
- if not text.strip():
106
- themes.append("Unknown")
107
- else:
108
- try:
109
- result = theme_classifier(text, candidate_labels)
110
- themes.append(result["labels"][0])
111
- except:
112
- themes.append("Unknown")
113
- progress_text.text(f"กำลังสรุปธีมข่าว {i+1}/{total}")
114
- progress_bar.progress((i+1)/total)
115
- progress_bar.empty()
116
- progress_text.empty()
117
- return themes
118
 
119
  # --------------------------
120
  # แปลงชื่อ/ตัวย่อ → (Company Name, Symbol)
@@ -123,6 +84,7 @@ def resolve_company_symbol(keyword: str):
123
  keyword = keyword.strip()
124
  ticker = None
125
  name = None
 
126
  try:
127
  data = yf.Ticker(keyword)
128
  info = data.info
@@ -138,10 +100,12 @@ def resolve_company_symbol(keyword: str):
138
  name = q.get("longname", q.get("shortname", keyword))
139
  except:
140
  pass
 
141
  if not ticker:
142
  ticker = keyword.upper()
143
  if not name:
144
  name = keyword.capitalize()
 
145
  return name, ticker
146
 
147
  # --------------------------
@@ -152,6 +116,7 @@ def fetch_financial_news(keyword):
152
  company, symbol = resolve_company_symbol(keyword)
153
  to_date = datetime.now().strftime('%Y-%m-%d')
154
  from_date = (datetime.now() - timedelta(days=7)).strftime('%Y-%m-%d')
 
155
  query_keyword = f"({company} OR {symbol}) finance stock"
156
 
157
  all_articles = []
@@ -198,14 +163,18 @@ def fetch_stock_price(symbol, start_date, end_date):
198
  start_str = (start_date - timedelta(days=2)).strftime('%Y-%m-%d')
199
  end_str = (end_date + timedelta(days=1)).strftime('%Y-%m-%d')
200
  df = yf.download(symbol, start=start_str, end=end_str, interval="1d")
 
201
  if df.empty:
202
  st.warning("ไม่พบข้อมูลราคาหุ้น")
203
  return pd.DataFrame()
 
204
  df = df.reset_index()
205
  df_subset = df[['Date', 'Close']]
206
  df_subset.columns = ['date', 'price']
207
  df_subset["date"] = pd.to_datetime(df_subset["date"].dt.date)
 
208
  return df_subset
 
209
  except Exception as e:
210
  st.warning(f"ดึงราคาหุ้นล้มเหลว: {e}")
211
  return pd.DataFrame()
@@ -238,14 +207,6 @@ def main():
238
  news_df["sentiment"] = news_df["text"].apply(analyze_text)
239
  news_df["date"] = pd.to_datetime(news_df["date"])
240
 
241
- # สรุปข่าวเป็น 1 พารากราฟ พร้อม progress bar
242
- st.info("กำลังสรุปเนื้อหาข่าว...")
243
- news_df["text"] = summarize_texts(news_df["text"].tolist())
244
-
245
- # สรุปธีมข่าวพร้อม progress bar
246
- st.info("กำลังสรุปธีมข่าว...")
247
- news_df["theme"] = summarize_themes(news_df["text"].tolist())
248
-
249
  # Metrics
250
  avg_sentiment = news_df["sentiment"].mean()
251
  pos_pct = (news_df["sentiment"] > 0.1).mean() * 100
@@ -256,9 +217,114 @@ def main():
256
  col2.metric("ข่าวเชิงบวก", f"{pos_pct:.1f}%")
257
  col3.metric("ข่าวเชิงลบ", f"{neg_pct:.1f}%")
258
 
259
- # แสดงรายการข่าว
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  st.subheader("📰 รายการข่าวทั้งหมด")
261
- st.dataframe(news_df[["date", "source", "text", "sentiment", "theme", "url"]], use_container_width=True)
262
 
263
  # ---------------------------------------------------------
264
  # RUN APP
 
9
  from plotly.subplots import make_subplots
10
  import yfinance as yf
11
  import torch
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, pipeline
13
 
14
  # --------------------------
15
  # CONFIG
 
29
  tokenizer, model = load_finbert()
30
 
31
  # --------------------------
32
+ # โหลด Pegasus summarizer (slow tokenizer)
 
 
 
 
 
 
 
 
 
 
33
  # --------------------------
34
  @st.cache_resource
35
  def load_summarizer():
36
+ tokenizer_sum = AutoTokenizer.from_pretrained(
37
+ "Nerdward/financial-summarization-pegasus-finetuned-pytorch-model",
38
+ use_fast=False # ใช้ slow tokenizer
39
+ )
40
+ model_sum = AutoModelForSeq2SeqLM.from_pretrained(
41
+ "Nerdward/financial-summarization-pegasus-finetuned-pytorch-model"
42
+ )
43
+ summarizer = pipeline("summarization", model=model_sum, tokenizer=tokenizer_sum)
44
+ return summarizer
45
 
46
  summarizer = load_summarizer()
47
 
 
70
  score = (-1 * probs[0]) + (0 * probs[1]) + (1 * probs[2])
71
  return float(score)
72
 
73
+ def summarize_article(text):
74
+ """สรุปข่าวเป็น 1 พารากราฟ"""
75
+ if not text.strip():
76
+ return ""
77
+ summary_list = summarizer(text, max_length=150, min_length=50, do_sample=False)
78
+ return summary_list[0]['summary_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # --------------------------
81
  # แปลงชื่อ/ตัวย่อ → (Company Name, Symbol)
 
84
  keyword = keyword.strip()
85
  ticker = None
86
  name = None
87
+
88
  try:
89
  data = yf.Ticker(keyword)
90
  info = data.info
 
100
  name = q.get("longname", q.get("shortname", keyword))
101
  except:
102
  pass
103
+
104
  if not ticker:
105
  ticker = keyword.upper()
106
  if not name:
107
  name = keyword.capitalize()
108
+
109
  return name, ticker
110
 
111
  # --------------------------
 
116
  company, symbol = resolve_company_symbol(keyword)
117
  to_date = datetime.now().strftime('%Y-%m-%d')
118
  from_date = (datetime.now() - timedelta(days=7)).strftime('%Y-%m-%d')
119
+
120
  query_keyword = f"({company} OR {symbol}) finance stock"
121
 
122
  all_articles = []
 
163
  start_str = (start_date - timedelta(days=2)).strftime('%Y-%m-%d')
164
  end_str = (end_date + timedelta(days=1)).strftime('%Y-%m-%d')
165
  df = yf.download(symbol, start=start_str, end=end_str, interval="1d")
166
+
167
  if df.empty:
168
  st.warning("ไม่พบข้อมูลราคาหุ้น")
169
  return pd.DataFrame()
170
+
171
  df = df.reset_index()
172
  df_subset = df[['Date', 'Close']]
173
  df_subset.columns = ['date', 'price']
174
  df_subset["date"] = pd.to_datetime(df_subset["date"].dt.date)
175
+
176
  return df_subset
177
+
178
  except Exception as e:
179
  st.warning(f"ดึงราคาหุ้นล้มเหลว: {e}")
180
  return pd.DataFrame()
 
207
  news_df["sentiment"] = news_df["text"].apply(analyze_text)
208
  news_df["date"] = pd.to_datetime(news_df["date"])
209
 
 
 
 
 
 
 
 
 
210
  # Metrics
211
  avg_sentiment = news_df["sentiment"].mean()
212
  pos_pct = (news_df["sentiment"] > 0.1).mean() * 100
 
217
  col2.metric("ข่าวเชิงบวก", f"{pos_pct:.1f}%")
218
  col3.metric("ข่าวเชิงลบ", f"{neg_pct:.1f}%")
219
 
220
+ # สรุปข่าว 1 พารากราฟ
221
+ st.info("กำลังสรุปข่าวเป็น 1 พารากราฟต่อข่าว...")
222
+ news_df["text"] = news_df["text"].apply(lambda x: summarize_article(x) if x.strip() else "")
223
+
224
+ # ---------------------------------------------------------
225
+ # ส่วนกราฟ Sentiment & Price (เหมือนเดิม)
226
+ # ---------------------------------------------------------
227
+ st.subheader("📈 แนวโน้มอารมณ์ของข่าว & ราคาหุ้น")
228
+
229
+ news_df["date_day"] = pd.to_datetime(news_df["date"].dt.date)
230
+
231
+ def sentiment_type(score):
232
+ if score > 0.1:
233
+ return "positive"
234
+ if score < -0.1:
235
+ return "negative"
236
+ return "neutral"
237
+
238
+ news_df["sentiment_type"] = news_df["sentiment"].apply(sentiment_type)
239
+
240
+ daily_avg = news_df.groupby("date_day")["sentiment"].mean().reset_index(name="avg_sentiment")
241
+ daily_counts = news_df.groupby(["date_day", "sentiment_type"]).size().unstack(fill_value=0).reset_index()
242
+
243
+ df_sorted = pd.merge(daily_avg, daily_counts, on="date_day").sort_values("date_day")
244
+
245
+ if len(df_sorted) < 2:
246
+ st.warning("ข้อมูลไม่พอสร้างแนวโน้ม")
247
+ st.dataframe(news_df)
248
+ return
249
+
250
+ # ดึงราคาหุ้น
251
+ _, symbol = resolve_company_symbol(keyword)
252
+ min_date, max_date = df_sorted["date_day"].min(), df_sorted["date_day"].max()
253
+ st.info(f"กำลังดึงราคาหุ้น {symbol} ...")
254
+ stock_df = fetch_stock_price(symbol, min_date, max_date)
255
+
256
+ plot_data = pd.merge(df_sorted, stock_df, left_on="date_day", right_on="date", how="left")
257
+
258
+ # Correlation
259
+ correlation = plot_data['price'].corr(plot_data['avg_sentiment'])
260
+ corr_text = "ไม่มีความสัมพันธ์"
261
+ if correlation > 0.5:
262
+ corr_text = "มีความสัมพันธ์ในทิศทางเดียวกัน"
263
+ elif correlation < -0.5:
264
+ corr_text = "มีความสัมพันธ์ในทิศทางตรงข้าม"
265
+ st.metric("วิเคราะห์ความสัมพันธ์ระหว่างอารมณ์ของข่าวกับราคาหุ้น (Correlation)", corr_text, f"{correlation:.2f}")
266
+
267
+ # Forecast Sentiment
268
+ plot_data["timestamp"] = (plot_data["date_day"] - plot_data["date_day"].min()).dt.days
269
+ train_data = plot_data.dropna(subset=['avg_sentiment'])
270
+
271
+ if len(train_data) >= 2:
272
+ model_lr = LinearRegression()
273
+ model_lr.fit(train_data[["timestamp"]], train_data["avg_sentiment"])
274
+
275
+ future_days = 7
276
+ future_timestamps = np.arange(
277
+ plot_data["timestamp"].max() + 1,
278
+ plot_data["timestamp"].max() + future_days + 1
279
+ )
280
+ future_dates = [plot_data["date_day"].max() + timedelta(days=i) for i in range(1, future_days + 1)]
281
+ future_preds = model_lr.predict(future_timestamps.reshape(-1, 1))
282
+
283
+ # Plot
284
+ fig = make_subplots(rows=2, cols=1, specs=[[{"secondary_y": True}], [{}]],
285
+ row_heights=[0.7, 0.3], vertical_spacing=0.1,
286
+ shared_xaxes=True)
287
+
288
+ # ราคาหุ้น
289
+ fig.add_trace(go.Scatter(x=plot_data["date_day"], y=plot_data["price"], name=f"{symbol} Price",
290
+ mode="lines+markers", line=dict(color="orange")), row=1, col=1)
291
+ # Sentiment จริง
292
+ fig.add_trace(go.Scatter(x=plot_data["date_day"], y=plot_data["avg_sentiment"], name="Actual Sentiment",
293
+ mode="lines+markers", line=dict(color="blue")), row=1, col=1, secondary_y=True)
294
+ # Sentiment พยากรณ์
295
+ if "future_preds" in locals():
296
+ fig.add_trace(go.Scatter(x=future_dates, y=future_preds, name="Predicted Sentiment",
297
+ mode="lines+markers", line=dict(color="#05a0fa", dash="dash")), row=1, col=1, secondary_y=True)
298
+ # เส้นเชื่อม Actual -> Predicted
299
+ last_actual_date = plot_data["date_day"].max()
300
+ last_actual_value = plot_data["avg_sentiment"].iloc[-1]
301
+ first_pred_date = future_dates[0]
302
+ first_pred_value = future_preds[0]
303
+ fig.add_trace(go.Scatter(x=[last_actual_date, first_pred_date],
304
+ y=[last_actual_value, first_pred_value],
305
+ mode="lines",
306
+ line=dict(color="#05a0fa", dash="dot"),
307
+ name="Connector Actual→Predicted"), row=1, col=1, secondary_y=True)
308
+
309
+ # จำนวนข่าว
310
+ for col in ["neutral", "negative", "positive"]:
311
+ if col not in plot_data.columns:
312
+ plot_data[col] = 0
313
+ fig.add_trace(go.Bar(x=plot_data["date_day"], y=plot_data["neutral"], name="Neutral",
314
+ marker_color='rgba(128, 128, 128, 0.7)'), row=2, col=1)
315
+ fig.add_trace(go.Bar(x=plot_data["date_day"], y=plot_data["negative"], name="Negative",
316
+ marker_color='rgba(255, 0, 0, 0.7)'), row=2, col=1)
317
+ fig.add_trace(go.Bar(x=plot_data["date_day"], y=plot_data["positive"], name="Positive",
318
+ marker_color='rgba(0, 128, 0, 0.7)'), row=2, col=1)
319
+
320
+ fig.update_layout(title=f"แนวโน้มอารมณ์ของข่าว + ราคาหุ้น ({symbol})",
321
+ barmode="stack", height=650, hovermode="x unified", template="plotly_white")
322
+
323
+ st.plotly_chart(fig, use_container_width=True)
324
+
325
+ # แสดงรายการข่าวทั้งหมด (text เป็นสรุปแล้ว)
326
  st.subheader("📰 รายการข่าวทั้งหมด")
327
+ st.dataframe(news_df[["date", "source", "text", "sentiment", "url"]], use_container_width=True)
328
 
329
  # ---------------------------------------------------------
330
  # RUN APP