Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| import gradio as gr | |
# ------------------ 1. Load data ------------------
import csv  # local import: only needed to disable quote handling below

# SMSSpamCollection is read as tab-separated "label<TAB>text" lines.
# Messages are free text and may contain literal double quotes, so quote
# processing is disabled; with the default quoting a stray `"` can make the
# parser merge several physical lines into one field.
df = pd.read_csv(
    "SMSSpamCollection",
    sep="\t",
    names=["label", "text"],
    quoting=csv.QUOTE_NONE,
)
# Keep the demo small: only the first 10 rows are indexed.
df = df.head(10)
# Chroma documents must be strings: drop rows with a missing text column
# (NaN, e.g. a line without a tab) and coerce the rest to str.
chunks = df["text"].dropna().astype(str).tolist()
# ------------------ 2. Create Chroma collection (ephemeral) ------------------
client = chromadb.Client()  # ephemeral (in-memory) client, created when the Space starts
# get_or_create_collection + upsert keep this section safe to re-run in the
# same process (e.g. a reload): create_collection would raise if "sms_demo"
# already exists, and upsert overwrites the same ids instead of duplicating.
collection = client.get_or_create_collection(
    name="sms_demo",
    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    ),
)
if chunks:  # nothing to index; avoid calling upsert with empty lists
    collection.upsert(
        documents=chunks,
        # One independent metadata dict per document (no shared aliasing).
        metadatas=[{"source": "sms"} for _ in chunks],
        # Ids are the positional index of each chunk, as strings.
        ids=[str(i) for i in range(len(chunks))],
    )
# ------------------ 3. Gradio search UI ------------------
def semantic_search(query, k=5):
    """Return the stored documents most similar to *query*, separated by blank lines.

    Args:
        query: Free-text search string (from the UI textbox).
        k: Number of results requested. It may arrive from the Gradio slider
            as a float, so it is coerced to int before querying Chroma.

    Returns:
        The matching documents joined with blank lines, or "" when the
        collection holds no documents.
    """
    available = collection.count()
    if available == 0:
        return ""
    # n_results must be a positive int; clamp to [1, available] so a large
    # slider value never requests more results than the collection holds.
    n_results = max(1, min(int(k), available))
    results = collection.query(query_texts=[query], n_results=n_results)
    # query() returns one result list per query text; exactly one was sent.
    return "\n\n".join(results["documents"][0])
# Input/output components for the search UI, built in display order.
query_box = gr.Textbox(label="Enter query", value="win a free lottery ticket")  # default query
top_k_slider = gr.Slider(1, 10, value=5, step=1, label="Top K")
results_box = gr.Textbox(label="Results")

# Wire the components to semantic_search and start serving the app.
demo = gr.Interface(
    fn=semantic_search,
    inputs=[query_box, top_k_slider],
    outputs=results_box,
)
demo.launch()