ehejin's picture
new jr user study
b58a0fc
"""Save session state to local disk and upload JSON + flattened CSV to HuggingFace."""
import csv
import json
import os
import tempfile
import uuid
from datetime import datetime
from pathlib import Path
import streamlit as st
from huggingface_hub import HfApi
from src.data import release_reservation, record_completion
@st.cache_resource
def _get_hf_api(hf_token: str, output_repo: str) -> HfApi:
    """
    Return an HF API client, creating the private output dataset repo if missing.

    Cached with ``st.cache_resource`` per ``(hf_token, output_repo)`` pair.
    Without a token an anonymous ``HfApi()`` is returned and no repo check is
    attempted.

    Args:
        hf_token: HF access token; empty string means "no token".
        output_repo: Dataset repo id that exports are uploaded to.

    Returns:
        The ``HfApi`` client.

    Raises:
        Whatever ``HfApi.create_repo`` raises if the repo is missing and could
        not be created. Other errors from the existence check are only logged.
    """
    # Local import: the error class lives in a submodule of an already-used package.
    from huggingface_hub.utils import RepositoryNotFoundError

    api = HfApi(token=hf_token) if hf_token else HfApi()
    if not hf_token:
        return api
    try:
        api.repo_info(repo_id=output_repo, repo_type="dataset")
    except Exception as e:
        msg = str(e)
        repo_missing = (
            isinstance(e, RepositoryNotFoundError)
            or "404" in msg
            or "not found" in msg.lower()
        )
        if not repo_missing:
            print(f"[HF] Warning checking repo existence: {e}")
            return api
        # exist_ok=True: another process may create the repo between the
        # existence check above and this call; that must not be an error.
        api.create_repo(
            repo_id=output_repo, repo_type="dataset", private=True, exist_ok=True
        )
        print(f"[HF] Created output repo: {output_repo}")
    return api
def _hf_study_identity_snapshot(state: dict, cfg: dict) -> dict:
    """
    Build the study-identity block embedded in every HF JSON export.

    A fixed list of non-secret config fields is always copied from *cfg*
    (absent keys become ``None``). When ``cfg["study_type"]`` is
    ``"model_comparison"``, per-session model-assignment fields are copied from
    *state* as well, so each exported file records which configured model
    entry the participant saw.

    Args:
        state: Session state for one participant.
        cfg: Study configuration.

    Returns:
        A new dict; key order is stable (config fields first, then state fields).
    """
    config_keys = (
        "study_type",
        "pair_selection_seed",
        "categories",
        "pairs_per_user",
        "min_turns",
        "max_turns",
        "prolific_study_id",
        "output_dataset_repo",
    )
    snapshot: dict = {key: cfg.get(key) for key in config_keys}

    if cfg.get("study_type") == "model_comparison":
        # (exported key, state key): only the user id is renamed on export.
        state_key_map = (
            ("submission_id", "submission_id"),
            ("session_user_id", "user_id"),
            ("comparison_models_canonical", "comparison_models_canonical"),
            (
                "model_presentation_order_config_indices",
                "model_presentation_order_config_indices",
            ),
            ("model_identity_timeline", "model_identity_timeline"),
            ("comparison_final", "comparison_final"),
        )
        for exported_key, state_key in state_key_map:
            snapshot[exported_key] = state.get(state_key)

    return snapshot
