| """Save session state to local disk and upload JSON + flattened CSV to HuggingFace.""" |
| import csv |
| import json |
| import os |
| import tempfile |
| import uuid |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import streamlit as st |
| from huggingface_hub import HfApi |
| from src.data import release_reservation, record_completion |
|
|
|
|
@st.cache_resource
def _get_hf_api(hf_token: str, output_repo: str) -> HfApi:
    """Initialise the HF API client and ensure the output dataset repo exists.

    Memoised by ``st.cache_resource`` on ``(hf_token, output_repo)``, so the
    existence check only runs on the first call for each combination.

    Args:
        hf_token: HF access token. When empty, an unauthenticated client is
            returned and the repo is neither checked nor created.
        output_repo: Id of the dataset repo that results are uploaded to.

    Returns:
        The ``HfApi`` client.

    Raises:
        Whatever ``create_repo`` raises when the repo is missing and cannot be
        created (e.g. insufficient permissions). Errors from the existence
        check that do not look like "not found" are only logged.
    """
    api = HfApi(token=hf_token) if hf_token else HfApi()
    if not hf_token:
        # Nothing will be uploaded without a token, so skip the remote check.
        return api

    try:
        api.repo_info(repo_id=output_repo, repo_type="dataset")
    except Exception as e:
        # Prefer the HTTP status carried by HF errors; keep the message match
        # as a fallback for errors that carry no response object.
        status = getattr(getattr(e, "response", None), "status_code", None)
        message = str(e)
        if status == 404 or "404" in message or "not found" in message.lower():
            # exist_ok=True: another session may create the repo between our
            # repo_info() and create_repo() calls; that race must not crash.
            api.create_repo(
                repo_id=output_repo,
                repo_type="dataset",
                private=True,
                exist_ok=True,
            )
            print(f"[HF] Created output repo: {output_repo}")
        else:
            print(f"[HF] Warning checking repo existence: {e}")
    return api
|
|
|
|
| def _hf_study_identity_snapshot(state: dict, cfg: dict) -> dict: |
| """ |
| Compact, non-secret metadata duplicated into every HF JSON export so files |
| are self-describing (especially which YAML model entry each participant saw). |
| """ |
| snap: dict = { |
| "study_type": cfg.get("study_type"), |
| "pair_selection_seed": cfg.get("pair_selection_seed"), |
| "categories": cfg.get("categories"), |
| "pairs_per_user": cfg.get("pairs_per_user"), |
| "min_turns": cfg.get("min_turns"), |
| "max_turns": cfg.get("max_turns"), |
| "prolific_study_id": cfg.get("prolific_study_id"), |
| "output_dataset_repo": cfg.get("output_dataset_repo"), |
| } |
| if cfg.get("study_type") != "model_comparison": |
| return snap |
| snap.update( |
| { |
| "submission_id": state.get("submission_id"), |
| "session_user_id": state.get("user_id"), |
| "comparison_models_canonical": state.get("comparison_models_canonical"), |
| "model_presentation_order_config_indices": state.get( |
| "model_presentation_order_config_indices" |
| ), |
| "model_identity_timeline": state.get("model_identity_timeline"), |
| "comparison_final": state.get("comparison_final"), |
| } |
| ) |
| return snap |
|
|
|
|
def save_and_upload(state: dict, cfg: dict) -> None:
    """Write the full JSON to disk, then upload JSON + flattened CSV to HuggingFace.

    The local JSON is written before the HF client is initialised, so a
    failure talking to the Hub cannot lose the participant's data.

    Args:
        state: Session state to persist. It is copied shallowly; the
            ``hf_study_identity`` snapshot (and a generated ``submission_id``,
            if the state had none) are added to the copy only.
        cfg: Study config. Reads ``output_dataset_repo`` and
            ``annotations_dir`` (required) and ``hf_token`` (optional; when
            empty nothing is uploaded).

    Side effects:
        Writes ``<annotations_dir>/<worker>/<submission_id>.json``. Only after a
        successful JSON upload does it release the user's reservation, record
        completion and run the auto-pause check. The CSV step runs regardless
        of whether the JSON upload succeeded.

    Exceptions from HF client initialisation and from the reservation /
    completion bookkeeping are not caught here.
    """
    output_repo = cfg["output_dataset_repo"]
    hf_token = cfg.get("hf_token", "")

    # Fall back prolific id -> session user id -> "anonymous". An empty worker
    # id would otherwise write directly into annotations_dir and produce a
    # "json//..." path in the repo.
    worker_id = state.get("prolific_pid") or state.get("user_id") or "anonymous"
    # `or` rather than a .get() default, so a present-but-empty/None id still
    # gets a fresh UUID instead of producing "None.json" / ".json".
    submission_id = state.get("submission_id") or str(uuid.uuid4())
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))

    print("[SAVE] starting save_and_upload")
    print(f"[SAVE] output_repo={output_repo}")
    print(f"[SAVE] hf_token set={bool(hf_token)}")

    ann_dir = Path(cfg["annotations_dir"]) / safe_worker
    ann_dir.mkdir(parents=True, exist_ok=True)
    json_path = ann_dir / f"{submission_id}.json"

    export_state = dict(state)
    # Record the id actually used in the file name when the state had none,
    # so the exported JSON (and its identity snapshot) is self-describing.
    if not export_state.get("submission_id"):
        export_state["submission_id"] = submission_id
    export_state["hf_study_identity"] = _hf_study_identity_snapshot(export_state, cfg)

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(export_state, f, indent=2)
    print(f"[SAVE] JSON written: {json_path}")

    # Initialise HF only once the local copy is safely on disk.
    hf_api = _get_hf_api(hf_token, output_repo)

    uploaded = False
    if hf_token:
        try:
            hf_api.upload_file(
                path_or_fileobj=str(json_path),
                path_in_repo=f"json/{safe_worker}/{submission_id}.json",
                repo_id=output_repo,
                repo_type="dataset",
            )
            print("[HF] JSON uploaded.")
            uploaded = True
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    # Bookkeeping only counts the session as done once its JSON is on the Hub.
    if uploaded:
        release_reservation(state.get("user_id", ""), cfg)
        record_completion(state.get("user_id", ""), state.get("items", []), cfg)

        # Best effort: a failed auto-pause check must not break saving.
        try:
            from src.data import all_items_covered, pause_prolific_study
            if all_items_covered(cfg):
                pause_prolific_study(cfg)
        except Exception as e:
            print(f"[SAVE] Auto-pause check failed: {e}")

    _save_and_upload_csv(state, cfg, hf_api, safe_worker, submission_id)
