# Training

Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html). We support both single-label and multi-label classification, which work seamlessly based on the labels you provide.

# Installation

To train, make sure you install the training extra:

```
pip install model2vec[training]
```

# Quickstart

To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows:

```python
from model2vec.distill import distill
from model2vec.train import StaticModelForClassification

# From a distilled model
distilled_model = distill("baai/bge-base-en-v1.5")
classifier = StaticModelForClassification.from_static_model(model=distilled_model)

# From a pre-trained model: potion is the default
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")
```

This creates a very simple classifier: a `StaticModel` with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through parameters on both functions.

Note that the default for `from_pretrained` is [potion-base-32m](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data.

Now that you have created the classifier, let's train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
```python
from time import perf_counter

from datasets import load_dataset

# Load the subj dataset
ds = load_dataset("setfit/subj")
train = ds["train"]
test = ds["test"]

s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])
print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds

classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
print(classification_report)
# Achieved 91.0 test accuracy
```

As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.

The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default, the training loop splits the data into a train and validation split, with 90% of the data used for training and 10% for validation. It runs with early stopping on the validation set accuracy, with a patience of 5.

Note that this model is as fast as you're used to from us:

```python
from time import perf_counter

s = perf_counter()
classifier.predict(test["text"])
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
# Took 67 milliseconds for 2000 instances on CPU.
```

## Multi-label classification

Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained.
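Under the hood, a list-of-lists label set corresponds to a binary indicator matrix, with one column per label. A minimal sketch with scikit-learn's `MultiLabelBinarizer` illustrates the mapping (purely illustrative — `model2vec` performs this conversion for you):

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical multi-label targets: each sample can carry several labels
labels = [["joy", "love"], ["joy"], ["anger", "sadness"]]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(labels)

print(mlb.classes_)  # ['anger' 'joy' 'love' 'sadness']
print(y)
# [[0 1 1 0]
#  [0 1 0 0]
#  [1 0 0 1]]
```

Each row is one sample; a 1 in a column means that sample carries the corresponding label.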
For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:

```python
from datasets import load_dataset

from model2vec.train import StaticModelForClassification

# Initialize a classifier from a pre-trained model
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")

# Load a multi-label dataset
ds = load_dataset("google-research-datasets/go_emotions")

# Inspect some of the labels
print(ds["train"]["labels"][40:50])
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]

# Train the classifier on text (X) and labels (y)
classifier.fit(ds["train"]["text"], ds["train"]["labels"])
```

Then, we can evaluate the classifier:

```python
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
print(classification_report)
# Accuracy: 0.410
# Precision: 0.527
# Recall: 0.410
# F1: 0.439
```

The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.

# Persistence

You can turn a classifier into a scikit-learn compatible pipeline, as follows:

```python
pipeline = classifier.to_pipeline()
```

This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no need to install torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
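The joblib round-trip looks like this. The sketch below uses a plain scikit-learn pipeline as a stand-in so it runs anywhere; in practice you would pass the object returned by `classifier.to_pipeline()` to `joblib.dump` instead:

```python
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Stand-in for classifier.to_pipeline(): any scikit-learn pipeline persists the same way
pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression())
pipeline.fit(["good movie", "bad movie"], [1, 0])

# Persist to disk and restore
joblib.dump(pipeline, "pipeline.joblib")
restored = joblib.load("pipeline.joblib")

print(restored.predict(["good movie"]))  # [1]
```

Remember that, like pickle, joblib files execute arbitrary code on load, so only load files from sources you trust.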
If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:

```python
pipeline.save_pretrained(path)
pipeline.push_to_hub("my_cool/project")
```

Later, you can load these as follows:

```python
from model2vec.inference import StaticModelPipeline

pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
```

Loading pipelines in this way is _extremely_ fast: it takes only 30ms to load a pipeline from disk.

# Bring your own architecture

Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.

The core functionality of `StaticModelForClassification` is contained in a couple of functions:

* `construct_head`: This function constructs the classifier head on top of the `StaticModel`. For example, if you want to create a model that has LayerNorm, just subclass and replace this function. This should be the main function to update if you want to change model behavior.
* `train_test_split`: Governs the train/test split before classification.
* `prepare_dataset`: Selects the `torch.Dataset` that will be used in the `DataLoader` during training.
* `_encode`: The encoding function used in the model.
* `fit`: Contains all the lightning-related fitting logic.

The training of the model is done in a `lightning.LightningModule`, which can be modified but is very basic.

# Results

We ran extensive benchmarks comparing our model to several well-known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.
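As a footnote to the "Bring your own architecture" section above: a LayerNorm head of the kind `construct_head` could return might look like the sketch below. This is a hypothetical stand-alone module, not the library's implementation — the real `construct_head` signature may differ, so check the `model2vec` source before subclassing:

```python
import torch
from torch import nn


def construct_layernorm_head(embed_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    """Hypothetical classifier head with LayerNorm between the hidden layer and the output."""
    return nn.Sequential(
        nn.Linear(embed_dim, hidden_dim),
        nn.LayerNorm(hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_dim),
    )


# Sanity check: a batch of 8 pooled embeddings of size 256, mapped to 4 class logits
head = construct_layernorm_head(embed_dim=256, hidden_dim=512, out_dim=4)
logits = head(torch.randn(8, 256))
print(logits.shape)  # torch.Size([8, 4])
```

In a subclass of `StaticModelForClassification`, you would return a module like this from your overridden `construct_head`.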