def save_and_upload(state: dict, cfg: dict) -> None:
    """
    Persist a finished session locally, then push it to the HF output repo.

    Steps:
      1. Write the full session state plus an ``hf_study_identity`` snapshot to
         ``<annotations_dir>/<worker>/<submission_id>.json``. This happens
         before any HF call, so HF client/repo setup failures cannot lose data.
      2. If ``cfg["hf_token"]`` is set, upload that file to
         ``json/<worker>/<submission_id>.json`` in ``cfg["output_dataset_repo"]``.
      3. Only after a successful JSON upload, run post-completion bookkeeping
         (see ``_finalize_completed_session``).
      4. Always attempt the flattened CSV upload (``_save_and_upload_csv``),
         even if step 3 raised. Step 3's exception still propagates afterwards.

    Upload errors are printed, not raised.

    Args:
        state: Session state; must be JSON-serialisable.
        cfg: Study config. Reads ``output_dataset_repo`` and
            ``annotations_dir`` (required) and ``hf_token`` (optional).
    """
    output_repo = cfg["output_dataset_repo"]
    hf_token = cfg.get("hf_token", "")
    # `or` chains rather than .get() defaults, so empty-string / None ids fall
    # through. Otherwise the worker dir could be "" (path "json//x.json") or
    # the file could be named "None.json".
    worker_id = state.get("prolific_pid") or state.get("user_id") or "anonymous"
    submission_id = state.get("submission_id") or str(uuid.uuid4())
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
    print("[SAVE] starting save_and_upload")
    print(f"[SAVE] output_repo={output_repo}")
    print(f"[SAVE] hf_token set={bool(hf_token)}")

    # ── Write JSON locally first ──────────────────────────────────────────────
    ann_dir = Path(cfg["annotations_dir"]) / safe_worker
    ann_dir.mkdir(parents=True, exist_ok=True)
    json_path = ann_dir / f"{submission_id}.json"
    export_state = dict(state)
    export_state["hf_study_identity"] = _hf_study_identity_snapshot(state, cfg)
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(export_state, f, indent=2)
    print(f"[SAVE] JSON written: {json_path}")

    # ── Upload JSON ───────────────────────────────────────────────────────────
    hf_api = _get_hf_api(hf_token, output_repo)
    uploaded = False
    if hf_token:
        try:
            hf_api.upload_file(
                path_or_fileobj=str(json_path),
                path_in_repo=f"json/{safe_worker}/{submission_id}.json",
                repo_id=output_repo,
                repo_type="dataset",
            )
            print("[HF] JSON uploaded.")
            uploaded = True
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    try:
        if uploaded:
            _finalize_completed_session(state, cfg)
    finally:
        # ── Write + upload CSV ────────────────────────────────────────────────
        # Runs even if bookkeeping raised, so the CSV mirror is not lost.
        _save_and_upload_csv(state, cfg, hf_api, safe_worker, submission_id)


def _finalize_completed_session(state: dict, cfg: dict) -> None:
    """Release the user's reservation, record completion, and auto-pause if done."""
    user_id = state.get("user_id", "")
    # Release reservations so items are immediately available for re-assignment
    release_reservation(user_id, cfg)
    # Record completion locally — updates counts immediately without waiting
    # for an HF re-scan. Also invalidates the HF cache.
    record_completion(user_id, state.get("items", []), cfg)
    # Auto-pause Prolific study if all items are now covered (best effort).
    try:
        from src.data import all_items_covered, pause_prolific_study
        if all_items_covered(cfg):
            pause_prolific_study(cfg)
    except Exception as e:
        print(f"[SAVE] Auto-pause check failed: {e}")
# ── CSV schema ────────────────────────────────────────────────────────────────
# Column groups shared by the per-item and model-comparison layouts. The full
# headers below are assembled from these groups; the row builders in
# _save_and_upload_csv must emit values in exactly this order.
_SESSION_COLUMNS = [
    "submission_id", "prolific_pid", "study_id", "session_id",
    "submission_time", "duration_seconds", "study_type",
]
_DEMOGRAPHIC_COLUMNS = [
    "age", "gender", "geographic_region", "education_level", "race",
    "us_citizen", "marital_status", "religion", "religious_attendance",
    "political_affiliation", "income", "political_views",
    "household_size", "employment_status",
]
_BACKGROUND_COLUMNS = [
    "movies_criteria", "movies_enjoy", "movies_avoid",
    "groceries_criteria", "groceries_enjoy", "groceries_avoid",
]
_PAIR_PRODUCT_COLUMNS = [
    "product_a_id", "product_a_title", "product_a_price",
    "product_b_id", "product_b_title", "product_b_price",
    "familiarity_a", "familiarity_b",
]