|
|
|
|
| |
|
|
| _COMMON_HEADER = [ |
| "submission_id", "prolific_pid", "study_id", "session_id", |
| "submission_time", "duration_seconds", |
| "study_type", "model_name", |
| "prompt_personalization", "prompt_detailed_instruction", |
| "pair_selection_seed", "category", |
| |
| "age", "gender", "geographic_region", "education_level", "race", |
| "us_citizen", "marital_status", "religion", "religious_attendance", |
| "political_affiliation", "income", "political_views", |
| "household_size", "employment_status", |
| |
| "movies_criteria", "movies_enjoy", "movies_avoid", |
| "groceries_criteria", "groceries_enjoy", "groceries_avoid", |
| |
| "pre_rating", "post_rating", "rating_delta", |
| |
| "num_turns", "conversation_json", |
| |
| "standout_moment", "thinking_change", |
| ] |
|
|
| _PREFERENCE_EXTRA_HEADER = [ |
| "pair_index", "pair_id", |
| "product_a_id", "product_a_title", "product_a_price", |
| "product_b_id", "product_b_title", "product_b_price", |
| "familiarity_a", "familiarity_b", |
| ] |
|
|
| _LIKELIHOOD_EXTRA_HEADER = [ |
| "item_index", "item_id", |
| "product_title", "product_price", |
| "familiarity", |
| ] |
|
|
| _MODEL_COMPARISON_HEADER = [ |
| "submission_id", "prolific_pid", "study_id", "session_id", |
| "submission_time", "duration_seconds", |
| "study_type", "pair_selection_seed", |
| "comparison_models_canonical_json", |
| "model_presentation_order_config_indices_json", |
| "model_identity_timeline_json", |
| "comparison_round_index", "user_model_label", "variant_key", |
| "config_index", "config_variant_name", |
| "model_name", "sampler_path", |
| "use_demographics", "use_background", "prompt_personalization", "prompt_detailed_instruction", |
| "category", |
| "age", "gender", "geographic_region", "education_level", "race", |
| "us_citizen", "marital_status", "religion", "religious_attendance", |
| "political_affiliation", "income", "political_views", |
| "household_size", "employment_status", |
| "movies_criteria", "movies_enjoy", "movies_avoid", |
| "groceries_criteria", "groceries_enjoy", "groceries_avoid", |
| "pair_id", |
| "product_a_id", "product_a_title", "product_a_price", |
| "product_b_id", "product_b_title", "product_b_price", |
| "familiarity_a", "familiarity_b", "pre_rating", |
| "post_convincing", "post_more_likely_buy_target", "post_trustworthy_natural", |
| "post_stood_out", |
| "num_turns", "conversation_json", |
| "final_preferred_user_label", |
| "final_preferred_config_index", "final_preferred_variant_key", |
| "final_rank_convincing_labels_json", |
| "final_rank_buy_target_labels_json", |
| "final_rank_convincing_config_indices_json", |
| "final_rank_buy_target_config_indices_json", |
| "user_label_to_config_index_json", |
| ] |
|
|
|
|
def _session_csv_values(state: dict, submission_id: str, study_type: str) -> list:
    """Return the leading submission/session columns shared by every CSV layout."""
    meta = state.get("meta", {})
    return [
        submission_id,
        state.get("prolific_pid", ""),
        state.get("study_id", ""),
        state.get("session_id", ""),
        meta.get("submission_time", ""),
        meta.get("duration_seconds", ""),
        study_type,
    ]


