File size: 1,248 Bytes
958d6d4
bfe3791
958d6d4
bfe3791
 
958d6d4
 
 
 
bfe3791
958d6d4
 
 
 
 
 
bfe3791
 
958d6d4
 
 
 
 
 
 
 
 
 
bfe3791
958d6d4
 
192668f
 
 
 
958d6d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import csv

import chromadb
import gradio as gr
import pandas as pd
from chromadb.utils import embedding_functions

# ------------------ 1. Load data ------------------
# Number of messages to index; kept small so the demo starts quickly.
N_DOCS = 10

# The file is tab-separated (label<TAB>message). Messages are assumed to be
# stored unquoted (as in the UCI SMS Spam Collection), so quote characters must
# be read literally: with the default QUOTE_MINIMAL a message beginning with `"`
# would be parsed as a quoted field and could swallow the following lines.
df = pd.read_csv(
    "SMSSpamCollection",
    sep="\t",
    names=["label", "text"],
    quoting=csv.QUOTE_NONE,
)
# Drop rows without message text: a NaN would not be a valid Chroma document.
df = df.dropna(subset=["text"]).head(N_DOCS)
chunks = df["text"].astype(str).tolist()

# ------------------ 2. Create the Chroma collection (ephemeral) ------------------
# In-memory client: the index is rebuilt each time the app starts.
client = chromadb.Client()
embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
# get_or_create + upsert keep this section idempotent: if the module is re-run
# in the same process (where the in-memory client may already hold "sms_demo"),
# neither the create step nor the duplicate ids raise.
collection = client.get_or_create_collection(
    name="sms_demo",
    embedding_function=embedder,
)
if chunks:  # avoid sending an empty batch, which Chroma's validation rejects
    collection.upsert(
        documents=chunks,
        # One independent metadata dict per document.
        metadatas=[{"source": "sms"} for _ in chunks],
        ids=[str(i) for i in range(len(chunks))],
    )

# ------------------ 3. Gradio search UI ------------------
def semantic_search(query, k=5):
    """Return the indexed messages most similar to *query*.

    Args:
        query: Free-text search string.
        k: Number of results requested. It comes from a UI slider and may
            arrive as a float; it is converted to int and clamped to the
            range [1, number of indexed documents].

    Returns:
        The matching documents in the order Chroma returns them, joined by
        blank lines. Returns an empty string for a blank query or an empty
        collection.
    """
    if not query or not query.strip():
        return ""
    # Chroma expects an integer n_results no larger than the collection size.
    n_results = min(max(1, int(k)), collection.count())
    if n_results == 0:
        return ""
    results = collection.query(query_texts=[query], n_results=n_results)
    return "\n\n".join(results["documents"][0])

# The slider cannot request more results than there are indexed documents.
top_k_max = max(1, len(chunks))

demo = gr.Interface(
    fn=semantic_search,
    inputs=[
        gr.Textbox(label="Enter query", value="win a free lottery ticket"),  # default query
        gr.Slider(1, top_k_max, value=min(5, top_k_max), step=1, label="Top K"),
    ],
    outputs=gr.Textbox(label="Results"),
)

# Launch only when run as a script; importing the module (e.g. by reload
# tooling that looks up `demo`) then does not start a server.
if __name__ == "__main__":
    demo.launch()