Spaces:
Build error
Build error
Commit ·
6d4a64c
1
Parent(s): c8405c4
mend indexers
Browse files
app.py
CHANGED
|
@@ -25,54 +25,28 @@ class NewsProcessor:
|
|
| 25 |
self.similarity_threshold = similarity_threshold
|
| 26 |
self.time_threshold = time_threshold
|
| 27 |
|
| 28 |
-
def mean_pooling(self, model_output, attention_mask):
|
| 29 |
-
token_embeddings = model_output[0]
|
| 30 |
-
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 31 |
-
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 32 |
-
|
| 33 |
-
def encode_text(self, text):
|
| 34 |
-
# Convert text to string and handle NaN values
|
| 35 |
-
if pd.isna(text):
|
| 36 |
-
text = ""
|
| 37 |
-
else:
|
| 38 |
-
text = str(text)
|
| 39 |
-
|
| 40 |
-
encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
| 41 |
-
with torch.no_grad():
|
| 42 |
-
model_output = self.model(**encoded_input)
|
| 43 |
-
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
| 44 |
-
return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
|
| 45 |
-
|
| 46 |
-
def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
|
| 47 |
-
if pd.isna(text):
|
| 48 |
-
return False, ""
|
| 49 |
-
|
| 50 |
-
text_lower = str(text).lower()
|
| 51 |
-
|
| 52 |
-
for company in companies:
|
| 53 |
-
company_lower = str(company).lower()
|
| 54 |
-
if company_lower in text_lower.split('.')[0]:
|
| 55 |
-
return True, company
|
| 56 |
-
if text_lower.count(company_lower) >= 3:
|
| 57 |
-
return True, company
|
| 58 |
-
doc = self.nlp(text_lower)
|
| 59 |
-
for sent in doc.sents:
|
| 60 |
-
if company_lower in sent.text:
|
| 61 |
-
for token in sent:
|
| 62 |
-
if token.dep_ == 'nsubj' and company_lower in token.text:
|
| 63 |
-
return True, company
|
| 64 |
-
return False, ""
|
| 65 |
-
|
| 66 |
def process_news(self, df: pd.DataFrame, progress_bar=None):
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
df = df.sort_values('datetime')
|
|
|
|
| 69 |
clusters = []
|
| 70 |
processed = set()
|
| 71 |
|
| 72 |
-
for i
|
| 73 |
if i in processed:
|
| 74 |
continue
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
cluster = [i]
|
| 77 |
processed.add(i)
|
| 78 |
text1_embedding = self.encode_text(row1['text'])
|
|
@@ -80,11 +54,16 @@ class NewsProcessor:
|
|
| 80 |
if progress_bar:
|
| 81 |
progress_bar.progress(len(processed) / len(df))
|
| 82 |
progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
|
| 83 |
-
|
| 84 |
-
|
|
|
|
| 85 |
if j in processed:
|
| 86 |
continue
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
|
| 89 |
if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
|
| 90 |
continue
|
|
@@ -95,6 +74,7 @@ class NewsProcessor:
|
|
| 95 |
is_main1, main_company1 = self.is_company_main_subject(row1['text'], row1['company_list'])
|
| 96 |
is_main2, main_company2 = self.is_company_main_subject(row2['text'], row2['company_list'])
|
| 97 |
|
|
|
|
| 98 |
companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))
|
| 99 |
|
| 100 |
if similarity >= self.similarity_threshold and companies_overlap:
|
|
@@ -105,24 +85,31 @@ class NewsProcessor:
|
|
| 105 |
|
| 106 |
result_data = []
|
| 107 |
for cluster_id, cluster in enumerate(clusters, 1):
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
return pd.DataFrame(result_data)
|
| 128 |
|
|
@@ -184,7 +171,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
|
|
| 184 |
return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
|
| 185 |
|
| 186 |
def main():
|
| 187 |
-
st.title("
|
| 188 |
st.write("Upload Excel file with columns: company, datetime, text")
|
| 189 |
|
| 190 |
uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
|
|
|
|
| 25 |
self.similarity_threshold = similarity_threshold
|
| 26 |
self.time_threshold = time_threshold
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def process_news(self, df: pd.DataFrame, progress_bar=None):
|
| 29 |
+
# Ensure the DataFrame is not empty
|
| 30 |
+
if df.empty:
|
| 31 |
+
return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'main_company', 'text', 'cluster_size'])
|
| 32 |
+
|
| 33 |
+
# Create company_list safely
|
| 34 |
+
df['company_list'] = df['company'].fillna('').str.split(' | ')
|
| 35 |
df = df.sort_values('datetime')
|
| 36 |
+
|
| 37 |
clusters = []
|
| 38 |
processed = set()
|
| 39 |
|
| 40 |
+
for i in tqdm(range(len(df)), total=len(df)):
|
| 41 |
if i in processed:
|
| 42 |
continue
|
| 43 |
+
|
| 44 |
+
row1 = df.iloc[i]
|
| 45 |
+
if pd.isna(row1['text']) or not row1['company_list']:
|
| 46 |
+
processed.add(i)
|
| 47 |
+
clusters.append([i])
|
| 48 |
+
continue
|
| 49 |
+
|
| 50 |
cluster = [i]
|
| 51 |
processed.add(i)
|
| 52 |
text1_embedding = self.encode_text(row1['text'])
|
|
|
|
| 54 |
if progress_bar:
|
| 55 |
progress_bar.progress(len(processed) / len(df))
|
| 56 |
progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
|
| 57 |
+
|
| 58 |
+
# Use index-based iteration instead of iterrows
|
| 59 |
+
for j in range(len(df)):
|
| 60 |
if j in processed:
|
| 61 |
continue
|
| 62 |
|
| 63 |
+
row2 = df.iloc[j]
|
| 64 |
+
if pd.isna(row2['text']) or not row2['company_list']:
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
|
| 68 |
if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
|
| 69 |
continue
|
|
|
|
| 74 |
is_main1, main_company1 = self.is_company_main_subject(row1['text'], row1['company_list'])
|
| 75 |
is_main2, main_company2 = self.is_company_main_subject(row2['text'], row2['company_list'])
|
| 76 |
|
| 77 |
+
# Safe set operation
|
| 78 |
companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))
|
| 79 |
|
| 80 |
if similarity >= self.similarity_threshold and companies_overlap:
|
|
|
|
| 85 |
|
| 86 |
result_data = []
|
| 87 |
for cluster_id, cluster in enumerate(clusters, 1):
|
| 88 |
+
try:
|
| 89 |
+
cluster_texts = df.iloc[cluster]
|
| 90 |
+
main_companies = []
|
| 91 |
+
|
| 92 |
+
for _, row in cluster_texts.iterrows():
|
| 93 |
+
if not pd.isna(row['text']) and isinstance(row['company_list'], list):
|
| 94 |
+
is_main, company = self.is_company_main_subject(row['text'], row['company_list'])
|
| 95 |
+
if is_main and company:
|
| 96 |
+
main_companies.append(company)
|
| 97 |
+
|
| 98 |
+
main_company = main_companies[0] if main_companies else "Multiple/Unclear"
|
| 99 |
+
|
| 100 |
+
for idx in cluster:
|
| 101 |
+
row_data = df.iloc[idx]
|
| 102 |
+
result_data.append({
|
| 103 |
+
'cluster_id': cluster_id,
|
| 104 |
+
'datetime': row_data['datetime'],
|
| 105 |
+
'company': ' | '.join(row_data['company_list']) if isinstance(row_data['company_list'], list) else '',
|
| 106 |
+
'main_company': main_company,
|
| 107 |
+
'text': row_data['text'],
|
| 108 |
+
'cluster_size': len(cluster)
|
| 109 |
+
})
|
| 110 |
+
except Exception as e:
|
| 111 |
+
print(f"Error processing cluster {cluster_id}: {str(e)}")
|
| 112 |
+
continue
|
| 113 |
|
| 114 |
return pd.DataFrame(result_data)
|
| 115 |
|
|
|
|
| 171 |
return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
|
| 172 |
|
| 173 |
def main():
|
| 174 |
+
st.title("кластеризуем новости v.1.2")
|
| 175 |
st.write("Upload Excel file with columns: company, datetime, text")
|
| 176 |
|
| 177 |
uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
|