Spaces:
Sleeping
Sleeping
File size: 6,429 Bytes
4c91838 3977aa0 4c91838 3977aa0 4c91838 3977aa0 4c91838 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""
train_dimred_model.py
Trains a dimensionality reduction model (e.g., PCA, t-SNE, UMAP) on a dataset.
It can drop or select specific columns, perform label encoding on any non-numeric columns,
and optionally visualize the reduced data (2D or 3D).
Example Usage:
--------------
python scripts/train_dimred_model.py \
--model_module pca \
--data_path data/raw/breast-cancer-wisconsin-data/data.csv \
--drop_columns "id" \
--select_columns "radius_mean, texture_mean, perimeter_mean, area_mean" \
--visualize
"""
import os
import sys
import argparse
import importlib
import pandas as pd
import numpy as np
import joblib
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
def main(args):
# Move to project root if needed
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
os.chdir(project_root)
sys.path.insert(0, project_root)
# Dynamically import the chosen model module (pca.py, tsne.py, umap.py, etc.)
model_module_path = f"models.unsupervised.dimred.{args.model_module}"
model_module = importlib.import_module(model_module_path)
# Retrieve the estimator from the model file
estimator = model_module.estimator
default_n_components = getattr(model_module, 'default_n_components', 2) # fallback
# Prepare results directory
if args.results_path is None:
# e.g., 'results/PCA_DimRed'
args.results_path = os.path.join('results', f"{estimator.__class__.__name__}_DimRed")
os.makedirs(args.results_path, exist_ok=True)
# Prepare model directory
if args.model_path is None:
# e.g., 'saved_model/PCA_DimRed'
args.model_path = os.path.join('saved_models', f"{estimator.__class__.__name__}_DimRed")
os.makedirs(args.model_path, exist_ok=True)
# Load data from CSV
df = pd.read_csv(args.data_path)
print(f"Data loaded from {args.data_path}, initial shape: {df.shape}")
# Drop empty columns
df = df.dropna(axis='columns', how='all')
print("After dropping empty columns:", df.shape)
# Drop specified columns if any
if args.drop_columns:
drop_cols = [col.strip() for col in args.drop_columns.split(',') if col.strip()]
df = df.drop(columns=drop_cols, errors='ignore')
print(f"Dropped columns: {drop_cols} | New shape: {df.shape}")
# Select specified columns if any
if args.select_columns:
keep_cols = [col.strip() for col in args.select_columns.split(',') if col.strip()]
df = df[keep_cols]
print(f"Selected columns: {keep_cols} | New shape: {df.shape}")
# Label-encode non-numeric columns
for col in df.columns:
if df[col].dtype == 'object':
le = LabelEncoder()
df[col] = le.fit_transform(df[col])
# Impute
imputer = SimpleImputer(strategy='mean') # or 'median'
df_array = imputer.fit_transform(df)
df_imputed = pd.DataFrame(df_array, columns=df.columns)
print("After label-encoding and imputation:", df_imputed.shape)
# Convert DataFrame to numpy array
X = df_imputed.values
print(f"Final data shape after dropping/selecting columns and encoding: {X.shape}")
# Fit-transform the data (typical for dimensionality reduction)
X_transformed = estimator.fit_transform(X)
print(f"Dimensionality reduction done using {args.model_module}. Output shape: {X_transformed.shape}")
# Save the model
model_output_path = os.path.join(args.model_path, "dimred_model.pkl")
joblib.dump(estimator, model_output_path)
print(f"Model saved to {model_output_path}")
# Save the transformed data
transformed_path = os.path.join(args.results_path, "X_transformed.csv")
pd.DataFrame(X_transformed).to_csv(transformed_path, index=False)
print(f"Transformed data saved to {transformed_path}")
# Visualization (only if 2D or 3D)
if args.visualize:
n_dims = X_transformed.shape[1]
if n_dims == 2:
plt.figure(figsize=(6,5))
plt.scatter(X_transformed[:,0], X_transformed[:,1], s=30, alpha=0.7, c='blue')
plt.title(f"{estimator.__class__.__name__} 2D Projection")
plt.xlabel("Component 1")
plt.ylabel("Component 2")
plot_path = os.path.join(args.results_path, "dimred_plot_2D.png")
plt.savefig(plot_path)
plt.show()
print(f"2D plot saved to {plot_path}")
elif n_dims == 3:
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X_transformed[:,0], X_transformed[:,1], X_transformed[:,2], s=30, alpha=0.7, c='blue')
ax.set_title(f"{estimator.__class__.__name__} 3D Projection")
ax.set_xlabel("Component 1")
ax.set_ylabel("Component 2")
ax.set_zlabel("Component 3")
plot_path = os.path.join(args.results_path, "dimred_plot_3D.png")
plt.savefig(plot_path)
plt.show()
print(f"3D plot saved to {plot_path}")
else:
print(f"Visualization only supported for 2D or 3D outputs. Got {n_dims}D, skipping.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a dimensionality reduction model.")
parser.add_argument('--model_module', type=str, required=True,
help='Name of the dimred model module (e.g. pca, tsne, umap).')
parser.add_argument('--data_path', type=str, required=True,
help='Path to the CSV dataset file.')
parser.add_argument('--model_path', type=str, default=None,
help='Where to save the fitted model.')
parser.add_argument('--results_path', type=str, default=None,
help='Directory to store results (transformed data, plots).')
parser.add_argument('--visualize', action='store_true',
help='Plot the transformed data if 2D or 3D.')
parser.add_argument('--drop_columns', type=str, default='',
help='Comma-separated column names to drop from the dataset.')
parser.add_argument('--select_columns', type=str, default='',
help='Comma-separated column names to keep (ignore the rest).')
args = parser.parse_args()
main(args)
|