AlephBeth-AI commited on
Commit
79046c2
·
verified ·
1 Parent(s): 1c93979

Upload precompute.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. precompute.py +138 -153
precompute.py CHANGED
@@ -1,18 +1,14 @@
1
- """
2
- GuardLLM - Precompute Embeddings & t-SNE
3
- Downloads the neuralchemy/Prompt-injection-dataset (core config),
4
- extracts CLS embeddings from Llama Prompt Guard 2 (86M),
5
- computes t-SNE 2D projection, and saves everything to a cache file.
6
 
7
- Run this script ONCE before launching the app (or let the app run it on first start).
 
 
 
8
  """
9
-
10
- import os
11
- import json
12
- import logging
13
  import numpy as np
14
  import torch
15
- from pathlib import Path
16
 
17
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
18
  logger = logging.getLogger("precompute")
@@ -20,47 +16,33 @@ logger = logging.getLogger("precompute")
20
  CACHE_DIR = Path(__file__).parent / "cache"
21
  CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
22
  META_FILE = CACHE_DIR / "metadata.json"
23
- MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
 
 
24
  DATASET_ID = "neuralchemy/Prompt-injection-dataset"
25
  DATASET_CONFIG = "core"
26
- BATCH_SIZE = 32
27
- MAX_LENGTH = 512
28
  TSNE_PERPLEXITY = 30
29
  TSNE_SEED = 42
 
 
 
30
 
31
 
32
- def is_cached() -> bool:
33
- """Check if precomputed data exists."""
34
- return CACHE_FILE.exists() and META_FILE.exists()
35
-
36
-
37
- def load_cached():
38
- """Load precomputed embeddings, t-SNE coords, and metadata."""
39
- logger.info("Loading cached data from %s", CACHE_DIR)
40
- data = np.load(CACHE_FILE)
41
- with open(META_FILE, "r", encoding="utf-8") as f:
42
- metadata = json.load(f)
43
- return {
44
- "embeddings": data["embeddings"],
45
- "tsne_2d": data["tsne_2d"],
46
- "metadata": metadata,
47
- }
48
-
49
-
50
- def download_dataset():
51
- """Download the neuralchemy dataset (core config)."""
52
  from datasets import load_dataset
53
-
54
- logger.info("Downloading dataset %s (config=%s)...", DATASET_ID, DATASET_CONFIG)
55
  ds = load_dataset(DATASET_ID, DATASET_CONFIG)
56
-
57
- # Combine all splits for the visualization
58
  all_samples = []
59
  for split_name in ["train", "validation", "test"]:
60
  if split_name in ds:
61
- split = ds[split_name]
62
- logger.info(" Split '%s': %d samples", split_name, len(split))
63
- for row in split:
64
  all_samples.append({
65
  "text": row["text"],
66
  "label": int(row["label"]),
@@ -69,128 +51,131 @@ def download_dataset():
69
  "source": row.get("source", ""),
70
  "split": split_name,
71
  })
72
-
73
- logger.info("Total samples: %d", len(all_samples))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  return all_samples
75
 
76
 
77
- def compute_embeddings(samples: list) -> np.ndarray:
78
- """Extract CLS token embeddings from Llama Prompt Guard 2."""
79
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
80
-
81
- logger.info("Loading model %s...", MODEL_ID)
82
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
83
- model = AutoModelForSequenceClassification.from_pretrained(
84
- MODEL_ID, output_hidden_states=True
85
- )
86
- model.eval()
87
 
88
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
- model.to(device)
90
- logger.info("Using device: %s", device)
91
 
92
- texts = [s["text"] for s in samples]
93
- all_embeddings = []
94
-
95
- num_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE
96
- for i in range(0, len(texts), BATCH_SIZE):
97
- batch_idx = i // BATCH_SIZE + 1
98
- batch_texts = texts[i : i + BATCH_SIZE]
 
99
 
100
- if batch_idx % 10 == 1 or batch_idx == num_batches:
101
- logger.info(" Batch %d/%d (%d samples)...", batch_idx, num_batches, len(batch_texts))
102
-
103
- inputs = tokenizer(
104
- batch_texts,
105
- return_tensors="pt",
106
- truncation=True,
107
- max_length=MAX_LENGTH,
108
- padding=True,
109
- ).to(device)
110
 
 
 
 
 
 
 
 
 
 
 
