|
|
"""Append evaluations to shared HF Dataset as CSV.""" |
|
|
import os |
|
|
import pandas as pd |
|
|
from datetime import datetime |
|
|
import threading |
|
|
import queue |
|
|
import tempfile |
|
|
from huggingface_hub import HfApi, hf_hub_download |
|
|
|
|
|
|
|
|
DATASET_REPO = os.getenv("EVAL_DATASET_REPO") |
|
|
HF_TOKEN = os.getenv("HF_TOKEN") |
|
|
CSV_FILENAME = "traj_evaluations.csv" |
|
|
|
|
|
_dataset_sync_enabled = False |
|
|
_sync_queue = queue.Queue() |
|
|
_sync_thread = None |
|
|
_api = None |
|
|
|
|
|
|
|
|
def init_dataset_sync():
    """Initialize shared dataset sync against the HF Hub.

    Requires both EVAL_DATASET_REPO and HF_TOKEN environment variables
    (read at import time into DATASET_REPO / HF_TOKEN).

    Returns:
        bool: True when sync was enabled, False when configuration is
        missing or the HfApi client could not be created.
    """
    global _dataset_sync_enabled, _api

    if not DATASET_REPO or not HF_TOKEN:
        print("📊 Dataset sync disabled (set EVAL_DATASET_REPO + HF_TOKEN to enable)")
        return False

    try:
        _api = HfApi(token=HF_TOKEN)

        # Probe for an existing CSV. A download failure here just means
        # the file doesn't exist yet and will be created on first sync,
        # so any exception is treated as "file absent".
        # NOTE: narrowed from a bare `except:` so SystemExit /
        # KeyboardInterrupt are no longer swallowed.
        try:
            hf_hub_download(DATASET_REPO, CSV_FILENAME, repo_type="dataset", token=HF_TOKEN)
            print(f"✅ Connected to: {DATASET_REPO}/{CSV_FILENAME}")
        except Exception:
            print(f"✅ Will create: {DATASET_REPO}/{CSV_FILENAME}")

        _dataset_sync_enabled = True
        return True
    except Exception as e:
        print(f"⚠️ Dataset sync failed: {e}")
        return False
|
|
|
|
|
|
|
|
def _sync_worker():
    """Background worker that syncs to CSV on HF Hub (upserts by trajectory_id).

    Loops forever pulling row dicts from ``_sync_queue``; a ``None``
    sentinel stops the worker. Each row is merged into the remote CSV:
    an existing ``trajectory_id`` is updated in place, otherwise the
    row is appended. Failures are logged and the worker keeps running.
    """
    while True:
        row = _sync_queue.get()
        if row is None:  # shutdown sentinel
            break
        try:
            # Re-download the latest CSV each time (force_download) so we
            # don't clobber rows written by other processes in between.
            # Fall back to an empty frame when the file doesn't exist yet.
            # NOTE: narrowed from a bare `except:` so SystemExit /
            # KeyboardInterrupt are no longer swallowed.
            try:
                csv_path = hf_hub_download(DATASET_REPO, CSV_FILENAME, repo_type="dataset",
                                           token=HF_TOKEN, force_download=True)
                df = pd.read_csv(csv_path, keep_default_na=False, na_values=[''])
                # Normalize stringified NaNs left over from earlier writes.
                df = df.replace(['nan', 'NaN', 'None'], '')
            except Exception:
                df = pd.DataFrame()

            # Upsert: update the most recent row carrying this
            # trajectory_id, otherwise append a brand-new row.
            traj_id = row.get("trajectory_id", "")
            if len(df) > 0 and "trajectory_id" in df.columns and traj_id in df["trajectory_id"].values:
                idx = df[df["trajectory_id"] == traj_id].index[-1]
                for col, val in row.items():
                    df.at[idx, col] = val
                action = "Updated"
            else:
                df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
                action = "Added"

            # Serialize to a temp file, upload, then always remove the
            # temp file — even when the upload raises. (The original
            # leaked one temp CSV per failed upload because the unlink
            # only ran on success.)
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
                df.fillna('').to_csv(f, index=False)
                temp_path = f.name

            try:
                _api.upload_file(
                    path_or_fileobj=temp_path,
                    path_in_repo=CSV_FILENAME,
                    repo_id=DATASET_REPO,
                    repo_type="dataset",
                    commit_message=f"{action} eval @ {datetime.now().strftime('%H:%M:%S')}"
                )
            finally:
                os.unlink(temp_path)
            print(f"✅ {action} in {CSV_FILENAME}")
        except Exception as e:
            # Best-effort sync: log and move on to the next queued row.
            print(f"⚠️ Sync failed: {e}")
        finally:
            _sync_queue.task_done()
|
|
|
|
|
|
|
|
def _ensure_worker():
    """Lazily start the daemon sync thread; safe to call repeatedly."""
    global _sync_thread
    current = _sync_thread
    # Nothing to do while a live worker already exists.
    if current is not None and current.is_alive():
        return
    worker = threading.Thread(target=_sync_worker, daemon=True)
    worker.start()
    _sync_thread = worker
|
|
|
|
|
|
|
|
def append_to_dataset(evaluation_row):
    """Queue an evaluation row for asynchronous upload to the shared CSV.

    Args:
        evaluation_row: mapping of column name -> value to upsert.

    Returns:
        True when the row was queued, False when dataset sync is disabled.
    """
    if _dataset_sync_enabled:
        _ensure_worker()
        _sync_queue.put(evaluation_row)
        return True
    return False
|
|
|
|
|
|