# Per-item studies: one row per item; a study-specific extra header follows.
_COMMON_HEADER = [
    *_SESSION_COLUMNS,
    "model_name",
    "prompt_personalization", "prompt_detailed_instruction",
    "pair_selection_seed", "category",
    *_DEMOGRAPHIC_COLUMNS,
    *_BACKGROUND_COLUMNS,
    "pre_rating", "post_rating", "rating_delta",      # ratings
    "num_turns", "conversation_json",                 # conversation
    "standout_moment", "thinking_change",             # reflection
]
# Trailing columns for "preference" studies (each item is a product pair).
_PREFERENCE_EXTRA_HEADER = ["pair_index", "pair_id", *_PAIR_PRODUCT_COLUMNS]
# Trailing columns for the other per-item study type (single product per item).
_LIKELIHOOD_EXTRA_HEADER = [
    "item_index", "item_id", "product_title", "product_price", "familiarity",
]
# "model_comparison" studies: one row per model round.
_MODEL_COMPARISON_HEADER = [
    *_SESSION_COLUMNS,
    "pair_selection_seed",
    "comparison_models_canonical_json",
    "model_presentation_order_config_indices_json",
    "model_identity_timeline_json",
    "comparison_round_index", "user_model_label", "variant_key",
    "config_index", "config_variant_name",
    "model_name", "sampler_path",
    "use_demographics", "use_background",
    "prompt_personalization", "prompt_detailed_instruction",
    "category",
    *_DEMOGRAPHIC_COLUMNS,
    *_BACKGROUND_COLUMNS,
    "pair_id",
    *_PAIR_PRODUCT_COLUMNS,
    "pre_rating",
    "post_convincing", "post_more_likely_buy_target", "post_trustworthy_natural",
    "post_stood_out",
    "num_turns", "conversation_json",
    "final_preferred_user_label",
    "final_preferred_config_index", "final_preferred_variant_key",
    "final_rank_convincing_labels_json",
    "final_rank_buy_target_labels_json",
    "final_rank_convincing_config_indices_json",
    "final_rank_buy_target_config_indices_json",
    "user_label_to_config_index_json",
]
def _save_and_upload_csv(
    state: dict, cfg: dict, hf_api: HfApi, safe_worker: str, submission_id: str
) -> None:
    """
    Flatten the session into CSV rows and upload them to the HF output repo.

    The layout depends on ``cfg["study_type"]``:
      * ``"model_comparison"``: ``_MODEL_COMPARISON_HEADER``, one row per entry
        of ``state["model_rounds"]``. The single shown pair is
        ``state["items"][0]``.
      * ``"preference"``: ``_COMMON_HEADER + _PREFERENCE_EXTRA_HEADER``, one
        row per item (product pair).
      * any other value: ``_COMMON_HEADER + _LIKELIHOOD_EXTRA_HEADER``, one
        row per item (single product).

    The CSV goes to ``csv/<YYYYmmdd_HHMMSS>_<worker>_<8 hex>.csv`` and is not
    kept locally. Nothing happens when ``cfg["hf_token"]`` is unset; upload
    errors are printed, not raised.

    Args:
        state: Session state (demographics, background, items, model_rounds, ...).
        cfg: Study config; reads ``study_type``, ``hf_token``,
            ``pair_selection_seed`` and ``output_dataset_repo``.
        hf_api: Client used for the upload.
        safe_worker: Filesystem-safe worker id used in the repo path.
        submission_id: Value written to the ``submission_id`` column.
    """
    if not cfg.get("hf_token"):
        # The CSV is only ever uploaded (never kept locally), so there is
        # nothing useful to produce without a token.
        return
    study_type = cfg["study_type"]
    profile = _csv_profile_values(
        state.get("demographics") or {}, state.get("background") or {}
    )
    if study_type == "model_comparison":
        header = _MODEL_COMPARISON_HEADER
        rows = _csv_model_comparison_rows(
            state, cfg, submission_id, study_type, profile
        )
    else:
        header = _COMMON_HEADER + (
            _PREFERENCE_EXTRA_HEADER if study_type == "preference"
            else _LIKELIHOOD_EXTRA_HEADER
        )
        rows = _csv_per_item_rows(state, cfg, submission_id, study_type, profile)
    _upload_csv(header, rows, cfg, hf_api, safe_worker)


