# traj-eval / hf_dataset_sync.py
# Author: KaushikSid
# Commit 00d0080: Force fresh CSV download and strict validation of issue_type values
"""Append evaluations to shared HF Dataset as CSV."""
import os
import pandas as pd
from datetime import datetime
import threading
import queue
import tempfile
from huggingface_hub import HfApi, hf_hub_download
# Config
# Target HF dataset repo (e.g. "user/dataset") and auth token; both must be
# set in the environment for sync to be enabled (see init_dataset_sync).
DATASET_REPO = os.getenv("EVAL_DATASET_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
# Name of the CSV file inside the dataset repo that rows are upserted into.
CSV_FILENAME = "traj_evaluations.csv"
# Module state: sync stays off until init_dataset_sync() succeeds.
_dataset_sync_enabled = False
# Rows (dicts) awaiting upload; a None item is the worker's shutdown sentinel.
_sync_queue = queue.Queue()
# Lazily-started daemon thread running _sync_worker (see _ensure_worker).
_sync_thread = None
# HfApi client, created in init_dataset_sync().
_api = None
def init_dataset_sync():
    """Enable syncing evaluations to the shared HF dataset CSV.

    Requires both EVAL_DATASET_REPO and HF_TOKEN environment variables;
    otherwise sync stays disabled. Probes for an existing CSV only to
    report whether rows will be appended to it or a new file created.

    Returns:
        bool: True if sync was enabled, False otherwise.
    """
    global _dataset_sync_enabled, _api
    if not DATASET_REPO or not HF_TOKEN:
        print("📊 Dataset sync disabled (set EVAL_DATASET_REPO + HF_TOKEN to enable)")
        return False
    try:
        _api = HfApi(token=HF_TOKEN)
        # Probe for an existing CSV; absence is not an error (the file is
        # created on first upload). Catch Exception, not bare except, so
        # SystemExit/KeyboardInterrupt still propagate.
        try:
            hf_hub_download(DATASET_REPO, CSV_FILENAME, repo_type="dataset", token=HF_TOKEN)
            print(f"✅ Connected to: {DATASET_REPO}/{CSV_FILENAME}")
        except Exception:
            print(f"✅ Will create: {DATASET_REPO}/{CSV_FILENAME}")
        _dataset_sync_enabled = True
        return True
    except Exception as e:
        print(f"⚠️ Dataset sync failed: {e}")
        return False
def _sync_worker():
    """Background worker that syncs to CSV on HF Hub (upserts by trajectory_id).

    Consumes row dicts from _sync_queue forever; a None item is the shutdown
    sentinel. Each row is merged into the remote CSV (matched on
    "trajectory_id") and the whole file is re-uploaded.
    """
    while True:
        row = _sync_queue.get()
        if row is None:  # shutdown sentinel
            break
        try:
            # Download existing CSV or start fresh. force_download avoids a
            # stale local cache clobbering newer remote data.
            try:
                csv_path = hf_hub_download(DATASET_REPO, CSV_FILENAME, repo_type="dataset",
                                           token=HF_TOKEN, force_download=True)
                df = pd.read_csv(csv_path, keep_default_na=False, na_values=[''])
                # Clean up any legacy "nan" strings
                df = df.replace(['nan', 'NaN', 'None'], '')
            except Exception:
                # Missing file / first run: begin with an empty frame.
                df = pd.DataFrame()
            # Upsert: update if exists, append if not
            traj_id = row.get("trajectory_id", "")
            if len(df) > 0 and "trajectory_id" in df.columns and traj_id in df["trajectory_id"].values:
                # Update the most recent matching row in place.
                idx = df[df["trajectory_id"] == traj_id].index[-1]
                for col, val in row.items():
                    df.at[idx, col] = val
                action = "Updated"
            else:
                # Append new row
                df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
                action = "Added"
            # Upload updated CSV (fillna to ensure no NaN values). try/finally
            # guarantees the temp file is removed even if the upload fails,
            # so failed uploads no longer leak files on disk.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
                df.fillna('').to_csv(f, index=False)
                temp_path = f.name
            try:
                _api.upload_file(
                    path_or_fileobj=temp_path,
                    path_in_repo=CSV_FILENAME,
                    repo_id=DATASET_REPO,
                    repo_type="dataset",
                    commit_message=f"{action} eval @ {datetime.now().strftime('%H:%M:%S')}"
                )
            finally:
                os.unlink(temp_path)
            print(f"✅ {action} in {CSV_FILENAME}")
        except Exception as e:
            print(f"⚠️ Sync failed: {e}")
        finally:
            _sync_queue.task_done()
def _ensure_worker():
    """Lazily start the background sync thread if none is currently alive."""
    global _sync_thread
    alive = _sync_thread is not None and _sync_thread.is_alive()
    if not alive:
        worker = threading.Thread(target=_sync_worker, daemon=True)
        worker.start()
        _sync_thread = worker
def append_to_dataset(evaluation_row):
    """Queue an evaluation row for asynchronous upload to the HF dataset.

    Returns True when the row was queued, False when sync is disabled.
    """
    if _dataset_sync_enabled:
        _ensure_worker()
        _sync_queue.put(evaluation_row)
        return True
    return False