dralsarrani committed on
Commit
9888a9a
·
verified ·
1 Parent(s): e1759b3
Files changed (1) hide show
  1. rag_pipeline.py +134 -132
rag_pipeline.py CHANGED
@@ -1,132 +1,134 @@
1
- #pip install datasets sentence-transformers chromadb pandas
2
-
3
-
4
- from datasets import load_dataset
5
- from sentence_transformers import SentenceTransformer
6
- import chromadb
7
- import pandas as pd
8
-
9
- # CONFIG
10
-
11
- HF_DATASET_NAME = "dralsarrani/prompt_safety_with_synthetic_labeled"
12
- EMBEDDING_MODEL = "all-MiniLM-L6-v2" # fast, free, good enough
13
- CHROMA_DIR = "./chroma_db" # local folder, created automatically
14
- COLLECTION_NAME = "safety_prompts"
15
- TOP_K = 5 # how many similar prompts to retrieve
16
-
17
-
18
-
19
- # 1 LOAD DATASET
20
-
21
- def load_safety_dataset():
22
- print("Loading dataset from HuggingFace...")
23
- dataset = load_dataset(HF_DATASET_NAME, cache_dir="./hf_cache")
24
- df = dataset["train"].to_pandas()
25
-
26
- # Normalise column names to lowercase
27
- df.columns = [c.lower().strip() for c in df.columns]
28
-
29
- # Keep only rows with valid prompt + label
30
- df = df.dropna(subset=["text", "label"])
31
- df = df[df["label"].isin(["safe", "unsafe"])]
32
- df = df.reset_index(drop=True)
33
-
34
- print(f" Loaded {len(df)} rows | SAFE: {(df.label==0).sum()} UNSAFE: {(df.label==1).sum()}")
35
- return df
36
-
37
-
38
- # 2 BUILD CHROMA VECTOR STORE
39
-
40
- def build_vector_store(df: pd.DataFrame):
41
- print("Building vector store...")
42
- model = SentenceTransformer(EMBEDDING_MODEL)
43
- client = chromadb.PersistentClient(path=CHROMA_DIR)
44
-
45
- # Check if already built — skip if so
46
- try:
47
- collection = client.get_collection(COLLECTION_NAME)
48
- if collection.count() > 0:
49
- print(f" Vector store already exists ({collection.count()} vectors). Skipping rebuild.")
50
- return collection, model
51
- except Exception:
52
- pass
53
-
54
- collection = client.create_collection(COLLECTION_NAME)
55
-
56
- prompts = df["text"].tolist()
57
- labels = df["label"].tolist()
58
- ids = [str(i) for i in range(len(prompts))]
59
-
60
- # Embed in batches of 512 to avoid memory issues on large datasets
61
- batch_size = 512
62
- all_embeddings = []
63
- for i in range(0, len(prompts), batch_size):
64
- batch = prompts[i : i + batch_size]
65
- embeddings = model.encode(batch, show_progress_bar=False).tolist()
66
- all_embeddings.extend(embeddings)
67
- print(f" Embedded {min(i + batch_size, len(prompts))}/{len(prompts)}")
68
-
69
- batch_size_chroma = 5000
70
- for i in range(0, len(ids), batch_size_chroma):
71
- batch_ids = ids[i : i + batch_size_chroma]
72
- batch_embeds = all_embeddings[i : i + batch_size_chroma]
73
- batch_docs = prompts[i : i + batch_size_chroma]
74
- batch_metadatas = [{"label": l} for l in labels[i : i + batch_size_chroma]]
75
- collection.add(
76
- ids=batch_ids,
77
- embeddings=batch_embeds,
78
- documents=batch_docs,
79
- metadatas=batch_metadatas
80
- )
81
-
82
-
83
- print(f" Stored {collection.count()} vectors in Chroma")
84
- return collection, model
85
-
86
- # 3 RETRIEVAL FUNCTION
87
-
88
- def retrieve_similar(query: str, collection, model, top_k: int = TOP_K):
89
- """
90
- Given a new prompt, return the top_k most similar prompts
91
- from the dataset with their labels and similarity scores.
92
- """
93
- query_embedding = model.encode([query]).tolist()
94
-
95
- results = collection.query(
96
- query_embeddings = query_embedding,
97
- n_results = top_k,
98
- include = ["documents", "metadatas", "distances"],
99
- )
100
-
101
- similar = []
102
- for doc, meta, dist in zip(
103
- results["documents"][0],
104
- results["metadatas"][0],
105
- results["distances"][0],
106
- ):
107
- similar.append({
108
- "prompt": doc,
109
- "label": meta["label"],
110
- "similarity": round(1 - dist, 3), # cosine distance → similarity
111
- })
112
-
113
- return similar
114
-
115
-
116
- # 4 LOAD EXISTING STORE (skip rebuild if already done)
117
-
118
- def load_vector_store():
119
- """Load an already-built Chroma store without re-embedding."""
120
- model = SentenceTransformer(EMBEDDING_MODEL)
121
- client = chromadb.PersistentClient(path=CHROMA_DIR)
122
-
123
- try:
124
- collection = client.get_collection(COLLECTION_NAME)
125
- print(f"Loaded existing vector store ({collection.count()} vectors)")
126
- except Exception:
127
- print("No existing vector store found — building from scratch...")
128
- df = load_safety_dataset()
129
- collection, model = build_vector_store(df)
130
-
131
- return collection, model
132
-
 
 
 
