820nam committed on
Commit
2c8cb06
·
verified ·
1 Parent(s): 04a8794

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -144
app.py CHANGED
@@ -2,17 +2,28 @@ import streamlit as st
2
  import requests
3
  import openai
4
  import os
 
5
  from sklearn.feature_extraction.text import TfidfVectorizer
6
  from sklearn.linear_model import LogisticRegression
7
  from sklearn.model_selection import train_test_split, cross_val_score
8
- from sklearn.metrics import accuracy_score
9
  import joblib
10
- from sklearn.model_selection import GridSearchCV
 
 
 
 
11
 
12
  # OpenAI API ํ‚ค ์„ค์ •
13
  openai.api_key = os.getenv("OPENAI_API_KEY")
14
 
15
- # ๋„ค์ด๋ฒ„ ๋‰ด์Šค API๋ฅผ ํ†ตํ•ด ๋‰ด์Šค ๊ธฐ์‚ฌ ๊ฐ€์ ธ์˜ค๊ธฐ
 
 
 
 
 
 
16
  def fetch_naver_news(query, display=5):
17
  client_id = "I_8koTJh3R5l4wLurQbG" # ๋„ค์ด๋ฒ„ ๊ฐœ๋ฐœ์ž ์„ผํ„ฐ์—์„œ ๋ฐœ๊ธ‰๋ฐ›์€ Client ID
18
  client_secret = "W5oWYlAgur" # ๋„ค์ด๋ฒ„ ๊ฐœ๋ฐœ์ž ์„ผํ„ฐ์—์„œ ๋ฐœ๊ธ‰๋ฐ›์€ Client Secret
@@ -26,166 +37,93 @@ def fetch_naver_news(query, display=5):
26
  "query": query,
27
  "display": display,
28
  "start": 1,
29
- "sort": "date", # ์ตœ์‹ ์ˆœ์œผ๋กœ ์ •๋ ฌ
30
  }
31
 
32
  response = requests.get(url, headers=headers, params=params)
33
  if response.status_code == 200:
34
- news_data = response.json()
35
- return news_data['items']
36
  else:
37
  st.error("๋‰ด์Šค ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ๋ฐ ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")
38
  return []
39
 
40
- # ๋จธ์‹ ๋Ÿฌ๋‹ ๋ชจ๋ธ ํ•™์Šต ๋ฐ ๊ฐœ์„ 
41
- def train_ml_model():
42
- # ์˜ˆ์‹œ ๋ฐ์ดํ„ฐ
43
- data = [
44
- ("์ง„๋ณด์ ์ธ ์ •๋ถ€ ์ •์ฑ…์„ ๊ฐ•ํ™”ํ•ด์•ผ ํ•œ๋‹ค", "LEFT"),
45
- ("๋ณด์ˆ˜์ ์ธ ๊ฒฝ์ œ ์ •์ฑ…์ด ํ•„์š”ํ•˜๋‹ค", "RIGHT"),
46
- ("์ค‘๋ฆฝ์ ์ธ ์ž…์žฅ์—์„œ ์ƒํ™ฉ์„ ํ‰๊ฐ€ํ•œ๋‹ค", "NEUTRAL")
47
- ]
48
- texts, labels = zip(*data)
49
-
50
- # TF-IDF ๋ฒกํ„ฐํ™”
51
- vectorizer = TfidfVectorizer(max_features=1000)
52
- X = vectorizer.fit_transform(texts)
53
- y = labels
54
-
55
- # ํ›ˆ๋ จ ๋ฐ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ ๋‚˜๋ˆ„๊ธฐ
56
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
57
-
58
- # ๋กœ์ง€์Šคํ‹ฑ ํšŒ๊ท€ ๋ชจ๋ธ ํ•™์Šต
59
- model = LogisticRegression(max_iter=1000, solver='liblinear') # ๋” ๋งŽ์€ ๋ฐ˜๋ณต ํšŸ์ˆ˜์™€ 'liblinear' solver ์‚ฌ์šฉ
60
-
61
- # ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹ (์ •๊ทœํ™” ๊ฐ•๋„ C)
62
- param_grid = {'C': [0.1, 1, 10, 100]}
63
- grid_search = GridSearchCV(model, param_grid, cv=5)
64
- grid_search.fit(X_train, y_train)
65
- best_model = grid_search.best_estimator_
66
-
67
- # ๊ต์ฐจ ๊ฒ€์ฆ์„ ํ†ตํ•œ ํ‰๊ฐ€
68
- cv_scores = cross_val_score(best_model, X, y, cv=5)
69
- st.write(f"๊ต์ฐจ ๊ฒ€์ฆ ํ‰๊ท  ์ •ํ™•๋„: {cv_scores.mean():.2f}")
70
-
71
- # ๋ชจ๋ธ ์„ฑ๋Šฅ ํ‰๊ฐ€
72
- y_pred = best_model.predict(X_test)
73
- accuracy = accuracy_score(y_test, y_pred)
74
- st.write(f"๋ชจ๋ธ ์ •ํ™•๋„: {accuracy:.2f}")
75
 
