pentarosarium committed on
Commit
2412746
·
0 Parent(s):

initial commit

Browse files
Files changed (2) hide show
  1. app.py +185 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ from huggingface_hub import HfApi, InferenceClient
6
+ from transformers import pipeline
7
+ from datetime import datetime
8
+ import io
9
+ import base64
10
+ from typing import Dict, List, Set, Tuple
11
+ from rapidfuzz import fuzz, process
12
+ from collections import defaultdict
13
+ from tqdm.auto import tqdm
14
+
15
# Initialize HuggingFace client with token
@st.cache_resource
def get_hf_client():
    """Return a cached InferenceClient authenticated with the Streamlit secret token."""
    return InferenceClient(token=st.secrets["hf_token"])
20
+
21
@st.cache_resource
def get_embeddings_pipeline():
    """Return a cached multilingual feature-extraction pipeline for sentence embeddings."""
    hf_token = st.secrets["hf_token"]
    return pipeline(
        "feature-extraction",
        model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        token=hf_token,
    )
26
+
27
class NewsProcessor:
    """Greedy clustering of news items by embedding similarity, time proximity,
    and company overlap."""

    def __init__(self, similarity_threshold=0.75, time_threshold=24):
        # similarity_threshold: minimum dot-product between mean-pooled
        #   embeddings for two items to be clustered together.
        # time_threshold: maximum gap between items, in hours.
        self.client = get_hf_client()
        self.embeddings_pipeline = get_embeddings_pipeline()
        self.similarity_threshold = similarity_threshold
        self.time_threshold = time_threshold

    def encode_text(self, text):
        """Return a single mean-pooled embedding vector for `text`."""
        embeddings = self.embeddings_pipeline(text)
        # Pipeline output is token-level vectors; average over tokens.
        return np.mean(embeddings[0], axis=0)

    def process_news(self, df: pd.DataFrame, progress_bar=None) -> pd.DataFrame:
        """Cluster the rows of `df` (columns: company, datetime, text).

        Greedy single pass: each unprocessed row seeds a cluster and absorbs
        later rows that are similar enough, within the time window, and share
        at least one company. Returns one row per item with cluster_id and
        cluster_size columns.
        """
        # FIX: pandas str.split treats a multi-char pattern as a regex, and
        # '|' is alternation, so ' | ' split on every single space. Split on
        # the literal ' | ' delimiter instead.
        df['company_list'] = df['company'].str.split(' | ', regex=False)
        df = df.sort_values('datetime')

        clusters = []
        processed = set()
        total_items = len(df)

        # Hoist per-row work out of the O(n^2) comparison loop: the original
        # re-encoded row2 and re-parsed both timestamps on every comparison.
        embeddings = {idx: self.encode_text(row['text']) for idx, row in df.iterrows()}
        timestamps = pd.to_datetime(df['datetime'])

        for i, row1 in df.iterrows():
            if i in processed:
                continue

            cluster = [i]
            processed.add(i)
            text1_embedding = embeddings[i]

            if progress_bar:
                progress_bar.progress(len(processed) / total_items)

            for j, row2 in df.iterrows():
                if j in processed:
                    continue

                time_diff = abs(timestamps[i] - timestamps[j])
                if time_diff.total_seconds() / 3600 > self.time_threshold:
                    continue

                similarity = np.dot(text1_embedding, embeddings[j])

                if similarity >= self.similarity_threshold:
                    companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))
                    if companies_overlap:
                        cluster.append(j)
                        processed.add(j)

            clusters.append(cluster)

        return self._create_result_df(df, clusters)

    def _create_result_df(self, df: pd.DataFrame, clusters: List[List[int]]) -> pd.DataFrame:
        """Flatten `clusters` into one row per news item, tagged with cluster id/size."""
        result_data = []
        for cluster_id, cluster in enumerate(clusters, 1):
            for idx in cluster:
                # FIX: cluster entries are index *labels* from iterrows(); after
                # dedup + sort the labels are not positional, so .iloc selected
                # the wrong rows (or raised). Use label-based .loc.
                row = df.loc[idx]
                result_data.append({
                    'cluster_id': cluster_id,
                    'datetime': row['datetime'],
                    'company': ' | '.join(row['company_list']),
                    'text': row['text'],
                    'cluster_size': len(cluster),
                })

        return pd.DataFrame(result_data)