1
+ from datasets import load_dataset
2
+ from sentence_transformers import SentenceTransformer
3
+ import chromadb
4
+ import pandas as pd
5
+
6
# CONFIG
# Module-level settings shared by every stage of the pipeline.

# Hugging Face Hub dataset that provides the labelled prompts.
HF_DATASET_NAME = "dralsarrani/prompt_safety_with_synthetic_labeled"

# Sentence-Transformers model used to embed prompts: fast, free, good enough.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# Local folder for the persistent Chroma store; created automatically.
CHROMA_DIR = "./chroma_db"

# Name of the Chroma collection that holds the prompt vectors.
COLLECTION_NAME = "safety_prompts"

# How many similar prompts to retrieve per query by default.
TOP_K = 5
13
+
14
+
15
+
16
+ # 1 LOAD DATASET
17
+
18
def load_safety_dataset(max_per_label: int = 25_000):
    """Load the prompt-safety dataset and return a cleaned, capped DataFrame.

    The "train" split of ``HF_DATASET_NAME`` is converted to pandas, column
    names are lower-cased and stripped, rows without a ``text`` or ``label``
    value are dropped, and only rows labelled ``"safe"`` or ``"unsafe"`` are
    kept. Each label is then down-sampled to at most ``max_per_label`` rows
    with a fixed seed, so repeated runs select the same rows.

    Args:
        max_per_label: Upper bound on the rows kept for each label. The
            default of 25,000 caps the result at 50K rows in total.

    Returns:
        A DataFrame with a fresh 0..n-1 index, grouped by label.
    """
    print("Loading dataset from HuggingFace...")
    dataset = load_dataset(HF_DATASET_NAME, cache_dir="./hf_cache")
    df = dataset["train"].to_pandas()

    # Normalise column names to lowercase
    df.columns = [c.lower().strip() for c in df.columns]

    # Keep only rows with valid prompt + label
    df = df.dropna(subset=["text", "label"])
    df = df[df["label"].isin(["safe", "unsafe"])]
    df = df.reset_index(drop=True)

    # Cap each label separately. Sampling group by group and concatenating
    # avoids GroupBy.apply, whose handling of the grouping column is
    # deprecated and, in newer pandas, drops "label" from the result.
    sampled = [
        group.sample(n=min(len(group), max_per_label), random_state=42)
        for _, group in df.groupby("label")
    ]
    if sampled:  # pd.concat raises on an empty list (no valid rows at all)
        df = pd.concat(sampled).reset_index(drop=True)

    # Labels are the strings "safe"/"unsafe"; comparing against 0/1 would
    # always report zero rows.
    n_safe = int((df["label"] == "safe").sum())
    n_unsafe = int((df["label"] == "unsafe").sum())
    print(f" Loaded {len(df)} rows | SAFE: {n_safe} UNSAFE: {n_unsafe}")
    return df
