"""Save session state to local disk and upload JSON + flattened CSV to HuggingFace."""

import csv
import json
import os
import tempfile
import uuid
from datetime import datetime
from pathlib import Path

import streamlit as st
from huggingface_hub import HfApi

from src.data import release_reservation, record_completion


@st.cache_resource
def _get_hf_api(hf_token: str, output_repo: str) -> HfApi:
    """Initialise HF API client and ensure the output repo exists.

    Args:
        hf_token: HuggingFace access token; when empty, an unauthenticated
            client is returned and no repo check is attempted.
        output_repo: Dataset repo id that exports are uploaded to.

    Returns:
        The ``HfApi`` client. The function is wrapped in
        ``st.cache_resource``, so the repo check is not repeated for the
        same argument pair.
    """
    api = HfApi(token=hf_token) if hf_token else HfApi()
    if hf_token:
        try:
            api.repo_info(repo_id=output_repo, repo_type="dataset")
        except Exception as e:
            # The error is classified by message text: a "not found" style
            # failure means the repo must be created; anything else is only
            # reported so that the app keeps running.
            if "404" in str(e) or "not found" in str(e).lower():
                api.create_repo(repo_id=output_repo, repo_type="dataset", private=True)
                print(f"[HF] Created output repo: {output_repo}")
            else:
                print(f"[HF] Warning checking repo existence: {e}")
    return api


def _hf_study_identity_snapshot(state: dict, cfg: dict) -> dict:
    """Build the study-identity metadata embedded in every JSON export.

    Compact, non-secret metadata duplicated into every HF JSON export so
    files are self-describing (especially which YAML model entry each
    participant saw).

    Args:
        state: Session state for the submission being exported.
        cfg: Study configuration.

    Returns:
        A dict of study-level settings read from ``cfg``. For the
        ``model_comparison`` study type it additionally carries the
        submission/user ids and model-identity fields read from ``state``.
    """
    snap: dict = {
        "study_type": cfg.get("study_type"),
        "pair_selection_seed": cfg.get("pair_selection_seed"),
        "categories": cfg.get("categories"),
        "pairs_per_user": cfg.get("pairs_per_user"),
        "min_turns": cfg.get("min_turns"),
        "max_turns": cfg.get("max_turns"),
        "prolific_study_id": cfg.get("prolific_study_id"),
        "output_dataset_repo": cfg.get("output_dataset_repo"),
    }
    if cfg.get("study_type") != "model_comparison":
        return snap
    snap.update(
        {
            "submission_id": state.get("submission_id"),
            "session_user_id": state.get("user_id"),
            "comparison_models_canonical": state.get("comparison_models_canonical"),
            "model_presentation_order_config_indices": state.get(
                "model_presentation_order_config_indices"
            ),
            "model_identity_timeline": state.get("model_identity_timeline"),
            "comparison_final": state.get("comparison_final"),
        }
    )
    return snap


