Spaces:
Running
Running
Fix validation and normalization logic
Browse files
notebooks/01_data_acquisition.ipynb
CHANGED
|
@@ -352,28 +352,32 @@
|
|
| 352 |
"source": [
|
| 353 |
"def validate_dataset(name: str, df: pd.DataFrame) -> Dict[str, Any]:\n",
|
| 354 |
" \"\"\"Validate dataset quality and return report\"\"\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
" report = {\n",
|
| 356 |
" \"name\": name,\n",
|
| 357 |
" \"samples\": len(df),\n",
|
| 358 |
" \"features\": len(df.columns),\n",
|
| 359 |
" \"missing_values\": df.isnull().sum().sum(),\n",
|
| 360 |
-
" \"missing_pct\": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100,\n",
|
| 361 |
" \"duplicate_rows\": df.duplicated().sum(),\n",
|
| 362 |
" \"numeric_columns\": len(df.select_dtypes(include=[np.number]).columns),\n",
|
| 363 |
" \"categorical_columns\": len(df.select_dtypes(include=['object', 'category']).columns),\n",
|
| 364 |
" \"memory_mb\": df.memory_usage(deep=True).sum() / (1024 * 1024),\n",
|
| 365 |
-
" \"has_label\":
|
| 366 |
" \"valid\": True\n",
|
| 367 |
" }\n",
|
| 368 |
" \n",
|
| 369 |
-
" #
|
| 370 |
" issues = []\n",
|
| 371 |
-
" if report[\"samples\"] <
|
| 372 |
-
" issues.append(\"Too few samples (<
|
| 373 |
-
" if report[\"missing_pct\"] >
|
| 374 |
-
" issues.append(\"Too many missing values (>
|
| 375 |
-
" if not report[\"has_label\"]:\n",
|
| 376 |
-
" issues.append(\"No label column found\")\n",
|
| 377 |
" \n",
|
| 378 |
" report[\"issues\"] = issues\n",
|
| 379 |
" report[\"valid\"] = len(issues) == 0\n",
|
|
@@ -393,7 +397,7 @@
|
|
| 393 |
" print(f\"{name:<30} {report['samples']:>10} {report['features']:>10} {report['missing_pct']:>9.2f}% {status:>8}\")\n",
|
| 394 |
"\n",
|
| 395 |
"valid_datasets = [r[\"name\"] for r in validation_reports if r[\"valid\"]]\n",
|
| 396 |
-
"print(f\"\\n\u2713 {len(valid_datasets)} datasets passed validation\")"
|
| 397 |
]
|
| 398 |
},
|
| 399 |
{
|
|
@@ -418,16 +422,20 @@
|
|
| 418 |
" # Standardize column names\n",
|
| 419 |
" df.columns = df.columns.str.lower().str.replace(' ', '_').str.replace('-', '_')\n",
|
| 420 |
" \n",
|
| 421 |
-
" # Find and standardize label column\n",
|
| 422 |
-
" label_columns = ['label', 'target', 'class', 'is_malicious', 'attack_type',
|
|
|
|
|
|
|
| 423 |
" for col in label_columns:\n",
|
| 424 |
-
"
|
| 425 |
-
"
|
|
|
|
| 426 |
" break\n",
|
| 427 |
" \n",
|
| 428 |
" # Handle missing values\n",
|
| 429 |
" numeric_cols = df.select_dtypes(include=[np.number]).columns\n",
|
| 430 |
-
"
|
|
|
|
| 431 |
" \n",
|
| 432 |
" categorical_cols = df.select_dtypes(include=['object', 'category']).columns\n",
|
| 433 |
" for col in categorical_cols:\n",
|
|
@@ -443,17 +451,19 @@
|
|
| 443 |
" \n",
|
| 444 |
" return df\n",
|
| 445 |
"\n",
|
| 446 |
-
"# Normalize
|
| 447 |
"normalized_datasets = {}\n",
|
| 448 |
"print(\"Normalizing datasets...\")\n",
|
| 449 |
"\n",
|
| 450 |
-
"for name in
|
| 451 |
-
"
|
| 452 |
-
"
|
| 453 |
-
" normalized_datasets[name] =
|
| 454 |
-
" print(f\" \u2713 {name}: {len(
|
|
|
|
|
|
|
| 455 |
"\n",
|
| 456 |
-
"print(f\"\\n\u2713 Normalized {len(normalized_datasets)} datasets\")"
|
| 457 |
]
|
| 458 |
},
|
| 459 |
{
|
|
|
|
| 352 |
"source": [
|
| 353 |
"def validate_dataset(name: str, df: pd.DataFrame) -> Dict[str, Any]:\n",
|
| 354 |
" \"\"\"Validate dataset quality and return report\"\"\"\n",
|
| 355 |
+
" # Expanded label column detection\n",
|
| 356 |
+
" label_columns = ['label', 'target', 'class', 'is_malicious', 'attack_type', \n",
|
| 357 |
+
" 'attack', 'category', 'malware', 'phishing', 'threat', \n",
|
| 358 |
+
" 'classification', 'type', 'result', 'output', 'y']\n",
|
| 359 |
+
" has_label = any(col.lower() in [c.lower() for c in df.columns] for col in label_columns)\n",
|
| 360 |
+
" \n",
|
| 361 |
" report = {\n",
|
| 362 |
" \"name\": name,\n",
|
| 363 |
" \"samples\": len(df),\n",
|
| 364 |
" \"features\": len(df.columns),\n",
|
| 365 |
" \"missing_values\": df.isnull().sum().sum(),\n",
|
| 366 |
+
" \"missing_pct\": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100 if len(df) > 0 else 0,\n",
|
| 367 |
" \"duplicate_rows\": df.duplicated().sum(),\n",
|
| 368 |
" \"numeric_columns\": len(df.select_dtypes(include=[np.number]).columns),\n",
|
| 369 |
" \"categorical_columns\": len(df.select_dtypes(include=['object', 'category']).columns),\n",
|
| 370 |
" \"memory_mb\": df.memory_usage(deep=True).sum() / (1024 * 1024),\n",
|
| 371 |
+
" \"has_label\": has_label,\n",
|
| 372 |
" \"valid\": True\n",
|
| 373 |
" }\n",
|
| 374 |
" \n",
|
| 375 |
+
" # More lenient validation - only fail on critical issues\n",
|
| 376 |
" issues = []\n",
|
| 377 |
+
" if report[\"samples\"] < 10:\n",
|
| 378 |
+
" issues.append(\"Too few samples (<10)\")\n",
|
| 379 |
+
" if report[\"missing_pct\"] > 80:\n",
|
| 380 |
+
" issues.append(\"Too many missing values (>80%)\")\n",
|
|
|
|
|
|
|
| 381 |
" \n",
|
| 382 |
" report[\"issues\"] = issues\n",
|
| 383 |
" report[\"valid\"] = len(issues) == 0\n",
|
|
|
|
| 397 |
" print(f\"{name:<30} {report['samples']:>10} {report['features']:>10} {report['missing_pct']:>9.2f}% {status:>8}\")\n",
|
| 398 |
"\n",
|
| 399 |
"valid_datasets = [r[\"name\"] for r in validation_reports if r[\"valid\"]]\n",
|
| 400 |
+
"print(f\"\\n\u2713 {len(valid_datasets)} datasets passed validation\")\n"
|
| 401 |
]
|
| 402 |
},
|
| 403 |
{
|
|
|
|
| 422 |
" # Standardize column names\n",
|
| 423 |
" df.columns = df.columns.str.lower().str.replace(' ', '_').str.replace('-', '_')\n",
|
| 424 |
" \n",
|
| 425 |
+
" # Find and standardize label column (expanded list)\n",
|
| 426 |
+
" label_columns = ['label', 'target', 'class', 'is_malicious', 'attack_type', \n",
|
| 427 |
+
" 'attack', 'category', 'malware', 'phishing', 'threat',\n",
|
| 428 |
+
" 'classification', 'type', 'result', 'output', 'y']\n",
|
| 429 |
" for col in label_columns:\n",
|
| 430 |
+
" matching_cols = [c for c in df.columns if c.lower() == col.lower()]\n",
|
| 431 |
+
" if matching_cols:\n",
|
| 432 |
+
" df = df.rename(columns={matching_cols[0]: 'label'})\n",
|
| 433 |
" break\n",
|
| 434 |
" \n",
|
| 435 |
" # Handle missing values\n",
|
| 436 |
" numeric_cols = df.select_dtypes(include=[np.number]).columns\n",
|
| 437 |
+
" if len(numeric_cols) > 0:\n",
|
| 438 |
+
" df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].median())\n",
|
| 439 |
" \n",
|
| 440 |
" categorical_cols = df.select_dtypes(include=['object', 'category']).columns\n",
|
| 441 |
" for col in categorical_cols:\n",
|
|
|
|
| 451 |
" \n",
|
| 452 |
" return df\n",
|
| 453 |
"\n",
|
| 454 |
+
"# Normalize ALL loaded datasets (not just valid ones)\n",
|
| 455 |
"normalized_datasets = {}\n",
|
| 456 |
"print(\"Normalizing datasets...\")\n",
|
| 457 |
"\n",
|
| 458 |
+
"for name, df in loaded_datasets.items():\n",
|
| 459 |
+
" try:\n",
|
| 460 |
+
" normalized_df = normalize_dataset(df, name)\n",
|
| 461 |
+
" normalized_datasets[name] = normalized_df\n",
|
| 462 |
+
" print(f\" \u2713 {name}: {len(normalized_df)} samples after normalization\")\n",
|
| 463 |
+
" except Exception as e:\n",
|
| 464 |
+
" print(f\" \u26a0 {name}: Error during normalization - {e}\")\n",
|
| 465 |
"\n",
|
| 466 |
+
"print(f\"\\n\u2713 Normalized {len(normalized_datasets)} datasets\")\n"
|
| 467 |
]
|
| 468 |
},
|
| 469 |
{
|