CaffeinatedCoding commited on
Commit
7d0fa43
·
verified ·
1 Parent(s): 34df332

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. api/main.py +90 -4
  2. frontend/app.js +104 -0
  3. frontend/index.html +58 -0
  4. frontend/style.css +171 -1
  5. src/agent_v2.py +7 -1
  6. src/logger.py +64 -0
  7. src/reranker.py +104 -0
  8. src/verify.py +4 -22
api/main.py CHANGED
@@ -5,7 +5,7 @@ V2 agent with conversation memory and 3-pass reasoning.
5
  Port 7860 for HuggingFace Spaces compatibility.
6
  """
7
 
8
- from fastapi import FastAPI, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.staticfiles import StaticFiles
11
  from fastapi.responses import FileResponse
@@ -15,10 +15,14 @@ import time
15
  import os
16
  import sys
17
  import logging
 
 
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
 
22
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
 
24
 
@@ -82,6 +86,9 @@ download_models()
82
  from src.ner import load_ner_model
83
  load_ner_model()
84
 
 
 
 
85
  from src.citation_graph import load_citation_graph
86
  load_citation_graph()
87
 
@@ -119,6 +126,7 @@ class QueryResponse(BaseModel):
119
  num_sources: int
120
  truncated: bool
121
  latency_ms: float
 
122
 
123
 
124
  @app.get("/")
@@ -134,13 +142,14 @@ def health():
134
 
135
 
136
  @app.post("/query", response_model=QueryResponse)
137
- def query(request: QueryRequest):
138
  if not request.query.strip():
139
  raise HTTPException(status_code=400, detail="Query cannot be empty")
140
  if len(request.query) < 10:
141
  raise HTTPException(status_code=400, detail="Query too short — minimum 10 characters")
142
  if len(request.query) > 1000:
143
  raise HTTPException(status_code=400, detail="Query too long — maximum 1000 characters")
 
144
  start = time.time()
145
  try:
146
  if USE_V2:
@@ -148,8 +157,85 @@ def query(request: QueryRequest):
148
  result = _run_query(request.query, session_id)
149
  else:
150
  result = _run_query_v1(request.query)
 
151
  except Exception as e:
152
  logger.error(f"Pipeline error: {e}")
153
  raise HTTPException(status_code=500, detail=f"Pipeline error: {str(e)}")
154
- result["latency_ms"] = round((time.time() - start) * 1000, 2)
155
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  Port 7860 for HuggingFace Spaces compatibility.
6
  """
7
 
8
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.staticfiles import StaticFiles
11
  from fastapi.responses import FileResponse
 
15
  import os
16
  import sys
17
  import logging
18
+ import json
19
+ from collections import Counter
20
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
+ from src.logger import log_inference
25
+
26
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27
 
28
 
 
86
  from src.ner import load_ner_model
87
  load_ner_model()
88
 
89
+ from src.reranker import load_reranker
90
+ load_reranker()
91
+
92
  from src.citation_graph import load_citation_graph
93
  load_citation_graph()
94
 
 
126
  num_sources: int
127
  truncated: bool
128
  latency_ms: float
129
+ session_id: Optional[str] = None
130
 
131
 
132
  @app.get("/")
 
142
 
143
 
144
  @app.post("/query", response_model=QueryResponse)
145
+ def query(request: QueryRequest, background_tasks: BackgroundTasks):
146
  if not request.query.strip():
147
  raise HTTPException(status_code=400, detail="Query cannot be empty")
148
  if len(request.query) < 10:
149
  raise HTTPException(status_code=400, detail="Query too short — minimum 10 characters")
150
  if len(request.query) > 1000:
151
  raise HTTPException(status_code=400, detail="Query too long — maximum 1000 characters")
152
+
153
  start = time.time()
154
  try:
155
  if USE_V2:
 
157
  result = _run_query(request.query, session_id)
158
  else:
159
  result = _run_query_v1(request.query)
160
+ session_id = "v1"
161
  except Exception as e:
162
  logger.error(f"Pipeline error: {e}")
163
  raise HTTPException(status_code=500, detail=f"Pipeline error: {str(e)}")
