Che237 committed on
Commit
2fecf00
·
verified ·
1 Parent(s): 460801e

Fix validation and normalization logic

Browse files
Files changed (1) hide show
  1. notebooks/01_data_acquisition.ipynb +32 -22
notebooks/01_data_acquisition.ipynb CHANGED
@@ -352,28 +352,32 @@
352
  "source": [
353
  "def validate_dataset(name: str, df: pd.DataFrame) -> Dict[str, Any]:\n",
354
  " \"\"\"Validate dataset quality and return report\"\"\"\n",
 
 
 
 
 
 
355
  " report = {\n",
356
  " \"name\": name,\n",
357
  " \"samples\": len(df),\n",
358
  " \"features\": len(df.columns),\n",
359
  " \"missing_values\": df.isnull().sum().sum(),\n",
360
- " \"missing_pct\": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100,\n",
361
  " \"duplicate_rows\": df.duplicated().sum(),\n",
362
  " \"numeric_columns\": len(df.select_dtypes(include=[np.number]).columns),\n",
363
  " \"categorical_columns\": len(df.select_dtypes(include=['object', 'category']).columns),\n",
364
  " \"memory_mb\": df.memory_usage(deep=True).sum() / (1024 * 1024),\n",
365
- " \"has_label\": any(col in df.columns for col in ['label', 'target', 'class', 'is_malicious', 'attack_type']),\n",
366
  " \"valid\": True\n",
367
  " }\n",
368
  " \n",
369
- " # Validation checks\n",
370
  " issues = []\n",
371
- " if report[\"samples\"] < 100:\n",
372
- " issues.append(\"Too few samples (<100)\")\n",
373
- " if report[\"missing_pct\"] > 50:\n",
374
- " issues.append(\"Too many missing values (>50%)\")\n",
375
- " if not report[\"has_label\"]:\n",
376
- " issues.append(\"No label column found\")\n",
377
  " \n",
378
  " report[\"issues\"] = issues\n",
379
  " report[\"valid\"] = len(issues) == 0\n",
@@ -393,7 +397,7 @@
393
  " print(f\"{name:<30} {report['samples']:>10} {report['features']:>10} {report['missing_pct']:>9.2f}% {status:>8}\")\n",
394
  "\n",
395
  "valid_datasets = [r[\"name\"] for r in validation_reports if r[\"valid\"]]\n",
396
- "print(f\"\\n\u2713 {len(valid_datasets)} datasets passed validation\")"
397
  ]
398
  },
399
  {
@@ -418,16 +422,20 @@
418
  " # Standardize column names\n",
419
  " df.columns = df.columns.str.lower().str.replace(' ', '_').str.replace('-', '_')\n",
420
  " \n",
421
- " # Find and standardize label column\n",
422
- " label_columns = ['label', 'target', 'class', 'is_malicious', 'attack_type', 'attack', 'category']\n",
 
 
423
  " for col in label_columns:\n",
424
- " if col in df.columns:\n",
425
- " df = df.rename(columns={col: 'label'})\n",
 
426
  " break\n",
427
  " \n",
428
  " # Handle missing values\n",
429
  " numeric_cols = df.select_dtypes(include=[np.number]).columns\n",
430
- " df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].median())\n",
 
431
  " \n",
432
  " categorical_cols = df.select_dtypes(include=['object', 'category']).columns\n",
433
  " for col in categorical_cols:\n",
@@ -443,17 +451,19 @@
443
  " \n",
444
  " return df\n",
445
  "\n",
446
- "# Normalize all valid datasets\n",
447
  "normalized_datasets = {}\n",
448
  "print(\"Normalizing datasets...\")\n",
449
  "\n",
450
- "for name in valid_datasets:\n",
451
- " if name in loaded_datasets:\n",
452
- " df = normalize_dataset(loaded_datasets[name], name)\n",
453
- " normalized_datasets[name] = df\n",
454
- " print(f\" \u2713 {name}: {len(df)} samples after normalization\")\n",
 
 
455
  "\n",
456
- "print(f\"\\n\u2713 Normalized {len(normalized_datasets)} datasets\")"
457
  ]
458
  },
