# mihretgold's picture
# Upload 2 files
# cf97cee verified
"""
User Study: Firestore logging and study configuration.
STORAGE: Google Cloud Firestore (Firebase).
The service account JSON file is loaded from:
1. The file path in FIREBASE_KEY_PATH env var, OR
2. The default path: steering-vision-language-model-firebase-adminsdk-fbsvc-b7d95b30a2.json
(relative to this file's directory), OR
3. The JSON string in FIREBASE_SERVICE_ACCOUNT_JSON env var (for HF Spaces secrets).
Firestore collections:
participants, interaction_log, image_annotations,
method_comparison, survey_responses, final_selections, final_survey
"""
import json
import os
import re
from pathlib import Path
from datetime import datetime, timezone
# ── Firestore initialisation ────────────────────────────────────────────────
# Module-wide Firestore state, filled in lazily by _init_firestore().
_firestore_init_done = False  # flips to True on the first initialisation attempt
_firestore_db = None  # Firestore client once initialisation has succeeded
# Fallback service-account key file, expected to sit beside this module.
# NOTE(review): service-account keys are secrets — confirm this file is kept
# out of version control.
_DEFAULT_KEY_FILE = Path(__file__).with_name(
    "steering-vision-language-model-firebase-adminsdk-fbsvc-b7d95b30a2.json"
)
def _init_firestore():
    """Lazily create the module-wide Firestore client.

    Only the first call does any work: the "done" flag is set before the
    attempt, so a failed initialisation is not retried and later calls return
    without raising (``_firestore_db`` simply stays ``None``).

    If a default Firebase app already exists it is reused.  Otherwise the
    credential sources are tried in this order:
      1. the file named by the FIREBASE_KEY_PATH env var (when it exists),
      2. the default key file next to this module,
      3. the JSON string in the FIREBASE_SERVICE_ACCOUNT_JSON env var.

    Raises:
        RuntimeError: if none of the credential sources is available.
        Errors from parsing the JSON secret or from firebase_admin itself
        are not caught here and propagate to the caller.
    """
    global _firestore_db, _firestore_init_done
    if _firestore_init_done:
        return
    _firestore_init_done = True

    import firebase_admin
    from firebase_admin import credentials, firestore

    # Already initialised by another module?  Ask through the public
    # get_app() instead of the private _apps registry: firestore.client()
    # needs the *default* app, and get_app() raises ValueError when that
    # app does not exist (even if some other, named app does).
    try:
        firebase_admin.get_app()
    except ValueError:
        pass
    else:
        _firestore_db = firestore.client()
        print("[study_utils] Firestore: reusing existing app")
        return

    key_path = os.environ.get("FIREBASE_KEY_PATH", "").strip()
    if key_path and not Path(key_path).exists():
        # Make a misconfigured path visible instead of silently ignoring it.
        print(
            f"[study_utils] Firestore: FIREBASE_KEY_PATH={key_path} does not exist; "
            "trying other credential sources"
        )

    cred = None
    source = ""
    if key_path and Path(key_path).exists():
        # Option 1: explicit file path from env var
        cred = credentials.Certificate(key_path)
        source = f"FIREBASE_KEY_PATH={key_path}"
    elif _DEFAULT_KEY_FILE.exists():
        # Option 2: default key file next to this module
        cred = credentials.Certificate(str(_DEFAULT_KEY_FILE))
        source = _DEFAULT_KEY_FILE.name
    else:
        # Option 3: JSON string in env var (for HF Spaces secrets)
        sa_json = os.environ.get("FIREBASE_SERVICE_ACCOUNT_JSON", "").strip()
        if sa_json:
            cred = credentials.Certificate(json.loads(sa_json))
            source = "FIREBASE_SERVICE_ACCOUNT_JSON"

    if cred is None:
        raise RuntimeError(
            "No Firebase credentials found. Place the service account JSON file "
            f"at {_DEFAULT_KEY_FILE}, or set the FIREBASE_KEY_PATH or "
            "FIREBASE_SERVICE_ACCOUNT_JSON env var."
        )

    firebase_admin.initialize_app(cred)
    _firestore_db = firestore.client()
    print(f"[study_utils] Firestore: initialised from {source}")
