Spaces:
Sleeping
Sleeping
File size: 1,248 Bytes
958d6d4 bfe3791 958d6d4 bfe3791 958d6d4 bfe3791 958d6d4 bfe3791 958d6d4 bfe3791 958d6d4 192668f 958d6d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import pandas as pd
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
# ------------------ 1. Load data ------------------
# The file is tab-separated; the two columns are named "label" and "text" here.
df = pd.read_csv(
    "SMSSpamCollection",
    sep="\t",
    names=["label", "text"],
).head(10)  # keep only the first 10 rows
# Only the text column is carried forward as the documents to index.
chunks = df["text"].tolist()
# ------------------ 2. Create the Chroma collection (ephemeral) ------------------
client = chromadb.Client()  # ephemeral client, created when the Space runs
# get_or_create (rather than create) so that re-executing this script in the
# same process does not fail because "sms_demo" already exists.
collection = client.get_or_create_collection(
    name="sms_demo",
    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    ),
)
# upsert keeps the load idempotent: ids that are already present are
# overwritten instead of being rejected or duplicated.
collection.upsert(
    documents=chunks,
    # One dict per document so the entries do not share a single object.
    metadatas=[{"source": "sms"} for _ in chunks],
    # Ids are the positional index of each document, as strings.
    ids=[str(i) for i in range(len(chunks))],
)
# ------------------ 3. Gradio search interface ------------------
def semantic_search(query, k=5):
    """Return the documents in ``collection`` that best match *query*.

    Args:
        query: Free-text search string.
        k: Maximum number of documents to return. Coerced to ``int`` because
            the value may arrive as a float from the UI slider.

    Returns:
        The matching documents joined by blank lines, or an empty string
        when *query* is blank or the collection holds no documents.
    """
    # A blank query carries nothing to embed; skip the vector search.
    if not query or not query.strip():
        return ""
    # Keep the request within [1, number of stored documents] so the query
    # never asks for more results than the collection can supply.
    limit = min(max(int(k), 1), collection.count())
    if limit < 1:
        return ""
    results = collection.query(query_texts=[query], n_results=limit)
    # One query was sent, so the matches are in the first (only) result list.
    return "\n\n".join(results["documents"][0])
# Input widgets: the query box is pre-filled with a default query, and the
# slider selects how many results to show.
query_box = gr.Textbox(label="Enter query", value="win a free lottery ticket")
top_k_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top K")
results_box = gr.Textbox(label="Results")

# Wire the widgets to semantic_search and start the app.
demo = gr.Interface(
    fn=semantic_search,
    inputs=[query_box, top_k_slider],
    outputs=results_box,
)
demo.launch()
|