{
 "nbformat": 4,
 "nbformat_minor": 5,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro",
   "metadata": {},
   "source": [
    "# IT vs Non-IT Job Title Classifier \u2014 `intfloat/e5-base-v2`\n",
    "\n",
    "Trains a logistic regression classifier on top of `intfloat/e5-base-v2` embeddings to classify job titles as IT or Non-IT. Exports the classifier head to ONNX for lightweight, runtime-friendly inference.\n",
    "\n",
    "**Steps:**\n",
    "1. Load and split labeled job title data\n",
    "2. Encode titles with e5-base-v2 (mean pool + L2 normalize)\n",
    "3. Train logistic regression on embeddings\n",
    "4. Evaluate on held-out test split\n",
    "5. Run threshold sweep to inform deployment threshold choice\n",
    "6. Export classifier to ONNX\n",
    "\n",
    "> \u26a0\ufe0f After running the install cell, go to **Runtime \u2192 Restart session**, then run all cells from the top."
   ]
  },
  {
   "cell_type": "code",
   "id": "install",
   "metadata": {},
   "source": [
    "# \u2500\u2500 Install \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
    "# Restart the runtime after running this cell.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q -U transformers \"sentence-transformers[onnx]\" scikit-learn skl2onnx pandas"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "load-data",
   "metadata": {},
   "source": [
    "# \u2500\u2500 Load data \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "DATA_URL = 'https://docs.google.com/spreadsheets/d/e/2PACX-1vQ2s_EXIc36mtdVCi72RY7iO380wMMEhxlhyBUeE71uBCC5fMsRlKpHgafasxcQochvQCBsQF8IuNei/pub?gid=1233103818&single=true&output=csv'\n",
    "REQUIRED_COLUMNS = ['job_title', 'split', 'label']\n",
    "\n",
    "jobs_df = pd.read_csv(DATA_URL)\n",
    "\n",
    "# Fail early, with a readable message, if the downloaded sheet lacks a column\n",
    "# that the cells below read.\n",
    "missing_columns = [c for c in REQUIRED_COLUMNS if c not in jobs_df.columns]\n",
    "assert not missing_columns, f'Missing expected columns: {missing_columns}'\n",
    "\n",
    "# Missing titles become '' and non-string cells are stringified, so every\n",
    "# value handed to the encoder is a stripped str.\n",
    "jobs_df['text'] = jobs_df['job_title'].fillna('').astype(str).str.strip()\n",
    "\n",
    "# The sheet carries its own train/test assignment in the 'split' column.\n",
    "train_df = jobs_df[jobs_df['split'] == 'train'].reset_index(drop=True)\n",
    "test_df = jobs_df[jobs_df['split'] == 'test'].reset_index(drop=True)\n",
    "assert len(train_df) > 0 and len(test_df) > 0, \"Expected non-empty 'train' and 'test' splits\"\n",
    "\n",
    "print(f'Train: {len(train_df)} | Test: {len(test_df)}')\n",
    "print(train_df['label'].value_counts())"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "encode",
   "metadata": {},
   "source": [
    "# \u2500\u2500 Encode with e5-base-v2 \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "encoder = SentenceTransformer('intfloat/e5-base-v2')\n",
    "\n",
    "def encode(texts, batch_size=500):\n",
    "    \"\"\"Embed texts with the module-level `encoder`, L2-normalized.\n",
    "\n",
    "    Every text is prefixed with 'query: ' before encoding. Pooling is\n",
    "    whatever the loaded model applies; normalization is requested\n",
    "    explicitly via `normalize_embeddings=True`.\n",
    "\n",
    "    Args:\n",
    "        texts: iterable of job-title strings.\n",
    "        batch_size: number of texts passed to the encoder per batch.\n",
    "\n",
    "    Returns:\n",
    "        The result of `encoder.encode`; used below as a 2-D array with\n",
    "        one row per input text.\n",
    "    \"\"\"\n",
    "    # f-string formatting also tolerates a stray non-str value.\n",
    "    prefixed = [f'query: {t}' for t in texts]\n",
    "    return encoder.encode(\n",
    "        prefixed,\n",
    "        batch_size=batch_size,\n",
    "        normalize_embeddings=True,\n",
    "        show_progress_bar=True,\n",
    "    )\n",
    "\n",
    "train_embs = encode(train_df['text'].tolist())\n",
    "test_embs = encode(test_df['text'].tolist())\n",
    "print(f'Train: {train_embs.shape} | Test: {test_embs.shape}')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "train",
   "metadata": {},
   "source": [
    "# \u2500\u2500 Train logistic regression \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import classification_report, confusion_matrix, accuracy_score\n",
    "\n",
    "# Class order used by every metric below: index 0 = Non-IT, index 1 = IT.\n",
    "LABELS = [0, 1]\n",
    "\n",
    "# Cast to plain ints so a float/bool label column still yields classes 0/1.\n",
    "y_train = train_df['label'].astype(int).tolist()\n",
    "y_test = test_df['label'].astype(int).tolist()\n",
    "\n",
    "clf = LogisticRegression(C=1.0, max_iter=1000, class_weight='balanced')\n",
    "clf.fit(train_embs, y_train)\n",
    "preds = clf.predict(test_embs)\n",
    "\n",
    "print(f'Accuracy: {accuracy_score(y_test, preds):.4f}')\n",
    "print()\n",
    "# labels= pins the row order so target_names always lines up, even if one\n",
    "# class is absent from the test split or from the predictions.\n",
    "print(classification_report(y_test, preds, labels=LABELS, target_names=['Non-IT (0)', 'IT (1)']))\n",
    "\n",
    "# With labels=[0, 1] the matrix is always 2x2 and ravel() yields tn, fp, fn, tp.\n",
    "cm = confusion_matrix(y_test, preds, labels=LABELS)\n",
    "tn, fp, fn, tp = cm.ravel()\n",
    "print('\u2500\u2500\u2500 Confusion Matrix \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500')\n",
    "print(f\"{'':<16}{'Pred IT':<10}Pred Non-IT\")\n",
    "print(f\"{'Actual IT':<16}{tp:<10}{fn:<12}\u2190 fn: missed IT\")\n",
    "print(f\"{'Actual Non-IT':<16}{fp:<10}{tn:<12}\u2190 fp: Non-IT incorrectly kept\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "threshold-sweep",
   "metadata": {},
   "source": [
    "# \u2500\u2500 Threshold sweep \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
    "# Use this to inform your deployment threshold choice.\n",
    "# Lower threshold \u2192 higher IT recall (fewer missed IT jobs, more false positives).\n",
    "# Higher threshold \u2192 higher IT precision (fewer false positives, more missed IT jobs).\n",
    "\n",
    "def safe_ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or 0 when the denominator is 0.\"\"\"\n",
    "    return numerator / denominator if denominator else 0\n",
    "\n",
    "# Column 1 of predict_proba is the probability of the second class (label 1 = IT).\n",
    "probas = clf.predict_proba(test_embs)[:, 1]\n",
    "y_true = np.array(y_test)\n",
    "actual_it = y_true == 1\n",
    "actual_non_it = y_true == 0\n",
    "\n",
    "header = f\"{'Threshold':>10} {'Accuracy':>9} {'IT Prec':>8} {'IT Rec':>8} {'NonIT Prec':>10} {'NonIT Rec':>10} {'FN':>6} {'FP':>6}\"\n",
    "print(header)\n",
    "print('\u2500' * len(header))\n",
    "\n",
    "# Round the grid so each threshold is the closest float to 0.20, 0.25, ..., 0.80;\n",
    "# raw arange values such as 0.30000000000000004 would wrongly exclude a\n",
    "# probability of exactly 0.30.\n",
    "for t in np.round(np.arange(0.20, 0.81, 0.05), 2):\n",
    "    pred_it = probas >= t\n",
    "    pred_non_it = ~pred_it\n",
    "    tp_ = int((actual_it & pred_it).sum())\n",
    "    tn_ = int((actual_non_it & pred_non_it).sum())\n",
    "    fp_ = int((actual_non_it & pred_it).sum())\n",
    "    fn_ = int((actual_it & pred_non_it).sum())\n",
    "    acc_ = (tp_ + tn_) / len(y_true)\n",
    "    it_p = safe_ratio(tp_, tp_ + fp_)\n",
    "    it_r = safe_ratio(tp_, tp_ + fn_)\n",
    "    nt_p = safe_ratio(tn_, tn_ + fn_)\n",
    "    nt_r = safe_ratio(tn_, tn_ + fp_)\n",
    "    print(f'{t:>10.2f} {acc_:>9.4f} {it_p:>8.4f} {it_r:>8.4f} {nt_p:>10.4f} {nt_r:>10.4f} {fn_:>6} {fp_:>6}')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "export-onnx",
   "metadata": {},
   "source": [
    "# \u2500\u2500 Export classifier to ONNX \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
    "from skl2onnx import convert_sklearn\n",
    "from skl2onnx.common.data_types import FloatTensorType\n",
    "\n",
    "ONNX_PATH = 'e5_it_classifier.onnx'\n",
    "embedding_dim = train_embs.shape[1]\n",
    "\n",
    "# Only the classifier head is exported: the graph takes precomputed embeddings\n",
    "# of shape (batch, embedding_dim); None leaves the batch dimension dynamic.\n",
    "# NOTE(review): skl2onnx's default classifier output format (ZipMap of class\n",
    "# probabilities) is kept as-is; confirm that is what the consumer expects.\n",
    "initial_type = [('input', FloatTensorType([None, embedding_dim]))]\n",
    "onnx_clf = convert_sklearn(clf, initial_types=initial_type)\n",
    "\n",
    "with open(ONNX_PATH, 'wb') as f:\n",
    "    f.write(onnx_clf.SerializeToString())\n",
    "\n",
    "print(f'Saved \u2192 {ONNX_PATH} (embedding dim: {embedding_dim})')"
   ],
   "outputs": [],
   "execution_count": null
  }
 ]
}