def _profile_csv_values(demographics: dict, background: dict) -> list:
    """Return the demographic answers followed by the background answers, in CSV column order."""
    demographic_keys = (
        "age", "gender", "geographic_region", "education_level", "race",
        "us_citizen", "marital_status", "religion", "religious_attendance",
        "political_affiliation", "income", "political_views",
        "household_size", "employment_status",
    )
    background_keys = (
        "movies_criteria", "movies_enjoy", "movies_avoid",
        "groceries_criteria", "groceries_enjoy", "groceries_avoid",
    )
    return [demographics.get(key, "") for key in demographic_keys] + [
        background.get(key, "") for key in background_keys
    ]


def _pair_csv_values(item: dict) -> list:
    """Return the pair id, id/title/price of both products, and both familiarity ratings."""
    product_a = item.get("product_a", {})
    product_b = item.get("product_b", {})
    return [
        item.get("pair_id", ""),
        product_a.get("id", ""), product_a.get("title", ""), product_a.get("price", ""),
        product_b.get("id", ""), product_b.get("title", ""), product_b.get("price", ""),
        item.get("familiarity_a", ""),
        item.get("familiarity_b", ""),
    ]


def _model_comparison_csv_rows(
    state: dict, cfg: dict, submission_id: str, study_type: str, profile: list
) -> list[list]:
    """Build one row per entry in ``state["model_rounds"]``.

    Study-level data (session, model ordering, the compared pair and the final
    ranking) is repeated on every row so each row stands on its own. Only the
    first entry of ``state["items"]`` is used for the pair columns.
    """
    items = state.get("items", [])
    item = items[0] if items else {}
    fin = state.get("comparison_final") or {}

    leading = [
        *_session_csv_values(state, submission_id, study_type),
        cfg.get("pair_selection_seed", 42),
        json.dumps(state.get("comparison_models_canonical", [])),
        json.dumps(state.get("model_presentation_order_config_indices", [])),
        json.dumps(state.get("model_identity_timeline", [])),
    ]
    pair = _pair_csv_values(item)
    final = [
        fin.get("preferred_user_label", ""),
        fin.get("preferred_config_index", ""),
        fin.get("preferred_variant_key", ""),
        json.dumps(fin.get("rank_convincing_labels", [])),
        json.dumps(fin.get("rank_buy_target_labels", [])),
        json.dumps(fin.get("rank_convincing_config_indices", [])),
        json.dumps(fin.get("rank_buy_target_config_indices", [])),
        json.dumps(fin.get("user_label_to_config_index", {})),
    ]

    rows: list[list] = []
    for round_index, model_round in enumerate(state.get("model_rounds", [])):
        mconf = model_round.get("config", {})
        conv = model_round.get("conversation", {})
        ratings = model_round.get("post_chat_ratings") or {}
        use_demographics = mconf.get("use_demographics", False)
        use_background = mconf.get("use_background", False)
        rows.append([
            *leading,
            round_index,
            model_round.get("user_label", ""),
            model_round.get("variant_key", ""),
            model_round.get("config_index", ""),
            mconf.get("name", ""),
            mconf.get("model_name", ""),
            mconf.get("sampler_path", ""),
            use_demographics,
            use_background,
            # Fall back to the use_* flags when the config has no explicit
            # "personalization" entry.
            mconf.get("personalization", use_demographics or use_background),
            mconf.get("detailed_instruction", ""),
            item.get("category", ""),
            *profile,
            *pair,
            item.get("pre_rating", ""),
            ratings.get("convincing", ""),
            ratings.get("more_likely_buy_target", ""),
            ratings.get("trustworthy_natural", ""),
            ratings.get("stood_out", ""),
            conv.get("num_turns", 0),
            json.dumps(conv.get("turns", [])),
            *final,
        ])
    return rows


