# File size: 1,099 Bytes
# a309487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.pipeline import Pipeline


def train_model(preprocessor, X, y, problem_type, strategy):
    """Build a preprocessing + estimator pipeline, fit it, and return it.

    The estimator is chosen from ``strategy["model_family"]`` and
    ``problem_type``:

    * ``"linear_or_tree"``: ``LogisticRegression`` for classification,
      ``LinearRegression`` otherwise.
    * ``"tree_ensemble"``, any other value, or a missing key: a random
      forest with 100 estimators (``RandomForestClassifier`` for
      classification, ``RandomForestRegressor`` otherwise).

    Args:
        preprocessor: First pipeline step, registered under the name
            ``"preprocessor"``. Presumably a scikit-learn transformer;
            confirm against the caller.
        X: Training features, passed unchanged to ``Pipeline.fit``.
        y: Training targets, passed unchanged to ``Pipeline.fit``.
        problem_type: ``"classification"`` selects a classifier; every
            other value selects a regressor.
        strategy: Mapping that may carry a ``"model_family"`` entry.
            ``None`` or a mapping without that key selects the
            random-forest fallback instead of raising ``KeyError``.

    Returns:
        The fitted ``Pipeline`` with steps ``"preprocessor"`` and
        ``"model"``.
    """
    # Tolerate a missing strategy / missing key: both mean "not recognized"
    # and therefore take the same fallback as an unknown family name.
    family = (strategy or {}).get("model_family")
    is_classification = problem_type == "classification"

    if family == "linear_or_tree":
        model = LogisticRegression() if is_classification else LinearRegression()
    else:
        # "tree_ensemble" and the unrecognized-family fallback use the exact
        # same random-forest configuration, so they share one branch.
        forest_cls = (
            RandomForestClassifier if is_classification else RandomForestRegressor
        )
        model = forest_cls(n_estimators=100)

    clf = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("model", model),
        ]
    )
    clf.fit(X, y)
    return clf