Spaces:
Runtime error
Runtime error
| """data_loader.py — Load dataset JSON files and build in-memory indexes.""" | |
| import json | |
| import re | |
| import time | |
| import os | |
| from typing import Dict, List, Tuple | |
| from dataclasses import dataclass, field | |
# Datasets live in a "data" directory that is a sibling of the directory
# holding this module (i.e. "<module dir>/../data", normalized).
DATA_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, "data")
)
@dataclass
class DataStore:
    """In-memory container for the loaded dataset and its lookup indexes.

    Every collection starts empty (each instance gets its own fresh list or
    dict via ``default_factory``); ``load_all_data`` fills them.
    """

    # Raw paper records read from REVIEWS_*.json, one dict per paper.
    reviews_all: List[dict] = field(default_factory=list)
    # Raw rebuttal records read from REBUTTAL_*.json.
    rebuttals_all: List[dict] = field(default_factory=list)

    # Primary indexes
    # paper_id -> paper record
    review_by_paper_id: Dict[str, dict] = field(default_factory=dict)
    # paper_id -> all rebuttal records for that paper
    rebuttals_by_paper_id: Dict[str, List[dict]] = field(default_factory=dict)
    # (paper_id, reviewer_id) -> rebuttal record
    rebuttal_by_paper_reviewer: Dict[Tuple[str, str], dict] = field(default_factory=dict)

    # Filter indexes: each maps a key to a list of paper_ids
    papers_by_conference: Dict[str, List[str]] = field(default_factory=dict)
    # Year 0 is used when no 4-digit year could be parsed from the label.
    papers_by_year: Dict[int, List[str]] = field(default_factory=dict)
    papers_by_conf_year: Dict[Tuple[str, int], List[str]] = field(default_factory=dict)

    # Parsed metadata: raw "conference_year_track" label -> (conference, year, track)
    parsed_cyt: Dict[str, Tuple[str, int, str]] = field(default_factory=dict)

    # Sorted unique values for dropdowns (years excludes the placeholder 0)
    conferences: List[str] = field(default_factory=list)
    years: List[int] = field(default_factory=list)

    # All paper IDs sorted
    all_paper_ids: List[str] = field(default_factory=list)

    # Seconds spent loading files and building indexes
    load_time: float = 0.0
def parse_conference_year_track(cyt: str) -> Tuple[str, int, str]:
    """Split a conference/year/track label into its three parts.

    Examples:
        'NeurIPS 2022 Conference' -> ('NeurIPS', 2022, 'Conference')
        'NeurIPS 2022'            -> ('NeurIPS', 2022, '')

    Labels that do not look like ``<conference> <4-digit year>[ <track>]`` are
    returned whole as the conference name, with the first 4-digit run found
    anywhere in the label as the year (0 if there is none) and an empty track.

    Args:
        cyt: The raw label string.

    Returns:
        ``(conference, year, track)``.
    """
    # The year must be followed either by whitespace plus a track, or by the
    # end of the label. Accepting end-of-label lets "NeurIPS 2022" split
    # cleanly instead of keeping the year inside the conference name.
    match = re.match(r"^(.+?)\s+(\d{4})(?:\s+(.*)|\s*$)", cyt)
    if match:
        conference, year, track = match.groups()
        # group 3 is None when the end-of-label alternative matched
        return conference.strip(), int(year), (track or "").strip()
    year_match = re.search(r"(\d{4})", cyt)
    year = int(year_match.group(1)) if year_match else 0
    return cyt, year, ""
def _load_json_files(
    data_dir: str, fnames: Tuple[str, ...], target: List[dict], unit: str
) -> None:
    """Extend *target* with the records of each file in *fnames* that exists.

    Each file must contain a JSON list at top level. Progress is printed per
    file, using *unit* to label the record count; a missing file is reported
    and skipped rather than treated as an error.

    Raises:
        ValueError: if a file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or its top-level value is not a list.
    """
    for fname in fnames:
        fpath = os.path.join(data_dir, fname)
        if not os.path.exists(fpath):
            print(f"Skipping {fname}: not found in {data_dir}")
            continue
        print(f"Loading {fname}...")
        with open(fpath, "r", encoding="utf-8") as f:
            data = json.load(f)
        # list.extend() on a dict would silently add its keys as "records".
        if not isinstance(data, list):
            raise ValueError(
                f"{fpath}: expected a JSON list at top level, "
                f"got {type(data).__name__}"
            )
        target.extend(data)
        print(f" -> {len(data)} {unit}")


def load_all_data(data_dir: str = None) -> DataStore:
    """Load all review and rebuttal JSON files and build indexes.

    Called once at startup.

    Args:
        data_dir: Directory holding ``REVIEWS_train.json``,
            ``REVIEWS_test.json``, ``REBUTTAL_train.json`` and
            ``REBUTTAL_test.json``. Defaults to ``DATA_DIR`` when None.
            Files that are absent are reported and skipped.

    Returns:
        A populated ``DataStore``; its ``load_time`` holds the seconds spent
        loading files and building indexes.

    Raises:
        ValueError: if a data file is not valid JSON or is not a JSON list.
    """
    if data_dir is None:
        data_dir = DATA_DIR

    store = DataStore()
    # perf_counter is monotonic, so the measured duration is not skewed by
    # wall-clock adjustments the way time.time() can be.
    start = time.perf_counter()

    _load_json_files(
        data_dir, ("REVIEWS_train.json", "REVIEWS_test.json"),
        store.reviews_all, "papers",
    )
    _load_json_files(
        data_dir, ("REBUTTAL_train.json", "REBUTTAL_test.json"),
        store.rebuttals_all, "conversations",
    )

    _build_indexes(store)
    store.load_time = time.perf_counter() - start

    review_count = sum(len(p['reviews']) for p in store.reviews_all)
    print(f"\nData loaded in {store.load_time:.1f}s")
    print(f" Papers: {len(store.reviews_all)}")
    print(f" Reviews: {review_count}")
    print(f" Rebuttals: {len(store.rebuttals_all)}")
    print(f" Conferences: {len(store.conferences)}")
    print(f" Years: {store.years}")
    return store
| def _build_indexes(store: DataStore): | |
| conf_set = set() | |
| year_set = set() | |
| paper_id_set = set() | |
| for paper in store.reviews_all: | |
| pid = paper["paper_id"] | |
| cyt = paper["conference_year_track"] | |
| store.review_by_paper_id[pid] = paper | |
| paper_id_set.add(pid) | |
| if cyt not in store.parsed_cyt: | |
| store.parsed_cyt[cyt] = parse_conference_year_track(cyt) | |
| conf, year, track = store.parsed_cyt[cyt] | |
| conf_set.add(conf) | |
| year_set.add(year) | |
| store.papers_by_conference.setdefault(conf, []).append(pid) | |
| store.papers_by_year.setdefault(year, []).append(pid) | |
| store.papers_by_conf_year.setdefault((conf, year), []).append(pid) | |
| for rebuttal in store.rebuttals_all: | |
| pid = rebuttal["paper_id"] | |
| rid = rebuttal["reviewer_id"] | |
| store.rebuttals_by_paper_id.setdefault(pid, []).append(rebuttal) | |
| store.rebuttal_by_paper_reviewer[(pid, rid)] = rebuttal | |
| store.conferences = sorted(conf_set) | |
| store.years = sorted(y for y in year_set if y > 0) | |
| store.all_paper_ids = sorted(paper_id_set) | |