zkmine's picture
Update app.py
192668f verified
import csv

import pandas as pd
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
# ------------------ 1. Load data ------------------
# SMSSpamCollection is read as tab-separated "<label>\t<text>" lines with no
# header row (column names are supplied explicitly).
# QUOTE_NONE: SMS text may contain stray double quotes; with pandas' default
# quoting those are treated as field delimiters and can merge several messages
# into one row. Only the first 10 messages are used for this demo, so read
# just those instead of parsing the whole file and discarding the rest.
df = pd.read_csv(
    "SMSSpamCollection",
    sep="\t",
    names=["label", "text"],
    quoting=csv.QUOTE_NONE,
    dtype=str,
    nrows=10,
)
# Chroma only accepts string documents; drop any row whose text is missing.
df = df.dropna(subset=["text"])
chunks = df["text"].tolist()
# ------------------ 2. Create the Chroma collection (ephemeral) ------------------
client = chromadb.Client()  # ephemeral (in-memory) client, created when the Space runs
# get_or_create + upsert make this section safe to re-execute: in-process
# client state may persist across re-runs in the same process (e.g. a reload
# or a notebook re-run), where create_collection would fail because
# "sms_demo" already exists.
collection = client.get_or_create_collection(
    name="sms_demo",
    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    ),
)
# Only write when there is something to index; Chroma may reject empty batches.
if chunks:
    collection.upsert(
        documents=chunks,
        # One dict per document, rather than [d] * n, which would share a single dict object.
        metadatas=[{"source": "sms"} for _ in chunks],
        ids=[str(i) for i in range(len(chunks))],
    )
# ------------------ 3. Gradio search UI ------------------
def semantic_search(query, k=5):
    """Return the ``k`` stored SMS messages closest to ``query``.

    Args:
        query: Free-text search string; it is embedded by the collection's
            embedding function.
        k: Number of results requested. Gradio's Slider may deliver this as a
            float, so it is coerced to ``int`` and clamped to the range
            ``[1, number of stored documents]``.

    Returns:
        The matching messages in the order Chroma returns them (closest
        first), joined by blank lines. Returns an empty string when the
        collection holds no documents.
    """
    total = collection.count()
    if total == 0:
        return ""
    # Chroma validates n_results as an int; never ask for more than exists.
    n_results = max(1, min(int(k), total))
    results = collection.query(query_texts=[query], n_results=n_results)
    return "\n\n".join(results["documents"][0])
# Bind the UI to `demo`, the name Gradio's reload mode (`gradio app.py`)
# looks for by default.
demo = gr.Interface(
    fn=semantic_search,
    inputs=[
        gr.Textbox(label="Enter query", value="win a free lottery ticket"),  # default query
        gr.Slider(1, 10, value=5, step=1, label="Top K"),
    ],
    outputs=gr.Textbox(label="Results"),
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch the UI as a side effect.
if __name__ == "__main__":
    demo.launch()