Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import json | |
| import io | |
| import os | |
| import random | |
| from collections import defaultdict | |
| from sentence_transformers import SentenceTransformer | |
| import hdbscan | |
| from sklearn.metrics import silhouette_score, davies_bouldin_score | |
| import numpy as np | |
| import umap | |
| from sklearn.preprocessing import MinMaxScaler | |
| # 加载模型,放到全局避免重复加载 | |
| model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") | |
| def color_for_label(label): | |
| try: | |
| label_int = int(label) | |
| except: | |
| label_int = -1 | |
| if label_int < 0: | |
| return "rgb(150,150,150)" # 噪声点 | |
| random.seed(label_int + 1000) | |
| return f"rgb({random.randint(50,200)}, {random.randint(50,200)}, {random.randint(50,200)})" | |
| def cluster_sentences(sentences): | |
| embeddings = model.encode(sentences) | |
| clusterer = hdbscan.HDBSCAN(min_cluster_size=2, metric='euclidean') | |
| labels = clusterer.fit_predict(embeddings) | |
| valid_idxs = labels != -1 | |
| if np.sum(valid_idxs) > 1: | |
| silhouette = silhouette_score(embeddings[valid_idxs], labels[valid_idxs]) | |
| db = davies_bouldin_score(embeddings[valid_idxs], labels[valid_idxs]) | |
| else: | |
| silhouette, db = -1, -1 | |
| return labels, embeddings, {"silhouette": silhouette, "db": db} | |
| def generate_force_graph(sentences, labels): | |
| nodes = [] | |
| links = [] | |
| label_map = defaultdict(list) | |
| for i, (s, l) in enumerate(zip(sentences, labels)): | |
| color = color_for_label(l) | |
| nodes.append({"name": s, "symbolSize": 10, "category": int(l) if l >=0 else 0, "itemStyle": {"color": color}}) | |
| label_map[l].append(i) | |
| for group in label_map.values(): | |
| max_edges_per_node = 10 | |
| for i in group: | |
| connected = 0 | |
| for j in group: | |
| if i < j: | |
| links.append({"source": sentences[i], "target": sentences[j]}) | |
| connected += 1 | |
| if connected >= max_edges_per_node: | |
| break | |
| return {"type": "force", "nodes": nodes, "links": links} | |
| def generate_bubble_chart(sentences, labels): | |
| counts = defaultdict(int) | |
| for l in labels: | |
| counts[l] += 1 | |
| data = [{"name": f"簇{l}" if l >=0 else "噪声", "value": v, "itemStyle": {"color": color_for_label(l)}} for l, v in counts.items()] | |
| return {"type": "bubble", "series": [{"type": "scatter", "data": data}]} | |
| def generate_umap_plot(embeddings, labels): | |
| reducer = umap.UMAP(n_components=2, random_state=42) | |
| umap_emb = reducer.fit_transform(embeddings) | |
| scaled = MinMaxScaler().fit_transform(umap_emb) | |
| data = [{"x": float(x), "y": float(y), "label": int(l), "itemStyle": {"color": color_for_label(l)}} for (x, y), l in zip(scaled, labels)] | |
| return {"type": "scatter", "series": [{"data": data}]} | |
| def process(text_input, file_obj): | |
| # 先收集所有句子 | |
| sentences = [] | |
| # 读取txt文件内容 | |
| if file_obj is not None: | |
| try: | |
| # file_obj 是 tempfile.NamedTemporaryFile,直接打开它的 file_obj.name | |
| with open(file_obj.name, "r", encoding="utf-8") as f: | |
| content = f.read() | |
| lines = content.strip().splitlines() | |
| sentences.extend([line.strip() for line in lines if line.strip()]) | |
| except Exception as e: | |
| return f"❌ 文件读取失败: {str(e)}", None, None, None, None, None, None | |
| # 处理文本框输入 | |
| if text_input: | |
| lines = text_input.strip().splitlines() | |
| sentences.extend([line.strip() for line in lines if line.strip()]) | |
| # 去重 | |
| sentences = list(dict.fromkeys(sentences)) | |
| if len(sentences) < 2: | |
| return "⚠️ 请输入至少两个有效句子进行聚类", None, None, None, None, None, None | |
| # 聚类 | |
| labels, embeddings, scores = cluster_sentences(sentences) | |
| # 生成数据 | |
| df = pd.DataFrame({"句子": sentences, "簇ID": labels}) | |
| force_json = generate_force_graph(sentences, labels) | |
| bubble_json = generate_bubble_chart(sentences, labels) | |
| umap_json = generate_umap_plot(embeddings, labels) | |
| csv_data = df.to_csv(index=False, encoding="utf-8-sig") | |
| return ( | |
| f"✅ Silhouette: {scores['silhouette']:.4f}, DB: {scores['db']:.4f}", | |
| df, | |
| json.dumps(force_json, ensure_ascii=False, indent=2), | |
| json.dumps(bubble_json, ensure_ascii=False, indent=2), | |
| json.dumps(umap_json, ensure_ascii=False, indent=2), | |
| csv_data | |
| ) | |
| def csv_download(csv_str): | |
| return io.BytesIO(csv_str.encode("utf-8-sig")) | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# 中文句子语义聚类 Demo") | |
| with gr.Row(): | |
| text_input = gr.Textbox(label="输入多句子(每行一句)", lines=8) | |
| file_input = gr.File(label="上传文本文件 (.txt)", file_types=['.txt']) | |
| btn = gr.Button("开始聚类") | |
| output_score = gr.Textbox(label="聚类指标", interactive=False) | |
| output_table = gr.Dataframe(headers=["句子", "簇ID"], interactive=False) | |
| output_force = gr.JSON(label="力导图数据") | |
| output_bubble = gr.JSON(label="气泡图数据") | |
| output_umap = gr.JSON(label="UMAP二维数据") | |
| output_csv = gr.File(label="导出CSV") | |
| btn.click( | |
| fn=process, | |
| inputs=[text_input, file_input], | |
| outputs=[output_score, output_table, output_force, output_bubble, output_umap, output_csv] | |
| ) | |
| output_csv.download = csv_download | |
| demo.launch() | |