459
  {
 
352
  "source": [
353
  "def validate_dataset(name: str, df: pd.DataFrame) -> Dict[str, Any]:\n",
354
  " \"\"\"Validate dataset quality and return report\"\"\"\n",
355
+ " # Expanded label column detection\n",
356
+ " label_columns = ['label', 'target', 'class', 'is_malicious', 'attack_type', \n",
357
+ " 'attack', 'category', 'malware', 'phishing', 'threat', \n",
358
+ " 'classification', 'type', 'result', 'output', 'y']\n",
359
+ " has_label = any(col.lower() in [c.lower() for c in df.columns] for col in label_columns)\n",
360
+ " \n",
361
  " report = {\n",
362
  " \"name\": name,\n",
363
  " \"samples\": len(df),\n",
364
  " \"features\": len(df.columns),\n",
365
  " \"missing_values\": df.isnull().sum().sum(),\n",
366
+ " \"missing_pct\": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100 if len(df) > 0 else 0,\n",
367
  " \"duplicate_rows\": df.duplicated().sum(),\n",
368
  " \"numeric_columns\": len(df.select_dtypes(include=[np.number]).columns),\n",
369
  " \"categorical_columns\": len(df.select_dtypes(include=['object', 'category']).columns),\n",
370
  " \"memory_mb\": df.memory_usage(deep=True).sum() / (1024 * 1024),\n",
371
+ " \"has_label\": has_label,\n",
372
  " \"valid\": True\n",
373
  " }\n",
374
  " \n",
375
+ " # More lenient validation - only fail on critical issues\n",
376
  " issues = []\n",
377
+ " if report[\"samples\"] < 10:\n",
378
+ " issues.append(\"Too few samples (<10)\")\n",
379
+ " if report[\"missing_pct\"] > 80:\n",
380
+ " issues.append(\"Too many missing values (>80%)\")\n",
 
 
381
  " \n",
382
  " report[\"issues\"] = issues\n",
383
  " report[\"valid\"] = len(issues) == 0\n",
 
397
  " print(f\"{name:<30} {report['samples']:>10} {report['features']:>10} {report['missing_pct']:>9.2f}% {status:>8}\")\n",
398
  "\n",
399
  "valid_datasets = [r[\"name\"] for r in validation_reports if r[\"valid\"]]\n",
400
+ "print(f\"\\n\u2713 {len(valid_datasets)} datasets passed validation\")\n"
401
  ]
402
  },
403
  {
 
422
  " # Standardize column names\n",
423
  " df.columns = df.columns.str.lower().str.replace(' ', '_').str.replace('-', '_')\n",
424
  " \n",
425
+ " # Find and standardize label column (expanded list)\n",
426
+ " label_columns = ['label', 'target', 'class', 'is_malicious', 'attack_type', \n",
427
+ " 'attack', 'category', 'malware', 'phishing', 'threat',\n",
428
+ " 'classification', 'type', 'result', 'output', 'y']\n",
429
  " for col in label_columns:\n",
430
+ " matching_cols = [c for c in df.columns if c.lower() == col.lower()]\n",
431
+ " if matching_cols:\n",
432
+ " df = df.rename(columns={matching_cols[0]: 'label'})\n",
433
  " break\n",
434
  " \n",
435
  " # Handle missing values\n",
436
  " numeric_cols = df.select_dtypes(include=[np.number]).columns\n",
437
+ " if len(numeric_cols) > 0:\n",
438
+ " df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].median())\n",
439
  " \n",
440
  " categorical_cols = df.select_dtypes(include=['object', 'category']).columns\n",
441
  " for col in categorical_cols:\n",
 
451
  " \n",
452
  " return df\n",
453
  "\n",
454
+ "# Normalize ALL loaded datasets (not just valid ones)\n",
455
  "normalized_datasets = {}\n",
456
  "print(\"Normalizing datasets...\")\n",
457
  "\n",
458
+ "for name, df in loaded_datasets.items():\n",
459
+ " try:\n",
460
+ " normalized_df = normalize_dataset(df, name)\n",
461
+ " normalized_datasets[name] = normalized_df\n",
462
+ " print(f\" \u2713 {name}: {len(normalized_df)} samples after normalization\")\n",
463
+ " except Exception as e:\n",
464
+ " print(f\" \u26a0 {name}: Error during normalization - {e}\")\n",
465
  "\n",
466
+ "print(f\"\\n\u2713 Normalized {len(normalized_datasets)} datasets\")\n"
467
  ]
468
  },
469
  {