王昱
Enhance cluster_analysis.py and embeddings.py: Add new RSS sources, improve logging for processing times, and implement BGE-M3 API for embeddings. Update .DS_Store binary file.
687a4ef
import asyncio
import json
import logging
import os
import time
import traceback
from collections import defaultdict
from datetime import datetime, timezone, timedelta
from email.utils import parsedate_to_datetime
from typing import Optional

import numpy as np
import requests
import xmltodict
from dotenv import load_dotenv
from pymilvus import connections, Collection
from sklearn.cluster import DBSCAN

from config import DOCS_INDEX_NAME
from embeddings import get_embeddings_model
from storage.azure_table import AzureTableStorage

load_dotenv()
logger = logging.getLogger("backend")
def parse_rss_items(rss_dict: dict, source_name: str) -> list:
    """Parse RSS feeds of different formats (RDF 1.0 and RSS 2.0)."""
    logger.info(f"Parsing RSS source: {source_name}")
    if "rdf:RDF" in rss_dict:
        items = rss_dict["rdf:RDF"]["item"]
    elif "rss" in rss_dict:
        items = rss_dict["rss"]["channel"]["item"]
    else:
        logger.error(f"Unrecognized RSS format: {source_name}")
        return []
    # xmltodict returns a bare dict (not a list) when the feed has a single item
    return items if isinstance(items, list) else [items]
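# For reference, the two feed shapes handled above (illustrative, trimmed):
#   RDF 1.0:  {"rdf:RDF": {"item": [{"title": ..., "dc:date": ...}, ...]}}
#   RSS 2.0:  {"rss": {"channel": {"item": [{"title": ..., "pubDate": ...}, ...]}}}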
def parse_news_date(item: dict) -> Optional[datetime]:
    """Parse a news item's publication date, trying several common formats."""
    logger.debug(f"Parsing news date: {item}")
    date_fields = ["pubDate", "published", "dc:date"]
    for field in date_fields:
        if field in item:
            date_str = item[field]
            try:
                # RFC 2822, as used by most RSS 2.0 feeds
                return parsedate_to_datetime(date_str)
            except (TypeError, ValueError):
                try:
                    # ISO 8601, as used by RDF/Atom-style feeds
                    return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
                except ValueError:
                    try:
                        return datetime.strptime(date_str, "%A %b %d %Y %H:%M:%S")
                    except ValueError:
                        continue
    return None
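# Examples of the date strings each branch above accepts (illustrative):
#   parsedate_to_datetime: "Tue, 07 Jan 2025 08:30:00 GMT"   (RFC 2822)
#   fromisoformat:         "2025-01-07T08:30:00Z"            (ISO 8601)
#   strptime fallback:     "Tuesday Jan 07 2025 08:30:00"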
def get_news_content(item: dict) -> Optional[str]:
    """Return the text used for embedding; currently the item title only."""
    if "title" in item and item["title"]:
        return item["title"]
    return None
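# Note: only the title is embedded. Falling back to the item description would
# be a natural extension, but feed descriptions often carry HTML and
# boilerplate, so they would need cleaning (e.g. via
# utils.text_cleaner.strip_html_tags) before embedding.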
def get_recent_news_clusters():
    # RSS feed list
    rss_sources = {
        "NYTimes": "https://rss.nytimes.com/services/xml/rss/nyt/World.xml",
        "CNN": "http://rss.cnn.com/rss/edition_world.rss",
        "ABCNews": "https://abcnews.go.com/abcnews/internationalheadlines",
        # Europe
        "BBC": "https://feeds.bbci.co.uk/news/world/rss.xml",
        "Reuters": "https://rsshub.app/reuters/world",
        "DW": "http://rss.dw-world.de/rdf/rss-en-top",
        "Guardian": "https://www.theguardian.com/world/rss",
        "SkyNews": "https://feeds.skynews.com/feeds/rss/world.xml",
        "TheSun": "https://thesun.my/rss/world",
        # Asia
        "Aljazeera": "http://www.aljazeera.com/xml/rss/all.xml",
        "TimesOfIndia": "https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
        "ChannelNewsAsia": "https://www.channelnewsasia.com/api/v1/rss-outbound-feed?_format=xml&category=6311",
        # Other regions
        "GlobalNews": "https://globalnews.ca/world/feed/",
        "SMH": "https://www.smh.com.au/rss/world.xml",
        "Capi24": "https://feeds.capi24.com/v1/Search/articles/news24/World/rss",
        "IFPNews": "https://ifpnews.com/feed/",
        # Russia
        "MoscowTimes": "https://themoscowtimes.com/feeds/main.xml",
        # China
        "SCMP": "https://www.scmp.com/rss/91/feed",
        "ChinaNews": "https://www.chinanews.com.cn/rss/world.xml",
        "People": "http://www.people.com.cn/rss/world.xml",
        # Proxied through a self-hosted RSSHub instance
        "AP News": "https://march42-rsshub.hf.space/apnews/api/apf-topnews",
        "CNBC": "https://march42-rsshub.hf.space/cnbc/rss",
        "Tass": "https://march42-rsshub.hf.space/tass/world",
        "Sputnik News": "https://march42-rsshub.hf.space/sputniknews/world",
        "Economist": "https://march42-rsshub.hf.space/economist/latest",
        "Straits Times": "https://march42-rsshub.hf.space/straitstimes/world",
        "Huanqiu": "https://march42-rsshub.hf.space/huanqiu/news/world",
        "Zaobao": "https://march42-rsshub.hf.space/zaobao/znews/world",
    }
    current_time = datetime.now(timezone.utc)
    embeddings_model = get_embeddings_model()
    recent_news = []  # News published within the past 48 hours
    for source_name, url in rss_sources.items():
        try:
            start_time = time.time()
            logger.info(f"Processing source: {source_name}")
            # Time the HTTP request
            request_start = time.time()
            response = requests.get(url, timeout=10)
            # Fail fast on HTTP errors instead of trying to parse an error page
            response.raise_for_status()
            request_time = time.time() - request_start
            logger.info(f"{source_name} request took {request_time:.2f}s")
            # Time the XML parsing
            parse_start = time.time()
            rss_dict = xmltodict.parse(response.content)
            items = parse_rss_items(rss_dict, source_name)
            parse_time = time.time() - parse_start
            logger.info(f"{source_name} XML parsing took {parse_time:.2f}s")
            # Time the embedding step
            embed_start = time.time()
            processed_items = 0
            for item in items:
                content = get_news_content(item)
                if not content:
                    logger.debug(f"Skipping item with empty title: {source_name}")
                    continue
                pub_date = parse_news_date(item)
                if not pub_date:
                    logger.debug(f"Skipping item with unparseable date: {source_name}")
                    continue
                pub_date = pub_date.astimezone(timezone.utc)
                if (current_time - pub_date).total_seconds() <= 48 * 3600:
                    recent_news.append(
                        {
                            "source": source_name,
                            "content": content,
                            "pub_date": pub_date,
                            "embedding": embeddings_model.embed_query(content),
                        }
                    )
                    processed_items += 1
            embed_time = time.time() - embed_start
            logger.info(
                f"{source_name} embedding took {embed_time:.2f}s "
                f"({processed_items} items)"
            )
            logger.info(
                f"{source_name} total: {time.time() - start_time:.2f}s "
                f"[request: {request_time:.2f}s, parse: {parse_time:.2f}s, "
                f"embedding: {embed_time:.2f}s]"
            )
        except Exception as e:
            logger.error(f"Error processing news source {source_name}: {str(e)}")
            continue
    # Cluster the embeddings
    if recent_news:
        embeddings = np.array([news["embedding"] for news in recent_news])
        clustering = DBSCAN(eps=0.4, min_samples=2, metric="cosine").fit(embeddings)
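        # With metric="cosine", DBSCAN's eps is a cosine *distance* threshold:
        # eps=0.4 groups items whose pairwise cosine similarity is >= 0.6, and
        # min_samples=2 means a story needs at least two reports to form a cluster.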
        # Collect cluster members, excluding noise points (label == -1)
        clusters = defaultdict(list)
        for idx, label in enumerate(clustering.labels_):
            if label != -1:
                clusters[label].append(recent_news[idx])
        # Prepare JSON output, sorted by cluster size in descending order
        clusters_output = []
        sorted_clusters = sorted(
            clusters.items(), key=lambda x: len(x[1]), reverse=True
        )
        for i, (cluster_id, news_list) in enumerate(sorted_clusters):
            # Cluster center: the mean of the member embeddings
            cluster_embeddings = np.array([news["embedding"] for news in news_list])
            cluster_center = np.mean(cluster_embeddings, axis=0)
            clusters_output.append(
                {
                    "id": i,  # Re-indexed by size rank
                    "size": len(news_list),
                    "center": cluster_center.tolist(),
                }
            )
        output_data = {
            "clusters": clusters_output,
            "timestamp": current_time.strftime("%Y-%m-%d %H:%M:%S UTC"),
        }
        return output_data
    return None  # No news collected within the window
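# Shape of the value returned above (illustrative):
# {
#     "clusters": [{"id": 0, "size": 5, "center": [0.01, ...]}, ...],
#     "timestamp": "2025-01-07 08:30:00 UTC",
# }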
async def analyze_news_clusters():
    try:
        zilliz_uri = os.getenv("ZILLIZ_CLOUD_URI")
        zilliz_token = os.getenv("ZILLIZ_CLOUD_TOKEN")
        if not zilliz_uri or not zilliz_token:
            logger.error("Missing Zilliz Cloud configuration")
            return {"status": "error", "message": "Missing required environment variables"}
        connections.connect(
            alias="default",
            uri=zilliz_uri,
            token=zilliz_token,
        )
        logger.info("Connected to Zilliz Cloud")
        collection = Collection(DOCS_INDEX_NAME)
        collection.load()
        # Fetch the clustering results
        clusters_data = get_recent_news_clusters()
        if not clusters_data:
            logger.warning("No news clusters found within the past 48 hours")
            return {
                "status": "no_clusters",
                "message": "No news clusters found within the past 48 hours",
            }
        azure_storage = AzureTableStorage()
        # Collect all cluster-center vectors
        all_centers = []
        for cluster in clusters_data["clusters"]:
            all_centers.append(
                {
                    "id": cluster["id"],
                    "size": cluster["size"],
                    "vector": cluster["center"],
                }
            )
        # Accumulate per-cluster search results
        all_cluster_results = []
        timestamp = clusters_data["timestamp"]
        logger.info(f"Processing {len(all_centers)} clusters")
        # Run a separate vector search for each cluster center
        for cluster in all_centers:
            logger.info(f"Processing cluster #{cluster['id']}, size: {cluster['size']}")
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
            current_time = datetime.now(timezone.utc)
            two_days_ago = current_time - timedelta(days=2)
            time_filter = (
                f'publish_time >= "{two_days_ago.strftime("%Y-%m-%dT%H:%M:%S.000Z")}"'
            )
            results = collection.search(
                data=[cluster["vector"]],
                anns_field="embedding",
                param=search_params,
                limit=3,
                expr=time_filter,  # Restrict matches to the past two days
                output_fields=["source"],
            )
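            # With metric_type COSINE, hit.score is cosine similarity (higher is
            # more similar), so the 0.75 cutoff below keeps only close matches.
            # nprobe=10 trades recall for speed on IVF-style indexes.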
            # Process the search hits
            if results:
                cluster_result = {"size": cluster["size"], "articles": []}
                for hits in results:
                    for hit in hits:
                        similarity = hit.score
                        if similarity > 0.75:
                            cluster_result["articles"].append(
                                {
                                    "url": hit.entity.get("source"),
                                    "similarity": round(similarity, 3),
                                }
                            )
                if cluster_result["articles"]:
                    all_cluster_results.append(cluster_result)
        # Persist valid cluster results to Azure Table Storage
        if all_cluster_results:
            entity = {
                # Date part of "YYYY-MM-DD HH:MM:SS UTC" (split on the space,
                # not "T": the literal "T" in "UTC" would truncate wrongly)
                "PartitionKey": timestamp.split(" ")[0],
                "RowKey": timestamp,
                "timestamp": timestamp,
                "clusters": json.dumps(all_cluster_results, ensure_ascii=False),
            }
            azure_storage.store_clusters(entity)
            logger.info(f"Successfully processed {len(all_cluster_results)} clusters")
            return {
                "status": "success",
                "message": f"Successfully processed {len(all_cluster_results)} clusters",
                "data": {"timestamp": timestamp, "clusters": all_cluster_results},
            }
        return {"status": "no_results", "message": "No sufficiently similar news clusters found"}
    except Exception as e:
        logger.error(f"Error during cluster analysis: {str(e)}")
        logger.error(f"Traceback:\n{traceback.format_exc()}")
        return {"status": "error", "message": str(e)}
    finally:
        connections.disconnect("default")
if __name__ == "__main__":
    try:
        # analyze_news_clusters is a coroutine, so it must be driven by an event loop
        asyncio.run(analyze_news_clusters())
    except Exception as e:
        logger.error(f"Execution failed: {str(e)}")