def save_and_upload(state: dict, cfg: dict) -> None:
    """Write the full JSON to disk, then upload JSON + flattened CSV to HuggingFace.

    Args:
        state: Session state to export. It is copied before the
            ``hf_study_identity`` key is added, so the caller's dict is not
            mutated by the export itself.
        cfg: Study configuration. ``output_dataset_repo`` and
            ``annotations_dir`` are required; ``hf_token`` is optional and
            uploads are skipped when it is empty.

    Upload failures are printed rather than raised. Reservation release,
    completion recording and the auto-pause check only run after a
    successful JSON upload.
    """
    output_repo = cfg["output_dataset_repo"]
    hf_token = cfg.get("hf_token", "")
    hf_api = _get_hf_api(hf_token, output_repo)

    # Fall back on falsy values as well as missing keys, so that an empty id
    # never yields an empty path component and a None submission id never
    # yields a file literally named "None.json".
    worker_id = state.get("prolific_pid") or state.get("user_id") or "anonymous"
    submission_id = state.get("submission_id") or str(uuid.uuid4())
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))

    print("[SAVE] starting save_and_upload")
    print(f"[SAVE] output_repo={output_repo}")
    print(f"[SAVE] hf_token set={bool(hf_token)}")

    # ── Write JSON ────────────────────────────────────────────────────────────
    ann_dir = Path(cfg["annotations_dir"]) / safe_worker
    ann_dir.mkdir(parents=True, exist_ok=True)
    json_path = ann_dir / f"{submission_id}.json"

    export_state = dict(state)
    export_state["hf_study_identity"] = _hf_study_identity_snapshot(state, cfg)

    # Explicit encoding, matching the CSV temp file, so the output does not
    # depend on the platform's default locale.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(export_state, f, indent=2)
    print(f"[SAVE] JSON written: {json_path}")

    uploaded = False
    if hf_token:
        try:
            hf_api.upload_file(
                path_or_fileobj=str(json_path),
                path_in_repo=f"json/{safe_worker}/{submission_id}.json",
                repo_id=output_repo,
                repo_type="dataset",
            )
            print("[HF] JSON uploaded.")
            uploaded = True
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    if uploaded:
        # Release reservations so items are immediately available for re-assignment
        release_reservation(state.get("user_id", ""), cfg)
        # Record completion locally — updates counts immediately without waiting
        # for an HF re-scan. Also invalidates the HF cache.
        record_completion(state.get("user_id", ""), state.get("items", []), cfg)

        # Auto-pause Prolific study if all items are now covered. The import
        # sits inside the try so that any failure here is only logged.
        try:
            from src.data import all_items_covered, pause_prolific_study

            if all_items_covered(cfg):
                pause_prolific_study(cfg)
        except Exception as e:
            print(f"[SAVE] Auto-pause check failed: {e}")

    # ── Write + upload CSV ────────────────────────────────────────────────────
    _save_and_upload_csv(state, cfg, hf_api, safe_worker, submission_id)


# ── CSV schema ────────────────────────────────────────────────────────────────

# Participant columns shared by every CSV layout. The row builders read these
# same tuples, so header order and cell order cannot drift apart.
_DEMOGRAPHIC_FIELDS = (
    "age",
    "gender",
    "geographic_region",
    "education_level",
    "race",
    "us_citizen",
    "marital_status",
    "religion",
    "religious_attendance",
    "political_affiliation",
    "income",
    "political_views",
    "household_size",
    "employment_status",
)

_BACKGROUND_FIELDS = (
    "movies_criteria",
    "movies_enjoy",
    "movies_avoid",
    "groceries_criteria",
    "groceries_enjoy",
    "groceries_avoid",
)

_COMMON_HEADER = [
    "submission_id",
    "prolific_pid",
    "study_id",
    "session_id",
    "submission_time",
    "duration_seconds",
    "study_type",
    "model_name",
    "prompt_personalization",
    "prompt_detailed_instruction",
    "pair_selection_seed",
    "category",
    # Demographics
    *_DEMOGRAPHIC_FIELDS,
    # Background
    *_BACKGROUND_FIELDS,
    # Ratings
    "pre_rating",
    "post_rating",
    "rating_delta",
    # Conversation
    "num_turns",
    "conversation_json",
    # Reflection
    "standout_moment",
    "thinking_change",
]

_PREFERENCE_EXTRA_HEADER = [
    "pair_index",
    "pair_id",
    "product_a_id",
    "product_a_title",
    "product_a_price",
    "product_b_id",
    "product_b_title",
    "product_b_price",
    "familiarity_a",
    "familiarity_b",
]

_LIKELIHOOD_EXTRA_HEADER = [
    "item_index",
    "item_id",
    "product_title",
    "product_price",
    "familiarity",
]

_MODEL_COMPARISON_HEADER = [
    "submission_id",
    "prolific_pid",
    "study_id",
    "session_id",
    "submission_time",
    "duration_seconds",
    "study_type",
    "pair_selection_seed",
    "comparison_models_canonical_json",
    "model_presentation_order_config_indices_json",
    "model_identity_timeline_json",
    "comparison_round_index",
    "user_model_label",
    "variant_key",
    "config_index",
    "config_variant_name",
    "model_name",
    "sampler_path",
    "use_demographics",
    "use_background",
    "prompt_personalization",
    "prompt_detailed_instruction",
    "category",
    *_DEMOGRAPHIC_FIELDS,
    *_BACKGROUND_FIELDS,
    "pair_id",
    "product_a_id",
    "product_a_title",
    "product_a_price",
    "product_b_id",
    "product_b_title",
    "product_b_price",
    "familiarity_a",
    "familiarity_b",
    "pre_rating",
    "post_convincing",
    "post_more_likely_buy_target",
    "post_trustworthy_natural",
    "post_stood_out",
    "num_turns",
    "conversation_json",
    "final_preferred_user_label",
    "final_preferred_config_index",
    "final_preferred_variant_key",
    "final_rank_convincing_labels_json",
    "final_rank_buy_target_labels_json",
    "final_rank_convincing_config_indices_json",
    "final_rank_buy_target_config_indices_json",
    "user_label_to_config_index_json",
]


