AlephBeth-AI commited on
Commit
f303380
·
verified ·
1 Parent(s): 7335dc8

Upload precompute.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. precompute.py +196 -0
precompute.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GuardLLM - Precompute Embeddings & t-SNE
3
+ Downloads the neuralchemy/Prompt-injection-dataset (core config),
4
+ extracts CLS embeddings from Llama Prompt Guard 2 (86M),
5
+ computes t-SNE 2D projection, and saves everything to a cache file.
6
+
7
+ Run this script ONCE before launching the app (or let the app run it on first start).
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import logging
13
+ import numpy as np
14
+ import torch
15
+ from pathlib import Path
16
+
17
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
18
+ logger = logging.getLogger("precompute")
19
+
20
+ CACHE_DIR = Path(__file__).parent / "cache"
21
+ CACHE_FILE = CACHE_DIR / "embeddings_tsne.npz"
22
+ META_FILE = CACHE_DIR / "metadata.json"
23
+ MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"
24
+ DATASET_ID = "neuralchemy/Prompt-injection-dataset"
25
+ DATASET_CONFIG = "core"
26
+ BATCH_SIZE = 32
27
+ MAX_LENGTH = 512
28
+ TSNE_PERPLEXITY = 30
29
+ TSNE_SEED = 42
30
+
31
+
32
+ def is_cached() -> bool:
33
+ """Check if precomputed data exists."""
34
+ return CACHE_FILE.exists() and META_FILE.exists()
35
+
36
+
37
+ def load_cached():
38
+ """Load precomputed embeddings, t-SNE coords, and metadata."""
39
+ logger.info("Loading cached data from %s", CACHE_DIR)
40
+ data = np.load(CACHE_FILE)
41
+ with open(META_FILE, "r", encoding="utf-8") as f:
42
+ metadata = json.load(f)
43
+ return {
44
+ "embeddings": data["embeddings"],
45
+ "tsne_2d": data["tsne_2d"],
46
+ "metadata": metadata,
47
+ }
48
+
49
+
50
+ def download_dataset():
51
+ """Download the neuralchemy dataset (core config)."""
52
+ from datasets import load_dataset
53
+
54
+ logger.info("Downloading dataset %s (config=%s)...", DATASET_ID, DATASET_CONFIG)
55
+ ds = load_dataset(DATASET_ID, DATASET_CONFIG)
56
+
57
+ # Combine all splits for the visualization
58
+ all_samples = []
59
+ for split_name in ["train", "validation", "test"]:
60
+ if split_name in ds:
61
+ split = ds[split_name]
62
+ logger.info(" Split '%s': %d samples", split_name, len(split))
63
+ for row in split:
64
+ all_samples.append({
65
+ "text": row["text"],
66
+ "label": int(row["label"]),
67
+ "category": row.get("category", "unknown"),
68
+ "severity": row.get("severity", ""),
69
+ "source": row.get("source", ""),
70
+ "split": split_name,
71
+ })
72
+
73
+ logger.info("Total samples: %d", len(all_samples))
74
+ return all_samples
75
+
76
+
77
+ def compute_embeddings(samples: list) -> np.ndarray:
78
+ """Extract CLS token embeddings from Llama Prompt Guard 2."""
79
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
80
+
81
+ logger.info("Loading model %s...", MODEL_ID)
82
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
83
+ model = AutoModelForSequenceClassification.from_pretrained(
84
+ MODEL_ID, output_hidden_states=True
85
+ )
86
+ model.eval()
87
+
88
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+ model.to(device)
90
+ logger.info("Using device: %s", device)
91
+
92
+ texts = [s["text"] for s in samples]
93
+ all_embeddings = []
94
+
95
+ num_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE
96
+ for i in range(0, len(texts), BATCH_SIZE):
97
+ batch_idx = i // BATCH_SIZE + 1
98
+ batch_texts = texts[i : i + BATCH_SIZE]
99
+
100
+ if batch_idx % 10 == 1 or batch_idx == num_batches:
101
+ logger.info(" Batch %d/%d (%d samples)...", batch_idx, num_batches, len(batch_texts))
102
+
103
+ inputs = tokenizer(
104
+ batch_texts,
105
+ return_tensors="pt",
106
+ truncation=True,
107
+ max_length=MAX_LENGTH,
108
+ padding=True,
109
+ ).to(device)
110
+
111
+ with torch.no_grad():
112
+ outputs = model(**inputs)
113
+ # CLS token embedding from last hidden layer
114
+ hidden_states = outputs.hidden_states[-1] # [batch, seq_len, 768]
115
+ cls_embeddings = hidden_states[:, 0, :].cpu().numpy() # [batch, 768]
116
+ all_embeddings.append(cls_embeddings)
117
+
118
+ embeddings = np.concatenate(all_embeddings, axis=0)
119
+ logger.info("Embeddings shape: %s", embeddings.shape)
120
+ return embeddings
121
+
122
+
123
+ def compute_tsne(embeddings: np.ndarray) -> np.ndarray:
124
+ """Run t-SNE dimensionality reduction to 2D."""
125
+ from sklearn.manifold import TSNE
126
+
127
+ n_samples = embeddings.shape[0]
128
+ perplexity = min(TSNE_PERPLEXITY, n_samples - 1)
129
+
130
+ logger.info(
131
+ "Running t-SNE (n=%d, perplexity=%d, random_state=%d)...",
132
+ n_samples, perplexity, TSNE_SEED,
133
+ )
134
+ tsne = TSNE(
135
+ n_components=2,
136
+ perplexity=perplexity,
137
+ random_state=TSNE_SEED,
138
+ n_iter=1000,
139
+ learning_rate="auto",
140
+ init="pca",
141
+ )
142
+ coords_2d = tsne.fit_transform(embeddings)
143
+ logger.info("t-SNE done. Output shape: %s", coords_2d.shape)
144
+ return coords_2d
145
+
146
+
147
+ def precompute_all():
148
+ """Full pipeline: download → embed → t-SNE → save."""
149
+ if is_cached():
150
+ logger.info("Cache already exists. Loading...")
151
+ return load_cached()
152
+
153
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
154
+
155
+ # Step 1: Download dataset
156
+ samples = download_dataset()
157
+
158
+ # Step 2: Compute embeddings
159
+ embeddings = compute_embeddings(samples)
160
+
161
+ # Step 3: Compute t-SNE
162
+ tsne_2d = compute_tsne(embeddings)
163
+
164
+ # Step 4: Save
165
+ logger.info("Saving to cache...")
166
+ np.savez_compressed(
167
+ CACHE_FILE,
168
+ embeddings=embeddings,
169
+ tsne_2d=tsne_2d,
170
+ )
171
+
172
+ # Save metadata (text, labels, categories) as JSON
173
+ metadata = []
174
+ for s in samples:
175
+ metadata.append({
176
+ "text": s["text"],
177
+ "label": s["label"],
178
+ "category": s["category"],
179
+ "severity": s["severity"],
180
+ "source": s["source"],
181
+ "split": s["split"],
182
+ })
183
+
184
+ with open(META_FILE, "w", encoding="utf-8") as f:
185
+ json.dump(metadata, f, ensure_ascii=False)
186
+
187
+ logger.info("Cache saved to %s", CACHE_DIR)
188
+ return {
189
+ "embeddings": embeddings,
190
+ "tsne_2d": tsne_2d,
191
+ "metadata": metadata,
192
+ }
193
+
194
+
195
+ if __name__ == "__main__":
196
+ precompute_all()