"""data_loader.py — Load dataset JSON files and build in-memory indexes."""
import json
import os
import re
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")


@dataclass
class DataStore:
    reviews_all: List[dict] = field(default_factory=list)
    rebuttals_all: List[dict] = field(default_factory=list)
    # Primary indexes
    review_by_paper_id: Dict[str, dict] = field(default_factory=dict)
    rebuttals_by_paper_id: Dict[str, List[dict]] = field(default_factory=dict)
    rebuttal_by_paper_reviewer: Dict[Tuple[str, str], dict] = field(default_factory=dict)
    # Filter indexes
    papers_by_conference: Dict[str, List[str]] = field(default_factory=dict)
    papers_by_year: Dict[int, List[str]] = field(default_factory=dict)
    papers_by_conf_year: Dict[Tuple[str, int], List[str]] = field(default_factory=dict)
    # Parsed metadata
    parsed_cyt: Dict[str, Tuple[str, int, str]] = field(default_factory=dict)
    # Sorted unique values for dropdowns
    conferences: List[str] = field(default_factory=list)
    years: List[int] = field(default_factory=list)
    # All paper IDs sorted
    all_paper_ids: List[str] = field(default_factory=list)
    load_time: float = 0.0


def parse_conference_year_track(cyt: str) -> Tuple[str, int, str]:
    """Parse 'NeurIPS 2022 Conference' -> ('NeurIPS', 2022, 'Conference')."""
    match = re.match(r'^(.+?)\s+(\d{4})\s+(.*)', cyt)
    if match:
        return match.group(1).strip(), int(match.group(2)), match.group(3).strip()
    # Fallback: extract a year if one is present; keep the full string as the name.
    year_match = re.search(r'(\d{4})', cyt)
    year = int(year_match.group(1)) if year_match else 0
    return cyt, year, ""
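
# Illustrative parses (a sketch; the inputs below are hypothetical examples):
#   parse_conference_year_track("ICLR 2023 Conference") -> ("ICLR", 2023, "Conference")
#   parse_conference_year_track("ACL 2021") -> ("ACL 2021", 2021, "")  # fallback: no track segment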


def load_all_data(data_dir: Optional[str] = None) -> DataStore:
    """Load all JSON files and build indexes. Called once at startup."""
    if data_dir is None:
        data_dir = DATA_DIR
    store = DataStore()
    start = time.time()
    # Load reviews
    for fname in ["REVIEWS_train.json", "REVIEWS_test.json"]:
        fpath = os.path.join(data_dir, fname)
        if os.path.exists(fpath):
            print(f"Loading {fname}...")
            with open(fpath, "r", encoding="utf-8") as f:
                data = json.load(f)
            store.reviews_all.extend(data)
            print(f" -> {len(data)} papers")
    # Load rebuttals
    for fname in ["REBUTTAL_train.json", "REBUTTAL_test.json"]:
        fpath = os.path.join(data_dir, fname)
        if os.path.exists(fpath):
            print(f"Loading {fname}...")
            with open(fpath, "r", encoding="utf-8") as f:
                data = json.load(f)
            store.rebuttals_all.extend(data)
            print(f" -> {len(data)} conversations")
    # Build indexes
    _build_indexes(store)
    store.load_time = time.time() - start
    print(f"\nData loaded in {store.load_time:.1f}s")
    print(f" Papers: {len(store.reviews_all)}")
    print(f" Reviews: {sum(len(p['reviews']) for p in store.reviews_all)}")
    print(f" Rebuttals: {len(store.rebuttals_all)}")
    print(f" Conferences: {len(store.conferences)}")
    print(f" Years: {store.years}")
    return store


def _build_indexes(store: DataStore):
    conf_set = set()
    year_set = set()
    paper_id_set = set()
    for paper in store.reviews_all:
        pid = paper["paper_id"]
        cyt = paper["conference_year_track"]
        store.review_by_paper_id[pid] = paper
        paper_id_set.add(pid)
        if cyt not in store.parsed_cyt:
            store.parsed_cyt[cyt] = parse_conference_year_track(cyt)
        conf, year, track = store.parsed_cyt[cyt]
        conf_set.add(conf)
        year_set.add(year)
        store.papers_by_conference.setdefault(conf, []).append(pid)
        store.papers_by_year.setdefault(year, []).append(pid)
        store.papers_by_conf_year.setdefault((conf, year), []).append(pid)
    for rebuttal in store.rebuttals_all:
        pid = rebuttal["paper_id"]
        rid = rebuttal["reviewer_id"]
        store.rebuttals_by_paper_id.setdefault(pid, []).append(rebuttal)
        store.rebuttal_by_paper_reviewer[(pid, rid)] = rebuttal
    store.conferences = sorted(conf_set)
    store.years = sorted(y for y in year_set if y > 0)
    store.all_paper_ids = sorted(paper_id_set)
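

if __name__ == "__main__":
    # Minimal smoke test (a sketch: it assumes the REVIEWS_*/REBUTTAL_* JSON
    # files are present under DATA_DIR; nothing here is required by the app).
    ds = load_all_data()
    if ds.all_paper_ids:
        sample_pid = ds.all_paper_ids[0]  # illustrative: just the first paper ID
        print(f"Sample paper: {sample_pid}")
        print(f" Rebuttal conversations: {len(ds.rebuttals_by_paper_id.get(sample_pid, []))}")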