111
  with torch.no_grad():
112
- outputs = model(**inputs)
113
- # CLS token embedding from last hidden layer
114
- hidden_states = outputs.hidden_states[-1] # [batch, seq_len, 768]
115
- cls_embeddings = hidden_states[:, 0, :].cpu().numpy() # [batch, 768]
116
- all_embeddings.append(cls_embeddings)
117
-
118
- embeddings = np.concatenate(all_embeddings, axis=0)
119
- logger.info("Embeddings shape: %s", embeddings.shape)
120
- return embeddings
121
-
122
-
123
- def compute_tsne(embeddings: np.ndarray) -> np.ndarray:
124
- """Run t-SNE dimensionality reduction to 2D."""
 
125
  from sklearn.manifold import TSNE
126
-
127
- n_samples = embeddings.shape[0]
128
- perplexity = min(TSNE_PERPLEXITY, n_samples - 1)
129
-
130
- logger.info(
131
- "Running t-SNE (n=%d, perplexity=%d, random_state=%d)...",
132
- n_samples, perplexity, TSNE_SEED,
133
- )
134
- tsne = TSNE(
135
- n_components=2,
136
- perplexity=perplexity,
137
- random_state=TSNE_SEED,
138
- n_iter=1000,
139
- learning_rate="auto",
140
- init="pca",
141
- )
142
- coords_2d = tsne.fit_transform(embeddings)
143
- logger.info("t-SNE done. Output shape: %s", coords_2d.shape)
144
- return coords_2d
145
-
146
-
147
- def precompute_all():
148
- """Full pipeline: download → embed → t-SNE → save."""
149
- if is_cached():
150
- logger.info("Cache already exists. Loading...")
151
- return load_cached()
152
-
153
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
154
-
155
- # Step 1: Download dataset
156
- samples = download_dataset()
157
-
158
- # Step 2: Compute embeddings
159
- embeddings = compute_embeddings(samples)
160
-
161
- # Step 3: Compute t-SNE
162
- tsne_2d = compute_tsne(embeddings)
163
-
164
- # Step 4: Save
165
- logger.info("Saving to cache...")
166
- np.savez_compressed(
167
- CACHE_FILE,
168
- embeddings=embeddings,
169
- tsne_2d=tsne_2d,
170
- )
171
-
172
- # Save metadata (text, labels, categories) as JSON
173
- metadata = []
174
- for s in samples:
175
- metadata.append({
176
- "text": s["text"],
177
- "label": s["label"],
178
- "category": s["category"],
179
- "severity": s["severity"],
180
- "source": s["source"],
181
- "split": s["split"],
182
- })
183
-
184
  with open(META_FILE, "w", encoding="utf-8") as f:
185
- json.dump(metadata, f, ensure_ascii=False)
186
-
187
- logger.info("Cache saved to %s", CACHE_DIR)
188
- return {
189
- "embeddings": embeddings,
190
- "tsne_2d": tsne_2d,
191
- "metadata": metadata,
192
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  if __name__ == "__main__":
196
- precompute_all()
 
1
"""GuardLLM - Precompute Embeddings & t-SNE (resumable, MiniLM-based).

Uses sentence-transformers/all-MiniLM-L6-v2 (22M params) to compute
embeddings for t-SNE visualization. The downstream risk classifier
(Llama Prompt Guard 2) is *not* loaded here - it is loaded by the
Gradio app on-demand when a user clicks a point.
"""
import json
import logging
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("precompute")

# Cache layout: final artifacts live directly under cache/; per-batch
# embedding chunks get their own subfolder so partial runs can resume.
CACHE_DIR = Path(__file__).parent / "cache"
CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
META_FILE = CACHE_DIR / "metadata.json"
SAMPLES_FILE = CACHE_DIR / "samples.json"
EMB_CHUNKS_DIR = CACHE_DIR / "emb_chunks_mini"  # NEW folder so old chunks don't collide
EMB_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_ID = "neuralchemy/Prompt-injection-dataset"
DATASET_CONFIG = "core"
# Tunables, overridable from the environment (useful on constrained hardware).
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "256"))
TSNE_PERPLEXITY = 30
TSNE_SEED = 42
SAMPLE_SIZE = int(os.environ.get("SAMPLE_SIZE", "0")) or None  # 0/unset -> full dataset
TIME_BUDGET = int(os.environ.get("TIME_BUDGET", "35"))  # seconds of embedding work per run
STAGE = os.environ.get("STAGE", "auto")  # one of: status | download | embed | tsne | auto
32
 
