Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| from datetime import datetime | |
| import io | |
| import base64 | |
| from typing import Dict, List, Set, Tuple | |
| from rapidfuzz import fuzz, process | |
| from collections import defaultdict | |
| from tqdm import tqdm | |
| import spacy | |
| import torch.nn.functional as F | |
class NewsProcessor:
    """Cluster near-duplicate news items and score company relevance.

    Clustering (``process_news``) groups rows whose texts have similar
    sentence embeddings and whose ``datetime`` values are close in time.
    ``is_company_main_subject`` is a standalone heuristic estimating whether a
    company is the main subject of an article.
    """

    # Words/abbreviations suggesting that a sentence mentions some organisation.
    # They are matched against whole tokens (or their pymorphy2 normal form),
    # never as substrings: short entries such as "ак", "ао" or "нк" would
    # otherwise match ordinary words like "как", "также" or "банк".
    OTHER_COMPANY_INDICATORS = frozenset({
        'компания', 'корпорация', 'фирма', 'банк', 'группа', 'холдинг',
        'организация', 'предприятие', 'производитель', 'ао', 'оао', 'пао', 'нк', 'гк',
        'ооо', 'лк', 'фк', 'акб', 'ук', 'зао', 'ак',
    })
    # Phrases suggesting the company is just one item of an enumeration.
    LIST_INDICATORS = ('среди', 'включая', 'такие как', 'в том числе', 'и другие', 'а также')

    def __init__(self, similarity_threshold=0.75, time_threshold=24):
        """Load the NLP resources and store the clustering thresholds.

        Args:
            similarity_threshold: minimum dot product between two
                L2-normalised text embeddings for rows to share a cluster.
            time_threshold: maximum absolute gap, in hours, between two rows'
                ``datetime`` values for them to share a cluster.
        """
        try:
            self.nlp = spacy.load("ru_core_news_sm")
        except OSError:
            # spaCy raises OSError when the Russian pipeline is not installed;
            # fall back to English without hiding unrelated errors.
            self.nlp = spacy.load("en_core_web_sm")
        import pymorphy2  # Russian morphology: word forms and normal forms
        self.morph = pymorphy2.MorphAnalyzer()
        model_name = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.similarity_threshold = similarity_threshold
        self.time_threshold = time_threshold

    def mean_pooling(self, model_output, attention_mask):
        """Average ``model_output[0]`` over dim 1, counting only unmasked tokens.

        Positions where ``attention_mask`` is 0 (padding) contribute nothing;
        the divisor is clamped to 1e-9 to avoid division by zero.
        Presumably ``model_output[0]`` is (batch, seq, hidden) -- TODO confirm.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def encode_text(self, text):
        """Return the L2-normalised sentence embedding of *text* as a NumPy vector.

        NaN/None becomes the empty string; other values are converted with
        ``str``. Input longer than 512 tokens is truncated.
        """
        if pd.isna(text):
            text = ""
        else:
            text = str(text)
        encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()

    def get_company_variants(self, company_name: str) -> Set[str]:
        """Generate lowercase strings that may denote *company_name* in a text.

        The set contains the cleaned full name (surrounding quotes stripped,
        everything after the first comma dropped), every form in the pymorphy2
        lexeme of each word of 3+ characters, and each pair of consecutive
        such words. Returns an empty set for NaN/None.
        """
        if pd.isna(company_name):
            return set()
        name = str(company_name).strip('"\'').strip()
        name = name.split(',')[0].strip()  # keep only the part before the first comma
        variants = {name.lower()}
        words = [w for w in name.split() if len(w) >= 3]
        for word in words:
            lexeme = self.morph.parse(word)[0].lexeme
            variants.update(form.word.lower() for form in lexeme)
        for first, second in zip(words, words[1:]):
            variants.add(f"{first} {second}".lower())
        return variants

    def _mentions_other_company(self, sent) -> bool:
        """Return True if any token of *sent* is (or normalises to) an organisation indicator."""
        for token in sent:
            word = token.text
            if word in self.OTHER_COMPANY_INDICATORS:
                return True
            if self.morph.parse(word)[0].normal_form in self.OTHER_COMPANY_INDICATORS:
                return True
        return False

    def is_company_main_subject(self, title: str, text: str, company_name: str, threshold_score: float = 0.5) -> Tuple[bool, float]:
        """Heuristically score how much an article is about *company_name*.

        Score components (all case-insensitive substring matches of the
        variants from ``get_company_variants``):

        * 0.4 if any variant occurs in the title;
        * 0.2 if any variant occurs in the first non-blank line of *text*;
        * up to 0.2 for the share of company mentions where a token containing
          the variant has spaCy dependency ``nsubj``/``nsubjpass``;
        * up to 0.2 for company mentions relative to sentences that mention
          some other organisation instead.

        The score is halved when more than 5 sentences mention other
        organisations, and multiplied by 0.7 for each list phrase found in the
        text (penalties compound).

        Returns:
            ``(final_score >= threshold_score, final_score)``; ``(False, 0.0)``
            when *text* or *company_name* is missing or yields no variants.
        """
        if pd.isna(text) or pd.isna(company_name):
            return False, 0.0
        title = str(title) if not pd.isna(title) else ""
        text = str(text)
        company_variants = self.get_company_variants(company_name)
        if not company_variants:
            return False, 0.0

        def mentions_company(haystack: str) -> bool:
            return any(variant in haystack for variant in company_variants)

        text_lower = text.lower()

        # Title (weight 0.4).
        title_score = 0.4 if mentions_company(title.lower()) else 0.0

        # First non-blank paragraph (weight 0.2).
        paragraphs = [p.strip() for p in text_lower.split('\n') if p.strip()]
        first_para = paragraphs[0] if paragraphs else ""
        first_para_score = 0.2 if mentions_company(first_para) else 0.0

        # Per-sentence counts. company_mentions counts every matching variant,
        # so one sentence can contribute several mentions.
        company_mentions = 0
        subject_mentions = 0
        other_companies = 0
        doc = self.nlp(text_lower)
        for sent in doc.sents:
            sent_text = sent.text
            matched = [variant for variant in company_variants if variant in sent_text]
            if matched:
                company_mentions += len(matched)
                subject_mentions += sum(
                    1
                    for variant in matched
                    for token in sent
                    if variant in token.text and token.dep_ in ('nsubj', 'nsubjpass')
                )
                continue
            if self._mentions_other_company(sent):
                other_companies += 1

        # Subject score (weight 0.2).
        subject_score = min(0.2, (subject_mentions / max(1, company_mentions)) * 0.2)

        # Frequency score (weight 0.2).
        frequency_score = 0.0
        if company_mentions > 0:
            company_ratio = company_mentions / (company_mentions + other_companies + 1)
            frequency_score = min(0.2, company_ratio * 0.2)

        final_score = title_score + first_para_score + subject_score + frequency_score

        if other_companies > 5:  # many sentences are about other organisations
            final_score *= 0.5
        for indicator in self.LIST_INDICATORS:  # company likely just listed among others
            if indicator in text_lower:
                final_score *= 0.7

        return final_score >= threshold_score, final_score

    def process_news(self, df: pd.DataFrame, progress_bar=None):
        """Greedily cluster rows of *df* by text similarity within a time window.

        Rows are visited in index order. Each not-yet-clustered row starts a
        cluster and absorbs every later unclustered row with non-missing text
        whose ``datetime`` is within ``time_threshold`` hours and whose
        embedding dot product is >= ``similarity_threshold``. Rows with missing
        text always form singleton clusters. Each text is encoded at most once.

        Args:
            df: frame with ``datetime``, ``company`` and ``text`` columns.
            progress_bar: optional object with a ``progress(fraction)`` method,
                updated each time a new cluster is started.

        Returns:
            One row per input row, indexed by the original labels and ordered
            cluster by cluster, with columns ``cluster_id`` (1-based),
            ``datetime``, ``company``, ``text`` and ``cluster_size``.
        """
        result_columns = ['cluster_id', 'datetime', 'company', 'text', 'cluster_size']
        if df.empty:
            return pd.DataFrame(columns=result_columns)

        df = df.copy()
        embedding_cache = {}
        time_cache = {}

        def embedding_of(label):
            # The pairwise loop below would otherwise re-run the transformer
            # for every comparison.
            if label not in embedding_cache:
                embedding_cache[label] = self.encode_text(df.loc[label, 'text'])
            return embedding_cache[label]

        def time_of(label):
            if label not in time_cache:
                time_cache[label] = pd.to_datetime(df.loc[label, 'datetime'])
            return time_cache[label]

        clusters = []
        processed = set()
        for idx in df.index:
            if idx in processed:
                continue
            cluster = [idx]
            processed.add(idx)
            if progress_bar:
                progress_bar.progress(len(processed) / len(df))
            if not pd.isna(df.loc[idx, 'text']):
                for other_idx in df.index:
                    if other_idx in processed or pd.isna(df.loc[other_idx, 'text']):
                        continue
                    time_diff = time_of(idx) - time_of(other_idx)
                    if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
                        continue
                    similarity = np.dot(embedding_of(idx), embedding_of(other_idx))
                    if similarity >= self.similarity_threshold:
                        cluster.append(other_idx)
                        processed.add(other_idx)
            clusters.append(cluster)

        result_data = []
        result_index = []
        for cluster_id, cluster_indices in enumerate(clusters, 1):
            for idx in cluster_indices:
                result_index.append(idx)  # preserve the caller's index labels
                result_data.append({
                    'cluster_id': cluster_id,
                    'datetime': df.loc[idx, 'datetime'],
                    'company': df.loc[idx, 'company'],
                    'text': df.loc[idx, 'text'],
                    'cluster_size': len(cluster_indices),
                })
        return pd.DataFrame(result_data, index=result_index, columns=result_columns)
class NewsDeduplicator:
    """Drop near-identical news texts, merging the companies they were filed under.

    Two texts count as duplicates when ``rapidfuzz.fuzz.ratio`` between them
    is at least ``fuzzy_threshold``. The first occurrence is kept and later
    near-copies are folded into it.
    """

    def __init__(self, fuzzy_threshold=85):
        """
        Args:
            fuzzy_threshold: minimum ``fuzz.ratio`` score (0-100) for two texts
                to be treated as duplicates.
        """
        self.fuzzy_threshold = fuzzy_threshold

    def deduplicate(self, df: pd.DataFrame, progress_bar=None) -> pd.DataFrame:
        """Return one row per distinct text, sorted by ``datetime``.

        Args:
            df: frame with at least ``text``, ``company`` and ``datetime`` columns.
            progress_bar: optional object with a ``progress(fraction)`` method,
                updated once per input row with the fraction of rows processed.

        Returns:
            The kept rows, with original index labels preserved, plus:
            ``company`` -- the distinct companies of the kept row and of every
            row folded into it, sorted and joined with ``' | '``;
            ``company_count`` -- the number of those distinct companies;
            ``duplicate_count`` -- the number of input rows represented by the
            kept row, itself included.
            Rows with missing/empty text are never matched and keep their own
            company (both counts 1).
        """
        # One entry per kept row, aligned by position.
        kept_labels = []
        kept_companies: List[Set[str]] = []
        kept_row_counts: List[int] = []
        # Texts available for fuzzy matching, and for each the slot of the kept
        # row it belongs to (so lookups never depend on index labels or texts).
        seen_texts: List[str] = []
        seen_slots: List[int] = []
        total = len(df)

        for position, (idx, row) in enumerate(tqdm(df.iterrows(), total=total), start=1):
            text = str(row['text']) if not pd.isna(row['text']) else ""
            company = str(row['company']) if not pd.isna(row['company']) else ""

            match = None
            if text and seen_texts:
                # With a list of choices, extractOne returns (choice, score, position),
                # or None when no choice reaches score_cutoff.
                match = process.extractOne(
                    text,
                    seen_texts,
                    scorer=fuzz.ratio,
                    score_cutoff=self.fuzzy_threshold,
                )

            if match is not None:
                slot = seen_slots[match[2]]
                kept_companies[slot].add(company)
                kept_row_counts[slot] += 1
            else:
                if text:
                    seen_texts.append(text)
                    seen_slots.append(len(kept_labels))
                kept_labels.append(idx)
                kept_companies.append({company})
                kept_row_counts.append(1)

            if progress_bar:
                # Use the row position, not the index label, so the fraction stays in [0, 1].
                progress_bar.progress(position / total)

        dedup_df = df.loc[kept_labels].copy()
        dedup_df['company'] = [' | '.join(sorted(companies)) for companies in kept_companies]
        dedup_df['company_count'] = [len(companies) for companies in kept_companies]
        dedup_df['duplicate_count'] = kept_row_counts
        return dedup_df.sort_values('datetime', kind='stable')
def create_download_link(df: pd.DataFrame, filename: str) -> str:
    """Render *df* as an in-memory .xlsx workbook inside an HTML download link.

    The workbook (written without the index, via the openpyxl engine) is
    base64-encoded into a ``data:`` URI, so no file is stored server-side.
    """
    workbook_bytes = io.BytesIO()
    with pd.ExcelWriter(workbook_bytes, engine='openpyxl') as sheet_writer:
        df.to_excel(sheet_writer, index=False)
    encoded = base64.b64encode(workbook_bytes.getvalue()).decode()
    return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{encoded}" download="{filename}">Download {filename}</a>'
def _multi_item_clusters(result_df: pd.DataFrame):
    """Yield ``(cluster_id, index_labels)`` for every cluster with more than one row."""
    for cluster_id in result_df['cluster_id'].unique():
        cluster_mask = result_df['cluster_id'] == cluster_id
        labels = result_df[cluster_mask].index.tolist()
        if len(labels) > 1:
            yield cluster_id, labels


def _rows_to_delete(result_df: pd.DataFrame, dedup_df: pd.DataFrame) -> set:
    """Return labels of all clustered rows except the longest-text row of each cluster."""
    to_delete = set()
    for _, labels in _multi_item_clusters(result_df):
        text_lengths = dedup_df.loc[labels, 'text'].fillna('').str.len()
        to_delete.update(set(labels) - {text_lengths.idxmax()})
    return to_delete


def main():
    """Streamlit entry point: upload, deduplicate, cluster and export news.

    The sheet 'Публикации' is read from the uploaded workbook. The company,
    datetime, title and text columns are taken by position (0, 3, 5, 6).
    Rows go through NewsDeduplicator, then NewsProcessor clustering. From
    every multi-item cluster only the row with the longest text is kept.
    Debug output of indices is written along the way.
    """
    st.title("кластеризуем новости v.1.23 + print debug")
    st.write("Upload Excel file with columns: company, datetime, text")
    uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
    if not uploaded_file:
        return

    try:
        df_original = pd.read_excel(uploaded_file, sheet_name='Публикации')
        st.write("Available columns:", df_original.columns.tolist())

        # Columns are selected by position, not by header name.
        company_column = df_original.columns[0]
        datetime_column = df_original.columns[3]
        title_column = df_original.columns[5]
        text_column = df_original.columns[6]
        df = df_original[[company_column, datetime_column, title_column, text_column]].copy()
        df.columns = ['company', 'datetime', 'title', 'text']

        st.success(f'Loaded {len(df)} records')
        st.dataframe(df.head())

        col1, col2 = st.columns(2)
        with col1:
            fuzzy_threshold = st.slider("Fuzzy Match Threshold", 30, 100, 50)
        with col2:
            similarity_threshold = st.slider("Similarity Threshold", 0.5, 1.0, 0.75)
            time_threshold = st.slider("Time Threshold (hours)", 1, 72, 24)

        if st.button("Process News"):
            # Created before the try so the finally clause can always clear it.
            progress_bar = st.progress(0)
            try:
                # Step 1: deduplicate near-identical texts.
                deduplicator = NewsDeduplicator(fuzzy_threshold)
                dedup_df = deduplicator.deduplicate(df, progress_bar)
                # Same surviving rows, but with every column of the uploaded sheet.
                dedup_df_full = df_original.loc[dedup_df.index].copy()

                st.write("\nDeduplication Results:")
                st.write(f"Original indices: {df.index.tolist()}")
                st.write(f"Dedup indices: {dedup_df.index.tolist()}")
                st.write("Sample from dedup_df:")
                st.write(dedup_df[['company', 'text']].head())

                # Step 2: cluster the deduplicated news.
                processor = NewsProcessor(similarity_threshold, time_threshold)
                result_df = processor.process_news(dedup_df, progress_bar)
                st.write("\nClustering Results:")
                st.write(f"Result df indices: {result_df.index.tolist()}")

                has_results = len(result_df) > 0
                if has_results:
                    st.write("\nCluster Details:")
                    for cluster_id, cluster_indices in _multi_item_clusters(result_df):
                        st.write(f"\nCluster {cluster_id}:")
                        st.write(f"Indices: {cluster_indices}")
                        for idx in cluster_indices:
                            text = str(dedup_df.loc[idx, 'text'])
                            st.write(f"Index {idx} - Length {len(text)}:")
                            st.write(text[:100] + '...')

                # Step 3: keep only the longest text of each multi-item cluster.
                indices_to_delete = _rows_to_delete(result_df, dedup_df) if has_results else set()
                st.write("\nDeletion Summary:")
                st.write(f"Indices to delete: {sorted(indices_to_delete)}")

                declustered_df = dedup_df_full.drop(index=list(indices_to_delete))
                st.write(f"Final indices kept: {sorted(declustered_df.index.tolist())}")

                if has_results:
                    multi_cluster_count = len(result_df[result_df['cluster_size'] > 1]['cluster_id'].unique())
                else:
                    multi_cluster_count = 0
                # Built line by line so the Markdown list does not depend on
                # source indentation (indented lines would render as a code block).
                summary_lines = [
                    "Processing results:",
                    f"- Original rows: {len(df_original)}",
                    f"- After deduplication: {len(dedup_df_full)}",
                    f"- Multi-item clusters found: {multi_cluster_count}",
                    f"- Rows removed from clusters: {len(indices_to_delete)}",
                    f"- Final rows kept: {len(declustered_df)}",
                ]
                st.success("\n".join(summary_lines))

                st.subheader("Download Results")
                st.markdown(create_download_link(dedup_df_full, "deduplicated_news.xlsx"), unsafe_allow_html=True)
                st.markdown(create_download_link(result_df, "clustered_news.xlsx"), unsafe_allow_html=True)
                st.markdown(create_download_link(declustered_df, "declustered_news.xlsx"), unsafe_allow_html=True)

                if has_results:
                    st.subheader("Largest Clusters")
                    largest_clusters = result_df[result_df['cluster_size'] > 1].sort_values(
                        ['cluster_size', 'cluster_id', 'datetime'],
                        ascending=[False, True, True],
                    )
                    st.dataframe(largest_clusters)
            except Exception as e:
                st.error(f"Error: {str(e)}")
                import traceback
                st.error(traceback.format_exc())
            finally:
                progress_bar.empty()
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
        import traceback
        st.error(traceback.format_exc())
# Launch the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()