76
  # ๋ชจ๋ธ ์ €์žฅ
77
- joblib.dump(best_model, 'political_bias_model.pkl')
78
- joblib.dump(vectorizer, 'tfidf_vectorizer.pkl')
79
-
80
- return best_model, vectorizer
81
-
82
- # ๋กœ๋“œ๋œ ๋จธ์‹ ๋Ÿฌ๋‹ ๋ชจ๋ธ๋กœ ์„ฑํ–ฅ ๋ถ„์„
83
- def analyze_article_sentiment_ml(text, model, vectorizer):
84
- X = vectorizer.transform([text])
85
- prediction = model.predict(X)[0]
86
-
87
- if prediction == "LEFT":
88
- return "์ง„๋ณด"
89
- elif prediction == "RIGHT":
90
- return "๋ณด์ˆ˜"
91
- else:
92
- return "์ค‘๋ฆฝ"
93
 
94
  # GPT-4๋ฅผ ์ด์šฉํ•ด ๋ฐ˜๋Œ€ ๊ด€์  ๊ธฐ์‚ฌ ์ƒ์„ฑ
95
  def generate_article_gpt4(prompt):
96
  try:
97
  response = openai.ChatCompletion.create(
98
- model="gpt-4",
99
- messages=[
100
  {"role": "system", "content": "You are a helpful assistant that generates articles."},
101
- {"role": "user", "content": prompt}
102
  ],
103
- max_tokens=1024,
104
- temperature=0.7
105
  )
106
  return response['choices'][0]['message']['content']
107
  except Exception as e:
108
  return f"Error generating text: {e}"
109
 