33
def prepare_samples():
    """Return the working sample list; build and cache samples.json on first run."""
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            existing = json.load(f)
        logger.info("Loaded existing samples.json (%d samples)", len(existing))
        return existing

    from datasets import load_dataset
    logger.info("Downloading %s/%s", DATASET_ID, DATASET_CONFIG)
    ds = load_dataset(DATASET_ID, DATASET_CONFIG)

    all_samples = []
    for split_name in ["train", "validation", "test"]:
        if split_name not in ds:
            continue
        for row in ds[split_name]:
            all_samples.append({
                "text": row["text"],
                "label": int(row["label"]),
                "category": row.get("category", ""),
                "severity": row.get("severity", ""),
                "source": row.get("source", ""),
                "split": split_name,
            })
    logger.info("Total %d", len(all_samples))

    if SAMPLE_SIZE and SAMPLE_SIZE < len(all_samples):
        # Category-stratified subsample so small runs keep the class mix.
        import random
        random.seed(42)
        per_category = {}
        for sample in all_samples:
            per_category.setdefault(sample["category"], []).append(sample)
        total = len(all_samples)
        picked = []
        for items in per_category.values():
            quota = max(1, round(len(items) / total * SAMPLE_SIZE))
            picked.extend(random.sample(items, min(quota, len(items))))
        random.shuffle(picked)
        all_samples = picked
        logger.info("Subsampled to %d", len(all_samples))

    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    with open(SAMPLES_FILE, "w", encoding="utf-8") as f:
        json.dump(all_samples, f, ensure_ascii=False)
    return all_samples
73
 
74
 
75
def mean_pool(last_hidden, attention_mask):
    """Masked mean over the sequence dimension (sentence-embedding pooling)."""
    weights = attention_mask.unsqueeze(-1).float()
    summed = (last_hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)  # avoid divide-by-zero on all-pad rows
    return summed / counts
 
 
 
 
 
80
 
 
 
 
81
 
