"""Save session state to local disk and upload JSON + flattened CSV to HuggingFace."""

import csv
import json
import os
import tempfile
import uuid
from datetime import datetime
from pathlib import Path

import streamlit as st
from huggingface_hub import HfApi

from src.data import release_reservation, record_completion


@st.cache_resource
def _get_hf_api(hf_token: str, output_repo: str) -> HfApi:
    """Initialise the HF API client and ensure the output dataset repo exists.

    Wrapped in ``st.cache_resource``, so the client (and the existence check /
    possible repo creation) is reused for repeated calls with the same
    ``(hf_token, output_repo)`` arguments.

    Args:
        hf_token: HuggingFace access token. When falsy, an anonymous client is
            returned and no repo check is performed.
        output_repo: Repo id of the dataset repo that receives submissions.

    Returns:
        An ``HfApi`` client.
    """
    api = HfApi(token=hf_token) if hf_token else HfApi()
    if hf_token:
        try:
            api.repo_info(repo_id=output_repo, repo_type="dataset")
        except Exception as e:
            # A missing repo is detected from the error text; anything else is
            # only logged so a transient failure does not block the app.
            if "404" in str(e) or "not found" in str(e).lower():
                api.create_repo(repo_id=output_repo, repo_type="dataset", private=True)
                print(f"[HF] Created output repo: {output_repo}")
            else:
                print(f"[HF] Warning checking repo existence: {e}")
    return api


def save_and_upload(state: dict, cfg: dict) -> None:
    """Write the full JSON to disk, then upload JSON + flattened CSV to HuggingFace.

    Steps:
      1. Write ``state`` as indented JSON to
         ``cfg["annotations_dir"]/<worker>/<submission_id>.json``.
      2. If an HF token is configured, upload that file to
         ``json/<worker>/<submission_id>.json`` in the output dataset repo.
      3. Only if that upload succeeded: release the user's reservations,
         record the completion, and try to auto-pause the Prolific study when
         all items are covered.
      4. Build the flattened per-item CSV and upload it (attempted whether or
         not step 2 succeeded).

    Args:
        state: Session state. Reads ``prolific_pid``, ``user_id``,
            ``submission_id`` and ``items`` here; the CSV step reads more.
        cfg: App config. Reads ``output_dataset_repo``, ``annotations_dir``,
            ``hf_token`` (optional), ``study_type`` and ``pair_selection_seed``.

    Upload and auto-pause failures are logged rather than raised. Errors from
    writing the local JSON or from ``release_reservation`` /
    ``record_completion`` propagate to the caller.
    """
    output_repo = cfg["output_dataset_repo"]
    hf_token = cfg.get("hf_token", "")
    hf_api = _get_hf_api(hf_token, output_repo)

    # Prefer the Prolific participant id; fall back to the app's own user id.
    worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
    # `or` (not a .get default) so a stored None/"" id does not end up as the
    # file name ("None.json" / ".json").
    submission_id = state.get("submission_id") or str(uuid.uuid4())
    # Map every non-alphanumeric character to "_" so the id is a safe path part.
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))

    print("[SAVE] starting save_and_upload")
    print(f"[SAVE] output_repo={output_repo}")
    print(f"[SAVE] hf_token set={bool(hf_token)}")

    # ── Write JSON ────────────────────────────────────────────────────────────
    ann_dir = Path(cfg["annotations_dir"]) / safe_worker
    ann_dir.mkdir(parents=True, exist_ok=True)
    json_path = ann_dir / f"{submission_id}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=2)
    print(f"[SAVE] JSON written: {json_path}")

    uploaded = False
    if hf_token:
        try:
            hf_api.upload_file(
                path_or_fileobj=str(json_path),
                path_in_repo=f"json/{safe_worker}/{submission_id}.json",
                repo_id=output_repo,
                repo_type="dataset",
            )
            print("[HF] JSON uploaded.")
            uploaded = True
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    if uploaded:
        # Release reservations so items are immediately available for re-assignment
        release_reservation(state.get("user_id", ""), cfg)
        # Record completion locally — updates counts immediately without waiting
        # for an HF re-scan. Also invalidates the HF cache.
        record_completion(state.get("user_id", ""), state.get("items", []), cfg)
        # Auto-pause Prolific study if all items are now covered
        try:
            from src.data import all_items_covered, pause_prolific_study

            if all_items_covered(cfg):
                pause_prolific_study(cfg)
        except Exception as e:
            print(f"[SAVE] Auto-pause check failed: {e}")

    # ── Write + upload CSV ────────────────────────────────────────────────────
    _save_and_upload_csv(state, cfg, hf_api, safe_worker, submission_id)


# ── CSV schema ────────────────────────────────────────────────────────────────
# These column names are identical to the keys read from state["demographics"]
# and state["background"]; the row builder looks values up by the same names,
# so header and row columns cannot drift apart.
_DEMOGRAPHIC_FIELDS = [
    "age",
    "gender",
    "geographic_region",
    "education_level",
    "race",
    "us_citizen",
    "marital_status",
    "religion",
    "religious_attendance",
    "political_affiliation",
    "income",
    "political_views",
    "household_size",
    "employment_status",
]

_BACKGROUND_FIELDS = [
    "movies_criteria",
    "movies_enjoy",
    "movies_avoid",
    "groceries_criteria",
    "groceries_enjoy",
    "groceries_avoid",
]

