File size: 1,305 Bytes
141f1e0
 
 
 
 
 
 
bc714de
141f1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc714de
 
141f1e0
 
bc714de
141f1e0
 
bc714de
 
141f1e0
bc714de
 
 
 
 
141f1e0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os

from huggingface_hub import snapshot_download

from src.envs import TOKEN, TRUE_LABELS_PATH, TRUE_LABELS_REPO


def _read_true_label_file(filepath: str) -> dict[str, float]:
    """Parse a single ``true_label.csv`` file into an ``{id: label}`` mapping.

    Only rows whose ``label`` is numerically 0 or 1 are kept. Rows with a
    missing ``id`` or a non-numeric ``label`` are skipped rather than aborting
    the whole file. Returns an empty dict when either column is absent.
    Errors from ``pandas.read_csv`` itself (e.g. an empty file) propagate.
    """
    import pandas as pd

    # Read ids as text. If pandas inferred a numeric dtype, it would strip
    # leading zeros ("007" -> "7"). If any id were blank, the column would
    # become float, so "123" would come back as "123.0".
    df = pd.read_csv(filepath, dtype={"id": str})
    if "id" not in df.columns or "label" not in df.columns:
        return {}

    # Coerce unparseable labels to NaN so one bad row cannot discard the rest.
    label_values = pd.to_numeric(df["label"], errors="coerce").astype(float)
    keep = df["id"].notna() & label_values.isin([0.0, 1.0])
    # .tolist() yields plain Python str/float objects rather than numpy scalars.
    # If an id repeats within the file, the later row wins (dict semantics).
    return dict(zip(df.loc[keep, "id"].tolist(), label_values[keep].tolist()))


def load_true_labels() -> dict[str, float]:
    """Download the true-labels dataset and collect binary labels by id.

    The dataset repo ``TRUE_LABELS_REPO`` is mirrored into
    ``TRUE_LABELS_PATH`` with ``snapshot_download``. Every file named
    ``true_label.csv`` found anywhere under that directory is then read.

    Returns:
        Mapping from id (as a string) to label (0.0 or 1.0). Returns an
        empty dict if the download fails. Files that cannot be parsed are
        reported and skipped. When an id appears more than once, the row
        read last wins. Directories are walked in sorted order, so that
        choice is reproducible.
    """
    os.makedirs(TRUE_LABELS_PATH, exist_ok=True)

    try:
        snapshot_download(
            repo_id=TRUE_LABELS_REPO,
            local_dir=TRUE_LABELS_PATH,
            repo_type="dataset",
            tqdm_class=None,
            etag_timeout=30,
            token=TOKEN,
        )
    except Exception as e:
        # Best effort: callers get no labels rather than a crash.
        print(f"Warning: Could not download true labels: {e}")
        return {}

    labels: dict[str, float] = {}
    for root, dirs, files in os.walk(TRUE_LABELS_PATH):
        # Sorting in place makes os.walk visit subdirectories deterministically,
        # so "last one wins" for duplicate ids does not depend on the filesystem.
        dirs.sort()
        if "true_label.csv" not in files:
            continue
        filepath = os.path.join(root, "true_label.csv")
        try:
            labels.update(_read_true_label_file(filepath))
        except Exception as e:
            # Skip unreadable files, but report which one failed.
            print(f"Error loading {filepath}: {e}")

    return labels