110
- # ์ •์น˜์  ๊ด€์  ๋น„๊ต ๋ฐ ๋ฐ˜๋Œ€ ๊ด€์  ์ƒ์„ฑ
111
- def analyze_news_political_viewpoint(query, model, vectorizer):
112
- news_items = fetch_naver_news(query)
113
- if not news_items:
114
- return [], {}
115
-
116
- results = []
117
- sentiment_counts = {"์ง„๋ณด": 0, "๋ณด์ˆ˜": 0, "์ค‘๋ฆฝ": 0}
118
-
119
- for item in news_items:
120
- title = item["title"]
121
- description = item["description"]
122
- link = item["link"]
123
- combined_text = f"{title}. {description}"
124
-
125
- sentiment = analyze_article_sentiment_ml(combined_text, model, vectorizer)
126
- sentiment_counts[sentiment] += 1
127
-
128
- opposite_perspective = "๋ณด์ˆ˜์ " if sentiment == "์ง„๋ณด" else "์ง„๋ณด์ "
129
- prompt = f"{combined_text}๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ {opposite_perspective} ๊ด€์ ์˜ ๊ธฐ์‚ฌ๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”."
130
- opposite_article = generate_article_gpt4(prompt)
131
-
132
- results.append({
133
- "์ œ๋ชฉ": title,
134
- "์›๋ณธ ๊ธฐ์‚ฌ": description,
135
- "์„ฑํ–ฅ": sentiment,
136
- "๋Œ€์กฐ ๊ด€์  ๊ธฐ์‚ฌ": opposite_article,
137
- "๋‰ด์Šค ๋งํฌ": link
138
- })
139
-
140
- return results, sentiment_counts
141
-
142
- # ์„ฑํ–ฅ ๋ถ„ํฌ ์‹œ๊ฐํ™”
143
- def visualize_sentiment_distribution(sentiment_counts):
144
- import matplotlib.pyplot as plt
145
- import seaborn as sns
146
-
147
- fig, ax = plt.subplots(figsize=(8, 5))
148
- labels = list(sentiment_counts.keys())
149
- sizes = list(sentiment_counts.values())
150
-
151
- color_palette = sns.color_palette("pastel")[0:len(sizes)]
152
-
153
- ax.bar(labels, sizes, color=color_palette)
154
- ax.set_xlabel('์„ฑํ–ฅ', fontsize=14)
155
- ax.set_ylabel('๊ฑด์ˆ˜', fontsize=14)
156
- ax.set_title('๋‰ด์Šค ์„ฑํ–ฅ ๋ถ„ํฌ', fontsize=16)
157
- st.pyplot(fig)
158
-
159
- # Streamlit ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜
160
- st.title("๐Ÿ“ฐ ์ •์น˜์  ๊ด€์  ๋น„๊ต ๋ถ„์„ ๋„๊ตฌ")
161
- st.markdown("๋‰ด์Šค ๊ธฐ์‚ฌ์˜ ์ •์น˜ ์„ฑํ–ฅ ๋ถ„์„๊ณผ ๋ฐ˜๋Œ€ ๊ด€์  ๊ธฐ์‚ฌ๋ฅผ ์ƒ์„ฑํ•˜์—ฌ ๋น„๊ตํ•ฉ๋‹ˆ๋‹ค.")
162
-
163
- # ๋จธ์‹ ๋Ÿฌ๋‹ ๋ชจ๋ธ ๋กœ๋“œ
164
- if not os.path.exists('political_bias_model.pkl'):
165
- model, vectorizer = train_ml_model()
166
- else:
167
- model = joblib.load('political_bias_model.pkl')
168
- vectorizer = joblib.load('tfidf_vectorizer.pkl')
169
-
170
- # ์‚ฌ์šฉ์ž๋กœ๋ถ€ํ„ฐ ๊ฒ€์ƒ‰์–ด ์ž…๋ ฅ ๋ฐ›๊ธฐ
171
- query = st.text_input("๊ฒ€์ƒ‰ ํ‚ค์›Œ๋“œ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”", value="์ •์น˜")
172
-
173
- # ๋ถ„์„ ์‹œ์ž‘ ๋ฒ„ํŠผ
174
- if st.button("๐Ÿ” ๋ถ„์„ ์‹œ์ž‘"):
175
- with st.spinner("๋ถ„์„ ์ค‘..."):
176
- analysis_results, sentiment_counts = analyze_news_political_viewpoint(query, model, vectorizer)
177
-
178
- if analysis_results:
179
- st.success("๋‰ด์Šค ๋ถ„์„์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
180
-
181
- for result in analysis_results:
182
- st.subheader(result["์ œ๋ชฉ"])
183
- st.write(f"์„ฑํ–ฅ: {result['์„ฑํ–ฅ']}")
184
- st.write(f"๊ธฐ์‚ฌ: {result['์›๋ณธ ๊ธฐ์‚ฌ']}")
185
- st.write(f"[์›๋ณธ ๊ธฐ์‚ฌ ๋ณด๊ธฐ]({result['๋‰ด์Šค ๋งํฌ']})")
186
- st.write(f"๋Œ€์กฐ ๊ด€์  ๊ธฐ์‚ฌ: {result['๋Œ€์กฐ ๊ด€์  ๊ธฐ์‚ฌ']}")
187
- st.markdown("---")
188
-
189
- visualize_sentiment_distribution(sentiment_counts)
190
- else:
191
- st.warning("๊ฒ€์ƒ‰๋œ ๋‰ด์Šค๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
 
2
  import requests
3
  import openai
4
  import os
5
+ from datasets import load_dataset
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.linear_model import LogisticRegression
8
  from sklearn.model_selection import train_test_split, cross_val_score
9
+ from sklearn.metrics import classification_report, accuracy_score
10
  import joblib
11
+ import matplotlib.pyplot as plt
12
+ import seaborn as sns
13
+
14
+ # Streamlit ํŽ˜์ด์ง€ ์„ค์ •
15
+ st.set_page_config(page_title="์ •์น˜์  ์„ฑํ–ฅ ๋ถ„์„", page_icon="๐Ÿ“ฐ", layout="wide")
16
 
17
  # OpenAI API ํ‚ค ์„ค์ •
18
  openai.api_key = os.getenv("OPENAI_API_KEY")
19
 
20
# Load the Hugging Face political-tweets dataset (cached across Streamlit reruns).
@st.cache_data
def load_huggingface_data():
    """Download and return the `jacobvs/PoliticalTweets` dataset."""
    return load_dataset("jacobvs/PoliticalTweets")
25
+
26
+ # ๋„ค์ด๋ฒ„ ๋‰ด์Šค API๋ฅผ ํ†ตํ•ด ๋‰ด์Šค ๋ฐ์ดํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ
27
  def fetch_naver_news(query, display=5):
28
  client_id = "I_8koTJh3R5l4wLurQbG" # ๋„ค์ด๋ฒ„ ๊ฐœ๋ฐœ์ž ์„ผํ„ฐ์—์„œ ๋ฐœ๊ธ‰๋ฐ›์€ Client ID
29
  client_secret = "W5oWYlAgur" # ๋„ค์ด๋ฒ„ ๊ฐœ๋ฐœ์ž ์„ผํ„ฐ์—์„œ ๋ฐœ๊ธ‰๋ฐ›์€ Client Secret
 
37
  "query": query,
38
  "display": display,
39
  "start": 1,
40
+ "sort": "date",
41
  }
42
 
43
  response = requests.get(url, headers=headers, params=params)
44
  if response.status_code == 200:
45
+ return response.json()['items']
 
46
  else:
47
  st.error("๋‰ด์Šค ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ๋ฐ ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")
48
  return []
49
 
50
# Merge the Hugging Face training split with freshly fetched Naver news items.
def combine_datasets(huggingface_data, naver_data):
    """Return (texts, labels) combining both data sources.

    News items carry no annotation, so each one is labelled "NEUTRAL";
    tweet labels come from the dataset's 'party' column.
    """
    news_texts = [f"{entry['title']}. {entry['description']}" for entry in naver_data]
    news_labels = ["NEUTRAL"] * len(news_texts)
    train_split = huggingface_data['train']
    return train_split['text'] + news_texts, train_split['party'] + news_labels
57
+
58
# Train a TF-IDF + logistic-regression classifier (result cached by Streamlit).
@st.cache_data
def train_model(X, y):
    """Fit the text classifier and persist the artifacts to disk.

    Returns (model, vectorizer, X_test, y_test) so the caller can
    evaluate on the held-out 20% split.
    """
    vectorizer = TfidfVectorizer(max_features=1000, stop_words="english")
    features = vectorizer.fit_transform(X)
    X_train, X_test, y_train, y_test = train_test_split(
        features, y, test_size=0.2, random_state=42
    )

    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, y_train)

    # Persist the fitted artifacts so predictions can run without retraining.
    joblib.dump(model, "political_tweets_model.pkl")
    joblib.dump(vectorizer, "tfidf_vectorizer.pkl")
    return model, vectorizer, X_test, y_test
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
# Generate an opposing-viewpoint article through the OpenAI chat API.
def generate_article_gpt4(prompt):
    """Ask GPT-4 to write an article for *prompt*.

    Returns the generated text, or an error string if the API call
    fails (best-effort: the UI shows the message instead of crashing).
    """
    system_msg = {"role": "system", "content": "You are a helpful assistant that generates articles."}
    user_msg = {"role": "user", "content": prompt}
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-4",
            messages=[system_msg, user_msg],
            max_tokens=1024,
            temperature=0.7,
        )
    except Exception as exc:
        return f"Error generating text: {exc}"
    return completion['choices'][0]['message']['content']
