"""data_loader.py — Load dataset JSON files and build in-memory indexes."""
import json
import os
import re
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
# Default dataset directory: <repo root>/data. Assumes this file lives one
# directory below the repo root (two dirname() hops) — TODO confirm layout.
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
@dataclass
class DataStore:
    """In-memory container for the loaded dataset plus derived indexes.

    Populated once by load_all_data(); every index below is rebuilt from
    reviews_all / rebuttals_all by _build_indexes().
    """
    # Raw records exactly as read from the JSON files.
    reviews_all: List[dict] = field(default_factory=list)
    rebuttals_all: List[dict] = field(default_factory=list)
    # Primary indexes: O(1) lookup by paper id / (paper id, reviewer id).
    review_by_paper_id: Dict[str, dict] = field(default_factory=dict)
    rebuttals_by_paper_id: Dict[str, List[dict]] = field(default_factory=dict)
    rebuttal_by_paper_reviewer: Dict[Tuple[str, str], dict] = field(default_factory=dict)
    # Filter indexes: paper ids grouped by conference, by year, and by both.
    papers_by_conference: Dict[str, List[str]] = field(default_factory=dict)
    papers_by_year: Dict[int, List[str]] = field(default_factory=dict)
    papers_by_conf_year: Dict[Tuple[str, int], List[str]] = field(default_factory=dict)
    # Parsed metadata: cache of parse_conference_year_track() results,
    # keyed by the raw conference_year_track string.
    parsed_cyt: Dict[str, Tuple[str, int, str]] = field(default_factory=dict)
    # Sorted unique values for dropdowns (years <= 0, i.e. unparsed, excluded).
    conferences: List[str] = field(default_factory=list)
    years: List[int] = field(default_factory=list)
    # All paper IDs sorted
    all_paper_ids: List[str] = field(default_factory=list)
    # Wall-clock seconds spent in load_all_data().
    load_time: float = 0.0
def parse_conference_year_track(cyt: str) -> Tuple[str, int, str]:
    """Split a conference_year_track string into (conference, year, track).

    Examples:
        'NeurIPS 2022 Conference' -> ('NeurIPS', 2022, 'Conference')
        'ICLR 2023'               -> ('ICLR', 2023, '')

    If no '<name> <4-digit-year>' prefix is found, returns the whole
    string as the conference, any 4-digit year found anywhere in it
    (0 when none), and an empty track.
    """
    # The track segment is optional so 'Name YYYY' (no track) still parses
    # cleanly instead of falling through to the salvage path below.
    match = re.match(r'^(.+?)\s+(\d{4})(?:\s+(.*))?$', cyt)
    if match:
        track = match.group(3) or ""
        return match.group(1).strip(), int(match.group(2)), track.strip()
    # Salvage path: grab any 4-digit year; 0 signals "year unknown".
    year_match = re.search(r'(\d{4})', cyt)
    year = int(year_match.group(1)) if year_match else 0
    return cyt, year, ""
def _load_json_files(data_dir: str, filenames: List[str],
                     dest: List[dict], label: str) -> None:
    """Append the contents of each existing JSON file in *filenames* onto
    *dest*, logging progress. Missing files are skipped silently so a
    partial dataset (e.g. train-only) still loads."""
    for fname in filenames:
        fpath = os.path.join(data_dir, fname)
        if not os.path.exists(fpath):
            continue
        print(f"Loading {fname}...")
        with open(fpath, "r", encoding="utf-8") as f:
            data = json.load(f)
        dest.extend(data)
        print(f" -> {len(data)} {label}")


def load_all_data(data_dir: Optional[str] = None) -> DataStore:
    """Load all JSON files and build indexes. Called once at startup.

    Args:
        data_dir: Directory containing the dataset JSON files;
            defaults to the module-level DATA_DIR.

    Returns:
        A fully populated and indexed DataStore.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    store = DataStore()
    start = time.time()
    # Reviews and rebuttals share the same load pattern; only the file
    # names, destination list and log label differ.
    _load_json_files(data_dir, ["REVIEWS_train.json", "REVIEWS_test.json"],
                     store.reviews_all, "papers")
    _load_json_files(data_dir, ["REBUTTAL_train.json", "REBUTTAL_test.json"],
                     store.rebuttals_all, "conversations")
    # Build indexes
    _build_indexes(store)
    store.load_time = time.time() - start
    print(f"\nData loaded in {store.load_time:.1f}s")
    print(f" Papers: {len(store.reviews_all)}")
    print(f" Reviews: {sum(len(p['reviews']) for p in store.reviews_all)}")
    print(f" Rebuttals: {len(store.rebuttals_all)}")
    print(f" Conferences: {len(store.conferences)}")
    print(f" Years: {store.years}")
    return store
def _build_indexes(store: DataStore):
    """Derive every lookup/filter index on *store* from its raw lists."""
    seen_confs = set()
    seen_years = set()
    seen_pids = set()

    for paper in store.reviews_all:
        paper_id = paper["paper_id"]
        cyt_raw = paper["conference_year_track"]
        store.review_by_paper_id[paper_id] = paper
        seen_pids.add(paper_id)

        # Parse each distinct conference_year_track string only once,
        # memoizing the result on the store.
        parsed = store.parsed_cyt.get(cyt_raw)
        if parsed is None:
            parsed = parse_conference_year_track(cyt_raw)
            store.parsed_cyt[cyt_raw] = parsed
        conf, year, _track = parsed

        seen_confs.add(conf)
        seen_years.add(year)
        store.papers_by_conference.setdefault(conf, []).append(paper_id)
        store.papers_by_year.setdefault(year, []).append(paper_id)
        store.papers_by_conf_year.setdefault((conf, year), []).append(paper_id)

    for conversation in store.rebuttals_all:
        key = (conversation["paper_id"], conversation["reviewer_id"])
        store.rebuttals_by_paper_id.setdefault(key[0], []).append(conversation)
        store.rebuttal_by_paper_reviewer[key] = conversation

    store.conferences = sorted(seen_confs)
    # Year 0 marks "couldn't parse a year"; keep it out of the dropdown.
    store.years = sorted(y for y in seen_years if y > 0)
    store.all_paper_ids = sorted(seen_pids)