Nicolas Wagner
textual update
dcb04e7
import numbers
from io import StringIO

import pandas as pd
def normalize_label(label: any) -> float | None:
if pd.isna(label):
return None
if isinstance(label, (int, float)):
if label in [0, 1, 0.0, 1.0]:
return float(label)
return None
if isinstance(label, str):
label_stripped = label.strip()
if label_stripped in ["0", "1", "0.0", "1.0"]:
return float(label_stripped)
return None
return None
def _preview_ids(ids: list[str], limit: int = 5) -> str:
    """Join the first ``limit`` IDs with ", ", appending "..." if any were cut off."""
    return f"{', '.join(ids[:limit])}{'...' if len(ids) > limit else ''}"


def validate_csv(csv_content: str, true_labels: dict[str, float]) -> tuple[bool, str, pd.DataFrame | None]:
    """Validate a binary-label submission CSV against the expected IDs.

    The CSV must have ``id`` and ``label`` columns, at least one data row, no
    missing values, labels accepted by ``normalize_label``, each ID at most
    once, and exactly the set of IDs that are keys of ``true_labels``.

    Args:
        csv_content: Raw CSV text, header row included.
        true_labels: Mapping of expected ID to its true label; only the keys
            are consulted here.

    Returns:
        ``(is_valid, message, df)``. On success ``df`` is the parsed frame with
        ``id`` as stripped strings and ``label`` as floats (0.0 / 1.0). On
        failure ``df`` is ``None`` and ``message`` describes the first
        problem found.
    """
    if not csv_content or not csv_content.strip():
        return False, "CSV content is empty", None
    try:
        # Read IDs as text so values like "007" or "1e3" are not coerced to
        # numbers (which would stringify to "7" / "1000.0" and no longer match
        # the keys of true_labels). Labels keep pandas' type inference.
        df = pd.read_csv(StringIO(csv_content), dtype={"id": str})
    except Exception as e:  # any parse failure is reported, not raised
        return False, f"Invalid CSV format: {e}", None

    for column in ("id", "label"):
        if column not in df.columns:
            return False, f"CSV must contain '{column}' column", None
    if df.empty:
        return False, "CSV is empty", None
    if df["id"].isna().any():
        return False, "id column contains missing values", None
    if df["label"].isna().any():
        return False, "label column contains missing values", None

    df["id"] = df["id"].astype(str).str.strip()

    normalized_labels: list[float] = []
    invalid_labels: list[str] = []
    # enumerate yields 1-based data-row numbers (header excluded) regardless
    # of what index pandas assigned to the frame.
    for row_number, raw_label in enumerate(df["label"], start=1):
        label = normalize_label(raw_label)
        if label is None:
            invalid_labels.append(
                f"Row {row_number}: invalid label value '{raw_label}' "
                "(must be 0, 1, 0.0, or 1.0)"
            )
        else:
            normalized_labels.append(label)
    if invalid_labels:
        return False, "Invalid labels found:\n" + "\n".join(invalid_labels[:5]), None
    df["label"] = normalized_labels

    # A repeated ID makes the submission ambiguous (which label counts?).
    # dict.fromkeys de-duplicates while keeping first-seen order.
    duplicate_ids = list(dict.fromkeys(df.loc[df["id"].duplicated(), "id"]))
    if duplicate_ids:
        return False, f"Duplicate IDs found: {_preview_ids(duplicate_ids)}", None

    unknown_ids = [id_val for id_val in df["id"] if id_val not in true_labels]
    if unknown_ids:
        return False, f"Unknown IDs found: {_preview_ids(unknown_ids)}", None

    # Build the lookup set once; membership on the column array is a linear scan.
    submitted_ids = set(df["id"])
    missing_ids = [true_id for true_id in true_labels if true_id not in submitted_ids]
    if missing_ids:
        return (
            False,
            f"Missing IDs from true labels: {_preview_ids(missing_ids)} (total: {len(missing_ids)})",
            None,
        )
    return True, "CSV is valid", df