Eric Xu commited on
Commit
d94e0d2
·
unverified ·
1 Parent(s): 92cf501

Use Nemotron personas when available, add dataset setup UI

Browse files

Web interface now prefers census-grounded Nemotron personas (1M dataset)
over LLM-generated ones. Checks common paths on startup; if not found,
shows a setup panel where user provides a path — loads existing data or
downloads from HuggingFace (~2GB).

- Add /api/nemotron/setup endpoint (load or download to given path)
- Add /api/config nemotron_available field
- Cohort generation uses stratified sampling from Nemotron when available
- Progress log shows data source (census-grounded vs LLM-generated)

Files changed (2) hide show
  1. web/app.py +113 -16
  2. web/static/index.html +70 -1
web/app.py CHANGED
@@ -44,6 +44,8 @@ from bias_audit import (
44
  reframe_entity, add_authority_signals, reorder_entity,
45
  run_paired_evaluation, analyze_probe, generate_report, HUMAN_BASELINES,
46
  )
 
 
47
 
48
  app = FastAPI(title="SGO — Semantic Gradient Optimization")
49
  app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
@@ -51,6 +53,49 @@ app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), na
51
  # In-memory store for active sessions
52
  sessions: dict = {}
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def get_client():
56
  return OpenAI(
@@ -102,14 +147,43 @@ async def index():
102
 
103
  @app.get("/api/config")
104
  async def get_config():
105
- """Return current LLM config (model name, whether API key is set)."""
 
106
  return {
107
  "model": get_model(),
108
  "has_api_key": bool(os.getenv("LLM_API_KEY")),
109
  "base_url": os.getenv("LLM_BASE_URL", ""),
 
 
110
  }
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  @app.post("/api/session")
114
  async def create_session(entity: EntityInput):
115
  """Create a new evaluation session with an entity."""
@@ -182,22 +256,42 @@ Be concrete and relevant — no generic segments."""
182
 
183
  @app.post("/api/cohort/generate")
184
  async def generate_cohort_endpoint(config: CohortConfig):
185
- """Generate an LLM cohort and attach to a new session."""
186
  sid = uuid.uuid4().hex[:12]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- client = get_client()
189
- model = get_model()
190
- all_personas = []
191
-
192
- with concurrent.futures.ThreadPoolExecutor(max_workers=config.parallel) as pool:
193
- futs = {
194
- pool.submit(generate_segment, client, model,
195
- seg["label"], seg["count"], config.description): seg
196
- for seg in config.segments
197
- }
198
- for fut in concurrent.futures.as_completed(futs):
199
- personas = fut.result()
200
- all_personas.extend(personas)
201
 
202
  for i, p in enumerate(all_personas):
203
  p["user_id"] = i
@@ -211,7 +305,10 @@ async def generate_cohort_endpoint(config: CohortConfig):
211
  "created": datetime.now().isoformat(),
212
  }
213
 
214
- return {"session_id": sid, "cohort_size": len(all_personas), "cohort": all_personas}
 
 
 
215
 
216
 
217
  @app.post("/api/cohort/upload/{sid}")
 
44
  reframe_entity, add_authority_signals, reorder_entity,
45
  run_paired_evaluation, analyze_probe, generate_report, HUMAN_BASELINES,
46
  )
47
+ from persona_loader import load_personas, filter_personas, to_profile
48
+ from stratified_sampler import stratified_sample, age_bracket, make_occupation_fn
49
 
50
  app = FastAPI(title="SGO — Semantic Gradient Optimization")
51
  app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
 
53
  # In-memory store for active sessions
54
  sessions: dict = {}
55
 
56
+ # Nemotron dataset — loaded once if available
57
+ _nemotron_ds = None
58
+ _nemotron_checked = False
59
+
60
+ NEMOTRON_SEARCH_PATHS = [
61
+ PROJECT_ROOT / "data" / "nemotron",
62
+ Path.home() / "data" / "nvidia" / "Nemotron-Personas-USA",
63
+ Path.home() / "data" / "nemotron",
64
+ Path(os.getenv("NEMOTRON_DATA_DIR", "/nonexistent")),
65
+ ]
66
+
67
+
68
+ def find_nemotron_path():
69
+ """Find Nemotron dataset on disk. Returns path or None."""
70
+ for path in NEMOTRON_SEARCH_PATHS:
71
+ if (path / "dataset_info.json").exists():
72
+ return path
73
+ return None
74
+
75
+
76
+ def get_nemotron(data_dir=None):
77
+ """Load Nemotron dataset. Returns None if not found."""
78
+ global _nemotron_ds, _nemotron_checked
79
+ if data_dir:
80
+ # Explicit path — reset cache
81
+ _nemotron_checked = False
82
+ _nemotron_ds = None
83
+ NEMOTRON_SEARCH_PATHS.insert(0, Path(data_dir))
84
+
85
+ if _nemotron_checked:
86
+ return _nemotron_ds
87
+
88
+ _nemotron_checked = True
89
+ path = find_nemotron_path()
90
+ if path:
91
+ try:
92
+ _nemotron_ds = load_personas(data_dir=path)
93
+ print(f"Nemotron loaded: {len(_nemotron_ds)} personas from {path}")
94
+ return _nemotron_ds
95
+ except Exception as e:
96
+ print(f"Failed to load Nemotron from {path}: {e}")
97
+ return None
98
+
99
 
100
  def get_client():
101
  return OpenAI(
 
147
 
148
  @app.get("/api/config")
149
  async def get_config():
150
+ """Return current LLM config and Nemotron status."""
151
+ nem_path = find_nemotron_path()
152
  return {
153
  "model": get_model(),
154
  "has_api_key": bool(os.getenv("LLM_API_KEY")),
155
  "base_url": os.getenv("LLM_BASE_URL", ""),
156
+ "nemotron_path": str(nem_path) if nem_path else None,
157
+ "nemotron_available": nem_path is not None,
158
  }
159
 
160
 
161
+ class NemotronPathInput(BaseModel):
162
+ path: str
163
+
164
+
165
+ @app.post("/api/nemotron/setup")
166
+ async def setup_nemotron(input: NemotronPathInput):
167
+ """Point to existing Nemotron data, or download it to the given path."""
168
+ p = Path(input.path).expanduser().resolve()
169
+
170
+ if (p / "dataset_info.json").exists():
171
+ # Already there — just load it
172
+ ds = get_nemotron(data_dir=str(p))
173
+ if ds is None:
174
+ raise HTTPException(500, "Failed to load dataset")
175
+ return {"status": "loaded", "path": str(p), "count": len(ds)}
176
+
177
+ # Not there — download to this path
178
+ from setup_data import setup
179
+ try:
180
+ ds = setup(data_dir=p)
181
+ get_nemotron(data_dir=str(p))
182
+ return {"status": "downloaded", "path": str(p), "count": len(ds)}
183
+ except Exception as e:
184
+ raise HTTPException(500, f"Download failed: {e}")
185
+
186
+
187
  @app.post("/api/session")
188
  async def create_session(entity: EntityInput):
189
  """Create a new evaluation session with an entity."""
 
256
 
257
  @app.post("/api/cohort/generate")
258
  async def generate_cohort_endpoint(config: CohortConfig):
259
+ """Generate a cohort from Nemotron if available, else LLM-generated."""
260
  sid = uuid.uuid4().hex[:12]
261
+ total = sum(s.get("count", 8) for s in config.segments)
262
+
263
+ ds = get_nemotron()
264
+ if ds is not None:
265
+ # Use census-grounded Nemotron personas
266
+ filtered = filter_personas(ds, {}, limit=max(total * 20, 2000))
267
+ profiles = [to_profile(row, i) for i, row in enumerate(filtered)]
268
+
269
+ dim_fns = [
270
+ lambda p: age_bracket(p.get("age", 30)),
271
+ lambda p: p.get("marital_status", "unknown"),
272
+ lambda p: p.get("education_level", "") or "unknown",
273
+ ]
274
+ diversity_fn = lambda p: p.get("occupation", "unknown") or "unknown"
275
+
276
+ all_personas = stratified_sample(profiles, dim_fns, total=total,
277
+ diversity_fn=diversity_fn)
278
+ source = "nemotron"
279
+ else:
280
+ # Fallback: LLM-generated
281
+ client = get_client()
282
+ model = get_model()
283
+ all_personas = []
284
 
285
+ with concurrent.futures.ThreadPoolExecutor(max_workers=config.parallel) as pool:
286
+ futs = {
287
+ pool.submit(generate_segment, client, model,
288
+ seg["label"], seg["count"], config.description): seg
289
+ for seg in config.segments
290
+ }
291
+ for fut in concurrent.futures.as_completed(futs):
292
+ personas = fut.result()
293
+ all_personas.extend(personas)
294
+ source = "llm-generated"
 
 
 
295
 
296
  for i, p in enumerate(all_personas):
297
  p["user_id"] = i
 
305
  "created": datetime.now().isoformat(),
306
  }
307
 
308
+ return {
309
+ "session_id": sid, "cohort_size": len(all_personas),
310
+ "cohort": all_personas, "source": source,
311
+ }
312
 
313
 
314
  @app.post("/api/cohort/upload/{sid}")
web/static/index.html CHANGED
@@ -308,8 +308,28 @@
308
  <h1>Semantic Gradient Optimization</h1>
309
  <p>Evaluate anything against a synthetic panel. Find what to change first.</p>
310
  <div id="configBadge" class="config-badge">checking...</div>
 
311
  </header>
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  <!-- STEP 1: Entity + Evaluate (one click) -->
314
  <div class="step active" id="step1">
315
  <div class="step-header">
@@ -547,6 +567,7 @@ let evalResultsData = null;
547
  async function init() {
548
  const resp = await fetch('/api/config');
549
  const cfg = await resp.json();
 
550
  const badge = document.getElementById('configBadge');
551
  if (cfg.has_api_key) {
552
  badge.textContent = cfg.model;
@@ -556,10 +577,57 @@ async function init() {
556
  badge.className = 'config-badge warn';
557
  }
558
 
 
 
 
 
 
 
 
 
 
 
 
 
559
  addChange('', '');
560
  addChange('', '');
561
  }
562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  // ── Templates ──
564
 
565
  function loadTemplate(name) {
@@ -669,7 +737,8 @@ async function runFullPipeline() {
669
  body: JSON.stringify(cohortData.cohort),
670
  });
671
 
672
- logStep(`${cohortData.cohort_size} evaluators generated`, 'pos');
 
673
  document.getElementById('pipelineProgressBar').style.width = '35%';
674
 
675
  // Phase 4: Evaluate via SSE
 
308
  <h1>Semantic Gradient Optimization</h1>
309
  <p>Evaluate anything against a synthetic panel. Find what to change first.</p>
310
  <div id="configBadge" class="config-badge">checking...</div>
311
+ <div id="nemotronBadge" class="config-badge" style="margin-left:8px">checking...</div>
312
  </header>
313
 
314
+ <!-- Nemotron setup (shown if not available) -->
315
+ <div class="step hidden" id="nemotronSetup" style="border-color:var(--yellow)">
316
+ <div class="step-header">
317
+ <div class="step-num" style="border-color:var(--yellow);color:var(--yellow)">!</div>
318
+ <div class="step-title">Persona dataset not found</div>
319
+ </div>
320
+ <p class="step-desc">SGO uses 1M census-grounded personas for realistic evaluations. Provide a path to the dataset or download it (~2GB).</p>
321
+ <div class="field">
322
+ <label>Dataset path</label>
323
+ <input type="text" id="nemotronPath" placeholder="">
324
+ </div>
325
+ <div class="btn-row">
326
+ <button onclick="setupNemotron()">Load or download</button>
327
+ </div>
328
+ <div id="nemotronStatus" class="hidden mt-16">
329
+ <div class="progress-text" id="nemotronStatusText"></div>
330
+ </div>
331
+ </div>
332
+
333
  <!-- STEP 1: Entity + Evaluate (one click) -->
334
  <div class="step active" id="step1">
335
  <div class="step-header">
 
567
  async function init() {
568
  const resp = await fetch('/api/config');
569
  const cfg = await resp.json();
570
+
571
  const badge = document.getElementById('configBadge');
572
  if (cfg.has_api_key) {
573
  badge.textContent = cfg.model;
 
577
  badge.className = 'config-badge warn';
578
  }
579
 
580
+ const nemBadge = document.getElementById('nemotronBadge');
581
+ if (cfg.nemotron_available) {
582
+ nemBadge.textContent = 'Nemotron 1M';
583
+ nemBadge.className = 'config-badge ok';
584
+ } else {
585
+ nemBadge.textContent = 'No persona dataset';
586
+ nemBadge.className = 'config-badge warn';
587
+ document.getElementById('nemotronSetup').classList.remove('hidden');
588
+ // Default path: project's data dir
589
+ document.getElementById('nemotronPath').value = cfg.base_url ? '' : 'data/nemotron';
590
+ }
591
+
592
  addChange('', '');
593
  addChange('', '');
594
  }
595
 
596
+ async function setupNemotron() {
597
+ const path = document.getElementById('nemotronPath').value.trim();
598
+ if (!path) return alert('Please enter a path.');
599
+
600
+ const status = document.getElementById('nemotronStatus');
601
+ const text = document.getElementById('nemotronStatusText');
602
+ status.classList.remove('hidden');
603
+ text.textContent = 'Loading dataset (or downloading if not found — ~2GB, may take a few minutes)...';
604
+
605
+ try {
606
+ const resp = await fetch('/api/nemotron/setup', {
607
+ method: 'POST',
608
+ headers: {'Content-Type': 'application/json'},
609
+ body: JSON.stringify({path}),
610
+ });
611
+ const data = await resp.json();
612
+ if (!resp.ok) throw new Error(data.detail || 'Failed');
613
+
614
+ text.textContent = `${data.status === 'downloaded' ? 'Downloaded' : 'Loaded'}: ${data.count.toLocaleString()} personas`;
615
+ text.style.color = 'var(--green)';
616
+
617
+ const nemBadge = document.getElementById('nemotronBadge');
618
+ nemBadge.textContent = 'Nemotron 1M';
619
+ nemBadge.className = 'config-badge ok';
620
+
621
+ // Hide setup after a moment
622
+ setTimeout(() => {
623
+ document.getElementById('nemotronSetup').classList.add('hidden');
624
+ }, 2000);
625
+ } catch (e) {
626
+ text.textContent = `Error: ${e.message}`;
627
+ text.style.color = 'var(--red)';
628
+ }
629
+ }
630
+
631
  // ── Templates ──
632
 
633
  function loadTemplate(name) {
 
737
  body: JSON.stringify(cohortData.cohort),
738
  });
739
 
740
+ const src = cohortData.source === 'nemotron' ? 'census-grounded (Nemotron)' : 'LLM-generated';
741
+ logStep(`${cohortData.cohort_size} evaluators ready — ${src}`, 'pos');
742
  document.getElementById('pipelineProgressBar').style.width = '35%';
743
 
744
  // Phase 4: Evaluate via SSE