"""
Flexible conditional grouping utilities for subgroup analysis.
"""

import json
import os
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

# Surface (rather than silence) library warnings and numpy floating-point issues.
warnings.filterwarnings('default')
np.seterr(all='warn')

class ConditionalGrouper(ABC):
    """Base class: computes one grouping value per sample and bins samples by it."""

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description

    @abstractmethod
    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one grouping value per sample (np.nan when unknown)."""
        pass

    def create_bins(self, values: np.ndarray, method: str = 'quartiles',
                    custom_bins: Optional[List[float]] = None) -> List[Tuple[float, float]]:
        """Split the value range into contiguous (lower, upper) bins.

        `method` selects quantile-based edges; 'custom' uses `custom_bins` as
        explicit edges. Unknown methods fall back to quartiles.
        """
        finite_values = values[np.isfinite(values)]

        if len(finite_values) == 0:
            # No finite values: return a single degenerate bin.
            if len(values) == 0:
                return [(np.nan, np.nan)]
            return [(float(np.min(values)), float(np.max(values)))]

        if method == 'custom' and custom_bins:
            qs = np.asarray(custom_bins, dtype=float)
        else:
            # Note: 'custom' without edges now falls through to the quartile
            # default instead of leaving `qs` undefined (a latent NameError
            # in the original control flow).
            quantile_map = {
                'quartiles': [0.0, 0.25, 0.5, 0.75, 1.0],
                'quintiles': [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
                'deciles': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                'tertiles': [0.0, 0.33, 0.67, 1.0],
                'median_split': [0.0, 0.5, 1.0],
            }
            quantiles = quantile_map.get(method, quantile_map['quartiles'])
            qs = np.quantile(finite_values, quantiles)

        return [(float(qs[i]), float(qs[i + 1])) for i in range(len(qs) - 1)]

    def get_group_info(self, dataset: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        """Summary statistics of the grouping values over `dataset`."""
        values = self.compute_values(dataset, **kwargs)
        finite_values = values[np.isfinite(values)]

        return {
            'name': self.name,
            'description': self.description,
            'total_samples': len(values),
            'valid_samples': len(finite_values),
            'min_value': float(np.min(finite_values)) if len(finite_values) > 0 else np.nan,
            'max_value': float(np.max(finite_values)) if len(finite_values) > 0 else np.nan,
            'mean_value': float(np.mean(finite_values)) if len(finite_values) > 0 else np.nan,
            'std_value': float(np.std(finite_values)) if len(finite_values) > 0 else np.nan,
        }

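# Illustrative sketch (not part of the original module): how quantile binning
# behaves for a trivial concrete subclass. `_UniformGrouper` and the synthetic
# 'score' field are hypothetical, for demonstration only.
def _example_create_bins():
    class _UniformGrouper(ConditionalGrouper):
        def compute_values(self, dataset, **kwargs):
            # Pretend each sample carries a numeric 'score' field.
            return np.array([s.get('score', np.nan) for s in dataset], dtype=float)

    grouper = _UniformGrouper('uniform_demo', 'demo grouping by a score field')
    demo = [{'score': float(i)} for i in range(100)]
    values = grouper.compute_values(demo)
    # Quartiles of 0..99 -> edges at 0, 24.75, 49.5, 74.25, 99.
    print(grouper.create_bins(values, method='quartiles'))
    # Explicit edges via method='custom'.
    print(grouper.create_bins(values, method='custom', custom_bins=[0.0, 50.0, 99.0]))
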
def _default_view_csv_path() -> str:
    """Default location of the Wikipedia view-count CSV, relative to the repo root."""
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    return os.path.join(repo_root, 'data', 'wiki_scores', 'wikibio_final.csv')


GLOBAL_VIEW_METADATA_CSV = _default_view_csv_path()


def set_view_metadata_csv(csv_path: str):
    """Point ViewCountGrouper at a different metadata CSV (non-empty strings only)."""
    global GLOBAL_VIEW_METADATA_CSV
    if isinstance(csv_path, str) and len(csv_path) > 0:
        GLOBAL_VIEW_METADATA_CSV = csv_path

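# Illustrative sketch: redirecting the view-count source before grouping. The
# path below is hypothetical; ViewCountGrouper reloads lazily, so the override
# takes effect on the next compute_values() call.
def _example_override_view_csv():
    set_view_metadata_csv('/tmp/my_wiki_views.csv')  # hypothetical path
    grouper = ViewCountGrouper()
    print(grouper.get_group_info([{'prompt': 'Tell me a bio about Ada Lovelace.'}]))
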
class ViewCountGrouper(ConditionalGrouper):
    """Groups samples by the Wikipedia view count of the entity named in the prompt."""

    def __init__(self):
        super().__init__(
            name="view_count",
            description="Wikipedia view count (from wikibio_final.csv)"
        )
        self._loaded = False
        self._csv_path = None
        self._name_to_views = {}
        self._global_min_count = 0.0

    @staticmethod
    def _parse_name_from_prompt(prompt: str) -> str:
        """Extract the entity name following 'about ...' in the prompt."""
        if not isinstance(prompt, str):
            try:
                prompt = str(prompt)
            except Exception:
                return ""
        txt = prompt.strip()

        m = re.search(r"about\s+(.+?)(?:[\.]|\n|$)", txt, flags=re.IGNORECASE)
        if m:
            return m.group(1).strip()

        if 'about ' in txt:
            return txt.split('about ', 1)[-1].strip().rstrip('.').strip()
        return txt

    def _ensure_loaded(self):
        """Lazily (re)load the name -> views mapping when the CSV path changes."""
        if (not self._loaded) or (self._csv_path != GLOBAL_VIEW_METADATA_CSV):
            try:
                df = pd.read_csv(GLOBAL_VIEW_METADATA_CSV)
                name_col = 'Name' if 'Name' in df.columns else None
                views_col = 'Views' if 'Views' in df.columns else None
                maxc_col = 'max_counts' if 'max_counts' in df.columns else None
                mapping = {}
                values_for_min = []
                if name_col and (views_col or maxc_col):
                    for _, row in df.iterrows():
                        name = str(row[name_col]).strip()
                        v = np.nan

                        if views_col is not None:
                            try:
                                vv = float(row[views_col])
                                if np.isfinite(vv):
                                    v = vv
                            except Exception:
                                pass
                        # Fall back to 'max_counts' when 'Views' is missing or invalid.
                        if (not np.isfinite(v)) and maxc_col is not None:
                            try:
                                mv = float(row[maxc_col])
                                if np.isfinite(mv):
                                    v = mv
                            except Exception:
                                pass
                        mapping[name] = v
                        if np.isfinite(v):
                            values_for_min.append(v)
                self._name_to_views = mapping
                self._csv_path = GLOBAL_VIEW_METADATA_CSV

                self._global_min_count = float(np.min(values_for_min)) if len(values_for_min) > 0 else 0.0
                self._loaded = True
            except Exception:
                # On any load failure, fall back to an empty mapping.
                self._name_to_views = {}
                self._csv_path = GLOBAL_VIEW_METADATA_CSV
                self._global_min_count = 0.0
                self._loaded = True

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        self._ensure_loaded()
        values = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            name = self._parse_name_from_prompt(prompt)

            val = self._name_to_views.get(name)
            if val is None:
                # Retry with a whitespace-normalized name.
                key2 = " ".join(name.split())
                val = self._name_to_views.get(key2, np.nan)

            # Unknown entities default to the global minimum count.
            if val is None or (isinstance(val, float) and not np.isfinite(val)):
                val = self._global_min_count
            values.append(float(val))
        return np.array(values, dtype=float)

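# Illustrative sketch of the prompt -> entity-name parsing the grouper relies
# on. Prompts are assumed to follow an '... about NAME.' pattern.
def _example_parse_names():
    prompts = [
        "Tell me a bio about Marie Curie.",
        "Write a paragraph about Alan Turing\nin formal style.",
    ]
    for p in prompts:
        print(ViewCountGrouper._parse_name_from_prompt(p))
    # -> 'Marie Curie' and 'Alan Turing'
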
class FalseClaimRiskGrouper(ConditionalGrouper):
    """Heuristic risk index: long, list-heavy, numeric, absolutist responses with
    few citation cues score higher (i.e., more likely to contain a false claim)."""

    def __init__(self):
        super().__init__(
            name="false_claim_risk",
            description="Text-based false-claim risk index (higher → more risk)"
        )
        self.abs_terms = [
            'always', 'never', 'guarantee', 'guaranteed', 'cure', 'proven',
            'will', 'must', 'definitely', 'certainly', 'undoubtedly', 'no doubt'
        ]
        self.enum_keywords = [
            'symptom', 'symptoms', 'signs', 'causes', 'cause', 'types', 'treatments',
            'treatment', 'risk factors', 'complications', 'side effects', 'prevention'
        ]
        self.citation_patterns = [
            r'according\s+to', r'based\s+on', r'research\s+(?:shows?|indicates?|suggests?)',
            r'studies?\s+(?:show|indicate|suggest|reveal|demonstrate)', r'\(\d{4}\)', r'\[[\d,\s-]+\]'
        ]
        self.compiled_cite = [re.compile(p, re.IGNORECASE) for p in self.citation_patterns]

    @staticmethod
    def _num_sentences(text: str) -> int:
        """Rough sentence count from terminal punctuation and newlines."""
        if not text:
            return 0
        return max(1, text.count('.') + text.count('!') + text.count('?') + text.count('\n'))

    @staticmethod
    def _listiness(text: str) -> int:
        """Count list-like markers: separators, bullets, and numbered items."""
        if not text:
            return 0
        markers = [',', ';', '\n', '-', '*', '•']
        count = sum(text.count(m) for m in markers)

        # Numbered list items such as '1.', '2)', '3]' at a word boundary.
        count += len(re.findall(r'(?:(?<=\s)|^)(?:\d{1,2}[\.)\]])', text))
        return count

    def _citation_density(self, text: str) -> float:
        """Citation-cue matches per word."""
        if not text:
            return 0.0
        words = text.split()
        if not words:
            return 0.0
        matches = 0
        low = text.lower()
        for pat in self.compiled_cite:
            matches += len(pat.findall(low))
        return matches / max(1, len(words))

    def _absolute_density(self, text: str) -> float:
        """Fraction of words that are absolutist terms ('always', 'never', ...)."""
        if not text:
            return 0.0
        words = re.findall(r"\b\w+\b", text.lower())
        if not words:
            return 0.0
        abs_cnt = sum(1 for w in words if w in self.abs_terms)
        return abs_cnt / max(1, len(words))

    def _enum_keyword_score(self, prompt: str, response: str) -> float:
        """Number of enumeration-prone keywords present in prompt + response."""
        txt = f"{prompt} {response}".lower()
        return float(sum(1 for k in self.enum_keywords if k in txt))

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        vals = []
        for sample in dataset:
            prompt = sample.get('prompt', '') or ''
            response = sample.get('response', '') or ''
            resp = str(response)

            # Normalize each signal into [0, 1].
            num_words = len(resp.split())
            len_norm = min(1.0, num_words / 400.0)
            sent_norm = min(1.0, self._num_sentences(resp) / 12.0)
            list_norm = min(1.0, self._listiness(resp) / 40.0)
            num_density = (sum(ch.isdigit() for ch in resp) / max(1, len(resp)))
            abs_density = self._absolute_density(resp)
            cite_density = self._citation_density(resp)
            enum_score = min(1.0, self._enum_keyword_score(str(prompt), resp) / 4.0)

            # Weighted combination; citation cues reduce risk.
            risk = (
                0.30 * len_norm +
                0.15 * sent_norm +
                0.20 * list_norm +
                0.10 * num_density +
                0.15 * abs_density +
                0.10 * enum_score -
                0.10 * cite_density
            )
            vals.append(float(np.clip(risk, 0.0, 1.0)))
        return np.array(vals, dtype=float)

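# Illustrative sketch: the risk index on two toy samples. The exact scores
# depend on the heuristic weights above; only their ordering matters here.
def _example_false_claim_risk():
    grouper = FalseClaimRiskGrouper()
    dataset = [
        {'prompt': 'What are the symptoms of flu?',
         'response': 'Flu always causes fever. Symptoms: 1. cough, 2. aches, 3. chills.'},
        {'prompt': 'What is aspirin?',
         'response': 'According to research, aspirin is a common pain reliever.'},
    ]
    risks = grouper.compute_values(dataset)
    print(risks)  # the absolutist, list-heavy answer should score higher
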
class MedicalContentGrouper(ConditionalGrouper):
    """Keyword classifier for medical prompts:
    0 = Information, 1 = Interpretation, 2 = Action."""

    def __init__(self):
        super().__init__(
            name="medical_content",
            description="Medical content (Information/Interpretation/Action)"
        )

    @staticmethod
    def _normalize(text: str) -> str:
        if not isinstance(text, str):
            try:
                text = str(text)
            except Exception:
                return ""
        return " ".join(text.strip().lower().split())

    def _classify(self, prompt: str) -> int:
        p = self._normalize(prompt)

        info_kw = [
            "what is", "what are", "definition", "define", "symptom", "signs", "cause", "why",
            "prognosis", "life expectancy", "effect", "does .* do", "means?", "treatment", "therapy",
            "disease", "syndrome", "disorder", "cancer", "diabetes", "ards", "tay-sachs", "paget",
            "thalassemia", "psp", "rosacea", "empyema"
        ]
        drug_kw = [
            "drug", "medication", "medicine", "dose", "dosage", "tablet", "pill", "mg", "patch",
            "paxlovid", "zoloft", "lexapro", "meloxicam", "naproxen", "fentanyl", "celexa", "restoril",
            "calcitonin", "latanoprost", "aldactazide", "nicoderm"
        ]
        symptom_kw = [
            "pain", "ache", "swelling", "lump", "dark urine", "dizziness", "lightheaded", "fatigue",
            "muscle aches", "discharge", "sunburn", "hoarder", "smell"
        ]
        interpret_kw = [
            "what does it mean", "what does .* mean", "when should you worry", "should i worry",
        ]
        action_kw = [
            "should i", "do i need", "is it okay", "can i", "how to", "how do i", "stop", "start",
            "continue", "switch", "swap", "get tested", "try", "take", "drink", "use"
        ]

        def contains_any(keys: List[str]) -> bool:
            for k in keys:
                # Keys containing '.*' are treated as regex patterns.
                if ".*" in k:
                    if re.search(k, p):
                        return True
                elif k in p:
                    return True
            return False

        # Priority 1: explicit action-seeking prompts.
        if contains_any(action_kw):
            return 2

        # Priority 2: information-seeking questions about conditions or drugs.
        if (contains_any(info_kw) or contains_any(drug_kw)) and ("?" in prompt or contains_any(["what", "why", "signs", "symptom", "life expectancy", "treatment"])):
            return 0

        # Priority 3: interpreting symptoms or findings.
        if contains_any(interpret_kw) or contains_any(symptom_kw):
            return 1

        if contains_any(["what", "why"]):
            return 0

        if contains_any(["how to", "how do i"]):
            return 2

        # Default: Interpretation.
        return 1

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        values = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            values.append(self._classify(prompt))
        return np.array(values, dtype=float)

    def create_bins(self, values: np.ndarray, method: str = 'ignored', custom_bins: Optional[List[float]] = None) -> List[Tuple[float, float]]:
        """Discrete labels: one fixed bin per class, regardless of `method`."""
        return [(-0.5, 0.5), (0.5, 1.5), (1.5, 2.5)]

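# Illustrative sketch: three prompts that should land in the three classes
# (0 = Information, 1 = Interpretation, 2 = Action).
def _example_medical_classes():
    grouper = MedicalContentGrouper()
    dataset = [
        {'prompt': 'What are the symptoms of diabetes?'},                   # -> 0
        {'prompt': 'My fatigue has lasted weeks, when should you worry?'},  # -> 1
        {'prompt': 'How do I take naproxen safely?'},                       # -> 2
    ]
    print(grouper.compute_values(dataset))   # expected [0., 1., 2.]
    print(grouper.create_bins(np.array([]))) # fixed class bins
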
class ExpertQAFieldGrouper(ConditionalGrouper):
    """Three-group classifier based on ExpertQA's official metadata.field.

    - 0: Biology/Medicine (Healthcare / Medicine, Biology, Chemistry, Psychology, etc.)
    - 1: Engineering/Technology (Engineering and Technology, Physics and Astronomy, Architecture, etc.)
    - 2: Other (all remaining fields)

    The prompt -> field mapping is loaded from '/expertqa_prompt_to_field.json'
    by default; if the file does not exist, every sample is classified as
    Other (2). Values are integer labels, and create_bins returns fixed
    discrete intervals.
    """

    def __init__(self, mapping_path: str = "/expertqa_prompt_to_field.json"):
        super().__init__(
            name="expertqa_field",
            description="ExpertQA metadata.field → {Bio/Med, Eng/Tech, Other}"
        )
        self.mapping_path = mapping_path
        self._loaded = False
        self._prompt_to_field = {}

        self.bio_med_fields = set([
            "Healthcare / Medicine",
            "Biology",
            "Chemistry",
            "Psychology",
            "Environmental Science",
        ])
        self.eng_tech_fields = set([
            "Engineering and Technology",
            "Physics and Astronomy",
            "Architecture",
        ])

    @staticmethod
    def _normalize(text: str) -> str:
        if not isinstance(text, str):
            try:
                text = str(text)
            except Exception:
                return ""
        return " ".join(text.strip().split())

    def _ensure_loaded(self):
        """Load the prompt -> field mapping once; fall back to an empty map."""
        if self._loaded:
            return
        try:
            if os.path.exists(self.mapping_path):
                with open(self.mapping_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                self._prompt_to_field = {self._normalize(k): v for k, v in data.items()}
            else:
                self._prompt_to_field = {}
        except Exception:
            self._prompt_to_field = {}
        finally:
            self._loaded = True

    def _field_to_group(self, field: str) -> int:
        if not isinstance(field, str):
            return 2
        f = field.strip()
        if f in self.bio_med_fields:
            return 0
        if f in self.eng_tech_fields:
            return 1
        return 2

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        self._ensure_loaded()
        labels = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            p_key = self._normalize(prompt)
            field = self._prompt_to_field.get(p_key)
            if field is None:
                # Fall back to the raw question text if the prompt is unmapped.
                q = sample.get('question', '')
                q_key = self._normalize(q)
                field = self._prompt_to_field.get(q_key)
            group_id = self._field_to_group(field)
            labels.append(float(group_id))
        return np.array(labels, dtype=float)

    def create_bins(self, values: np.ndarray, method: str = 'ignored', custom_bins: Optional[List[float]] = None) -> List[Tuple[float, float]]:
        """Discrete labels: one fixed bin per class, regardless of `method`."""
        return [(-0.5, 0.5), (0.5, 1.5), (1.5, 2.5)]

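# Illustrative sketch: exercising the field grouper against a temporary
# mapping file instead of the default absolute path. Prompts and fields here
# are made up for demonstration.
def _example_expertqa_fields():
    import tempfile

    mapping = {
        "Why do vaccines need boosters?": "Healthcare / Medicine",                  # -> 0
        "How does a suspension bridge carry load?": "Engineering and Technology",   # -> 1
        "What caused the 2008 financial crisis?": "Economics",                      # -> 2
    }
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
        json.dump(mapping, f)
        path = f.name

    grouper = ExpertQAFieldGrouper(mapping_path=path)
    dataset = [{'prompt': p} for p in mapping]
    print(grouper.compute_values(dataset))  # expected [0., 1., 2.]
    os.unlink(path)
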
def get_available_groupers() -> Dict[str, ConditionalGrouper]:
    """Registry of the default groupers. (ExpertQAFieldGrouper is not registered
    here; construct it directly when needed.)"""
    return {
        'view_count': ViewCountGrouper(),
        'medical_content': MedicalContentGrouper(),
        'false_claim_risk': FalseClaimRiskGrouper(),
    }

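# Illustrative sketch: summarizing every registered grouper over a toy dataset.
# ViewCountGrouper degrades gracefully (empty mapping) if its CSV is absent.
def _example_grouper_summaries():
    dataset = [
        {'prompt': 'Tell me a bio about Grace Hopper.', 'response': 'She was a pioneer.'},
        {'prompt': 'What are the signs of rosacea?', 'response': 'Redness and flushing.'},
    ]
    for key, grouper in get_available_groupers().items():
        info = grouper.get_group_info(dataset)
        print(key, info['valid_samples'], info['mean_value'])
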
def compute_conditional_coverage_by_grouper(
    filtered_dataset: List[Dict[str, Any]],
    grouping_values: np.ndarray,
    bins: List[Tuple[float, float]]
) -> List[float]:
    """Conditional coverage per bin: the fraction of samples in each bin whose
    retained claims ('filtered_claims') contain no unsupported claim."""

    def compute_marginal_coverage(sub_dataset: List[Dict[str, Any]]) -> float:
        """Marginal coverage over a subset of samples."""
        indicators = []
        for d in sub_dataset:
            retained = d.get('filtered_claims', [])
            has_false = any(not c.get('is_supported', False) for c in retained)
            indicators.append(0.0 if has_false else 1.0)
        return float(np.mean(indicators)) if indicators else 0.0

    coverage_results = []

    for bin_min, bin_max in bins:
        # Select samples whose (finite) grouping value falls inside the bin.
        mask = []
        for value in grouping_values:
            if np.isfinite(value):
                mask.append(bin_min <= value <= bin_max)
            else:
                mask.append(False)

        indices = [i for i, m in enumerate(mask) if m]

        if not indices:
            # Empty bin: coverage is undefined.
            coverage_results.append(np.nan)
            continue

        subset = [filtered_dataset[i] for i in indices]
        coverage_results.append(compute_marginal_coverage(subset))

    return coverage_results

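# Illustrative sketch: wiring grouping values and bins into the coverage
# computation. The 'score' field is hypothetical; 'filtered_claims' and
# 'is_supported' follow the dataset schema this module already assumes.
def _example_conditional_coverage():
    dataset = [
        {'score': 0.1, 'filtered_claims': [{'is_supported': True}]},
        {'score': 0.9, 'filtered_claims': [{'is_supported': True}, {'is_supported': False}]},
    ]
    values = np.array([d['score'] for d in dataset])
    bins = [(0.0, 0.5), (0.5, 1.0)]
    print(compute_conditional_coverage_by_grouper(dataset, values, bins))
    # -> [1.0, 0.0]: only the low bin's sample has no unsupported retained claim.
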
def compute_retention_by_grouper(
    filtered_dataset: List[Dict[str, Any]],
    grouping_values: np.ndarray,
    bins: List[Tuple[float, float]]
) -> List[Dict[str, Any]]:
    """Claim retention per bin: retained claims ('filtered_claims') divided by
    total claims ('atomic_facts') over the samples in each bin."""

    retention_results = []

    for bin_min, bin_max in bins:
        # Select samples whose (finite) grouping value falls inside the bin.
        mask = []
        for value in grouping_values:
            if np.isfinite(value):
                mask.append(bin_min <= value <= bin_max)
            else:
                mask.append(False)

        indices = [i for i, m in enumerate(mask) if m]

        if not indices:
            retention_results.append({
                'bin': (float(bin_min), float(bin_max)),
                'samples': 0,
                'retained': 0,
                'total': 0,
                'rate': np.nan,
            })
            continue

        total_claims = 0
        retained_claims = 0
        sample_count = len(indices)

        for idx in indices:
            d = filtered_dataset[idx]
            afs = d.get('atomic_facts', [])
            total_claims += len(afs)
            retained_claims += len(d.get('filtered_claims', []))

        rate = (retained_claims / total_claims) if total_claims > 0 else np.nan

        retention_results.append({
            'bin': (float(bin_min), float(bin_max)),
            'samples': sample_count,
            'retained': int(retained_claims),
            'total': int(total_claims),
            'rate': float(rate) if not np.isnan(rate) else np.nan,
        })

    return retention_results

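# Illustrative smoke test (a sketch, not part of the original module): runs the
# example sketches that need no external data files.
if __name__ == '__main__':
    _example_create_bins()
    _example_parse_names()
    _example_false_claim_risk()
    _example_medical_classes()
    _example_conditional_coverage()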