jimnoneill committed on
Commit
85de43a
·
verified ·
1 Parent(s): cd3adb9

Upload src/pubguard/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pubguard/train.py +280 -0
src/pubguard/train.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training pipeline for PubGuard classification heads.
3
+
4
+ Trains lightweight linear classifiers on frozen model2vec embeddings.
5
+ This follows the same paradigm as the openalex-topic-classifier:
6
+ the expensive embedding is pre-computed once, and the classifier
7
+ itself is a single matrix multiply — fast to train, fast to infer.
8
+
9
+ Training strategy:
10
+ 1. Load + cache model2vec embeddings for all training data
11
+ 2. For each head, fit a logistic regression (sklearn) with
12
+ class-balanced weights and L2 regularisation
13
+ 3. Export weights as .npz for the numpy-only inference path
14
+ 4. Report per-class precision / recall / F1 on held-out split
15
+
16
+ The entire pipeline trains in <5 minutes on CPU for ~50K samples,
17
+ consistent with the existing toolchain.
18
+ """
19
+
20
+ import json
21
+ import logging
22
+ import time
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple
25
+
26
+ import numpy as np
27
+ from sklearn.linear_model import LogisticRegression
28
+ from sklearn.metrics import classification_report
29
+ from sklearn.model_selection import train_test_split
30
+
31
+ from .config import PubGuardConfig, DOC_TYPE_LABELS, AI_DETECT_LABELS, TOXICITY_LABELS
32
+ from .classifier import LinearHead
33
+ from .text import clean_text, extract_structural_features, N_STRUCTURAL_FEATURES
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ def load_ndjson(path: Path) -> Tuple[List[str], List[str]]:
39
+ """Load NDJSON file β†’ (texts, labels)."""
40
+ texts, labels = [], []
41
+ with open(path) as f:
42
+ for line in f:
43
+ if line.strip():
44
+ row = json.loads(line)
45
+ texts.append(row["text"])
46
+ labels.append(row["label"])
47
+ return texts, labels
48
+
49
+
50
+ def embed_texts(
51
+ texts: List[str],
52
+ config: PubGuardConfig,
53
+ cache_path: Optional[Path] = None,
54
+ ) -> np.ndarray:
55
+ """
56
+ Encode texts with model2vec, L2-normalise, return (N, D) float32.
57
+
58
+ Optionally caches to disk to avoid re-embedding on repeat runs.
59
+ """
60
+ if cache_path and cache_path.exists():
61
+ logger.info(f"Loading cached embeddings from {cache_path}")
62
+ return np.load(cache_path)
63
+
64
+ from model2vec import StaticModel
65
+
66
+ model_path = config.distilled_model_path
67
+ if model_path.exists():
68
+ model = StaticModel.from_pretrained(str(model_path))
69
+ else:
70
+ model = StaticModel.from_pretrained(config.model_name)
71
+ model_path.parent.mkdir(parents=True, exist_ok=True)
72
+ model.save_pretrained(str(model_path))
73
+
74
+ logger.info(f"Embedding {len(texts)} texts...")
75
+ cleaned = [clean_text(t, config.max_text_chars) for t in texts]
76
+ embeddings = model.encode(cleaned, show_progress_bar=True)
77
+
78
+ # L2-normalise
79
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
80
+ norms = np.where(norms == 0, 1, norms)
81
+ embeddings = (embeddings / norms).astype("float32")
82
+
83
+ if cache_path:
84
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
85
+ np.save(cache_path, embeddings)
86
+ logger.info(f"Cached embeddings to {cache_path}")
87
+
88
+ return embeddings
89
+
90
+
91
+ def compute_structural_features(texts: List[str]) -> np.ndarray:
92
+ """Compute structural features for all texts."""
93
+ feats = []
94
+ for t in texts:
95
+ cleaned = clean_text(t)
96
+ feat_dict = extract_structural_features(cleaned)
97
+ feats.append(list(feat_dict.values()))
98
+ return np.array(feats, dtype="float32")
99
+
100
+
101
+ def train_head(
102
+ X_train: np.ndarray,
103
+ y_train: np.ndarray,
104
+ X_test: np.ndarray,
105
+ y_test: np.ndarray,
106
+ labels: List[str],
107
+ head_name: str,
108
+ C: float = 1.0,
109
+ max_iter: int = 1000,
110
+ ) -> LinearHead:
111
+ """
112
+ Train a single linear classification head.
113
+
114
+ Uses sklearn LogisticRegression with:
115
+ - L2 regularisation (C parameter)
116
+ - class_weight='balanced' for imbalanced data
117
+ - lbfgs solver (good for moderate feature counts)
118
+ - multinomial objective even for binary (consistent API)
119
+
120
+ Extracts W and b into a LinearHead for numpy-only inference.
121
+ """
122
+ logger.info(f"\n{'='*60}")
123
+ logger.info(f"Training {head_name} head")
124
+ logger.info(f"{'='*60}")
125
+ logger.info(f" Train: {X_train.shape[0]:,} | Test: {X_test.shape[0]:,}")
126
+ logger.info(f" Features: {X_train.shape[1]} | Classes: {len(labels)}")
127
+
128
+ # Class distribution
129
+ unique, counts = np.unique(y_train, return_counts=True)
130
+ for u, c in zip(unique, counts):
131
+ logger.info(f" {u}: {c:,}")
132
+
133
+ start = time.time()
134
+
135
+ clf = LogisticRegression(
136
+ C=C,
137
+ max_iter=max_iter,
138
+ class_weight="balanced",
139
+ solver="lbfgs",
140
+ n_jobs=-1,
141
+ random_state=42,
142
+ )
143
+ clf.fit(X_train, y_train)
144
+
145
+ elapsed = time.time() - start
146
+ logger.info(f" Trained in {elapsed:.1f}s")
147
+
148
+ # Evaluate
149
+ y_pred = clf.predict(X_test)
150
+ report = classification_report(y_test, y_pred, target_names=labels, digits=4)
151
+ logger.info(f"\n{report}")
152
+
153
+ # Extract weights into LinearHead
154
+ head = LinearHead(labels)
155
+ # sklearn stores coef_ as (n_classes, n_features) for multinomial
156
+ # We want W as (n_features, n_classes) for X @ W + b
157
+ if clf.coef_.shape[0] == 1:
158
+ # Binary case: sklearn only stores one row
159
+ # Expand to full 2-class format
160
+ head.W = np.vstack([-clf.coef_[0], clf.coef_[0]]).T.astype("float32")
161
+ head.b = np.array([-clf.intercept_[0], clf.intercept_[0]], dtype="float32")
162
+ else:
163
+ head.W = clf.coef_.T.astype("float32") # (features, classes)
164
+ head.b = clf.intercept_.astype("float32")
165
+
166
+ # Sanity check: reproduce sklearn predictions
167
+ logits = X_test[:5] @ head.W + head.b
168
+ e = np.exp(logits - logits.max(axis=-1, keepdims=True))
169
+ probs = e / e.sum(axis=-1, keepdims=True)
170
+ np_pred_idx = np.argmax(probs, axis=1)
171
+ sk_pred_idx = clf.predict(X_test[:5]) # returns integer class indices
172
+ assert list(np_pred_idx) == list(int(x) for x in sk_pred_idx), \
173
+ f"Mismatch: {list(np_pred_idx)} vs {list(sk_pred_idx)}"
174
+ logger.info(" βœ“ Numpy inference matches sklearn predictions")
175
+
176
+ return head
177
+
178
+
179
+ def train_all(
180
+ data_dir: Path,
181
+ config: Optional[PubGuardConfig] = None,
182
+ test_size: float = 0.15,
183
+ ):
184
+ """
185
+ Train all three classification heads.
186
+
187
+ Args:
188
+ data_dir: Directory containing the prepared NDJSON files
189
+ config: PubGuard configuration
190
+ test_size: Fraction of data held out for evaluation
191
+ """
192
+ config = config or PubGuardConfig()
193
+ data_dir = Path(data_dir)
194
+ cache_dir = data_dir / "embeddings_cache"
195
+
196
+ logger.info("=" * 60)
197
+ logger.info("PubGuard Training Pipeline")
198
+ logger.info("=" * 60)
199
+ logger.info(f"Data dir: {data_dir}")
200
+ logger.info(f"Models dir: {config.models_dir}")
201
+ start_total = time.time()
202
+
203
+ # ── HEAD 1: doc_type ────────────────────────────────────────
204
+ doc_type_path = data_dir / "doc_type_train.ndjson"
205
+ if doc_type_path.exists():
206
+ texts, labels = load_ndjson(doc_type_path)
207
+ label_to_idx = {l: i for i, l in enumerate(DOC_TYPE_LABELS)}
208
+
209
+ # Embed
210
+ embeddings = embed_texts(
211
+ texts, config,
212
+ cache_path=cache_dir / "doc_type_emb.npy",
213
+ )
214
+
215
+ # Add structural features
216
+ logger.info("Computing structural features...")
217
+ struct = compute_structural_features(texts)
218
+ X = np.concatenate([embeddings, struct], axis=1)
219
+
220
+ y = np.array([label_to_idx.get(l, 0) for l in labels])
221
+
222
+ X_tr, X_te, y_tr, y_te = train_test_split(
223
+ X, y, test_size=test_size, stratify=y, random_state=42
224
+ )
225
+
226
+ head = train_head(X_tr, y_tr, X_te, y_te, DOC_TYPE_LABELS, "doc_type")
227
+ head.save(config.doc_type_head_path)
228
+ logger.info(f"Saved β†’ {config.doc_type_head_path}")
229
+ else:
230
+ logger.warning(f"doc_type data not found: {doc_type_path}")
231
+
232
+ # ── HEAD 2: ai_detect ───────────────────────────────────────
233
+ ai_path = data_dir / "ai_detect_train.ndjson"
234
+ if ai_path.exists():
235
+ texts, labels = load_ndjson(ai_path)
236
+ label_to_idx = {l: i for i, l in enumerate(AI_DETECT_LABELS)}
237
+
238
+ embeddings = embed_texts(
239
+ texts, config,
240
+ cache_path=cache_dir / "ai_detect_emb.npy",
241
+ )
242
+
243
+ y = np.array([label_to_idx.get(l, 0) for l in labels])
244
+
245
+ X_tr, X_te, y_tr, y_te = train_test_split(
246
+ embeddings, y, test_size=test_size, stratify=y, random_state=42
247
+ )
248
+
249
+ head = train_head(X_tr, y_tr, X_te, y_te, AI_DETECT_LABELS, "ai_detect")
250
+ head.save(config.ai_detect_head_path)
251
+ logger.info(f"Saved β†’ {config.ai_detect_head_path}")
252
+ else:
253
+ logger.warning(f"ai_detect data not found: {ai_path}")
254
+
255
+ # ── HEAD 3: toxicity ────────────────────────────────────────
256
+ tox_path = data_dir / "toxicity_train.ndjson"
257
+ if tox_path.exists():
258
+ texts, labels = load_ndjson(tox_path)
259
+ label_to_idx = {l: i for i, l in enumerate(TOXICITY_LABELS)}
260
+
261
+ embeddings = embed_texts(
262
+ texts, config,
263
+ cache_path=cache_dir / "toxicity_emb.npy",
264
+ )
265
+
266
+ y = np.array([label_to_idx.get(l, 0) for l in labels])
267
+
268
+ X_tr, X_te, y_tr, y_te = train_test_split(
269
+ embeddings, y, test_size=test_size, stratify=y, random_state=42
270
+ )
271
+
272
+ head = train_head(X_tr, y_tr, X_te, y_te, TOXICITY_LABELS, "toxicity")
273
+ head.save(config.toxicity_head_path)
274
+ logger.info(f"Saved β†’ {config.toxicity_head_path}")
275
+ else:
276
+ logger.warning(f"toxicity data not found: {tox_path}")
277
+
278
+ elapsed = time.time() - start_total
279
+ logger.info(f"\nTotal training time: {elapsed/60:.1f} minutes")
280
+ logger.info("All heads saved to: " + str(config.models_dir))