92
+
93
class NewsDeduplicator:
    """Drops fuzzy-duplicate news texts, merging the company lists of duplicates."""

    def __init__(self, fuzzy_threshold=85):
        # fuzzy_threshold: rapidfuzz ratio (0-100) at or above which two
        # texts are treated as duplicates.
        self.fuzzy_threshold = fuzzy_threshold

    def deduplicate(self, df: pd.DataFrame, progress_bar=None) -> pd.DataFrame:
        """Return a copy of `df` with fuzzy-duplicate rows removed.

        The first occurrence of each text is kept; the companies of later
        duplicates are merged into its 'company' field (joined by ' | ').
        Result is sorted by 'datetime'.
        """
        seen_texts: List[str] = []
        text_to_companies: Dict[str, Set[str]] = defaultdict(set)
        indices_to_keep: List = []

        total = len(df)
        for position, (idx, row) in enumerate(df.iterrows(), 1):
            text = str(row['text'])
            company = str(row['company'])

            match = None
            if seen_texts:
                result = process.extractOne(
                    text,
                    seen_texts,
                    scorer=fuzz.ratio,
                    score_cutoff=self.fuzzy_threshold
                )
                match = result[0] if result else None

            if match:
                # Duplicate: only record this row's company under the kept text.
                text_to_companies[match].add(company)
            else:
                seen_texts.append(text)
                text_to_companies[text].add(company)
                indices_to_keep.append(idx)

            if progress_bar:
                # FIX: use the row position, not the index label — labels are
                # not guaranteed to be 0..n-1 integers.
                progress_bar.progress(position / total)

        # FIX: iterrows() yields index *labels*; the original mixed positional
        # .iloc with label-based .at, which breaks on any non-RangeIndex input.
        # Use label-based .loc consistently.
        dedup_df = df.loc[indices_to_keep].copy()

        for idx in indices_to_keep:
            text = str(dedup_df.at[idx, 'text'])
            companies = sorted(text_to_companies[text])
            dedup_df.at[idx, 'company'] = ' | '.join(companies)

        return dedup_df.sort_values('datetime')
135
+
136
def create_download_link(df: pd.DataFrame, filename: str) -> str:
    """Return an HTML anchor that downloads `df` as an Excel file named `filename`."""
    excel_buffer = io.BytesIO()
    df.to_excel(excel_buffer, index=False)
    excel_buffer.seek(0)
    b64 = base64.b64encode(excel_buffer.read()).decode()
    # FIX: the `filename` argument was ignored — the download attribute and
    # the link text contained a hard-coded placeholder instead.
    return (
        '<a href="data:application/vnd.openxmlformats-officedocument.'
        f'spreadsheetml.sheet;base64,{b64}" '
        f'download="{filename}">Download {filename}</a>'
    )
142
+
143
def main():
    """Streamlit entry point: upload Excel -> deduplicate -> cluster -> download."""
    st.title("News Clustering App")

    st.write("Upload Excel file with columns: company, datetime, text")

    uploaded_file = st.file_uploader("Choose file", type=['xlsx'])

    if uploaded_file:
        df = pd.read_excel(uploaded_file)
        st.dataframe(df.head())

        col1, col2 = st.columns(2)

        with col1:
            fuzzy_threshold = st.slider("Fuzzy Match Threshold", 30, 100, 50)

        with col2:
            similarity_threshold = st.slider("Similarity Threshold", 0.5, 1.0, 0.75)
            time_threshold = st.slider("Time Threshold (hours)", 1, 72, 24)

        if st.button("Process"):
            # FIX: create the progress bar *before* the try block — it was
            # created inside try but referenced in finally, so a failure on
            # the very first statement raised NameError in the cleanup.
            progress_bar = st.progress(0)
            try:
                deduplicator = NewsDeduplicator(fuzzy_threshold)
                dedup_df = deduplicator.deduplicate(df, progress_bar)
                st.success(f"Removed {len(df) - len(dedup_df)} duplicates")

                processor = NewsProcessor(similarity_threshold, time_threshold)
                result_df = processor.process_news(dedup_df, progress_bar)
                st.success(f"Found {result_df['cluster_id'].nunique()} clusters")

                st.markdown(create_download_link(result_df, "clustered_news.xlsx"), unsafe_allow_html=True)

                st.dataframe(result_df)

            except Exception as e:
                st.error(f"Error: {str(e)}")
            finally:
                progress_bar.empty()

if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ numpy
4
+ transformers
5
+ rapidfuzz
6
+ huggingface-hub
7
+ openpyxl