jimnoneill committed on
Commit
bc2f6a0
·
verified ·
1 Parent(s): be40856

Updated with real poster data support + poster pass-through gate

Browse files
Files changed (1) hide show
  1. src/pubguard/data.py +49 -12
src/pubguard/data.py CHANGED
@@ -283,8 +283,8 @@ def prepare_doc_type_dataset(
283
  except Exception:
284
  pass
285
 
286
- # ── junk ─────────────────────────────────────────────────────
287
- logger.info("Loading ag_news for junk class...")
288
  try:
289
  ds = load_dataset(
290
  "ag_news",
@@ -292,7 +292,7 @@ def prepare_doc_type_dataset(
292
  )
293
  count = 0
294
  for row in ds:
295
- if count >= n_per_class // 2:
296
  break
297
  text = row.get("text", "")
298
  if len(text) > 30:
@@ -302,16 +302,53 @@ def prepare_doc_type_dataset(
302
  except Exception as e:
303
  logger.warning(f"Could not load ag_news: {e}")
304
 
305
- logger.info("Generating synthetic junk...")
306
- synth_junk = generate_synthetic_junk(n_per_class // 2)
307
- all_samples.extend(synth_junk)
308
- logger.info(f" junk (synthetic): {len(synth_junk)}")
309
-
310
- # ── poster ───────────────────────────────────────────────────
311
- logger.info("Generating synthetic poster data...")
312
- synth_posters = generate_synthetic_posters(n_per_class)
 
 
 
 
 
 
313
  all_samples.extend(synth_posters)
314
- logger.info(f" poster (synthetic): {len(synth_posters)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  # ── Shuffle and save ─────────────────────────────────────────
317
  random.shuffle(all_samples)
 
283
  except Exception:
284
  pass
285
 
286
+ # ── junk (100% real data — ag_news) ────────────────────────────
287
+ logger.info("Loading ag_news for junk class (full — no synthetic)...")
288
  try:
289
  ds = load_dataset(
290
  "ag_news",
 
292
  )
293
  count = 0
294
  for row in ds:
295
+ if count >= n_per_class:
296
  break
297
  text = row.get("text", "")
298
  if len(text) > 30:
 
302
  except Exception as e:
303
  logger.warning(f"Could not load ag_news: {e}")
304
 
305
+ # ── poster ────────────────────────────────────────────────────
306
+ # NOTE: Real poster text is nearly identical to paper text in
307
+ # embedding space (both are scientific). PubGuard uses text-only
308
+ # features, so we need SHORT, STRUCTURED poster-style texts that
309
+ # the embedding can distinguish from full papers.
310
+ #
311
+ # Strategy: synthetic poster templates (structured, short) +
312
+ # real poster texts TRUNCATED to first 500 chars (title/authors
313
+ # block, which has distinct formatting from paper introductions).
314
+ logger.info("Loading poster data (structured templates + real poster headers)...")
315
+ poster_count = 0
316
+
317
+ # (a) Synthetic templates — provide distinctive poster structure signal
318
+ synth_posters = generate_synthetic_posters(min(n_per_class // 2, 7500))
319
  all_samples.extend(synth_posters)
320
+ poster_count += len(synth_posters)
321
+ logger.info(f" poster (synthetic templates): {len(synth_posters)}")
322
+
323
+ # (b) Real poster header text (first 500 chars only — title/authors block)
324
+ real_poster_count = 0
325
+ local_poster_data = Path("/home/joneill/pubverse_brett/poster_sentry/poster_texts_for_pubguard.ndjson")
326
+ if not local_poster_data.exists():
327
+ local_poster_data = Path.cwd().parent / "poster_sentry" / "poster_texts_for_pubguard.ndjson"
328
+
329
+ if local_poster_data.exists():
330
+ logger.info(f" Adding real poster headers from: {local_poster_data}")
331
+ with open(local_poster_data) as f:
332
+ for line in f:
333
+ if real_poster_count >= n_per_class // 2:
334
+ break
335
+ row = json.loads(line)
336
+ if row.get("label") == "poster":
337
+ # Truncate to header region (title, authors, affiliations)
338
+ text = row["text"][:500]
339
+ if len(text) > 50:
340
+ all_samples.append({"text": text, "label": "poster"})
341
+ real_poster_count += 1
342
+ poster_count += real_poster_count
343
+ logger.info(f" poster (real headers, ≤500 chars): {real_poster_count}")
344
+ else:
345
+ # Fill with more synthetic templates if no real data available
346
+ extra = generate_synthetic_posters(n_per_class // 2)
347
+ all_samples.extend(extra)
348
+ poster_count += len(extra)
349
+ logger.info(f" poster (synthetic fallback): {len(extra)}")
350
+
351
+ logger.info(f" poster total: {poster_count}")
352
 
353
  # ── Shuffle and save ─────────────────────────────────────────
354
  random.shuffle(all_samples)