# Upstream commit 0f4326e — "sync w/ detailed repo" (ehejin)
"""Save session state to local disk and upload JSON + flattened CSV to HuggingFace."""
import csv
import json
import os
import tempfile
import uuid
from datetime import datetime
from pathlib import Path
import streamlit as st
from huggingface_hub import HfApi
from src.data import release_reservation, record_completion
@st.cache_resource
def _get_hf_api(hf_token: str, output_repo: str) -> HfApi:
    """Build a cached HF API client, creating the private output repo if absent.

    When no token is supplied the anonymous client is returned as-is and no
    repo check is attempted (anonymous clients cannot create repos anyway).
    """
    client = HfApi(token=hf_token) if hf_token else HfApi()
    if not hf_token:
        return client
    try:
        client.repo_info(repo_id=output_repo, repo_type="dataset")
    except Exception as err:
        # A 404 / "not found" means the dataset repo doesn't exist yet.
        text = str(err)
        if "404" in text or "not found" in text.lower():
            client.create_repo(repo_id=output_repo, repo_type="dataset", private=True)
            print(f"[HF] Created output repo: {output_repo}")
        else:
            print(f"[HF] Warning checking repo existence: {err}")
    return client
def save_and_upload(state: dict, cfg: dict) -> None:
    """Persist the session locally, then upload JSON + flattened CSV to HuggingFace.

    Side effects, in order:
      1. Writes the full session ``state`` as JSON under
         ``cfg["annotations_dir"]/<safe_worker>/<submission_id>.json``.
      2. Uploads that JSON to ``cfg["output_dataset_repo"]`` when an HF token
         is configured.
      3. Only after a successful upload: releases the worker's item
         reservations, records the completion locally, and best-effort
         auto-pauses the Prolific study if every item is now covered.
      4. Writes and uploads the flattened per-item CSV.

    Upload / auto-pause failures are logged to stdout, never raised, so the
    participant's local JSON copy is always kept.
    """
    output_repo = cfg["output_dataset_repo"]
    hf_token = cfg.get("hf_token", "")
    hf_api = _get_hf_api(hf_token, output_repo)

    worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
    submission_id = state.get("submission_id", str(uuid.uuid4()))
    # File-system-safe worker id: every non-alphanumeric char becomes "_".
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))

    print("[SAVE] starting save_and_upload")
    print(f"[SAVE] output_repo={output_repo}")
    print(f"[SAVE] hf_token set={bool(hf_token)}")

    # ── Write JSON ────────────────────────────────────────────────────────────
    ann_dir = Path(cfg["annotations_dir"]) / safe_worker
    ann_dir.mkdir(parents=True, exist_ok=True)
    json_path = ann_dir / f"{submission_id}.json"
    # Explicit UTF-8 + ensure_ascii=False keeps participants' free-text answers
    # readable (not \uXXXX-escaped) regardless of the platform default encoding.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=2, ensure_ascii=False)
    print(f"[SAVE] JSON written: {json_path}")

    uploaded = False
    if hf_token:
        try:
            hf_api.upload_file(
                path_or_fileobj=str(json_path),
                path_in_repo=f"json/{safe_worker}/{submission_id}.json",
                repo_id=output_repo,
                repo_type="dataset",
            )
            print("[HF] JSON uploaded.")
            uploaded = True
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    if uploaded:
        # Release reservations so items are immediately available for re-assignment
        release_reservation(state.get("user_id", ""), cfg)
        # Record completion locally — updates counts immediately without waiting
        # for an HF re-scan. Also invalidates the HF cache.
        record_completion(state.get("user_id", ""), state.get("items", []), cfg)
        # Auto-pause Prolific study if all items are now covered.
        # Imported lazily so module load doesn't require these helpers.
        try:
            from src.data import all_items_covered, pause_prolific_study
            if all_items_covered(cfg):
                pause_prolific_study(cfg)
        except Exception as e:
            print(f"[SAVE] Auto-pause check failed: {e}")

    # ── Write + upload CSV ────────────────────────────────────────────────────
    _save_and_upload_csv(state, cfg, hf_api, safe_worker, submission_id)
