Prompt48 commited on
Commit
b34449a
·
verified ·
1 Parent(s): 2c1b758

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\sklearn\ensemble\tests\test_base.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//sklearn//ensemble//tests//test_base.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Testing for the base module (sklearn.ensemble.base).
3
+ """
4
+
5
+ # Authors: The scikit-learn developers
6
+ # SPDX-License-Identifier: BSD-3-Clause
7
+
8
+ from collections import OrderedDict
9
+
10
+ import numpy as np
11
+
12
+ from sklearn.datasets import load_iris
13
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
14
+ from sklearn.ensemble import BaggingClassifier
15
+ from sklearn.ensemble._base import _set_random_states
16
+ from sklearn.feature_selection import SelectFromModel
17
+ from sklearn.linear_model import Perceptron
18
+ from sklearn.pipeline import Pipeline
19
+
20
+
21
+ def test_base():
22
+ # Check BaseEnsemble methods.
23
+ ensemble = BaggingClassifier(
24
+ estimator=Perceptron(random_state=None), n_estimators=3
25
+ )
26
+
27
+ iris = load_iris()
28
+ ensemble.fit(iris.data, iris.target)
29
+ ensemble.estimators_ = [] # empty the list and create estimators manually
30
+
31
+ ensemble._make_estimator()
32
+ random_state = np.random.RandomState(3)
33
+ ensemble._make_estimator(random_state=random_state)
34
+ ensemble._make_estimator(random_state=random_state)
35
+ ensemble._make_estimator(append=False)
36
+
37
+ assert 3 == len(ensemble)
38
+ assert 3 == len(ensemble.estimators_)
39
+
40
+ assert isinstance(ensemble[0], Perceptron)
41
+ assert ensemble[0].random_state is None
42
+ assert isinstance(ensemble[1].random_state, int)
43
+ assert isinstance(ensemble[2].random_state, int)
44
+ assert ensemble[1].random_state != ensemble[2].random_state
45
+
46
+ np_int_ensemble = BaggingClassifier(
47
+ estimator=Perceptron(), n_estimators=np.int32(3)
48
+ )
49
+ np_int_ensemble.fit(iris.data, iris.target)
50
+
51
+
52
+ def test_set_random_states():
53
+ # Linear Discriminant Analysis doesn't have random state: smoke test
54
+ _set_random_states(LinearDiscriminantAnalysis(), random_state=17)
55
+
56
+ clf1 = Perceptron(random_state=None)
57
+ assert clf1.random_state is None
58
+ # check random_state is None still sets
59
+ _set_random_states(clf1, None)
60
+ assert isinstance(clf1.random_state, int)
61
+
62
+ # check random_state fixes results in consistent initialisation
63
+ _set_random_states(clf1, 3)
64
+ assert isinstance(clf1.random_state, int)
65
+ clf2 = Perceptron(random_state=None)
66
+ _set_random_states(clf2, 3)
67
+ assert clf1.random_state == clf2.random_state
68
+
69
+ # nested random_state
70
+
71
+ def make_steps():
72
+ return [
73
+ ("sel", SelectFromModel(Perceptron(random_state=None))),
74
+ ("clf", Perceptron(random_state=None)),
75
+ ]
76
+
77
+ est1 = Pipeline(make_steps())
78
+ _set_random_states(est1, 3)
79
+ assert isinstance(est1.steps[0][1].estimator.random_state, int)
80
+ assert isinstance(est1.steps[1][1].random_state, int)
81
+ assert (
82
+ est1.get_params()["sel__estimator__random_state"]
83
+ != est1.get_params()["clf__random_state"]
84
+ )
85
+
86
+ # ensure multiple random_state parameters are invariant to get_params()
87
+ # iteration order
88
+
89
+ class AlphaParamPipeline(Pipeline):
90
+ def get_params(self, *args, **kwargs):
91
+ params = Pipeline.get_params(self, *args, **kwargs).items()
92
+ return OrderedDict(sorted(params))
93
+
94
+ class RevParamPipeline(Pipeline):
95
+ def get_params(self, *args, **kwargs):
96
+ params = Pipeline.get_params(self, *args, **kwargs).items()
97
+ return OrderedDict(sorted(params, reverse=True))
98
+
99
+ for cls in [AlphaParamPipeline, RevParamPipeline]:
100
+ est2 = cls(make_steps())
101
+ _set_random_states(est2, 3)
102
+ assert (
103
+ est1.get_params()["sel__estimator__random_state"]
104
+ == est2.get_params()["sel__estimator__random_state"]
105
+ )
106
+ assert (
107
+ est1.get_params()["clf__random_state"]
108
+ == est2.get_params()["clf__random_state"]
109
+ )