Spaces:
Running
Running
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| import chromadb | |
| import pandas as pd | |
| # CONFIG | |
# Hugging Face Hub identifier of the labelled prompt-safety dataset.
HF_DATASET_NAME: str = "dralsarrani/prompt_safety_with_synthetic_labeled"

# Sentence-embedding model name: fast, free, good enough.
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"

# Local folder for the persistent Chroma store; created automatically.
CHROMA_DIR: str = "./chroma_db"

# Name of the Chroma collection holding the embedded prompts.
COLLECTION_NAME: str = "safety_prompts"

# How many similar prompts to retrieve per query by default.
TOP_K: int = 5
| # 1 LOAD DATASET | |
def load_safety_dataset(max_per_label: int = 25_000, random_state: int = 42):
    """Download the safety-prompt dataset and return a cleaned, capped frame.

    The "train" split of ``HF_DATASET_NAME`` is converted to pandas, column
    names are lower-cased and stripped, rows missing ``text`` or ``label``
    are dropped, and only rows labelled ``"safe"`` or ``"unsafe"`` are kept.
    Each label is then down-sampled to at most ``max_per_label`` rows.

    Args:
        max_per_label: Upper bound on the rows kept for each label. The
            default of 25,000 caps the result at 50K rows overall.
        random_state: Seed for the per-label sampling, so repeated runs
            select the same rows.

    Returns:
        A ``pandas.DataFrame`` with a fresh ``0..n-1`` index containing at
        least the ``text`` and ``label`` columns.

    Raises:
        KeyError: If the dataset has no ``text`` or ``label`` column after
            column-name normalisation.
    """
    print("Loading dataset from HuggingFace...")
    dataset = load_dataset(HF_DATASET_NAME, cache_dir="./hf_cache")
    df = dataset["train"].to_pandas()

    # Normalise column names so the lookups below are case/whitespace-proof.
    df.columns = [c.lower().strip() for c in df.columns]

    missing = {"text", "label"} - set(df.columns)
    if missing:
        raise KeyError(f"Dataset is missing required column(s): {sorted(missing)}")

    # Keep only rows with a prompt and one of the two recognised labels.
    df = df.dropna(subset=["text", "label"])
    df = df[df["label"].isin(["safe", "unsafe"])]

    # Cap each label independently. This bounds the size but does not force
    # equal counts when one label has fewer than ``max_per_label`` rows.
    # Sampling per group and concatenating (instead of ``groupby.apply``)
    # keeps the ``label`` column on pandas versions where ``apply`` no longer
    # passes grouping columns to the applied function.
    if not df.empty:
        sampled = [
            group.sample(n=min(len(group), max_per_label), random_state=random_state)
            for _, group in df.groupby("label")
        ]
        df = pd.concat(sampled)
    df = df.reset_index(drop=True)

    # Labels are the strings "safe"/"unsafe" at this point, so count those
    # (comparing against 0/1 would always report zero).
    safe_count = int((df["label"] == "safe").sum())
    unsafe_count = int((df["label"] == "unsafe").sum())
    print(f" Loaded {len(df)} rows | SAFE: {safe_count} UNSAFE: {unsafe_count}")
    return df