87
 
88
# ---- Streamlit application entry point ----
st.title("๐Ÿ“ฐ ์ •์น˜์  ์„ฑํ–ฅ ๋ถ„์„ ๋ฐ ๋‰ด์Šค ๋น„๊ต ๋„๊ตฌ")
st.markdown("ํ—ˆ๊น…ํŽ˜์ด์Šค์˜ `PoliticalTweets` ๋ฐ์ดํ„ฐ์…‹๊ณผ ๋„ค์ด๋ฒ„ ๋‰ด์Šค API๋ฅผ ํ™œ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ ์„ฑํ–ฅ์„ ๋ถ„์„ํ•ฉ๋‹ˆ๋‹ค.")

# BUG FIX: label_mapping must live at script scope. Streamlit re-runs the whole
# script on every interaction, and this mapping is needed by BOTH button
# branches below; previously it was defined only inside the training branch,
# so clicking the prediction button without training first raised NameError.
label_mapping = {"Democrat": 0, "Republican": 1, "NEUTRAL": 2}

# Load training data and fetch news for the user's query.
huggingface_data = load_huggingface_data()
query = st.text_input("๋„ค์ด๋ฒ„ ๋‰ด์Šค์—์„œ ๊ฒ€์ƒ‰ํ•  ํ‚ค์›Œ๋“œ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”", value="์ •์น˜")
naver_data = fetch_naver_news(query)