def _item_csv_rows(
    state: dict, cfg: dict, submission_id: str, study_type: str, profile: list
) -> list[list]:
    """Build one row per entry in ``state["items"]`` for the preference/likelihood layouts."""
    session = _session_csv_values(state, submission_id, study_type)
    seed = cfg.get("pair_selection_seed", 42)

    rows: list[list] = []
    for index, item in enumerate(state.get("items", []), start=1):
        conv = item.get("conversation", {})
        refl = item.get("reflection", {})
        pv = item.get("prompt_variant", {})
        pre = item.get("pre_rating", "")
        post = item.get("post_rating", "")
        # Only compute a delta when both ratings are integers.
        delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""

        common = [
            *session,
            item.get("model_name", ""),
            pv.get("personalization", False),
            pv.get("detailed_instruction", True),
            seed,
            item.get("category", ""),
            *profile,
            pre, post, delta,
            conv.get("num_turns", 0),
            json.dumps(conv.get("turns", [])),
            refl.get("standout_moment", ""),
            refl.get("thinking_change", ""),
        ]

        if study_type == "preference":
            extra = [index, *_pair_csv_values(item)]
        else:
            prod = item.get("product", {})
            extra = [
                index,
                item.get("item_id", ""),
                prod.get("title", ""), prod.get("price", ""),
                item.get("familiarity", ""),
            ]
        rows.append(common + extra)
    return rows


def _upload_csv_rows(
    header: list, rows: list[list], cfg: dict, hf_api: HfApi, repo_path: str
) -> None:
    """Write ``header`` + ``rows`` to a temporary CSV and upload it to ``repo_path``.

    Upload errors are logged, not raised (best effort). The temporary file is
    always removed, even when writing or uploading fails. Without an
    ``hf_token`` nothing is written or uploaded.
    """
    if not cfg.get("hf_token"):
        return

    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    )
    try:
        # Close the file before uploading so every row is flushed to disk.
        with tmp:
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)
        try:
            hf_api.upload_file(
                path_or_fileobj=tmp.name,
                path_in_repo=repo_path,
                repo_id=cfg["output_dataset_repo"],
                repo_type="dataset",
            )
            print("[HF] CSV uploaded.")
        except Exception as e:
            print(f"[HF] CSV upload error: {e}")
    finally:
        os.unlink(tmp.name)


def _save_and_upload_csv(
    state: dict, cfg: dict, hf_api: HfApi, safe_worker: str, submission_id: str
) -> None:
    """Flatten the session into CSV rows and upload them under ``csv/`` in the output repo.

    The layout depends on ``cfg["study_type"]``:
      * ``"model_comparison"``: one row per ``state["model_rounds"]`` entry,
        using ``_MODEL_COMPARISON_HEADER``;
      * ``"preference"``: one row per item, ``_COMMON_HEADER`` plus pair columns;
      * anything else: one row per item, ``_COMMON_HEADER`` plus single-product
        columns.

    The repo file name combines a local timestamp, the sanitised worker id and
    a random 8-hex tag, which makes collisions between submissions unlikely.
    """
    study_type = cfg["study_type"]
    profile = _profile_csv_values(
        state.get("demographics", {}), state.get("background", {})
    )

    if study_type == "model_comparison":
        header = _MODEL_COMPARISON_HEADER
        rows = _model_comparison_csv_rows(state, cfg, submission_id, study_type, profile)
    else:
        extra_header = (
            _PREFERENCE_EXTRA_HEADER if study_type == "preference"
            else _LIKELIHOOD_EXTRA_HEADER
        )
        header = _COMMON_HEADER + extra_header
        rows = _item_csv_rows(state, cfg, submission_id, study_type, profile)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    repo_path = f"csv/{timestamp}_{safe_worker}_{uuid.uuid4().hex[:8]}.csv"
    _upload_csv_rows(header, rows, cfg, hf_api, repo_path)