# Answer keys copied into every CSV row, in column order. These MUST match the
# order of the demographic / background columns in _COMMON_HEADER and
# _MODEL_COMPARISON_HEADER.
_CSV_DEMOGRAPHIC_KEYS = (
    "age", "gender", "geographic_region", "education_level", "race",
    "us_citizen", "marital_status", "religion", "religious_attendance",
    "political_affiliation", "income", "political_views",
    "household_size", "employment_status",
)
_CSV_BACKGROUND_KEYS = (
    "movies_criteria", "movies_enjoy", "movies_avoid",
    "groceries_criteria", "groceries_enjoy", "groceries_avoid",
)


def _csv_profile_values(demographics: dict, background: dict) -> list:
    """Return demographic then background answers in CSV column order ("" if absent)."""
    return [demographics.get(k, "") for k in _CSV_DEMOGRAPHIC_KEYS] + [
        background.get(k, "") for k in _CSV_BACKGROUND_KEYS
    ]


def _csv_session_prefix(state: dict, submission_id: str, study_type: str) -> list:
    """Return the 7 leading per-session columns shared by every CSV layout."""
    meta = state.get("meta") or {}
    return [
        submission_id,
        state.get("prolific_pid", ""),
        state.get("study_id", ""),
        state.get("session_id", ""),
        meta.get("submission_time", ""),
        meta.get("duration_seconds", ""),
        study_type,
    ]


def _csv_rating_delta(pre, post):
    """Return ``post - pre`` when both ratings are numbers (bools excluded), else ""."""
    def _is_number(x) -> bool:
        # bool is a subclass of int; True/False are not ratings.
        return isinstance(x, (int, float)) and not isinstance(x, bool)

    return post - pre if _is_number(pre) and _is_number(post) else ""


def _csv_model_comparison_rows(
    state: dict, cfg: dict, submission_id: str, study_type: str, profile: list
) -> list:
    """Build one _MODEL_COMPARISON_HEADER row per model round."""
    items = state.get("items") or []
    item = items[0] if items else {}
    fin = state.get("comparison_final") or {}
    pa = item.get("product_a") or {}
    pb = item.get("product_b") or {}

    # Columns identical for every round of this session.
    head = _csv_session_prefix(state, submission_id, study_type) + [
        cfg.get("pair_selection_seed", 42),
        json.dumps(state.get("comparison_models_canonical", [])),
        json.dumps(state.get("model_presentation_order_config_indices", [])),
        json.dumps(state.get("model_identity_timeline", [])),
    ]
    pair_cols = [
        item.get("pair_id", ""),
        pa.get("id", ""), pa.get("title", ""), pa.get("price", ""),
        pb.get("id", ""), pb.get("title", ""), pb.get("price", ""),
        item.get("familiarity_a", ""),
        item.get("familiarity_b", ""),
        item.get("pre_rating", ""),
    ]
    final_cols = [
        fin.get("preferred_user_label", ""),
        fin.get("preferred_config_index", ""),
        fin.get("preferred_variant_key", ""),
        json.dumps(fin.get("rank_convincing_labels", [])),
        json.dumps(fin.get("rank_buy_target_labels", [])),
        json.dumps(fin.get("rank_convincing_config_indices", [])),
        json.dumps(fin.get("rank_buy_target_config_indices", [])),
        json.dumps(fin.get("user_label_to_config_index", {})),
    ]

    rows = []
    for ridx, mr in enumerate(state.get("model_rounds") or []):
        mconf = mr.get("config") or {}
        conv = mr.get("conversation") or {}
        pr = mr.get("post_chat_ratings") or {}
        # Explicit "personalization" wins; otherwise infer it from the flags.
        personalization = mconf.get(
            "personalization",
            mconf.get("use_demographics", False) or mconf.get("use_background", False),
        )
        round_cols = [
            ridx,
            mr.get("user_label", ""),
            mr.get("variant_key", ""),
            mr.get("config_index", ""),
            mconf.get("name", ""),
            mconf.get("model_name", ""),
            mconf.get("sampler_path", ""),
            mconf.get("use_demographics", False),
            mconf.get("use_background", False),
            personalization,
            # Not passed to lsp get_seller_system_prompt; optional YAML audit only.
            mconf.get("detailed_instruction", ""),
            item.get("category", ""),
        ]
        chat_cols = [
            pr.get("convincing", ""),
            pr.get("more_likely_buy_target", ""),
            pr.get("trustworthy_natural", ""),
            pr.get("stood_out", ""),
            conv.get("num_turns", 0),
            json.dumps(conv.get("turns", [])),
        ]
        rows.append(head + round_cols + profile + pair_cols + chat_cols + final_cols)
    return rows