def _get_db():
    """Hand back the shared Firestore client, triggering lazy init first.

    Raises:
        RuntimeError: if no client is available after the init attempt.
    """
    _init_firestore()
    client = _firestore_db
    if client is not None:
        return client
    raise RuntimeError("Firestore is not initialised.")
def firestore_add(collection: str, data: dict) -> None:
    """Store *data* as a new document in the named Firestore collection."""
    target = _get_db().collection(collection)
    target.add(data)
def firestore_batch_add(items: list[tuple[str, dict]]) -> None:
    """Write many documents using Firestore batch writes.

    Args:
        items: list of (collection_name, data_dict) tuples.

    The items are committed in consecutive chunks of at most 450 writes,
    one batch commit per chunk, to stay under Firestore's 500-op limit.
    """
    db = _get_db()
    chunk_size = 450  # stay under Firestore's 500-op limit
    total = len(items)
    offset = 0
    while offset < total:
        writer = db.batch()
        for coll_name, payload in items[offset:offset + chunk_size]:
            # document() with no argument yields a fresh document reference.
            writer.set(db.collection(coll_name).document(), payload)
        writer.commit()
        offset += chunk_size
def firestore_query_exists(collection: str, field: str, value) -> bool:
    """Tell whether any document in *collection* has ``field == value``."""
    # limit(1): one hit is enough to answer the question.
    query = _get_db().collection(collection).where(field, "==", value).limit(1)
    matches = query.get()
    return len(matches) > 0
# ── 22 study queries ────────────────────────────────────────────────────────
# Flat list of (query_text, dataset_name) pairs, grouped by dataset in the
# order written below.
STUDY_QUERIES = [
    (query_text, dataset_name)
    for dataset_name, query_texts in (
        # Stanford Dogs (7)
        ("stanford_dogs", (
            "a golden retriever",
            "Dog on the beach",
            "Dog looking guilty",
            "friendly looking dog",
            "aggressive looking dog",
            "nervous looking dog",
            "Hyper active dog",
        )),
        # Flickr (7)
        ("flickr", (
            "a person riding a bicycle",
            "A dog playing",
            "an exciting action scene",
            "a joyful moment",
            "A kid having fun",
            "peaceful scene",
            "a photo with motion",
        )),
        # CelebA (8)
        ("celeba", (
            "wearing eyeglasses",
            "a person smiling",
            "looking guilty",
            "looking happy",
            "looking sad",
            "looking suspicious",
            "looking tired",
            "looking confident",
        )),
    )
    for query_text in query_texts
]
MAX_ROUNDS = 3
NUM_QUERIES = len(STUDY_QUERIES)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _iso_ts() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()
def validate_email(email: str) -> bool:
    """Check that *email* has the shape ``local@domain.tld``.

    Surrounding whitespace and letter case are ignored.  Empty values and
    non-string values are rejected.
    """
    if email and isinstance(email, str):
        candidate = email.strip().lower()
        pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
        return re.match(pattern, candidate) is not None
    return False
def participant_exists(participant_id: str) -> bool:
    """Report whether the "participants" collection already holds this id.

    The id is stripped and lower-cased before the lookup.
    """
    normalised_id = participant_id.strip().lower()
    return firestore_query_exists("participants", "participant_id", normalised_id)
def register_participant(email: str, gender: str, age_range: str) -> tuple:
    """
    Validate a new participant and write them to the "participants" collection.

    The stripped, lower-cased email is stored both as "email" and as
    "participant_id".  ``gender`` and ``age_range`` are stripped; a falsy
    value is stored as "".

    Returns (success: bool, error_message: str); error_message is "" on success.
    """
    invalid_email = (False, "Please enter a valid email address.")
    # Guard before calling str methods: a missing or non-string email (for
    # example None) is a validation failure, not an AttributeError.
    if not isinstance(email, str):
        return invalid_email
    email = email.strip().lower()
    if not validate_email(email):
        return invalid_email
    # NOTE(review): this check-then-add is not atomic, so two simultaneous
    # sign-ups with the same email could both pass — confirm whether that
    # matters for the study.
    if participant_exists(email):
        return False, "This email is already registered. Use a different one or contact the researchers."
    row = {
        "participant_id": email,
        "email": email,
        "gender": (gender or "").strip(),
        "age_range": (age_range or "").strip(),
        "timestamp": _iso_ts(),
    }
    firestore_add("participants", row)
    return True, ""