39
+
40
+
41
+ # 2 BUILD CHROMA VECTOR STORE
42
+
43
def build_vector_store(df: pd.DataFrame):
    """Embed every prompt in ``df`` and persist it in a Chroma collection.

    If the collection already holds vectors, nothing is re-embedded and the
    existing collection is returned as is.

    Args:
        df: DataFrame with a ``text`` column (the prompts) and a ``label``
            column (stored as per-document metadata).

    Returns:
        A ``(collection, model)`` tuple: the Chroma collection and the
        SentenceTransformer instance used to embed the prompts.
    """
    print("Building vector store...")
    model = SentenceTransformer(EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=CHROMA_DIR)

    # get_or_create also covers a collection that exists but is empty, where
    # a plain create_collection would raise. Cosine space is requested so
    # that "1 - distance" is a cosine similarity; note that the space of a
    # collection that already exists is fixed at its creation time.
    collection = client.get_or_create_collection(
        COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )

    # Check if already built — skip if so
    existing = collection.count()
    if existing > 0:
        print(f" Vector store already exists ({existing} vectors). Skipping rebuild.")
        return collection, model

    prompts = df["text"].tolist()
    labels = df["label"].tolist()
    total = len(prompts)

    # Embed and store one batch at a time so that only a single batch of
    # embeddings is held in memory, however large the dataset is.
    batch_size = 512
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        batch_docs = prompts[start:stop]
        batch_embeds = model.encode(batch_docs, show_progress_bar=False).tolist()
        collection.add(
            ids=[str(i) for i in range(start, stop)],  # row position as id
            embeddings=batch_embeds,
            documents=batch_docs,
            metadatas=[{"label": label} for label in labels[start:stop]],
        )
        print(f" Embedded {stop}/{total}")

    print(f" Stored {collection.count()} vectors in Chroma")
    return collection, model
88
+
89
+ # 3 RETRIEVAL FUNCTION
90
+
91
def retrieve_similar(query: str, collection, model, top_k: int = TOP_K):
    """Return the ``top_k`` stored prompts most similar to ``query``.

    Args:
        query: The new prompt to look up.
        collection: Chroma collection to search.
        model: Embedding model exposing ``encode``; it should be the same
            model that was used to build the collection.
        top_k: Number of neighbours to request.

    Returns:
        A list of dicts with keys ``"prompt"`` (stored text), ``"label"``
        (from the document metadata) and ``"similarity"`` (rounded to three
        decimals), in the order returned by the collection.
    """
    query_embedding = model.encode([query]).tolist()

    results = collection.query(
        query_embeddings=query_embedding,
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    # Chroma reports a distance whose meaning depends on the collection's
    # space: "cosine" and "ip" give 1 - similarity, while the default "l2"
    # gives the squared Euclidean distance, which equals 2 - 2*cos for
    # unit-length vectors.
    # NOTE(review): the l2 conversion assumes the model emits normalised
    # embeddings — confirm for models other than all-MiniLM-L6-v2.
    metadata = getattr(collection, "metadata", None) or {}
    space = metadata.get("hnsw:space", "l2")
    scale = 0.5 if space == "l2" else 1.0

    return [
        {
            "prompt": doc,
            "label": meta["label"],
            "similarity": round(1 - scale * dist, 3),  # distance → similarity
        }
        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]
117
+
118
+
119
+ # 4 LOAD EXISTING STORE (skip rebuild if already done)
120
+
121
def load_vector_store():
    """Load the persisted Chroma store, building it if it is missing or empty.

    Returns:
        A ``(collection, model)`` tuple: the Chroma collection and a
        SentenceTransformer instance for ``EMBEDDING_MODEL``.
    """
    client = chromadb.PersistentClient(path=CHROMA_DIR)

    # The exception raised for a missing collection differs between chromadb
    # releases, so any failure here is treated as "no usable store".
    try:
        collection = client.get_collection(COLLECTION_NAME)
        n_vectors = collection.count()
    except Exception:
        collection, n_vectors = None, 0

    if n_vectors == 0:
        if collection is not None:
            # An empty collection is useless for retrieval; drop it so the
            # rebuild can create the collection from scratch.
            client.delete_collection(COLLECTION_NAME)
        print("No existing vector store found — building from scratch...")
        # build_vector_store loads the embedding model itself, so it is not
        # loaded a second time here.
        return build_vector_store(load_safety_dataset())

    print(f"Loaded existing vector store ({n_vectors} vectors)")
    return collection, SentenceTransformer(EMBEDDING_MODEL)