_COMMON_HEADER = [
    "submission_id",
    "prolific_pid",
    "study_id",
    "session_id",
    "submission_time",
    "duration_seconds",
    "study_type",
    "model_name",
    "prompt_personalization",
    "prompt_detailed_instruction",
    "pair_selection_seed",
    "category",
    # Demographics
    *_DEMOGRAPHIC_FIELDS,
    # Background
    *_BACKGROUND_FIELDS,
    # Ratings
    "pre_rating",
    "post_rating",
    "rating_delta",
    # Conversation
    "num_turns",
    "conversation_json",
    # Reflection
    "standout_moment",
    "thinking_change",
]

_PREFERENCE_EXTRA_HEADER = [
    "pair_index",
    "pair_id",
    "product_a_id",
    "product_a_title",
    "product_a_price",
    "product_b_id",
    "product_b_title",
    "product_b_price",
    "familiarity_a",
    "familiarity_b",
]

_LIKELIHOOD_EXTRA_HEADER = [
    "item_index",
    "item_id",
    "product_title",
    "product_price",
    "familiarity",
]


def _build_csv_rows(state: dict, cfg: dict, submission_id: str) -> list:
    """Flatten ``state["items"]`` into CSV rows, one row per item.

    Each row is the ``_COMMON_HEADER`` columns followed by the
    ``_PREFERENCE_EXTRA_HEADER`` columns when ``cfg["study_type"]`` is
    ``"preference"``, otherwise the ``_LIKELIHOOD_EXTRA_HEADER`` columns.
    Missing or ``None`` sub-dicts are treated as empty.
    """
    study_type = cfg["study_type"]
    demographics = state.get("demographics") or {}
    background = state.get("background") or {}
    meta = state.get("meta") or {}

    # Participant-level columns are the same for every item; compute once.
    demographic_values = [demographics.get(k, "") for k in _DEMOGRAPHIC_FIELDS]
    background_values = [background.get(k, "") for k in _BACKGROUND_FIELDS]

    rows = []
    for index, item in enumerate(state.get("items") or [], start=1):
        conv = item.get("conversation") or {}
        refl = item.get("reflection") or {}
        pv = item.get("prompt_variant") or {}
        pre = item.get("pre_rating", "")
        post = item.get("post_rating", "")
        # Delta only when both ratings are present as ints; blank otherwise.
        delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""

        common = [
            submission_id,
            state.get("prolific_pid", ""),
            state.get("study_id", ""),
            state.get("session_id", ""),
            meta.get("submission_time", ""),
            meta.get("duration_seconds", ""),
            study_type,
            item.get("model_name", ""),
            pv.get("personalization", False),
            pv.get("detailed_instruction", True),
            cfg.get("pair_selection_seed", 42),
            item.get("category", ""),
            *demographic_values,
            *background_values,
            pre,
            post,
            delta,
            conv.get("num_turns", 0),
            json.dumps(conv.get("turns", [])),
            refl.get("standout_moment", ""),
            refl.get("thinking_change", ""),
        ]

        if study_type == "preference":
            pa = item.get("product_a") or {}
            pb = item.get("product_b") or {}
            extra = [
                index,
                item.get("pair_id", ""),
                pa.get("id", ""),
                pa.get("title", ""),
                pa.get("price", ""),
                pb.get("id", ""),
                pb.get("title", ""),
                pb.get("price", ""),
                item.get("familiarity_a", ""),
                item.get("familiarity_b", ""),
            ]
        else:
            prod = item.get("product") or {}
            extra = [
                index,
                item.get("item_id", ""),
                prod.get("title", ""),
                prod.get("price", ""),
                item.get("familiarity", ""),
            ]
        rows.append(common + extra)
    return rows


def _save_and_upload_csv(
    state: dict, cfg: dict, hf_api: HfApi, safe_worker: str, submission_id: str
) -> None:
    """Write the flattened per-item CSV to a temp file and upload it to HF.

    The CSV exists only to be uploaded, so nothing is built or written when no
    HF token is configured. The temporary file is always removed, even if
    writing or uploading fails. Upload errors are logged, not raised.
    """
    if not cfg.get("hf_token"):
        return

    study_type = cfg["study_type"]
    header = _COMMON_HEADER + (
        _PREFERENCE_EXTRA_HEADER if study_type == "preference" else _LIKELIHOOD_EXTRA_HEADER
    )
    rows = _build_csv_rows(state, cfg, submission_id)

    # Timestamp plus a random tag keeps repeated uploads for one worker from
    # overwriting each other in the repo.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_tag = uuid.uuid4().hex[:8]
    repo_path = f"csv/{timestamp}_{safe_worker}_{unique_tag}.csv"

    # delete=False + explicit close before upload lets the file be reopened by
    # name (required on Windows); cleanup is handled by the finally below.
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    )
    tmp_path = tmp.name
    try:
        with tmp:
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)
        try:
            hf_api.upload_file(
                path_or_fileobj=tmp_path,
                path_in_repo=repo_path,
                repo_id=cfg["output_dataset_repo"],
                repo_type="dataset",
            )
            print("[HF] CSV uploaded.")
        except Exception as e:
            print(f"[HF] CSV upload error: {e}")
    finally:
        os.unlink(tmp_path)