def _session_cells(state: dict, submission_id: str, study_type: str) -> list:
    """Return the seven leading cells (ids, timing, study type) of every CSV row."""
    meta = state.get("meta") or {}
    return [
        submission_id,
        state.get("prolific_pid", ""),
        state.get("study_id", ""),
        state.get("session_id", ""),
        meta.get("submission_time", ""),
        meta.get("duration_seconds", ""),
        study_type,
    ]


def _participant_cells(state: dict) -> list:
    """Return the demographic cells followed by the background cells, in header order."""
    demographics = state.get("demographics") or {}
    background = state.get("background") or {}
    return [demographics.get(key, "") for key in _DEMOGRAPHIC_FIELDS] + [
        background.get(key, "") for key in _BACKGROUND_FIELDS
    ]


def _pair_cells(item: dict) -> list:
    """Return pair id, both products' id/title/price, and both familiarity cells."""
    product_a = item.get("product_a") or {}
    product_b = item.get("product_b") or {}
    return [
        item.get("pair_id", ""),
        product_a.get("id", ""),
        product_a.get("title", ""),
        product_a.get("price", ""),
        product_b.get("id", ""),
        product_b.get("title", ""),
        product_b.get("price", ""),
        item.get("familiarity_a", ""),
        item.get("familiarity_b", ""),
    ]


def _model_comparison_rows(
    state: dict, cfg: dict, submission_id: str, study_type: str
) -> list[list]:
    """Build one row per entry of ``state["model_rounds"]`` in ``_MODEL_COMPARISON_HEADER`` order."""
    items = state.get("items") or []
    # Only the first item is read in this layout; every round row repeats it.
    item = items[0] if items else {}
    final = state.get("comparison_final") or {}

    # Everything below is identical for every round, so it is computed once.
    session = _session_cells(state, submission_id, study_type)
    study = [
        cfg.get("pair_selection_seed", 42),
        json.dumps(state.get("comparison_models_canonical", [])),
        json.dumps(state.get("model_presentation_order_config_indices", [])),
        json.dumps(state.get("model_identity_timeline", [])),
    ]
    participant = _participant_cells(state)
    pair = _pair_cells(item)
    final_cells = [
        final.get("preferred_user_label", ""),
        final.get("preferred_config_index", ""),
        final.get("preferred_variant_key", ""),
        json.dumps(final.get("rank_convincing_labels", [])),
        json.dumps(final.get("rank_buy_target_labels", [])),
        json.dumps(final.get("rank_convincing_config_indices", [])),
        json.dumps(final.get("rank_buy_target_config_indices", [])),
        json.dumps(final.get("user_label_to_config_index", {})),
    ]

    rows: list[list] = []
    for round_index, model_round in enumerate(state.get("model_rounds") or []):
        model_cfg = model_round.get("config") or {}
        conversation = model_round.get("conversation") or {}
        ratings = model_round.get("post_chat_ratings") or {}

        # When the config has no explicit "personalization" key, it is derived
        # from the two use_* flags.
        personalization = model_cfg.get(
            "personalization",
            model_cfg.get("use_demographics", False)
            or model_cfg.get("use_background", False),
        )
        # Not passed to lsp get_seller_system_prompt; optional YAML audit only.
        detailed_instruction = model_cfg.get("detailed_instruction", "")

        rows.append(
            session
            + study
            + [
                round_index,
                model_round.get("user_label", ""),
                model_round.get("variant_key", ""),
                model_round.get("config_index", ""),
                model_cfg.get("name", ""),
                model_cfg.get("model_name", ""),
                model_cfg.get("sampler_path", ""),
                model_cfg.get("use_demographics", False),
                model_cfg.get("use_background", False),
                personalization,
                detailed_instruction,
                item.get("category", ""),
            ]
            + participant
            + pair
            + [
                item.get("pre_rating", ""),
                ratings.get("convincing", ""),
                ratings.get("more_likely_buy_target", ""),
                ratings.get("trustworthy_natural", ""),
                ratings.get("stood_out", ""),
                conversation.get("num_turns", 0),
                json.dumps(conversation.get("turns", [])),
            ]
            + final_cells
        )
    return rows