# ── CSV schema ────────────────────────────────────────────────────────────────
# Column order below must match the row-construction order in
# _save_and_upload_csv: the common columns first, then the study-type-specific
# extras (_PREFERENCE_EXTRA_HEADER or _LIKELIHOOD_EXTRA_HEADER).
_COMMON_HEADER = [
    # Submission / session identifiers and timing
    "submission_id", "prolific_pid", "study_id", "session_id",
    "submission_time", "duration_seconds",
    # Study configuration and per-item prompt variant
    "study_type", "model_name",
    "prompt_personalization", "prompt_detailed_instruction",
    "pair_selection_seed", "category",
    # Demographics
    "age", "gender", "geographic_region", "education_level", "race",
    "us_citizen", "marital_status", "religion", "religious_attendance",
    "political_affiliation", "income", "political_views",
    "household_size", "employment_status",
    # Background (free-text shopping preferences)
    "movies_criteria", "movies_enjoy", "movies_avoid",
    "groceries_criteria", "groceries_enjoy", "groceries_avoid",
    # Ratings (delta = post - pre, blank when either rating is missing)
    "pre_rating", "post_rating", "rating_delta",
    # Conversation (turns are embedded as a JSON string)
    "num_turns", "conversation_json",
    # Reflection
    "standout_moment", "thinking_change",
]
# Extra columns appended when cfg["study_type"] == "preference" (A/B pairs).
_PREFERENCE_EXTRA_HEADER = [
    "pair_index", "pair_id",
    "product_a_id", "product_a_title", "product_a_price",
    "product_b_id", "product_b_title", "product_b_price",
    "familiarity_a", "familiarity_b",
]
# Extra columns appended for the likelihood study (single product per item).
_LIKELIHOOD_EXTRA_HEADER = [
    "item_index", "item_id",
    "product_title", "product_price",
    "familiarity",
]
def _save_and_upload_csv(
    state: dict, cfg: dict, hf_api: HfApi, safe_worker: str, submission_id: str
) -> None:
    """Flatten the session into one CSV row per item and upload it to HF.

    The header is _COMMON_HEADER plus the study-type-specific extra columns;
    each row is built in exactly that order. The CSV is written to a temp file
    that is always removed, even when writing or uploading fails. The upload
    only runs when ``cfg["hf_token"]`` is set; upload errors are logged, not
    raised.
    """
    study_type = cfg["study_type"]
    demographics = state.get("demographics", {})
    background = state.get("background", {})
    items = state.get("items", [])

    header = _COMMON_HEADER + (
        _PREFERENCE_EXTRA_HEADER if study_type == "preference"
        else _LIKELIHOOD_EXTRA_HEADER
    )

    rows = []
    for i, item in enumerate(items):
        conv = item.get("conversation", {})
        refl = item.get("reflection", {})
        pre = item.get("pre_rating", "")
        post = item.get("post_rating", "")
        # Only compute the delta when both ratings are actual ints; otherwise
        # leave the cell blank rather than guessing.
        delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
        pv = item.get("prompt_variant", {})
        common = [
            submission_id,
            state.get("prolific_pid", ""),
            state.get("study_id", ""),
            state.get("session_id", ""),
            state.get("meta", {}).get("submission_time", ""),
            state.get("meta", {}).get("duration_seconds", ""),
            study_type,
            item.get("model_name", ""),
            pv.get("personalization", False),
            pv.get("detailed_instruction", True),
            cfg.get("pair_selection_seed", 42),
            item.get("category", ""),
            demographics.get("age", ""),
            demographics.get("gender", ""),
            demographics.get("geographic_region", ""),
            demographics.get("education_level", ""),
            demographics.get("race", ""),
            demographics.get("us_citizen", ""),
            demographics.get("marital_status", ""),
            demographics.get("religion", ""),
            demographics.get("religious_attendance", ""),
            demographics.get("political_affiliation", ""),
            demographics.get("income", ""),
            demographics.get("political_views", ""),
            demographics.get("household_size", ""),
            demographics.get("employment_status", ""),
            background.get("movies_criteria", ""),
            background.get("movies_enjoy", ""),
            background.get("movies_avoid", ""),
            background.get("groceries_criteria", ""),
            background.get("groceries_enjoy", ""),
            background.get("groceries_avoid", ""),
            pre, post, delta,
            conv.get("num_turns", 0),
            # Embed the raw turn list as a JSON string in a single cell.
            json.dumps(conv.get("turns", [])),
            refl.get("standout_moment", ""),
            refl.get("thinking_change", ""),
        ]
        if study_type == "preference":
            pa, pb = item.get("product_a", {}), item.get("product_b", {})
            extra = [
                i + 1,
                item.get("pair_id", ""),
                pa.get("id", ""), pa.get("title", ""), pa.get("price", ""),
                pb.get("id", ""), pb.get("title", ""), pb.get("price", ""),
                item.get("familiarity_a", ""),
                item.get("familiarity_b", ""),
            ]
        else:
            prod = item.get("product", {})
            extra = [
                i + 1,
                item.get("item_id", ""),
                prod.get("title", ""), prod.get("price", ""),
                item.get("familiarity", ""),
            ]
        rows.append(common + extra)

    # Timestamp + random tag keep repeated submissions from the same worker
    # from colliding in the repo.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_tag = uuid.uuid4().hex[:8]
    repo_path = f"csv/{timestamp}_{safe_worker}_{unique_tag}.csv"

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
        ) as tmp:
            tmp_path = tmp.name
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)
        if cfg.get("hf_token"):
            try:
                hf_api.upload_file(
                    path_or_fileobj=tmp_path,
                    path_in_repo=repo_path,
                    repo_id=cfg["output_dataset_repo"],
                    repo_type="dataset",
                )
                print("[HF] CSV uploaded.")
            except Exception as e:
                print(f"[HF] CSV upload error: {e}")
    finally:
        # Always clean up the temp file — previously a failure between
        # creation and unlink would leak it on disk.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)