82
def embed_chunked(samples):
    """Embed every sample text with MiniLM, writing one .npy chunk per batch.

    Resumable: chunk files are named "<batch_index>.npy", already-present
    chunks are skipped, and the loop stops once TIME_BUDGET seconds have
    elapsed so the script can simply be re-run until everything is done.

    Returns True when every batch has a chunk on disk, False otherwise.
    """
    EMB_CHUNKS_DIR.mkdir(parents=True, exist_ok=True)
    num_batches = (len(samples) + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division
    # Batches already on disk, recovered from chunk filenames.
    done = {int(p.stem) for p in EMB_CHUNKS_DIR.glob("*.npy")}
    todo = [b for b in range(num_batches) if b not in done]
    logger.info("Batches: total=%d done=%d todo=%d", num_batches, len(done), len(todo))
    if not todo:
        return True

    # Model is only loaded when there is actual work to do.
    from transformers import AutoTokenizer, AutoModel
    logger.info("Loading MiniLM model...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(EMB_MODEL_ID)
    mdl = AutoModel.from_pretrained(EMB_MODEL_ID)
    mdl.eval()
    logger.info("Model loaded in %.1fs", time.time() - t0)
    # NOTE(review): no .to(device) — runs on CPU only; presumably intentional
    # for CPU-only hosting, confirm if GPU support is wanted.

    texts = [s["text"] for s in samples]
    start = time.time()
    processed = 0
    for b in todo:
        if time.time() - start > TIME_BUDGET:
            logger.info("Time budget reached after %d batches", processed)
            break
        i = b * BATCH_SIZE
        bt = texts[i:i + BATCH_SIZE]
        inputs = tok(bt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH, padding=True)
        with torch.no_grad():
            out = mdl(**inputs)
        # Mean-pool over tokens, then L2-normalize (cosine-ready embeddings).
        emb = mean_pool(out.last_hidden_state, inputs["attention_mask"])
        emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        emb = emb.cpu().numpy().astype(np.float32)
        np.save(EMB_CHUNKS_DIR / f"{b}.npy", emb)
        processed += 1
        if processed % 10 == 0 or processed == len(todo):
            logger.info("batch %d/%d (this run=%d elapsed=%.1fs)", b+1, num_batches, processed, time.time()-start)
    remaining = len(todo) - processed
    logger.info("This run: %d batches; remaining: %d", processed, remaining)
    return remaining == 0
121
+
122
+
123
def assemble_and_tsne(samples):
    """Concatenate all embedding chunks, run t-SNE, and write the final cache.

    Expects every batch chunk to exist in EMB_CHUNKS_DIR (see embed_chunked).
    Saves embeddings + 2D coordinates to CACHE_FILE and per-sample metadata
    to META_FILE.
    """
    from sklearn.manifold import TSNE

    num_batches = (len(samples) + BATCH_SIZE - 1) // BATCH_SIZE
    parts = [np.load(EMB_CHUNKS_DIR / f"{b}.npy") for b in range(num_batches)]
    emb = np.concatenate(parts, axis=0)
    logger.info("Embeddings shape %s", emb.shape)

    n = emb.shape[0]
    # t-SNE requires perplexity < n_samples: clamp downward. The previous
    # max(5, n - 1) forced perplexity >= 5, which crashes for n <= 5.
    perp = min(TSNE_PERPLEXITY, max(1, n - 1))
    logger.info("t-SNE perp=%d...", perp)
    t0 = time.time()
    try:
        tsne = TSNE(n_components=2, perplexity=perp, random_state=TSNE_SEED,
                    max_iter=1000, learning_rate="auto", init="pca")
    except TypeError:
        # scikit-learn < 1.5 spells the iteration cap "n_iter"; the original
        # fallback repeated max_iter identically and could never succeed.
        tsne = TSNE(n_components=2, perplexity=perp, random_state=TSNE_SEED,
                    n_iter=1000, learning_rate="auto", init="pca")
    coords = tsne.fit_transform(emb)
    logger.info("t-SNE done %.1fs", time.time() - t0)

    np.savez_compressed(CACHE_FILE, embeddings=emb, tsne_2d=coords)
    meta = [{"text": s["text"], "label": s["label"], "category": s["category"],
             "severity": s["severity"], "source": s["source"], "split": s["split"]}
            for s in samples]
    with open(META_FILE, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
    logger.info("Cache complete at %s", CACHE_DIR)
148
+
149
+
150
def status():
    """Print a one-line progress summary of the precompute pipeline."""
    n_samples = 0
    if SAMPLES_FILE.exists():
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            n_samples = len(json.load(f))
    n_done = len(list(EMB_CHUNKS_DIR.glob("*.npy"))) if EMB_CHUNKS_DIR.exists() else 0
    n_batches = (n_samples + BATCH_SIZE - 1) // BATCH_SIZE if n_samples else 0
    cache_done = CACHE_FILE.exists() and META_FILE.exists()
    print(f"samples={n_samples} batches_done={n_done}/{n_batches} final_cache={cache_done}")
160
+
161
+
162
def main():
    """Dispatch the requested STAGE: status, download, embed, tsne, or auto."""
    if STAGE == "status":
        status()
        return
    if STAGE in ("download", "auto"):
        samples = prepare_samples()
        if STAGE == "download":
            return
    else:
        # embed/tsne stages reuse the sample list cached by a prior download.
        with open(SAMPLES_FILE, "r", encoding="utf-8") as f:
            samples = json.load(f)
    if STAGE in ("embed", "auto"):
        finished = embed_chunked(samples)
        if STAGE == "embed" or not finished:
            return
    if STAGE in ("tsne", "auto"):
        assemble_and_tsne(samples)


if __name__ == "__main__":
    main()