Vu Anh commited on
Commit
ae0d039
·
1 Parent(s): 108511c

Add ruff linter and format code

Browse files

- Add ruff as dev dependency
- Format code according to ruff standards
- All checks passing

Files changed (3) hide show
  1. model.py +18 -13
  2. pyproject.toml +5 -0
  3. train.py +6 -4
model.py CHANGED
@@ -23,15 +23,20 @@ class SonarModel:
23
  random_state : int, default=42
24
  Random state for reproducibility
25
  """
26
- self.pipeline = Pipeline([
27
- ('scaler', StandardScaler()),
28
- ('classifier', RandomForestClassifier(
29
- n_estimators=n_estimators,
30
- max_depth=max_depth,
31
- random_state=random_state,
32
- n_jobs=-1
33
- ))
34
- ])
 
 
 
 
 
35
  self.is_fitted = False
36
 
37
  def fit(self, X, y):
@@ -52,7 +57,7 @@ class SonarModel:
52
  """
53
  self.pipeline.fit(X, y)
54
  self.is_fitted = True
55
- self.classes_ = self.pipeline.named_steps['classifier'].classes_
56
  self.n_features_ = X.shape[1]
57
  return self
58
 
@@ -121,7 +126,7 @@ class SonarModel:
121
  filepath : str
122
  Path to save the model
123
  """
124
- with open(filepath, 'wb') as f:
125
  pickle.dump(self, f)
126
 
127
  @classmethod
@@ -139,7 +144,7 @@ class SonarModel:
139
  model : SonarModel
140
  Loaded model instance
141
  """
142
- with open(filepath, 'rb') as f:
143
  return pickle.load(f)
144
 
145
  def get_feature_importance(self):
@@ -153,4 +158,4 @@ class SonarModel:
153
  """
154
  if not self.is_fitted:
155
  raise ValueError("Model must be fitted to get feature importances")
156
- return self.pipeline.named_steps['classifier'].feature_importances_
 
23
  random_state : int, default=42
24
  Random state for reproducibility
25
  """
26
+ self.pipeline = Pipeline(
27
+ [
28
+ ("scaler", StandardScaler()),
29
+ (
30
+ "classifier",
31
+ RandomForestClassifier(
32
+ n_estimators=n_estimators,
33
+ max_depth=max_depth,
34
+ random_state=random_state,
35
+ n_jobs=-1,
36
+ ),
37
+ ),
38
+ ]
39
+ )
40
  self.is_fitted = False
41
 
42
  def fit(self, X, y):
 
57
  """
58
  self.pipeline.fit(X, y)
59
  self.is_fitted = True
60
+ self.classes_ = self.pipeline.named_steps["classifier"].classes_
61
  self.n_features_ = X.shape[1]
62
  return self
63
 
 
126
  filepath : str
127
  Path to save the model
128
  """
129
+ with open(filepath, "wb") as f:
130
  pickle.dump(self, f)
131
 
132
  @classmethod
 
144
  model : SonarModel
145
  Loaded model instance
146
  """
147
+ with open(filepath, "rb") as f:
148
  return pickle.load(f)
149
 
150
  def get_feature_importance(self):
 
158
  """
159
  if not self.is_fitted:
160
  raise ValueError("Model must be fitted to get feature importances")
161
+ return self.pipeline.named_steps["classifier"].feature_importances_
pyproject.toml CHANGED
@@ -8,3 +8,8 @@ dependencies = [
8
  "scikit-learn>=1.3.0",
9
  "numpy>=1.24.0",
10
  ]
 
 
 
 
 
 
8
  "scikit-learn>=1.3.0",
9
  "numpy>=1.24.0",
10
  ]
11
+
12
+ [dependency-groups]
13
+ dev = [
14
+ "ruff>=0.13.1",
15
+ ]
train.py CHANGED
@@ -25,7 +25,7 @@ def generate_sample_data(n_samples=1000, n_features=60, n_classes=2):
25
  n_clusters_per_class=2,
26
  weights=[0.5, 0.5],
27
  flip_y=0.01,
28
- random_state=42
29
  )
30
  return X, y
31
 
@@ -87,14 +87,16 @@ def main():
87
  y_proba = loaded_model.predict_proba(X_samples)
88
 
89
  for i in range(len(sample_indices)):
90
- print(f" Sample {i+1}:")
91
  print(f" True class: {y_true[i]}")
92
  print(f" Predicted class: {y_pred[i]}")
93
- print(f" Probabilities: Class 0={y_proba[i][0]:.3f}, Class 1={y_proba[i][1]:.3f}")
 
 
94
 
95
  print("\n" + "=" * 50)
96
  print("Training completed successfully!")
97
 
98
 
99
  if __name__ == "__main__":
100
- main()
 
25
  n_clusters_per_class=2,
26
  weights=[0.5, 0.5],
27
  flip_y=0.01,
28
+ random_state=42,
29
  )
30
  return X, y
31
 
 
87
  y_proba = loaded_model.predict_proba(X_samples)
88
 
89
  for i in range(len(sample_indices)):
90
+ print(f" Sample {i + 1}:")
91
  print(f" True class: {y_true[i]}")
92
  print(f" Predicted class: {y_pred[i]}")
93
+ print(
94
+ f" Probabilities: Class 0={y_proba[i][0]:.3f}, Class 1={y_proba[i][1]:.3f}"
95
+ )
96
 
97
  print("\n" + "=" * 50)
98
  print("Training completed successfully!")
99
 
100
 
101
  if __name__ == "__main__":
102
+ main()