def _csv_per_item_rows(
    state: dict, cfg: dict, submission_id: str, study_type: str, profile: list
) -> list:
    """Build one _COMMON_HEADER + study-specific-extra row per item."""
    rows = []
    for index, item in enumerate(state.get("items") or [], start=1):
        conv = item.get("conversation") or {}
        refl = item.get("reflection") or {}
        pv = item.get("prompt_variant") or {}
        pre = item.get("pre_rating", "")
        post = item.get("post_rating", "")
        row = (
            _csv_session_prefix(state, submission_id, study_type)
            + [
                item.get("model_name", ""),
                pv.get("personalization", False),
                pv.get("detailed_instruction", True),
                cfg.get("pair_selection_seed", 42),
                item.get("category", ""),
            ]
            + profile
            + [
                pre, post, _csv_rating_delta(pre, post),
                conv.get("num_turns", 0),
                json.dumps(conv.get("turns", [])),
                refl.get("standout_moment", ""),
                refl.get("thinking_change", ""),
            ]
        )
        if study_type == "preference":
            pa = item.get("product_a") or {}
            pb = item.get("product_b") or {}
            row += [
                index,
                item.get("pair_id", ""),
                pa.get("id", ""), pa.get("title", ""), pa.get("price", ""),
                pb.get("id", ""), pb.get("title", ""), pb.get("price", ""),
                item.get("familiarity_a", ""),
                item.get("familiarity_b", ""),
            ]
        else:
            prod = item.get("product") or {}
            row += [
                index,
                item.get("item_id", ""),
                prod.get("title", ""), prod.get("price", ""),
                item.get("familiarity", ""),
            ]
        rows.append(row)
    return rows


def _upload_csv(
    header: list, rows: list, cfg: dict, hf_api: HfApi, safe_worker: str
) -> None:
    """Write *rows* to a temp CSV, upload it, and always delete the temp file."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_tag = uuid.uuid4().hex[:8]
    repo_path = f"csv/{timestamp}_{safe_worker}_{unique_tag}.csv"
    fd, tmp_path = tempfile.mkstemp(suffix=".csv")
    try:
        with os.fdopen(fd, "w", newline="", encoding="utf-8") as tmp:
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)
        try:
            hf_api.upload_file(
                path_or_fileobj=tmp_path,
                path_in_repo=repo_path,
                repo_id=cfg["output_dataset_repo"],
                repo_type="dataset",
            )
            print("[HF] CSV uploaded.")
        except Exception as e:
            print(f"[HF] CSV upload error: {e}")
    finally:
        # Remove the temp file even if writing the CSV itself failed.
        os.unlink(tmp_path)