Spaces:
Sleeping
Sleeping
Create export_model_artifacts.py
Browse files- export_model_artifacts.py +45 -0
export_model_artifacts.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# export_model_artifacts.py
|
| 2 |
+
import joblib, xgboost as xgb, os
|
| 3 |
+
p = "best_overall_XGBoost.joblib"
|
| 4 |
+
m = joblib.load(p)
|
| 5 |
+
print("Loaded:", type(m))
|
| 6 |
+
|
| 7 |
+
# If pipeline, extract preprocessing and classifier
|
| 8 |
+
preproc_path = None
|
| 9 |
+
json_path = None
|
| 10 |
+
clf = m
|
| 11 |
+
try:
|
| 12 |
+
# sklearn Pipeline -> try to find the xgb step
|
| 13 |
+
from sklearn.pipeline import Pipeline
|
| 14 |
+
if isinstance(m, Pipeline):
|
| 15 |
+
# Save pipeline without the final estimator
|
| 16 |
+
steps = m.steps
|
| 17 |
+
# assume last step is classifier
|
| 18 |
+
*prefix, (last_name, last_obj) = steps
|
| 19 |
+
# Build preprocessor pipeline if any prefix exists
|
| 20 |
+
if prefix:
|
| 21 |
+
from sklearn.pipeline import Pipeline as SKPipeline
|
| 22 |
+
preproc = SKPipeline(prefix)
|
| 23 |
+
preproc_path = "preprocessor.joblib"
|
| 24 |
+
joblib.dump(preproc, preproc_path)
|
| 25 |
+
print("Saved preprocessor to", preproc_path)
|
| 26 |
+
clf = last_obj
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print("Not a Pipeline or failed to extract pipeline:", e)
|
| 29 |
+
|
| 30 |
+
# If clf is XGBClassifier, get booster
|
| 31 |
+
try:
|
| 32 |
+
booster = None
|
| 33 |
+
if hasattr(clf, "get_booster"):
|
| 34 |
+
booster = clf.get_booster()
|
| 35 |
+
elif isinstance(clf, xgb.Booster):
|
| 36 |
+
booster = clf
|
| 37 |
+
else:
|
| 38 |
+
print("Classifier type:", type(clf))
|
| 39 |
+
if booster is not None:
|
| 40 |
+
json_path = "best_overall_XGBoost.json"
|
| 41 |
+
booster.save_model(json_path)
|
| 42 |
+
print("Saved booster JSON to", json_path)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print("Failed to export booster JSON:", e)
|
| 45 |
+
import traceback; traceback.print_exc()
|