# tourism-rf-model / src/train_colab_model.py
# Uploaded by Shramik121: "Upload model and application files to Hugging Face Space"
# (commit 343148c, verified).
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from datasets import load_dataset
import os
# Load the "train" split of the tourism dataset from the Hugging Face Hub
# and convert it to a pandas DataFrame.
dataset = load_dataset("Shramik121/tourism-split-dataset")
data = pd.DataFrame(dataset['train'])

# Numeric columns: median-imputed here, standardized by the preprocessor.
num_cols = ['Age', 'DurationOfPitch', 'NumberOfPersonVisiting', 'NumberOfFollowups',
            'PreferredPropertyStar', 'NumberOfTrips', 'PitchSatisfactionScore',
            'NumberOfChildrenVisiting', 'MonthlyIncome']
# Categorical columns: missing values become 'Unknown'; one-hot encoded by
# the preprocessor.
cat_cols = ['TypeofContact', 'Occupation', 'Gender', 'ProductPitched',
            'MaritalStatus', 'Designation', 'CityTier']

# Drop columns that are not model features, when present: 'Unnamed: 0'
# (presumably an index column saved alongside the data — confirm) and the
# 'CustomerID' identifier. errors='ignore' skips whichever one is absent.
data = data.drop(columns=['Unnamed: 0', 'CustomerID'], errors='ignore')

# Impute missing values: column medians for numeric, a sentinel category
# for categorical.
data[num_cols] = data[num_cols].fillna(data[num_cols].median())
data[cat_cols] = data[cat_cols].fillna('Unknown')

# Normalize the 'Fe Male' spelling variant to 'Female'. 'Gender' is part of
# cat_cols, so the line above already requires it to exist; no guard needed.
data['Gender'] = data['Gender'].replace('Fe Male', 'Female')

# Features are every remaining column except the 'ProdTaken' target.
X = data.drop('ProdTaken', axis=1)
y = data['ProdTaken']
# Preprocessing stage:
# - standardize the numeric columns;
# - one-hot encode the categorical ones (a category unseen during fitting is
#   encoded as all zeros instead of raising);
# - pass every other column through unchanged.
_column_steps = [
    ('num', StandardScaler(), num_cols),
    ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
]
preprocessor = ColumnTransformer(_column_steps, remainder='passthrough')

# Full model: the preprocessing stage followed by a random forest.
# The forest uses a fixed seed so repeated runs give the same model.
_forest = RandomForestClassifier(random_state=42)
pipeline = Pipeline([('preprocessor', preprocessor), ('classifier', _forest)])
pipeline.fit(X, y)
# Recover, in output order, the names of the features produced by the
# fitted preprocessor.
#
# The preprocessor inside the pipeline is the very object fitted by
# pipeline.fit(X, y) above, so it must NOT be refitted here. The previous
# refit on an empty DataFrame either failed (StandardScaler needs at least
# one sample) or would overwrite the fitted state saved with the pipeline.
fitted_preprocessor = pipeline.named_steps['preprocessor']
feature_names = []
for name, transformer, cols in fitted_preprocessor.transformers_:
    if isinstance(transformer, str) and transformer == 'drop':
        # Dropped columns produce no output features.
        continue
    if hasattr(transformer, 'get_feature_names_out'):
        feature_names.extend(transformer.get_feature_names_out(cols))
    else:
        # 'passthrough' remainder: depending on the scikit-learn version the
        # columns may be reported as positional indices; map them to names.
        feature_names.extend(
            X.columns[c] if pd.api.types.is_integer(c) else c for c in cols
        )
columns = feature_names
# Persist the feature-name list and the fitted pipeline with joblib.
# The output directory defaults to the Colab path used originally. It can be
# overridden with the MODEL_DIR environment variable when running elsewhere.
model_dir = os.environ.get('MODEL_DIR', '/content/models')
os.makedirs(model_dir, exist_ok=True)
joblib.dump(columns, os.path.join(model_dir, 'columns.joblib'))
joblib.dump(pipeline, os.path.join(model_dir, 'best_rf_model.joblib'))
# os.path.join(dir, '') appends exactly one trailing separator, so the default
# prints "Model and columns saved to /content/models/" as before.
print(f"Model and columns saved to {os.path.join(model_dir, '')}")