"""data_loader.py — Load dataset JSON files and build in-memory indexes.""" import json import re import time import os from typing import Dict, List, Tuple from dataclasses import dataclass, field DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") @dataclass class DataStore: reviews_all: List[dict] = field(default_factory=list) rebuttals_all: List[dict] = field(default_factory=list) # Primary indexes review_by_paper_id: Dict[str, dict] = field(default_factory=dict) rebuttals_by_paper_id: Dict[str, List[dict]] = field(default_factory=dict) rebuttal_by_paper_reviewer: Dict[Tuple[str, str], dict] = field(default_factory=dict) # Filter indexes papers_by_conference: Dict[str, List[str]] = field(default_factory=dict) papers_by_year: Dict[int, List[str]] = field(default_factory=dict) papers_by_conf_year: Dict[Tuple[str, int], List[str]] = field(default_factory=dict) # Parsed metadata parsed_cyt: Dict[str, Tuple[str, int, str]] = field(default_factory=dict) # Sorted unique values for dropdowns conferences: List[str] = field(default_factory=list) years: List[int] = field(default_factory=list) # All paper IDs sorted all_paper_ids: List[str] = field(default_factory=list) load_time: float = 0.0 def parse_conference_year_track(cyt: str) -> Tuple[str, int, str]: """Parse 'NeurIPS 2022 Conference' -> ('NeurIPS', 2022, 'Conference').""" match = re.match(r'^(.+?)\s+(\d{4})\s+(.*)', cyt) if match: return match.group(1).strip(), int(match.group(2)), match.group(3).strip() year_match = re.search(r'(\d{4})', cyt) year = int(year_match.group(1)) if year_match else 0 return cyt, year, "" def load_all_data(data_dir: str = None) -> DataStore: """Load all JSON files and build indexes. Called once at startup.""" if data_dir is None: data_dir = DATA_DIR store = DataStore() start = time.time() # Load reviews for fname in ["REVIEWS_train.json", "REVIEWS_test.json"]: fpath = os.path.join(data_dir, fname) if os.path.exists(fpath): print(f"Loading {fname}...") with open(fpath, "r", encoding="utf-8") as f: data = json.load(f) store.reviews_all.extend(data) print(f" -> {len(data)} papers") # Load rebuttals for fname in ["REBUTTAL_train.json", "REBUTTAL_test.json"]: fpath = os.path.join(data_dir, fname) if os.path.exists(fpath): print(f"Loading {fname}...") with open(fpath, "r", encoding="utf-8") as f: data = json.load(f) store.rebuttals_all.extend(data) print(f" -> {len(data)} conversations") # Build indexes _build_indexes(store) store.load_time = time.time() - start print(f"\nData loaded in {store.load_time:.1f}s") print(f" Papers: {len(store.reviews_all)}") print(f" Reviews: {sum(len(p['reviews']) for p in store.reviews_all)}") print(f" Rebuttals: {len(store.rebuttals_all)}") print(f" Conferences: {len(store.conferences)}") print(f" Years: {store.years}") return store def _build_indexes(store: DataStore): conf_set = set() year_set = set() paper_id_set = set() for paper in store.reviews_all: pid = paper["paper_id"] cyt = paper["conference_year_track"] store.review_by_paper_id[pid] = paper paper_id_set.add(pid) if cyt not in store.parsed_cyt: store.parsed_cyt[cyt] = parse_conference_year_track(cyt) conf, year, track = store.parsed_cyt[cyt] conf_set.add(conf) year_set.add(year) store.papers_by_conference.setdefault(conf, []).append(pid) store.papers_by_year.setdefault(year, []).append(pid) store.papers_by_conf_year.setdefault((conf, year), []).append(pid) for rebuttal in store.rebuttals_all: pid = rebuttal["paper_id"] rid = rebuttal["reviewer_id"] store.rebuttals_by_paper_id.setdefault(pid, []).append(rebuttal) store.rebuttal_by_paper_reviewer[(pid, rid)] = rebuttal store.conferences = sorted(conf_set) store.years = sorted(y for y in year_set if y > 0) store.all_paper_ids = sorted(paper_id_set)