def log_interaction(
    participant_id: str,
    query_id: int,
    query_text: str,
    method: str,
    round_number: int,
    attributes_used: str,
    retrieved_image_ids: str,
    time_elapsed: float,
    user_satisfied: int,
) -> None:
    """Add one document to "interaction_log".

    Every argument is stored under a key of the same name, together with a
    "timestamp" taken from ``_iso_ts()``.
    """
    record = {
        "participant_id": participant_id,
        "query_id": query_id,
        "query_text": query_text,
        "method": method,
        "round_number": round_number,
        "attributes_used": attributes_used,
        "retrieved_image_ids": retrieved_image_ids,
        "time_elapsed": time_elapsed,
        "user_satisfied": user_satisfied,
        "timestamp": _iso_ts(),
    }
    firestore_add("interaction_log", record)
def log_image_annotations(
    participant_id: str,
    query_id: int,
    image_id: str,
    method: str,
    meets_intent: int,
) -> None:
    """Add one document to "image_annotations".

    Every argument is stored under a key of the same name, together with a
    "timestamp" taken from ``_iso_ts()``.
    """
    record = {
        "participant_id": participant_id,
        "query_id": query_id,
        "image_id": image_id,
        "method": method,
        "meets_intent": meets_intent,
        "timestamp": _iso_ts(),
    }
    firestore_add("image_annotations", record)
def log_method_comparison(
    participant_id: str,
    query_id: int,
    linear_better: str,
) -> None:
    """Add one document to "method_comparison".

    Every argument is stored under a key of the same name, together with a
    "timestamp" taken from ``_iso_ts()``.
    """
    record = {
        "participant_id": participant_id,
        "query_id": query_id,
        "linear_better": linear_better,
        "timestamp": _iso_ts(),
    }
    firestore_add("method_comparison", record)
def log_survey_responses(
    participant_id: str,
    query_id: int,
    alignment_score: int,
    agency_score: int,
    satisfaction_score: int,
    frustration_score: int,
    round_satisfied: int,
    time_elapsed: float,
) -> None:
    """Add one document to "survey_responses".

    Every argument is stored under a key of the same name, together with a
    "timestamp" taken from ``_iso_ts()``.
    """
    record = {
        "participant_id": participant_id,
        "query_id": query_id,
        "alignment_score": alignment_score,
        "agency_score": agency_score,
        "satisfaction_score": satisfaction_score,
        "frustration_score": frustration_score,
        "round_satisfied": round_satisfied,
        "time_elapsed": time_elapsed,
        "timestamp": _iso_ts(),
    }
    firestore_add("survey_responses", record)
def log_final_selections(
    participant_id: str,
    query_id: int,
    baseline_final_image_ids: str,
    linear_final_image_ids: str,
    round_satisfied: int,
    time_elapsed: float,
) -> None:
    """Add one document to "final_selections".

    Every argument is stored under a key of the same name, together with a
    "timestamp" taken from ``_iso_ts()``.
    """
    record = {
        "participant_id": participant_id,
        "query_id": query_id,
        "baseline_final_image_ids": baseline_final_image_ids,
        "linear_final_image_ids": linear_final_image_ids,
        "round_satisfied": round_satisfied,
        "time_elapsed": time_elapsed,
        "timestamp": _iso_ts(),
    }
    firestore_add("final_selections", record)
def log_final_survey(
    participant_id: str,
    preferred_system: str,
    concept_changed: str,
    open_feedback: str,
) -> None:
    """Add one document to "final_survey".

    Arguments are stored under keys of the same name, with ``open_feedback``
    stripped (a falsy value becomes ""), together with a "timestamp" taken
    from ``_iso_ts()``.
    """
    feedback_text = (open_feedback or "").strip()
    record = {
        "participant_id": participant_id,
        "preferred_system": preferred_system,
        "concept_changed": concept_changed,
        "open_feedback": feedback_text,
        "timestamp": _iso_ts(),
    }
    firestore_add("final_survey", record)