164
+
165
+ latency_ms = round((time.time() - start) * 1000, 2)
166
+ result["latency_ms"] = latency_ms
167
+ result["session_id"] = session_id
168
+
169
+ # Log inference as background task — non-blocking
170
+ background_tasks.add_task(
171
+ log_inference,
172
+ query=request.query,
173
+ session_id=session_id,
174
+ answer=result.get("answer", ""),
175
+ num_sources=result.get("num_sources", 0),
176
+ verification_status=result.get("verification_status", False),
177
+ entities=result.get("entities", {}),
178
+ latency_ms=latency_ms,
179
+ stage=result.get("analysis", {}).get("stage", ""),
180
+ truncated=result.get("truncated", False),
181
+ out_of_domain=result.get("num_sources", 0) == 0,
182
+ )
183
+
184
+ return result
185
+
186
+
187
@app.get("/analytics")
def analytics():
    """Return aggregated analytics computed from the inference JSONL log.

    Reads every record from LOG_PATH (one JSON object per line, as written
    by src.logger.log_inference), skipping unparseable lines, and returns:
      - total_queries, verified_ratio (%), avg_latency_ms,
        out_of_domain_rate (%), avg_sources
      - stage_distribution: count of records per pipeline stage
      - entity_type_frequency: top-10 entity types across all records
      - recent_latencies: last 20 latency samples (for the sparkline)

    Both "no log file yet" and "log file exists but has no valid records"
    return the same fully-keyed zeroed payload, so the frontend never has
    to special-case missing keys.
    """
    log_path = os.getenv("LOG_PATH", "logs/inference.jsonl")

    empty = {
        "total_queries": 0,
        "verified_ratio": 0,
        "avg_latency_ms": 0,
        "out_of_domain_rate": 0,
        "avg_sources": 0,
        "stage_distribution": {},
        "entity_type_frequency": {},
        "recent_latencies": [],
    }

    if not os.path.exists(log_path):
        return empty

    records = []
    try:
        with open(log_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError:
                    # Tolerate corrupt/truncated lines (e.g. a write cut
                    # short by a container restart).
                    continue
    except OSError:
        return {"error": "Could not read logs"}

    if not records:
        # Same schema as the no-file case (previously this returned only
        # {"total_queries": 0}, which was inconsistent).
        return empty

    total = len(records)
    verified = sum(1 for r in records if r.get("verified", False))
    out_of_domain = sum(1 for r in records if r.get("out_of_domain", False))
    # Zero/missing latencies are excluded from the average on purpose.
    latencies = [r.get("latency_ms", 0) for r in records if r.get("latency_ms")]
    sources = [r.get("num_sources", 0) for r in records]
    stages = Counter(r.get("stage", "unknown") for r in records)

    all_entity_types = []
    for r in records:
        all_entity_types.extend(r.get("entities_found", []))
    entity_freq = dict(Counter(all_entity_types).most_common(10))

    return {
        "total_queries": total,
        # total >= 1 here, so the ratios cannot divide by zero.
        "verified_ratio": round(verified / total * 100, 1),
        "avg_latency_ms": round(sum(latencies) / len(latencies), 0) if latencies else 0,
        "out_of_domain_rate": round(out_of_domain / total * 100, 1),
        "avg_sources": round(sum(sources) / len(sources), 1) if sources else 0,
        "stage_distribution": dict(stages),
        "entity_type_frequency": entity_freq,
        "recent_latencies": latencies[-20:],
    }
frontend/app.js CHANGED
@@ -371,4 +371,108 @@ function inline(text) {
371
 
372
  function showToast(msg) {
373
  alert(msg);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  }
 
371
 
372
  function showToast(msg) {
373
  alert(msg);
374
+ }
375
+
376
// ── Analytics ────────────────────────────────────────────────────────

// Switch the UI to the analytics screen, retitle the topbar, then fetch
// and render the latest metrics via loadAnalytics().
async function showAnalytics() {
  showScreen("analytics");
  document.getElementById("topbar-title").textContent = "System Analytics";
  await loadAnalytics();
}
382
+
383
// Fetch aggregated metrics from GET /analytics and paint the dashboard.
// When the backend reports zero queries, every stat card and chart shows
// a placeholder; on fetch/parse failure only the stages chart shows an
// error (the other widgets keep whatever they last displayed).
async function loadAnalytics() {
  try {
    const res = await fetch(`${API_BASE}/analytics`);
    const data = await res.json();

    if (data.total_queries === 0) {
      // Empty state: dashes for stats, friendly prompts for charts.
      document.getElementById("stat-total").textContent = "0";
      document.getElementById("stat-verified").textContent = "—";
      document.getElementById("stat-latency").textContent = "—";
      document.getElementById("stat-ood").textContent = "—";
      document.getElementById("stat-sources").textContent = "—";
      document.getElementById("chart-stages").innerHTML = "<p class='no-data'>No queries yet. Start asking questions.</p>";
      document.getElementById("chart-entities").innerHTML = "<p class='no-data'>No entity data yet.</p>";
      document.getElementById("chart-latency").innerHTML = "<p class='no-data'>No latency data yet.</p>";
      return;
    }

    // Stat cards
    document.getElementById("stat-total").textContent = data.total_queries;
    document.getElementById("stat-verified").textContent = data.verified_ratio + "%";
    document.getElementById("stat-latency").textContent = data.avg_latency_ms + "ms";
    document.getElementById("stat-ood").textContent = data.out_of_domain_rate + "%";
    document.getElementById("stat-sources").textContent = data.avg_sources;

    // Stage distribution bar chart
    renderBarChart("chart-stages", data.stage_distribution);

    // Entity frequency bar chart
    renderBarChart("chart-entities", data.entity_type_frequency);

    // Latency sparkline
    renderSparkline("chart-latency", data.recent_latencies);

  } catch (err) {
    document.getElementById("chart-stages").innerHTML = "<p class='no-data'>Could not load analytics.</p>";
  }
}
420
+
421
// Render a horizontal bar chart of a {label: value} map into the given
// container, rows sorted by value descending, widths scaled to the max.
function renderBarChart(containerId, data) {
  const container = document.getElementById(containerId);
  const entries = data ? Object.entries(data) : [];
  if (entries.length === 0) {
    container.innerHTML = "<p class='no-data'>No data yet.</p>";
    return;
  }

  const max = Math.max(...entries.map(([, v]) => v));
  // Labels come from log data, so escape them before innerHTML.
  const rows = entries
    .sort((a, b) => b[1] - a[1])
    .map(([label, value]) => `
      <div class="bar-row">
        <span class="bar-label">${escHtml(label)}</span>
        <div class="bar-track">
          <div class="bar-fill" style="width: ${Math.round(value / max * 100)}%"></div>
        </div>
        <span class="bar-value">${value}</span>
      </div>
    `).join("");

  container.innerHTML = `<div class="bar-chart">${rows}</div>`;
}
443
+
444
// Draw a fixed-size (300x60) SVG polyline of recent latency samples,
// normalized to the observed min/max, with a min/max legend underneath.
function renderSparkline(containerId, latencies) {
  const target = document.getElementById(containerId);
  if (!latencies || latencies.length === 0) {
    target.innerHTML = "<p class='no-data'>No data yet.</p>";
    return;
  }

  const hi = Math.max(...latencies);
  const lo = Math.min(...latencies);
  const span = hi - lo || 1; // avoid divide-by-zero on flat data
  const height = 60;
  const width = 300;
  const dx = width / (latencies.length - 1 || 1); // single point → dx unused

  // Scale each sample into SVG coordinates (y grows downward).
  const points = latencies
    .map((ms, idx) => `${idx * dx},${height - ((ms - lo) / span) * height}`)
    .join(" ");

  target.innerHTML = `
    <svg viewBox="0 0 ${width} ${height}" class="sparkline">
      <polyline points="${points}" fill="none" stroke="var(--accent)" stroke-width="2"/>
    </svg>
    <div class="sparkline-range">
      <span>${Math.round(lo)}ms min</span>
      <span>${Math.round(hi)}ms max</span>
    </div>
  `;
}
474
+
475
// Escape the five HTML-significant characters so arbitrary text can be
// interpolated into innerHTML templates safely.
function escHtml(text) {
  return String(text).replace(/[&<>"']/g, (ch) => (
    { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#039;' }[ch]
  ));
}
frontend/index.html CHANGED
@@ -27,6 +27,11 @@
27
  New Research Session
28
  </button>
29
 
 
 
 
 
 
30
  <div class="sidebar-section-label">SESSIONS</div>
31
  <div id="sessions-list" class="sessions-list">
32
  <div class="sessions-empty">No sessions yet</div>
@@ -87,6 +92,59 @@
87
  </div>
88
  </div>
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  <!-- ── SOURCES PANEL ── -->
91
  <div id="sources-panel" class="sources-panel">
92
  <div class="sources-panel-header">
 
27
  New Research Session
28
  </button>
29
 
30
+ <button class="analytics-btn" onclick="showAnalytics()">
31
+ <span class="analytics-icon">📊</span>
32
+ System Analytics
33
+ </button>
34
+
35
  <div class="sidebar-section-label">SESSIONS</div>
36
  <div id="sessions-list" class="sessions-list">
37
  <div class="sessions-empty">No sessions yet</div>
 
92
  </div>
93
  </div>
94
 
95
+ <!-- ── ANALYTICS SCREEN ── -->
96
+ <div id="screen-analytics" class="screen screen-analytics">
97
+ <div class="analytics-inner">
98
+ <div class="analytics-header">
99
+ <h2>System Analytics</h2>
100
+ <p>Live metrics from inference logs</p>
101
+ </div>
102
+
103
+ <div class="analytics-grid">
104
+ <div class="stat-card">
105
+ <div class="stat-value" id="stat-total">—</div>
106
+ <div class="stat-label">Total Queries</div>
107
+ </div>
108
+ <div class="stat-card">
109
+ <div class="stat-value" id="stat-verified">—</div>
110
+ <div class="stat-label">Verified Rate</div>
111
+ </div>
112
+ <div class="stat-card">
113
+ <div class="stat-value" id="stat-latency">—</div>
114
+ <div class="stat-label">Avg Latency</div>
115
+ </div>
116
+ <div class="stat-card">
117
+ <div class="stat-value" id="stat-ood">—</div>
118
+ <div class="stat-label">Out-of-Domain Rate</div>
119
+ </div>
120
+ <div class="stat-card">
121
+ <div class="stat-value" id="stat-sources">—</div>
122
+ <div class="stat-label">Avg Sources / Query</div>
123
+ </div>
124
+ </div>
125
+
126
+ <div class="analytics-charts">
127
+ <div class="chart-card">
128
+ <h3>Stage Distribution</h3>
129
+ <div id="chart-stages" class="chart-container"></div>
130
+ </div>
131
+ <div class="chart-card">
132
+ <h3>Entity Types Extracted</h3>
133
+ <div id="chart-entities" class="chart-container"></div>
134
+ </div>
135
+ <div class="chart-card">
136
+ <h3>Recent Query Latencies (ms)</h3>
137
+ <div id="chart-latency" class="chart-container"></div>
138
+ </div>
139
+ </div>
140
+
141
+ <div class="analytics-footer">
142
+ <button class="refresh-btn" onclick="loadAnalytics()">↻ Refresh</button>
143
+ <span class="analytics-note">Data from current session logs. Resets on container restart.</span>
144
+ </div>
145
+ </div>
146
+ </div>
147
+
148
  <!-- ── SOURCES PANEL ── -->
149
  <div id="sources-panel" class="sources-panel">
150
  <div class="sources-panel-header">
frontend/style.css CHANGED
@@ -750,4 +750,174 @@ body {
750
  margin-bottom: 10px;
751
  }
752
 
753
- .bubble-ai p:last-child { margin-bottom: 0; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  margin-bottom: 10px;
751
  }
752
 
753
+ .bubble-ai p:last-child { margin-bottom: 0; }
754
+
755
+ /* ── Analytics ────────────────────────────────────────────── */
756
+ .analytics-btn {
757
+ display: flex;
758
+ align-items: center;
759
+ gap: 8px;
760
+ width: 100%;
761
+ padding: 10px 14px;
762
+ margin-top: 8px;
763
+ background: transparent;
764
+ border: 1px solid var(--border);
765
+ border-radius: 8px;
766
+ color: var(--text-2);
767
+ font-size: 13px;
768
+ cursor: pointer;
769
+ transition: all var(--transition);
770
+ }
771
+ .analytics-btn:hover {
772
+ background: var(--navy-3);
773
+ color: var(--text-1);
774
+ }
775
+
776
+ .screen-analytics {
777
+ padding: 32px;
778
+ overflow-y: auto;
779
+ height: 100%;
780
+ }
781
+ .analytics-inner {
782
+ max-width: 800px;
783
+ margin: 0 auto;
784
+ }
785
+ .analytics-header h2 {
786
+ font-family: 'Cormorant Garamond', serif;
787
+ font-size: 28px;
788
+ margin: 0 0 4px;
789
+ }
790
+ .analytics-header p {
791
+ color: var(--text-2);
792
+ font-size: 14px;
793
+ margin: 0 0 32px;
794
+ }
795
+
796
+ .analytics-grid {
797
+ display: grid;
798
+ grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
799
+ gap: 16px;
800
+ margin-bottom: 32px;
801
+ }
802
+ .stat-card {
803
+ background: var(--navy-2);
804
+ border: 1px solid var(--border);
805
+ border-radius: 12px;
806
+ padding: 20px 16px;
807
+ text-align: center;
808
+ }
809
+ .stat-value {
810
+ font-size: 28px;
811
+ font-weight: 600;
812
+ color: var(--text-1);
813
+ font-family: 'Cormorant Garamond', serif;
814
+ }
815
+ .stat-label {
816
+ font-size: 11px;
817
+ color: var(--text-3);
818
+ margin-top: 4px;
819
+ text-transform: uppercase;
820
+ letter-spacing: 0.05em;
821
+ }
822
+
823
+ .analytics-charts {
824
+ display: flex;
825
+ flex-direction: column;
826
+ gap: 24px;
827
+ }
828
+ .chart-card {
829
+ background: var(--navy-2);
830
+ border: 1px solid var(--border);
831
+ border-radius: 12px;
832
+ padding: 20px;
833
+ }
834
+ .chart-card h3 {
835
+ font-size: 14px;
836
+ font-weight: 500;
837
+ margin: 0 0 16px;
838
+ color: var(--text-2);
839
+ text-transform: uppercase;
840
+ letter-spacing: 0.05em;
841
+ }
842
+ .chart-container {
843
+ min-height: 60px;
844
+ }
845
+ .no-data {
846
+ color: var(--text-3);
847
+ font-size: 13px;
848
+ text-align: center;
849
+ padding: 16px 0;
850
+ }
851
+
852
+ .bar-chart {
853
+ display: flex;
854
+ flex-direction: column;
855
+ gap: 8px;
856
+ }
857
+ .bar-row {
858
+ display: flex;
859
+ align-items: center;
860
+ gap: 10px;
861
+ font-size: 12px;
862
+ }
863
+ .bar-label {
864
+ width: 100px;
865
+ color: var(--text-3);
866
+ text-align: right;
867
+ flex-shrink: 0;
868
+ }
869
+ .bar-track {
870
+ flex: 1;
871
+ height: 8px;
872
+ background: var(--navy-3);
873
+ border-radius: 4px;
874
+ overflow: hidden;
875
+ }
876
+ .bar-fill {
877
+ height: 100%;
878
+ background: var(--gold);
879
+ border-radius: 4px;
880
+ transition: width 0.4s ease;
881
+ }
882
+ .bar-value {
883
+ width: 30px;
884
+ color: var(--text-1);
885
+ font-weight: 500;
886
+ text-align: right;
887
+ }
888
+
889
+ .sparkline {
890
+ width: 100%;
891
+ height: 60px;
892
+ }
893
+ .sparkline-range {
894
+ display: flex;
895
+ justify-content: space-between;
896
+ font-size: 11px;
897
+ color: var(--text-3);
898
+ margin-top: 4px;
899
+ }
900
+
901
+ .analytics-footer {
902
+ display: flex;
903
+ align-items: center;
904
+ gap: 16px;
905
+ margin-top: 24px;
906
+ }
907
+ .refresh-btn {
908
+ padding: 8px 16px;
909
+ background: var(--navy-3);
910
+ border: 1px solid var(--border);
911
+ border-radius: 8px;
912
+ color: var(--text-1);
913
+ font-size: 13px;
914
+ cursor: pointer;
915
+ transition: background var(--transition);
916
+ }
917
+ .refresh-btn:hover {
918
+ background: var(--navy-4);
919
+ }
920
+ .analytics-note {
921
+ font-size: 12px;
922
+ color: var(--text-3);
923
+ }
src/agent_v2.py CHANGED
@@ -384,7 +384,13 @@ def run_query_v2(user_message: str, session_id: str) -> Dict[str, Any]:
384
 
385
  chunks = []
386
  try:
387
- chunks = retrieve_parallel(search_queries[:3], top_k=5)
 
 
 
 
 
 
388
  # Add precedent chain
389
  from src.citation_graph import get_precedent_chain
390
  retrieved_ids = [c.get("judgment_id", "") for c in chunks]
 
384
 
385
  chunks = []
386
  try:
387
+ # Retrieve more candidates for reranker to work with
388
+ raw_chunks = retrieve_parallel(search_queries[:3], top_k=10)
389
+
390
+ # Rerank candidates by true relevance
391
+ from src.reranker import rerank
392
+ chunks = rerank(user_message, raw_chunks, top_k=5)
393
+
394
  # Add precedent chain
395
  from src.citation_graph import get_precedent_chain
396
  retrieved_ids = [c.get("judgment_id", "") for c in chunks]
src/logger.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Inference logger.
Writes one JSON line per query to logs/inference.jsonl.
Called as FastAPI BackgroundTask — does not block response.

WHY two-layer logging?
HF Spaces containers are ephemeral — local files are wiped on restart.
Local JSONL is fast for same-session analytics.
In future, add HF Dataset API push here for durable storage.
"""

import hashlib
import json
import logging
import os
from datetime import datetime, timezone

logger = logging.getLogger(__name__)

# Override with the LOG_PATH env var; read once at import time.
LOG_PATH = os.getenv("LOG_PATH", "logs/inference.jsonl")


def ensure_log_dir():
    """Create the log directory if needed; no-op for a bare filename."""
    log_dir = os.path.dirname(LOG_PATH)
    # os.makedirs("") raises FileNotFoundError, so guard against a
    # LOG_PATH that has no directory component.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)


def log_inference(
    query: str,
    session_id: str,
    answer: str,
    num_sources: int,
    verification_status,
    entities: dict,
    latency_ms: float,
    stage: str = "",
    truncated: bool = False,
    out_of_domain: bool = False,
):
    """
    Write one inference record to LOG_PATH as a single JSON line.
    Called as BackgroundTask in api/main.py.
    Fails silently — never blocks or crashes the main response.

    Only the query's length and a stable hash are stored, never the raw
    query or answer text.
    """
    try:
        ensure_log_dir()
        record = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "session_id": session_id,
            "query_length": len(query),
            # sha256 instead of built-in hash(): Python's str hash is
            # salted per process (PYTHONHASHSEED), so hash() would change
            # on every restart and make cross-session dedup impossible.
            "query_hash": hashlib.sha256(query.encode("utf-8")).hexdigest()[:10],
            "num_sources": num_sources,
            "verification_status": str(verification_status),
            # verification_status may be a bool or a status string.
            "verified": verification_status is True or verification_status == "verified",
            "entities_found": list(entities.keys()) if entities else [],
            "num_entity_types": len(entities) if entities else 0,
            "latency_ms": latency_ms,
            "stage": stage,
            "truncated": truncated,
            "out_of_domain": out_of_domain,
            "answer_length": len(answer),
        }
        with open(LOG_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
    except Exception as e:
        # Best-effort by design: a logging failure must never surface to
        # the request path that scheduled this task.
        logger.warning(f"Inference logging failed: {e}")
src/reranker.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Cross-encoder reranker.
Re-scores FAISS retrieval candidates by joint query-document relevance.

WHY a cross-encoder on top of the bi-encoder (MiniLM)?
A bi-encoder embeds query and document independently — fast but
approximate. A cross-encoder reads the query and document together,
which is slower but more accurate, so it is applied only post-retrieval
to pick the best top_k out of a larger candidate pool.

Model: cross-encoder/ms-marco-MiniLM-L-6-v2 — trained on MS-MARCO passage
ranking, small enough (~80MB) for the HF Spaces free tier and fast enough
to rerank a handful of candidates on CPU.
"""

import logging
from typing import Dict, List

logger = logging.getLogger(__name__)

# Populated by load_reranker(); when loading fails these stay unset so
# rerank() degrades to a FAISS-order pass-through.
_reranker = None
_reranker_loaded = False


def load_reranker():
    """
    Load the cross-encoder once at startup (call from api/main.py after
    the other models).

    Failure is non-fatal: retrieval keeps working on FAISS scores alone.
    """
    global _reranker, _reranker_loaded

    try:
        from sentence_transformers import CrossEncoder

        logger.info("Loading cross-encoder reranker...")
        _reranker = CrossEncoder(
            "cross-encoder/ms-marco-MiniLM-L-6-v2",
            max_length=512,
        )
        _reranker_loaded = True
        logger.info("Cross-encoder reranker ready")
    except Exception as e:
        logger.warning(f"Reranker load failed: {e}. Retrieval will use FAISS scores only.")
        _reranker_loaded = False


def rerank(query: str, chunks: List[Dict], top_k: int = 5) -> List[Dict]:
    """
    Return the top_k chunks ordered by cross-encoder relevance.

    Args:
        query: user query string.
        chunks: candidate chunks from FAISS retrieval.
        top_k: how many chunks to keep after reranking.

    Returns:
        top_k chunks sorted by reranker score, highest first. When the
        reranker is unavailable or scoring raises, the original FAISS
        order is kept and simply truncated to top_k.
    """
    if not _reranker_loaded or _reranker is None:
        return chunks[:top_k]

    if not chunks:
        return []

    try:
        # Pair the query with each candidate's best available text field,
        # truncated to the model's context budget.
        pairs = [
            [
                query,
                (
                    c.get("expanded_context")
                    or c.get("chunk_text")
                    or c.get("text", "")
                )[:512],
            ]
            for c in chunks
        ]

        scores = _reranker.predict(pairs, batch_size=16)

        # Annotate each chunk with its score, then order best-first.
        for c, s in zip(chunks, scores):
            c["reranker_score"] = float(s)
        ordered = sorted(chunks, key=lambda c: c.get("reranker_score", 0), reverse=True)

        logger.info(
            f"Reranked {len(chunks)} chunks → top {top_k}. "
            f"Top score: {ordered[0].get('reranker_score', 0):.3f}"
        )

        return ordered[:top_k]

    except Exception as e:
        logger.warning(f"Reranking failed: {e}. Using FAISS order.")
        return chunks[:top_k]


def is_loaded() -> bool:
    """Whether the cross-encoder was successfully loaded."""
    return _reranker_loaded
src/verify.py CHANGED
@@ -67,30 +67,12 @@ def _extract_quotes(text: str) -> list:
67
 
68
 
69
  def _get_embedder():
70
- """Get the already-loaded embedder — no double loading."""
71
  try:
72
- from src.retrieval import _embedder as embedder
73
- return embedder
74
- except ImportError:
75
- pass
76
-
77
- try:
78
- from src.embed import _model as embedder
79
- return embedder
80
- except ImportError:
81
- pass
82
-
83
- try:
84
- # Last resort — import from retrieval module globals
85
- import src.retrieval as retrieval_module
86
- if hasattr(retrieval_module, '_embedder'):
87
- return retrieval_module._embedder
88
- if hasattr(retrieval_module, 'embedder'):
89
- return retrieval_module.embedder
90
  except Exception:
91
- pass
92
-
93
- return None
94
 
95
 
96
  def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
 
67
 
68
 
69
  def _get_embedder():
70
+ """Get the already-loaded MiniLM embedder."""
71
  try:
72
+ from src.embed import _model
73
+ return _model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  except Exception:
75
+ return None
 
 
76
 
77
 
78
  def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: