Spaces:
Running
Running
| """Save session state to local disk and upload JSON + flattened CSV to HuggingFace.""" | |
| import csv | |
| import json | |
| import os | |
| import tempfile | |
| import uuid | |
| from datetime import datetime | |
| from pathlib import Path | |
| import streamlit as st | |
| from huggingface_hub import HfApi | |
| from src.data import release_reservation, record_completion | |
def _get_hf_api(hf_token: str, output_repo: str) -> HfApi:
    """Initialise HF API client and ensure the output repo exists."""
    api = HfApi(token=hf_token) if hf_token else HfApi()
    # Without a token we cannot create repos anyway — return the anonymous client.
    if not hf_token:
        return api
    try:
        api.repo_info(repo_id=output_repo, repo_type="dataset")
    except Exception as e:
        # A missing repo surfaces as a 404 / "not found" error; create it lazily.
        reason = str(e)
        if "404" in reason or "not found" in reason.lower():
            api.create_repo(repo_id=output_repo, repo_type="dataset", private=True)
            print(f"[HF] Created output repo: {output_repo}")
        else:
            print(f"[HF] Warning checking repo existence: {e}")
    return api
def save_and_upload(state: dict, cfg: dict) -> None:
    """Write the full JSON to disk, then upload JSON + flattened CSV to HuggingFace.

    Steps:
      1. Persist the raw session ``state`` as JSON under
         ``cfg["annotations_dir"]/<safe_worker>/<submission_id>.json``.
      2. Upload that JSON to the output dataset repo (only when a token is set).
      3. On successful upload: release the worker's item reservations, record
         the completion locally, and best-effort auto-pause the Prolific study
         when every item is covered.
      4. Write and upload the flattened CSV via :func:`_save_and_upload_csv`.

    Args:
        state: Full session state (ids, demographics, items, conversations, ...).
        cfg:   App config; must contain ``output_dataset_repo`` and
               ``annotations_dir``; ``hf_token`` is optional.
    """
    output_repo = cfg["output_dataset_repo"]
    hf_token = cfg.get("hf_token", "")
    hf_api = _get_hf_api(hf_token, output_repo)

    worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
    submission_id = state.get("submission_id", str(uuid.uuid4()))
    # Sanitize the worker id so it is safe to embed in directory/repo paths.
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))

    print("[SAVE] starting save_and_upload")
    print(f"[SAVE] output_repo={output_repo}")
    print(f"[SAVE] hf_token set={bool(hf_token)}")

    # ── Write JSON ────────────────────────────────────────────────────────────
    ann_dir = Path(cfg["annotations_dir"]) / safe_worker
    ann_dir.mkdir(parents=True, exist_ok=True)
    json_path = ann_dir / f"{submission_id}.json"
    # Explicit UTF-8: session text may contain non-ASCII characters, and the
    # platform default encoding is not guaranteed to handle them.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=2)
    print(f"[SAVE] JSON written: {json_path}")

    uploaded = False
    if hf_token:
        try:
            hf_api.upload_file(
                path_or_fileobj=str(json_path),
                path_in_repo=f"json/{safe_worker}/{submission_id}.json",
                repo_id=output_repo,
                repo_type="dataset",
            )
            print("[HF] JSON uploaded.")
            uploaded = True
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    if uploaded:
        # Release reservations so items are immediately available for re-assignment
        release_reservation(state.get("user_id", ""), cfg)
        # Record completion locally — updates counts immediately without waiting
        # for an HF re-scan. Also invalidates the HF cache.
        record_completion(state.get("user_id", ""), state.get("items", []), cfg)
        # Auto-pause Prolific study if all items are now covered (best-effort:
        # a failure here must not block the save).
        try:
            from src.data import all_items_covered, pause_prolific_study
            if all_items_covered(cfg):
                pause_prolific_study(cfg)
        except Exception as e:
            print(f"[SAVE] Auto-pause check failed: {e}")

    # ── Write + upload CSV ────────────────────────────────────────────────────
    _save_and_upload_csv(state, cfg, hf_api, safe_worker, submission_id)
# ── CSV schema ────────────────────────────────────────────────────────────────
# Each CSV row is one annotated item; every row starts with these shared
# columns, followed by the study-type-specific extras below.
_COMMON_HEADER = [
    "submission_id", "prolific_pid", "study_id", "session_id",
    "submission_time", "duration_seconds",
    "study_type", "model_name",
    "prompt_personalization", "prompt_detailed_instruction",
    "pair_selection_seed", "category",
    # Demographics
    "age", "gender", "geographic_region", "education_level", "race",
    "us_citizen", "marital_status", "religion", "religious_attendance",
    "political_affiliation", "income", "political_views",
    "household_size", "employment_status",
    # Background
    "movies_criteria", "movies_enjoy", "movies_avoid",
    "groceries_criteria", "groceries_enjoy", "groceries_avoid",
    # Ratings (rating_delta = post - pre when both are ints, else "")
    "pre_rating", "post_rating", "rating_delta",
    # Conversation (conversation_json holds the JSON-encoded turn list)
    "num_turns", "conversation_json",
    # Reflection
    "standout_moment", "thinking_change",
]

# Extra columns appended when cfg["study_type"] == "preference"
# (pairwise product comparison: product A vs product B).
_PREFERENCE_EXTRA_HEADER = [
    "pair_index", "pair_id",
    "product_a_id", "product_a_title", "product_a_price",
    "product_b_id", "product_b_title", "product_b_price",
    "familiarity_a", "familiarity_b",
]

# Extra columns appended for every non-"preference" study type
# (single-product rating).
_LIKELIHOOD_EXTRA_HEADER = [
    "item_index", "item_id",
    "product_title", "product_price",
    "familiarity",
]
def _save_and_upload_csv(
    state: dict, cfg: dict, hf_api: HfApi, safe_worker: str, submission_id: str
) -> None:
    """Flatten the session into one CSV row per item and upload it to HF.

    Header is ``_COMMON_HEADER`` plus the study-type-specific extra columns.
    The CSV is written to a temp file, uploaded (only when ``cfg["hf_token"]``
    is set), and the temp file is always removed — even when writing or the
    upload raises.
    """
    study_type = cfg["study_type"]
    header = _COMMON_HEADER + (
        _PREFERENCE_EXTRA_HEADER if study_type == "preference"
        else _LIKELIHOOD_EXTRA_HEADER
    )
    rows = [
        _csv_row(state, cfg, item, i, study_type, submission_id)
        for i, item in enumerate(state.get("items", []))
    ]

    # Timestamp + random tag make the repo path unique across re-submissions.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_tag = uuid.uuid4().hex[:8]
    repo_path = f"csv/{timestamp}_{safe_worker}_{unique_tag}.csv"

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
        ) as tmp:
            tmp_path = tmp.name
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)
        if cfg.get("hf_token"):
            try:
                hf_api.upload_file(
                    path_or_fileobj=tmp_path,
                    path_in_repo=repo_path,
                    repo_id=cfg["output_dataset_repo"],
                    repo_type="dataset",
                )
                print("[HF] CSV uploaded.")
            except Exception as e:
                print(f"[HF] CSV upload error: {e}")
    finally:
        # Guarantee cleanup: the original unlinked only on the success path,
        # leaking the temp file whenever the CSV write raised.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def _csv_row(
    state: dict, cfg: dict, item: dict, index: int,
    study_type: str, submission_id: str,
) -> list:
    """Flatten a single item (one pair/product interaction) into a CSV row."""
    demographics = state.get("demographics", {})
    background = state.get("background", {})
    meta = state.get("meta", {})
    conv = item.get("conversation", {})
    refl = item.get("reflection", {})
    pv = item.get("prompt_variant", {})
    pre = item.get("pre_rating", "")
    post = item.get("post_rating", "")
    # Delta is only meaningful when both ratings are present as ints;
    # otherwise leave the cell empty.
    delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
    common = [
        submission_id,
        state.get("prolific_pid", ""),
        state.get("study_id", ""),
        state.get("session_id", ""),
        meta.get("submission_time", ""),
        meta.get("duration_seconds", ""),
        study_type,
        item.get("model_name", ""),
        pv.get("personalization", False),
        pv.get("detailed_instruction", True),
        cfg.get("pair_selection_seed", 42),
        item.get("category", ""),
        demographics.get("age", ""),
        demographics.get("gender", ""),
        demographics.get("geographic_region", ""),
        demographics.get("education_level", ""),
        demographics.get("race", ""),
        demographics.get("us_citizen", ""),
        demographics.get("marital_status", ""),
        demographics.get("religion", ""),
        demographics.get("religious_attendance", ""),
        demographics.get("political_affiliation", ""),
        demographics.get("income", ""),
        demographics.get("political_views", ""),
        demographics.get("household_size", ""),
        demographics.get("employment_status", ""),
        background.get("movies_criteria", ""),
        background.get("movies_enjoy", ""),
        background.get("movies_avoid", ""),
        background.get("groceries_criteria", ""),
        background.get("groceries_enjoy", ""),
        background.get("groceries_avoid", ""),
        pre, post, delta,
        conv.get("num_turns", 0),
        json.dumps(conv.get("turns", [])),
        refl.get("standout_moment", ""),
        refl.get("thinking_change", ""),
    ]
    if study_type == "preference":
        pa, pb = item.get("product_a", {}), item.get("product_b", {})
        extra = [
            index + 1,
            item.get("pair_id", ""),
            pa.get("id", ""), pa.get("title", ""), pa.get("price", ""),
            pb.get("id", ""), pb.get("title", ""), pb.get("price", ""),
            item.get("familiarity_a", ""),
            item.get("familiarity_b", ""),
        ]
    else:
        prod = item.get("product", {})
        extra = [
            index + 1,
            item.get("item_id", ""),
            prod.get("title", ""), prod.get("price", ""),
            item.get("familiarity", ""),
        ]
    return common + extra