| # 2 BUILD CHROMA VECTOR STORE | |
def build_vector_store(df: pd.DataFrame):
    """Embed every prompt in ``df`` and persist it in a Chroma collection.

    If a populated collection named ``COLLECTION_NAME`` already exists under
    ``CHROMA_DIR`` it is reused and nothing is re-embedded. An existing but
    empty collection is dropped and rebuilt.

    Args:
        df: Frame with a ``text`` column (the prompts) and a ``label`` column
            (stored as per-document metadata).

    Returns:
        A ``(collection, model)`` tuple: the Chroma collection and the
        ``SentenceTransformer`` used to embed it.
    """
    print("Building vector store...")
    model = SentenceTransformer(EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=CHROMA_DIR)

    # Check if already built — skip if so. Only the lookup is guarded; the
    # exception class raised for a missing collection differs between Chroma
    # releases, hence the broad catch.
    try:
        collection = client.get_collection(COLLECTION_NAME)
    except Exception:
        collection = None

    if collection is not None:
        existing = collection.count()
        if existing > 0:
            print(f" Vector store already exists ({existing} vectors). Skipping rebuild.")
            return collection, model
        # An empty leftover collection would make create_collection fail on
        # the duplicate name, so remove it before rebuilding.
        client.delete_collection(name=COLLECTION_NAME)

    # Use cosine space so that ``1 - distance`` (as computed at query time)
    # really is a cosine similarity; Chroma's default metric is squared L2.
    collection = client.create_collection(
        COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )

    prompts = df["text"].tolist()
    labels = df["label"].tolist()
    ids = [str(i) for i in range(len(prompts))]

    # Embed in batches of 512 to avoid memory issues on large datasets.
    embed_batch_size = 512
    all_embeddings = []
    for start in range(0, len(prompts), embed_batch_size):
        batch = prompts[start : start + embed_batch_size]
        all_embeddings.extend(model.encode(batch, show_progress_bar=False).tolist())
        print(f" Embedded {min(start + embed_batch_size, len(prompts))}/{len(prompts)}")

    # Insert in chunks of 5000 — presumably to stay under Chroma's per-call
    # batch limit; confirm against the installed Chroma version.
    chroma_batch_size = 5000
    for start in range(0, len(ids), chroma_batch_size):
        stop = start + chroma_batch_size
        collection.add(
            ids=ids[start:stop],
            embeddings=all_embeddings[start:stop],
            documents=prompts[start:stop],
            metadatas=[{"label": label} for label in labels[start:stop]],
        )

    print(f" Stored {collection.count()} vectors in Chroma")
    return collection, model
| # 3 RETRIEVAL FUNCTION | |
def retrieve_similar(query: str, collection, model, top_k: int = TOP_K):
    """Return the ``top_k`` stored prompts most similar to ``query``.

    Args:
        query: The new prompt to look up.
        collection: Chroma collection to search.
        model: Embedding model exposing ``encode``; it should be the same
            model that was used to populate ``collection``.
        top_k: Number of neighbours to request.

    Returns:
        A list of dicts, one per neighbour in the order Chroma returns them,
        with keys ``"prompt"`` (the stored text), ``"label"`` (from the
        stored metadata) and ``"similarity"`` (rounded to 3 decimals).
    """
    query_embedding = model.encode([query]).tolist()
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    # Convert distance to cosine similarity according to the collection's
    # metric. For "cosine" (and "ip") Chroma reports ``1 - similarity``.
    # For the default "l2" space it reports *squared* Euclidean distance,
    # which for unit-length vectors equals ``2 - 2*cos``, i.e. ``cos = 1 - d/2``.
    # NOTE(review): the l2 branch assumes unit-normalised embeddings, as
    # all-MiniLM-L6-v2 is expected to produce — confirm if the model changes.
    space = (getattr(collection, "metadata", None) or {}).get("hnsw:space", "l2")
    scale = 0.5 if space == "l2" else 1.0

    return [
        {
            "prompt": doc,
            "label": meta["label"],
            "similarity": round(1 - scale * dist, 3),  # distance → similarity
        }
        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]
| # 4 LOAD EXISTING STORE (skip rebuild if already done) | |
def load_vector_store():
    """Load an already-built Chroma store, building it only when necessary.

    A populated collection named ``COLLECTION_NAME`` under ``CHROMA_DIR`` is
    returned as-is without re-embedding. If the collection is missing or
    empty, the dataset is loaded and the store is built from scratch.

    Returns:
        A ``(collection, model)`` tuple: the Chroma collection and the
        ``SentenceTransformer`` matching ``EMBEDDING_MODEL``.
    """
    client = chromadb.PersistentClient(path=CHROMA_DIR)

    # Guard only the lookup. The exception class raised for a missing
    # collection differs between Chroma releases, hence the broad catch.
    try:
        collection = client.get_collection(COLLECTION_NAME)
    except Exception:
        collection = None

    if collection is not None:
        vector_count = collection.count()
        if vector_count > 0:
            model = SentenceTransformer(EMBEDDING_MODEL)
            print(f"Loaded existing vector store ({vector_count} vectors)")
            return collection, model
        # An empty collection is useless for retrieval and would block
        # re-creation under the same name, so drop it and rebuild.
        client.delete_collection(name=COLLECTION_NAME)

    print("No existing vector store found — building from scratch...")
    df = load_safety_dataset()
    # build_vector_store loads the embedding model itself, so it is not
    # loaded here first (avoids instantiating it twice).
    return build_vector_store(df)