if st.button("๋ฐ์ดํ„ฐ ๊ฒฐํ•ฉ ๋ฐ ํ•™์Šต"):
    texts, labels = combine_datasets(huggingface_data, naver_data)
    y = [label_mapping[label] for label in labels]
    model, vectorizer, X_test, y_test = train_model(texts, y)

    # Evaluate on the held-out split produced by train_model.
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    st.write(f"๋ชจ๋ธ ์ •ํ™•๋„: {accuracy:.2f}")
    st.text("๋ถ„๋ฅ˜ ๋ฆฌํฌํŠธ:")
    st.text(classification_report(y_test, y_pred, target_names=list(label_mapping.keys())))

# Free-text prediction against the persisted model.
st.subheader("ํŠธ์œ— ๋˜๋Š” ๋‰ด์Šค ์„ฑํ–ฅ ์˜ˆ์ธก")
user_input = st.text_area("๋ถ„์„ํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”", placeholder="์˜ˆ: The government should invest more in public health.")

if st.button("์„ฑํ–ฅ ๋ถ„์„"):
    # Robustness: model artifacts only exist after at least one training run.
    if not (os.path.exists("political_tweets_model.pkl") and os.path.exists("tfidf_vectorizer.pkl")):
        st.error("Model files not found. Run the training step first.")
    else:
        vectorizer = joblib.load("tfidf_vectorizer.pkl")
        model = joblib.load("political_tweets_model.pkl")
        user_tfidf = vectorizer.transform([user_input])
        prediction = model.predict(user_tfidf)[0]
        prediction_label = list(label_mapping.keys())[prediction]
        st.write(f"์˜ˆ์ธก๋œ ์„ฑํ–ฅ: {prediction_label}")

# Display the fetched news items, if any.
if naver_data:
    st.subheader("๋„ค์ด๋ฒ„ ๋‰ด์Šค ๋ฐ์ดํ„ฐ")
    for item in naver_data:
        st.write(f"์ œ๋ชฉ: {item['title']}")
        st.write(f"๋‚ด์šฉ: {item['description']}")
        st.write(f"[๊ธฐ์‚ฌ ๋งํฌ]({item['link']})")
        st.markdown("---")