def _per_item_rows(
    state: dict, cfg: dict, submission_id: str, study_type: str
) -> list[list]:
    """Build one row per item for the non-comparison layouts.

    Rows follow ``_COMMON_HEADER`` plus ``_PREFERENCE_EXTRA_HEADER`` when
    ``study_type`` is ``"preference"``, otherwise ``_LIKELIHOOD_EXTRA_HEADER``.
    """
    session = _session_cells(state, submission_id, study_type)
    participant = _participant_cells(state)
    seed = cfg.get("pair_selection_seed", 42)

    rows: list[list] = []
    for index, item in enumerate(state.get("items") or []):
        conversation = item.get("conversation") or {}
        reflection = item.get("reflection") or {}
        variant = item.get("prompt_variant") or {}
        pre = item.get("pre_rating", "")
        post = item.get("post_rating", "")
        # The delta is only meaningful when both ratings are integers;
        # otherwise the cell is left blank.
        delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""

        common = (
            session
            + [
                item.get("model_name", ""),
                variant.get("personalization", False),
                variant.get("detailed_instruction", True),
                seed,
                item.get("category", ""),
            ]
            + participant
            + [
                pre,
                post,
                delta,
                conversation.get("num_turns", 0),
                json.dumps(conversation.get("turns", [])),
                reflection.get("standout_moment", ""),
                reflection.get("thinking_change", ""),
            ]
        )

        if study_type == "preference":
            extra = [index + 1] + _pair_cells(item)
        else:
            product = item.get("product") or {}
            extra = [
                index + 1,
                item.get("item_id", ""),
                product.get("title", ""),
                product.get("price", ""),
                item.get("familiarity", ""),
            ]
        rows.append(common + extra)
    return rows


def _save_and_upload_csv(
    state: dict, cfg: dict, hf_api: HfApi, safe_worker: str, submission_id: str
) -> None:
    """Flatten the session into CSV rows and upload them to the output repo.

    Args:
        state: Session state to flatten.
        cfg: Study configuration; ``study_type`` is required and selects the
            column layout. The upload is skipped when ``hf_token`` is empty.
        hf_api: Client used for the upload.
        safe_worker: Sanitised worker id embedded in the repo file name.
        submission_id: Id written into the first column of every row.

    The CSV is written to a temporary file that is always removed before
    returning, including when writing or uploading fails. Upload errors are
    printed rather than raised.
    """
    study_type = cfg["study_type"]

    if study_type == "model_comparison":
        header = _MODEL_COMPARISON_HEADER
        rows = _model_comparison_rows(state, cfg, submission_id, study_type)
    else:
        header = _COMMON_HEADER + (
            _PREFERENCE_EXTRA_HEADER
            if study_type == "preference"
            else _LIKELIHOOD_EXTRA_HEADER
        )
        rows = _per_item_rows(state, cfg, submission_id, study_type)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # A random suffix keeps two uploads from the same worker within the same
    # second from colliding on the repo path.
    unique_tag = uuid.uuid4().hex[:8]
    repo_path = f"csv/{timestamp}_{safe_worker}_{unique_tag}.csv"

    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    )
    tmp_path = tmp.name
    try:
        # The file is closed before uploading so that its contents are flushed.
        with tmp:
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)

        if cfg.get("hf_token"):
            try:
                hf_api.upload_file(
                    path_or_fileobj=tmp_path,
                    path_in_repo=repo_path,
                    repo_id=cfg["output_dataset_repo"],
                    repo_type="dataset",
                )
                print("[HF] CSV uploaded.")
            except Exception as e:
                print(f"[HF] CSV upload error: {e}")
    finally:
        # delete=False means nothing else cleans this file up; removing it in
        # a finally keeps a failed write from leaking it.
        os.unlink(tmp_path)