MukulRay committed on
Commit
2138877
·
1 Parent(s): ad651e3

Phase 2.4: add edge reliability scorer — three-signal formula, foundational paper test passes

Browse files
Files changed (1) hide show
  1. src/reliability.py +240 -0
src/reliability.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Edge Reliability Scoring for RECON v2.
3
+
4
+ Computes a three-signal reliability score for each retrieved paper:
5
+
6
+ edge_reliability = (citation_centrality × 0.4)
7
+ + (recency_signal × 0.3)
8
+ + (content_coherence × 0.3)
9
+
10
+ Signals:
11
+ - citation_centrality: normalized cited_by_count from OpenAlex (or S2 fallback)
12
+ High centrality = foundational paper = high reliability regardless of age
13
+ - recency_signal: linear decay max(0, 1 - age/20) — same as RECON v1
14
+ Now one of three inputs, not the whole score
15
+ - content_coherence: LLM check — does this paper's abstract still represent
16
+ current scientific understanding? Batched for all papers in one LLM call.
17
+
18
+ Dominant signal labels (for explainability in synthesizer output):
19
+ FOUNDATIONAL: reliability >= 0.70 AND centrality >= 0.6
20
+ CURRENT: reliability >= 0.60 AND recency >= 0.7
21
+ DECLINING: reliability >= 0.35 without meeting the FOUNDATIONAL or CURRENT criteria (may exceed 0.60)
22
+ SUPERSEDED: reliability < 0.35
23
+ """
24
+
25
import json
import logging
import math
import os
import re
from dataclasses import dataclass
from datetime import date
from typing import Optional
32
+
33
logger = logging.getLogger(__name__)

# Reference year for every age computation in this module. It is read from the
# system clock at import time so recency values do not silently go stale when
# the calendar year changes (the value used to be pinned to 2026).
CURRENT_YEAR = date.today().year

# Reliability thresholds used when labelling a paper's dominant signal.
# FOUNDATIONAL needs both a high composite score and high citation centrality.
THRESHOLD_FOUNDATIONAL_RELIABILITY = 0.70
THRESHOLD_FOUNDATIONAL_CENTRALITY = 0.60
# CURRENT needs a good composite score and a high recency signal.
THRESHOLD_CURRENT_RELIABILITY = 0.60
THRESHOLD_CURRENT_RECENCY = 0.70
# Composite scores below this value are labelled SUPERSEDED.
THRESHOLD_DECLINING_LOW = 0.35

# Signal weights for the composite score (0.4 + 0.3 + 0.3 = 1.0).
W_CENTRALITY = 0.4
W_RECENCY = 0.3
W_COHERENCE = 0.3
48
+
49
+
50
@dataclass
class ReliabilityScore:
    """Reliability verdict for a single paper, plus the signals behind it.

    Attributes:
        score: composite reliability, intended to lie in [0, 1].
        centrality: normalized citation centrality.
        recency: linear-decay recency signal.
        coherence: content-coherence signal, intended to lie in [0, 1].
        dominant_signal: one of FOUNDATIONAL / CURRENT / DECLINING / SUPERSEDED.
        reason: one-line human-readable explanation of the label.
    """

    score: float
    centrality: float
    recency: float
    coherence: float
    dominant_signal: str
    reason: str
58
+
59
+
60
def _compute_centrality(citation_count: int, doi: str = "") -> float:
    """
    Return the citation-centrality signal for one paper.

    This is a thin wrapper: both arguments are forwarded, by keyword, to
    ``src.openalex_utils.get_citation_centrality`` and its result is returned
    unchanged.

    NOTE(review): the helper lives outside this module. It is described as
    using the OpenAlex cited_by_count when a DOI is available, otherwise the
    supplied count, normalized as min(1.0, log1p(count) / log1p(10000)) —
    confirm against the helper itself.
    """
    # Resolved at call time rather than at module import.
    from src.openalex_utils import get_citation_centrality

    lookup_args = {"doi": doi, "citation_count": citation_count}
    return get_citation_centrality(**lookup_args)
68
+
69
+
70
+ def _compute_recency(year: Optional[int]) -> float:
71
+ """Linear decay: max(0, 1 - age/20). Age 0 = 1.0, age 20+ = 0.0."""
72
+ if not year or year <= 0:
73
+ return 0.0
74
+ age = CURRENT_YEAR - year
75
+ return max(0.0, 1.0 - age / 20.0)
76
+
77
+
78
+ def _compute_coherence_batch(papers: list, query: str) -> list[float]:
79
+ """
80
+ LLM batch coherence check for all papers at once.
81
+
82
+ For each paper, asks: does this paper's abstract still represent
83
+ current scientific understanding on this topic?
84
+
85
+ Returns a list of float scores [0, 1] in the same order as input papers.
86
+ Falls back to recency-based heuristic if LLM call fails.
87
+
88
+ Batched: one LLM call for all papers, not one per paper.
89
+ """
90
+ if not papers:
91
+ return []
92
+
93
+ # Build batch prompt
94
+ paper_summaries = []
95
+ for i, p in enumerate(papers):
96
+ abstract_snippet = (p.abstract or "")[:300]
97
+ paper_summaries.append(
98
+ f"Paper {i+1}: [{p.year}] {p.title}\n"
99
+ f"Abstract: {abstract_snippet}"
100
+ )
101
+
102
+ papers_text = "\n\n".join(paper_summaries)
103
+
104
+ system_prompt = """You are a scientific literature analyst assessing whether papers represent current scientific understanding.
105
+
106
+ For each paper provided, assign a content_coherence score from 0.0 to 1.0:
107
+ - 1.0: Paper's central claims are still the consensus view, no major challenges
108
+ - 0.7: Paper is foundational and still cited, but some aspects have been refined
109
+ - 0.5: Paper's claims are actively debated; newer work challenges some findings
110
+ - 0.3: Paper's central claims have been substantially superseded by newer work
111
+ - 0.1: Paper is clearly outdated; its claims contradict current consensus
112
+
113
+ Respond ONLY with a JSON array of objects, one per paper, in the same order:
114
+ [{"paper_index": 1, "coherence": 0.8, "reason": "one sentence"}, ...]
115
+
116
+ Be concise. No other text."""
117
+
118
+ user_prompt = f"""Research query context: {query[:200]}
119
+
120
+ Papers to assess:
121
+ {papers_text}
122
+
123
+ Return ONLY the JSON array."""
124
+
125
+ try:
126
+ from langchain_groq import ChatGroq
127
+ from langchain_core.messages import SystemMessage, HumanMessage
128
+
129
+ llm = ChatGroq(
130
+ model="llama-3.3-70b-versatile",
131
+ temperature=0.1,
132
+ api_key=os.environ.get("GROQ_API_KEY"),
133
+ )
134
+ response = llm.invoke([
135
+ SystemMessage(content=system_prompt),
136
+ HumanMessage(content=user_prompt),
137
+ ])
138
+ raw = response.content.strip()
139
+
140
+ # Extract JSON array
141
+ match = re.search(r"\[.*\]", raw, re.DOTALL)
142
+ if match:
143
+ data = json.loads(match.group())
144
+ scores = [0.5] * len(papers) # default
145
+ for item in data:
146
+ idx = int(item.get("paper_index", 0)) - 1 # 1-indexed in prompt
147
+ if 0 <= idx < len(papers):
148
+ scores[idx] = float(item.get("coherence", 0.5))
149
+ return scores
150
+
151
+ except Exception as e:
152
+ logger.warning(f"Coherence batch LLM call failed: {e}")
153
+
154
+ # Fallback: use recency as coherence proxy
155
+ return [_compute_recency(p.year) for p in papers]
156
+
157
+
158
def _dominant_signal(score: float, centrality: float, recency: float) -> str:
    """Return the explainability label for a paper.

    Checks run in priority order and the first match wins: FOUNDATIONAL
    (score and centrality both at or above their thresholds), then CURRENT
    (score and recency both at or above theirs), then DECLINING (score at or
    above the declining floor), otherwise SUPERSEDED.
    """
    is_foundational = (
        score >= THRESHOLD_FOUNDATIONAL_RELIABILITY
        and centrality >= THRESHOLD_FOUNDATIONAL_CENTRALITY
    )
    if is_foundational:
        return "FOUNDATIONAL"

    is_current = (
        score >= THRESHOLD_CURRENT_RELIABILITY
        and recency >= THRESHOLD_CURRENT_RECENCY
    )
    if is_current:
        return "CURRENT"

    return "DECLINING" if score >= THRESHOLD_DECLINING_LOW else "SUPERSEDED"
168
+
169
+
170
+ def _build_reason(dominant: str, centrality: float, recency: float,
171
+ coherence: float, year: Optional[int]) -> str:
172
+ """One-line reason string for the trust summary."""
173
+ age = (CURRENT_YEAR - year) if year else None
174
+ age_str = f"{age}yr old" if age is not None else "unknown age"
175
+
176
+ if dominant == "FOUNDATIONAL":
177
+ return f"High citation centrality ({centrality:.2f}), {age_str} — foundational work still current"
178
+ elif dominant == "CURRENT":
179
+ return f"Recent ({age_str}), coherence={coherence:.2f} — aligns with current consensus"
180
+ elif dominant == "DECLINING":
181
+ return f"Mixed signals: centrality={centrality:.2f}, recency={recency:.2f}, coherence={coherence:.2f}"
182
+ else:
183
+ return f"Low reliability: {age_str}, centrality={centrality:.2f}, coherence={coherence:.2f} — likely superseded"
184
+
185
+
186
def score_papers(papers: list, query: str, use_llm: bool = True) -> "dict[str, ReliabilityScore]":
    """
    Main entry point. Scores all papers and returns a dict of paper_id -> ReliabilityScore.

    Args:
        papers: list of Paper objects. ``paper_id`` is required; ``year``,
                 ``citation_count`` and ``doi`` are read when present.
        query: the original research query (for coherence context)
        use_llm: if False, skips the coherence LLM call and uses each paper's
                 recency signal as its coherence instead.

    Returns:
        dict mapping paper_id -> ReliabilityScore. Every signal is clamped to
        [0, 1] before weighting, so the composite score stays in [0, 1] too.
        If two papers share a paper_id, the later one wins.
    """
    if not papers:
        return {}

    # Look each year up once; it feeds recency, the no-LLM fallback and the reason text.
    years = [getattr(p, "year", None) for p in papers]

    # Step 1: Centrality (delegated per paper to _compute_centrality)
    centralities = [
        _compute_centrality(
            citation_count=getattr(p, "citation_count", 0) or 0,
            doi=getattr(p, "doi", "") or "",
        )
        for p in papers
    ]

    # Step 2: Recency
    recencies = [_compute_recency(y) for y in years]

    # Step 3: Coherence (batched LLM call, or recency as a stand-in)
    if use_llm:
        coherences = list(_compute_coherence_batch(papers, query))
    else:
        coherences = list(recencies)
    # If the batch returned fewer scores than papers, pad with recency.
    if len(coherences) < len(papers):
        coherences.extend(recencies[len(coherences):])

    # Step 4: Composite score and labeling
    results = {}
    for p, year, c, r, co in zip(papers, years, centralities, recencies, coherences):
        # Enforce the [0, 1] contract on every input: an out-of-range LLM score
        # or a future-dated paper must not push the composite past its bounds
        # or skew the threshold-based label.
        c, r, co = (min(1.0, max(0.0, signal)) for signal in (c, r, co))

        score = W_CENTRALITY * c + W_RECENCY * r + W_COHERENCE * co
        dominant = _dominant_signal(score, c, r)
        reason = _build_reason(dominant, c, r, co, year)

        results[p.paper_id] = ReliabilityScore(
            score=round(score, 4),
            centrality=round(c, 4),
            recency=round(r, 4),
            coherence=round(co, 4),
            dominant_signal=dominant,
            reason=reason,
        )

    return results