| | import logging |
| | import pickle |
| |
|
| | import matplotlib.pyplot as plt |
| | import polars as pl |
| | import seaborn as sns |
| | from numpy.typing import NDArray |
| | from sklearn.metrics import auc, confusion_matrix, roc_curve |
| | from sklearn.svm import SVC |
| |
|
| | from utils.paths import DATA, IMGS, MODEL |
| |
|
# Configure the root logger at module load so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
def save_roc_curve(clf, X: NDArray, y: NDArray) -> None:
    """Plot the ROC curve of ``clf`` on ``(X, y)`` and save it as a PNG.

    The curve is built from column 1 of ``clf.predict_proba(X)`` and the
    image is written to ``IMGS / "roc_curve.png"``.

    Args:
        clf: Fitted classifier exposing ``predict_proba``.
        X: Feature matrix passed to ``clf.predict_proba``.
        y: True labels, passed to ``roc_curve`` alongside the scores.
    """
    # Column 1 is used as the positive-class score (assumes a binary
    # classifier whose second class is the positive one -- TODO confirm).
    probs = clf.predict_proba(X)[:, 1]
    # roc_curve also returns the decision thresholds; the plot does not use them.
    fpr, tpr, _ = roc_curve(y, probs)
    roc_auc = auc(fpr, tpr)

    fig = plt.figure(figsize=(6, 5))
    try:
        plt.plot(
            fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})"
        )
        # Diagonal reference line from (0, 0) to (1, 1).
        plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver Operating Characteristic (ROC)")
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig(IMGS / "roc_curve.png")
    finally:
        # Close this specific figure even if plotting or saving raised, so it
        # does not linger in pyplot's global figure registry.
        plt.close(fig)
| |
|
| |
|
def save_confusion_matrix(y: NDArray, pred: NDArray) -> None:
    """Render the confusion matrix of ``pred`` against ``y`` and save it as a PNG.

    The matrix is drawn as an annotated seaborn heatmap and written to
    ``IMGS / "confusion_matrix.png"``.

    Args:
        y: True labels.
        pred: Predicted labels, compared against ``y``.
    """
    # Both axes carry two tick labels, so a 2x2 matrix (binary labels) is
    # assumed here -- TODO confirm against the data.
    class_names = ["Not Relevant", "Relevant"]

    fig = plt.figure(figsize=(5, 4))
    try:
        sns.heatmap(
            confusion_matrix(y, pred),
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names,
        )
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Confusion Matrix")
        plt.tight_layout()
        plt.savefig(IMGS / "confusion_matrix.png")
    finally:
        # Close this specific figure even if plotting or saving raised, so it
        # does not linger in pyplot's global figure registry.
        plt.close(fig)
| |
|
| |
|
def _features_and_labels(frame):
    """Return the ``embeds`` and ``is_news`` columns of ``frame`` as NumPy arrays."""
    features = frame.get_column("embeds").to_numpy()
    labels = frame.get_column("is_news").to_numpy()
    return features, labels


def main() -> None:
    """Fit an SVC on the training split, pickle it, and save evaluation plots.

    Reads ``train.parquet`` and ``eval.parquet`` from ``DATA``, writes the
    fitted model to ``MODEL / "model.pickle"``, then produces the confusion
    matrix and ROC curve images for the evaluation split.
    """
    training_frame = pl.read_parquet(DATA / "train.parquet")
    train_X, train_y = _features_and_labels(training_frame)

    # probability=True is required so predict_proba is available for the ROC plot.
    model = SVC(kernel="poly", probability=True)
    model.fit(train_X, train_y)

    # Persist the fitted model before evaluating it.
    with open(MODEL / "model.pickle", "wb") as sink:
        pickle.dump(model, sink)

    evaluation_frame = pl.read_parquet(DATA / "eval.parquet")
    eval_X, eval_y = _features_and_labels(evaluation_frame)
    predictions = model.predict(eval_X)

    save_confusion_matrix(eval_y, predictions)
    save_roc_curve(model, eval_X, eval_y)
| |
|
| |
|
# Run the training/evaluation pipeline only when executed as a script.
if __name__ == "__main__":
    main()
| |
|