ACA050's picture
Upload 79 files
a309487 verified
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.pipeline import Pipeline
def train_model(preprocessor, X, y, problem_type, strategy, *, random_state=None):
    """Build a preprocessing + estimator pipeline and fit it on ``X``/``y``.

    The estimator is selected from ``strategy["model_family"]`` and
    ``problem_type``:

    * ``"linear_or_tree"`` -> ``LogisticRegression()`` for classification,
      ``LinearRegression()`` otherwise.
    * ``"tree_ensemble"``, any other value, or a missing ``"model_family"``
      key -> a 100-tree random forest (classifier or regressor).

    Args:
        preprocessor: Transformer placed as the first pipeline step under the
            name ``"preprocessor"`` (presumably a ``ColumnTransformer`` or
            similar — confirm against caller).
        X: Training features, passed to ``Pipeline.fit``.
        y: Training targets, passed to ``Pipeline.fit``.
        problem_type: ``"classification"`` selects a classifier; any other
            value selects a regressor.
        strategy: Mapping that may contain a ``"model_family"`` key. ``None``
            or an empty mapping selects the random-forest fallback.
        random_state: Optional seed forwarded to the random-forest estimators
            so training can be reproduced. ``None`` (the default) keeps
            scikit-learn's unseeded behaviour.

    Returns:
        The fitted ``Pipeline`` with steps ``"preprocessor"`` and ``"model"``.
    """
    is_classification = problem_type == "classification"
    # Use .get() so a strategy lacking "model_family" takes the fallback
    # branch below instead of raising KeyError.
    model_family = (strategy or {}).get("model_family")

    if model_family == "linear_or_tree":
        model = LogisticRegression() if is_classification else LinearRegression()
    else:
        # "tree_ensemble" and any unrecognised family share the random-forest
        # choice; previously this was duplicated across two branches.
        forest_cls = (
            RandomForestClassifier if is_classification else RandomForestRegressor
        )
        model = forest_cls(n_estimators=100, random_state=random_state)

    clf = Pipeline(steps=[
        ("preprocessor", preprocessor),
        ("model", model),
    ])
    clf.fit(X, y)
    return clf