File size: 2,582 Bytes
141f1e0
 
 
 
 
bc714de
 
141f1e0
 
bc714de
dcb04e7
bc714de
141f1e0
 
bc714de
 
dcb04e7
bc714de
141f1e0
 
 
 
 
bc714de
141f1e0
 
 
 
 
 
 
 
bc714de
 
141f1e0
bc714de
 
141f1e0
 
 
 
bc714de
 
141f1e0
bc714de
 
141f1e0
dcb04e7
 
bc714de
 
141f1e0
 
bc714de
 
141f1e0
bc714de
dcb04e7
141f1e0
bc714de
141f1e0
bc714de
 
141f1e0
bc714de
141f1e0
bc714de
 
 
 
141f1e0
bc714de
141f1e0
 
bc714de
 
 
 
 
 
 
 
 
 
 
 
 
141f1e0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numbers
from io import StringIO

import pandas as pd


def normalize_label(label: any) -> float | None:
    if pd.isna(label):
        return None

    if isinstance(label, (int, float)):
        if label in [0, 1, 0.0, 1.0]:
            return float(label)
        return None

    if isinstance(label, str):
        label_stripped = label.strip()
        if label_stripped in ["0", "1", "0.0", "1.0"]:
            return float(label_stripped)
        return None

    return None


def validate_csv(csv_content: str, true_labels: dict[str, float]) -> tuple[bool, str, pd.DataFrame | None]:
    if not csv_content or not csv_content.strip():
        return False, "CSV content is empty", None

    try:
        df = pd.read_csv(StringIO(csv_content))
    except Exception as e:
        return False, f"Invalid CSV format: {str(e)}", None

    if "id" not in df.columns:
        return False, "CSV must contain 'id' column", None

    if "label" not in df.columns:
        return False, "CSV must contain 'label' column", None

    if df.empty:
        return False, "CSV is empty", None

    if df["id"].isna().any():
        return False, "id column contains missing values", None

    if df["label"].isna().any():
        return False, "label column contains missing values", None

    df["id"] = df["id"].astype(str).str.strip()

    normalized_labels = []
    invalid_labels = []

    for idx, row in df.iterrows():
        id_val = str(row["id"]).strip()
        label = normalize_label(row["label"])

        if label is None:
            invalid_labels.append(f"Row {idx + 1}: invalid label value '{row['label']}' (must be 0, 1, 0.0, or 1.0)")
        else:
            normalized_labels.append(label)

    if invalid_labels:
        return False, "Invalid labels found:\n" + "\n".join(invalid_labels[:5]), None

    df["label"] = normalized_labels

    unknown_ids = []
    for id_val in df["id"]:
        if str(id_val) not in true_labels:
            unknown_ids.append(str(id_val))

    if unknown_ids:
        return (
            False,
            f"Unknown IDs found: {', '.join(unknown_ids[:5])}{'...' if len(unknown_ids) > 5 else ''}",
            None,
        )

    missing_ids = []
    for true_id in true_labels.keys():
        if true_id not in df["id"].values:
            missing_ids.append(true_id)

    if missing_ids:
        return (
            False,
            f"Missing IDs from true labels: {', '.join(missing_ids[:5])}{'...' if len(missing_ids) > 5 else ''} (total: {len(missing_ids)})",
            None,
        )

    return True, "CSV is valid", df