Nicolas Wagner
update for correct metric and label
bc714de
import os
from huggingface_hub import snapshot_download
from src.envs import TOKEN, TRUE_LABELS_PATH, TRUE_LABELS_REPO
def load_true_labels() -> dict[str, float]:
    """Download the true-label dataset and collect its binary labels.

    The dataset repo ``TRUE_LABELS_REPO`` is snapshotted into
    ``TRUE_LABELS_PATH``; every ``true_label.csv`` found beneath that
    directory that has both an ``id`` and a ``label`` column is read.

    Returns:
        A mapping from ``str(id)`` to the label as a float. Only rows whose
        label converts to exactly ``0.0`` or ``1.0`` are kept. When the same
        id occurs more than once, the row visited last wins. An empty dict
        is returned if the download fails.
    """
    os.makedirs(TRUE_LABELS_PATH, exist_ok=True)
    try:
        snapshot_download(
            repo_id=TRUE_LABELS_REPO,
            local_dir=TRUE_LABELS_PATH,
            repo_type="dataset",
            tqdm_class=None,
            etag_timeout=30,
            token=TOKEN,
        )
    except Exception as e:
        # Best effort: callers get an empty mapping instead of an exception
        # when the labels cannot be fetched.
        print(f"Warning: Could not download true labels: {e}")
        return {}

    # Kept function-local (after the download), matching the original.
    import pandas as pd

    labels: dict[str, float] = {}
    for root, _, files in os.walk(TRUE_LABELS_PATH):
        if "true_label.csv" not in files:
            continue
        filepath = os.path.join(root, "true_label.csv")

        # Only the read itself is guarded broadly: an unreadable or malformed
        # file is reported and skipped without aborting the whole scan.
        try:
            df = pd.read_csv(filepath)
        except Exception as e:
            print(f"Error loading true_label.csv: {e}")
            continue

        if "id" not in df.columns or "label" not in df.columns:
            continue

        # Iterate the two columns directly rather than via ``iterrows()``:
        # ``iterrows`` upcasts each row to one dtype, so an integer id next
        # to a float label would be stringified as e.g. "12.0", not "12".
        skipped = 0
        for raw_id, raw_label in zip(df["id"], df["label"]):
            try:
                label_val = float(raw_label)
            except (TypeError, ValueError):
                # One unparsable label must not discard the remaining rows.
                skipped += 1
                continue
            # NaN and any non-binary value fail this membership test.
            if label_val in (0.0, 1.0):
                labels[str(raw_id)] = label_val

        if skipped:
            print(
                f"Warning: skipped {skipped} row(s) with non-numeric "
                f"labels in {filepath}"
            )
    return labels