Spaces:
Sleeping
Sleeping
backend added
Browse files- .dockerignore +18 -0
- Dockerfile +48 -0
- README.md +125 -7
- _wsgi.py +122 -0
- docker-compose.yml +44 -0
- model.py +214 -0
- requirements-base.txt +2 -0
- requirements-test.txt +2 -0
- requirements.txt +3 -0
- test_api.py +62 -0
.dockerignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Exclude everything
|
| 2 |
+
**
|
| 3 |
+
|
| 4 |
+
# Include Dockerfile and docker-compose for reference (optional, decide based on your use case)
|
| 5 |
+
!Dockerfile
|
| 6 |
+
!docker-compose.yml
|
| 7 |
+
|
| 8 |
+
# Include Python application files
|
| 9 |
+
!*.py
|
| 10 |
+
|
| 11 |
+
# Include requirements files
|
| 12 |
+
!requirements*.txt
|
| 13 |
+
|
| 14 |
+
# Include script
|
| 15 |
+
!*.sh
|
| 16 |
+
|
| 17 |
+
# Exclude specific requirements if necessary
|
| 18 |
+
# requirements-test.txt (Uncomment if you decide to exclude this)
|
Dockerfile
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# syntax=docker/dockerfile:1
|
| 2 |
+
ARG PYTHON_VERSION=3.11
|
| 3 |
+
|
| 4 |
+
FROM python:${PYTHON_VERSION}-slim AS python-base
|
| 5 |
+
ARG TEST_ENV
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 10 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 11 |
+
PORT=${PORT:-9090} \
|
| 12 |
+
PIP_CACHE_DIR=/.cache \
|
| 13 |
+
WORKERS=1 \
|
| 14 |
+
THREADS=8
|
| 15 |
+
|
| 16 |
+
# Update the base OS
|
| 17 |
+
RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \
|
| 18 |
+
--mount=type=cache,target="/var/lib/apt/lists",sharing=locked \
|
| 19 |
+
set -eux; \
|
| 20 |
+
apt-get update; \
|
| 21 |
+
apt-get upgrade -y; \
|
| 22 |
+
apt install --no-install-recommends -y \
|
| 23 |
+
git; \
|
| 24 |
+
apt-get autoremove -y
|
| 25 |
+
|
| 26 |
+
# install base requirements
|
| 27 |
+
COPY requirements-base.txt .
|
| 28 |
+
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
|
| 29 |
+
pip install -r requirements-base.txt
|
| 30 |
+
|
| 31 |
+
# install custom requirements
|
| 32 |
+
COPY requirements.txt .
|
| 33 |
+
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
|
| 34 |
+
pip install -r requirements.txt
|
| 35 |
+
|
| 36 |
+
# install test requirements if needed
|
| 37 |
+
COPY requirements-test.txt .
|
| 38 |
+
# build only when TEST_ENV="true"
|
| 39 |
+
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
|
| 40 |
+
if [ "$TEST_ENV" = "true" ]; then \
|
| 41 |
+
pip install -r requirements-test.txt; \
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
COPY . .
|
| 45 |
+
|
| 46 |
+
EXPOSE 9090
|
| 47 |
+
|
| 48 |
+
CMD gunicorn --preload --bind :$PORT --workers $WORKERS --threads $THREADS --timeout 0 _wsgi:app
|
README.md
CHANGED
|
@@ -1,10 +1,128 @@
|
|
|
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!--
|
| 2 |
---
|
| 3 |
+
title: Classify text with a BERT model
|
| 4 |
+
type: guide
|
| 5 |
+
tier: all
|
| 6 |
+
order: 35
|
| 7 |
+
hide_menu: true
|
| 8 |
+
hide_frontmatter_title: true
|
| 9 |
+
meta_title: BERT-based text classification
|
| 10 |
+
meta_description: Tutorial on how to use BERT-based text classification with your Label Studio project
|
| 11 |
+
categories:
|
| 12 |
+
- Natural Language Processing
|
| 13 |
+
- Text Classification
|
| 14 |
+
- BERT
|
| 15 |
+
- Hugging Face
|
| 16 |
+
image: "/tutorials/bert.png"
|
| 17 |
---
|
| 18 |
+
-->
|
| 19 |
|
| 20 |
+
# BERT-based text classification
|
| 21 |
+
|
| 22 |
+
The NewModel is a BERT-based text classification model that is designed to work with Label Studio. This model uses the Hugging Face Transformers library to fine-tune a BERT model for text classification. The model is trained on the labeled data from Label Studio and then used to make predictions on new data. With this model connected to Label Studio, you can:
|
| 23 |
+
|
| 24 |
+
- Train a BERT model on your labeled data directly from Label Studio.
|
| 25 |
+
- Use any model for [AutoModelForSequenceClassification](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#automodelforsequenceclassification) from the Hugging Face model hub.
|
| 26 |
+
- Fine-tune the model on your specific task and use it to make predictions on new data.
|
| 27 |
+
- Automatically download the labeled tasks from Label Studio and prepare the data for training.
|
| 28 |
+
- Customize the training parameters such as learning rate, number of epochs, and weight decay.
|
| 29 |
+
|
| 30 |
+
## Before you begin
|
| 31 |
+
|
| 32 |
+
Before you begin, you must install the [Label Studio ML backend](https://github.com/HumanSignal/label-studio-ml-backend?tab=readme-ov-file#quickstart).
|
| 33 |
+
|
| 34 |
+
This tutorial uses the [`bert_classifier` example](https://github.com/HumanSignal/label-studio-ml-backend/tree/master/label_studio_ml/examples/bert_classifier).
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
## Running with Docker (recommended)
|
| 38 |
+
|
| 39 |
+
1. Start the Machine Learning backend on `http://localhost:9090` with the prebuilt image:
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
docker-compose up
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
2. Validate that backend is running:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
$ curl http://localhost:9090/
|
| 49 |
+
{"status":"UP"}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
3. Create a project in Label Studio. Then from the **Model** page in the project settings, [connect the model](https://labelstud.io/guide/ml#Connect-the-model-to-Label-Studio). The default URL is `http://localhost:9090`.
|
| 53 |
+
|
| 54 |
+
> Warning! Note the current limitation of the ML backend: models are loaded dynamically from huggingface.co. You may need the `HF_TOKEN` env variable provided in your environment. Consequently, this may result in a slow response time for the first prediction request. If you are experiencing timeouts on Label Studio side (i.e., no predictions are visible when opening the task), check the logs of the ML backend for any errors, and refresh the page in a few minutes.
|
| 55 |
+
|
| 56 |
+
## Building from source (advanced)
|
| 57 |
+
|
| 58 |
+
To build the ML backend from source, you have to clone the repository and build the Docker image:
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
docker-compose build
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Running without Docker (advanced)
|
| 65 |
+
|
| 66 |
+
To run the ML backend without Docker, you have to clone the repository and install all dependencies using pip:
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
python -m venv ml-backend
|
| 70 |
+
source ml-backend/bin/activate
|
| 71 |
+
pip install -r requirements.txt
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Then you can start the ML backend:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
label-studio-ml start ./dir_with_your_model
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
## Labeling configuration
|
| 82 |
+
|
| 83 |
+
In project `Settings > Labeling Interface > Browse Templates > Natural Language Processing > Text Classification`, you can find the default labeling configuration for text classification in Label Studio. This configuration includes a single `<Choices>` output and a single `<Text>` input.
|
| 84 |
+
Feel free to modify the set of labels in the `<Choices>` tag to match your specific task, for example:
|
| 85 |
+
|
| 86 |
+
```xml
|
| 87 |
+
<View>
|
| 88 |
+
<Text name="text" value="$text" />
|
| 89 |
+
<Choices name="label" toName="text" choice="single" showInLine="true">
|
| 90 |
+
<Choice value="label one" />
|
| 91 |
+
<Choice value="label two" />
|
| 92 |
+
<Choice value="label three" />
|
| 93 |
+
</Choices>
|
| 94 |
+
</View>
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
## Configuration
|
| 99 |
+
|
| 100 |
+
Parameters can be set in `docker-compose.yml` before running the container.
|
| 101 |
+
|
| 102 |
+
The following common parameters are available:
|
| 103 |
+
|
| 104 |
+
- `BASIC_AUTH_USER` - Specify the basic auth user for the model server
|
| 105 |
+
- `BASIC_AUTH_PASS` - Specify the basic auth password for the model server
|
| 106 |
+
- `LOG_LEVEL` - Set the log level for the model server
|
| 107 |
+
- `WORKERS` - Specify the number of workers for the model server
|
| 108 |
+
- `THREADS` - Specify the number of threads for the model server
|
| 109 |
+
- `BASELINE_MODEL_NAME`: The name of the baseline model to use for training. Default is `bert-base-multilingual-cased`.
|
| 110 |
+
|
| 111 |
+
## Training
|
| 112 |
+
|
| 113 |
+
The following parameters are available for training:
|
| 114 |
+
|
| 115 |
+
- `LABEL_STUDIO_HOST` (required): The URL of the Label Studio instance. Default is `http://localhost:8080`.
|
| 116 |
+
- `LABEL_STUDIO_API_KEY` (required): The [API key](https://labelstud.io/guide/user_account#Access-token) for the Label Studio instance.
|
| 117 |
+
- `START_TRAINING_EACH_N_UPDATES`: The number of labeled tasks to download from Label Studio before starting training. Default is 10.
|
| 118 |
+
- `LEARNING_RATE`: The learning rate for the model training. Default is 2e-5.
|
| 119 |
+
- `NUM_TRAIN_EPOCHS`: The number of epochs for model training. Default is 3.
|
| 120 |
+
- `WEIGHT_DECAY`: The weight decay for the model training. Default is 0.01.
|
| 121 |
+
- `FINETUNED_MODEL_NAME`: The name of the fine-tuned model. Default is `finetuned_model`. Checkpoints will be saved under this name.
|
| 122 |
+
|
| 123 |
+
> Note: The `LABEL_STUDIO_API_KEY` is required for training the model. You can find the API key in Label Studio under the [**Account & Settings** page](https://labelstud.io/guide/user_account#Access-token).
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Customization
|
| 127 |
+
|
| 128 |
+
The ML backend can be customized by adding your own models and logic inside the `./bert_classifier` directory.
|
_wsgi.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import logging.config
|
| 6 |
+
|
| 7 |
+
logging.config.dictConfig({
|
| 8 |
+
"version": 1,
|
| 9 |
+
"disable_existing_loggers": False,
|
| 10 |
+
"formatters": {
|
| 11 |
+
"standard": {
|
| 12 |
+
"format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
|
| 13 |
+
}
|
| 14 |
+
},
|
| 15 |
+
"handlers": {
|
| 16 |
+
"console": {
|
| 17 |
+
"class": "logging.StreamHandler",
|
| 18 |
+
"level": os.getenv('LOG_LEVEL'),
|
| 19 |
+
"stream": "ext://sys.stdout",
|
| 20 |
+
"formatter": "standard"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"root": {
|
| 24 |
+
"level": os.getenv('LOG_LEVEL'),
|
| 25 |
+
"handlers": [
|
| 26 |
+
"console"
|
| 27 |
+
],
|
| 28 |
+
"propagate": True
|
| 29 |
+
}
|
| 30 |
+
})
|
| 31 |
+
|
| 32 |
+
from label_studio_ml.api import init_app
|
| 33 |
+
from model import BertClassifier
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
|
| 40 |
+
if not os.path.exists(config_path):
|
| 41 |
+
return dict()
|
| 42 |
+
with open(config_path) as f:
|
| 43 |
+
config = json.load(f)
|
| 44 |
+
assert isinstance(config, dict)
|
| 45 |
+
return config
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
parser = argparse.ArgumentParser(description='Label studio')
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
'-p', '--port', dest='port', type=int, default=9090,
|
| 52 |
+
help='Server port')
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
'--host', dest='host', type=str, default='0.0.0.0',
|
| 55 |
+
help='Server host')
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
'--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
|
| 58 |
+
help='Additional LabelStudioMLBase model initialization kwargs')
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
'-d', '--debug', dest='debug', action='store_true',
|
| 61 |
+
help='Switch debug mode')
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
'--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
|
| 64 |
+
help='Logging level')
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
'--model-dir', dest='model_dir', default=os.path.dirname(__file__),
|
| 67 |
+
help='Directory where models are stored (relative to the project directory)')
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
'--check', dest='check', action='store_true',
|
| 70 |
+
help='Validate model instance before launching server')
|
| 71 |
+
parser.add_argument('--basic-auth-user',
|
| 72 |
+
default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
|
| 73 |
+
help='Basic auth user')
|
| 74 |
+
|
| 75 |
+
parser.add_argument('--basic-auth-pass',
|
| 76 |
+
default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
|
| 77 |
+
help='Basic auth pass')
|
| 78 |
+
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
|
| 81 |
+
# setup logging level
|
| 82 |
+
if args.log_level:
|
| 83 |
+
logging.root.setLevel(args.log_level)
|
| 84 |
+
|
| 85 |
+
def isfloat(value):
|
| 86 |
+
try:
|
| 87 |
+
float(value)
|
| 88 |
+
return True
|
| 89 |
+
except ValueError:
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
def parse_kwargs():
|
| 93 |
+
param = dict()
|
| 94 |
+
for k, v in args.kwargs:
|
| 95 |
+
if v.isdigit():
|
| 96 |
+
param[k] = int(v)
|
| 97 |
+
elif v == 'True' or v == 'true':
|
| 98 |
+
param[k] = True
|
| 99 |
+
elif v == 'False' or v == 'false':
|
| 100 |
+
param[k] = False
|
| 101 |
+
elif isfloat(v):
|
| 102 |
+
param[k] = float(v)
|
| 103 |
+
else:
|
| 104 |
+
param[k] = v
|
| 105 |
+
return param
|
| 106 |
+
|
| 107 |
+
kwargs = get_kwargs_from_config()
|
| 108 |
+
|
| 109 |
+
if args.kwargs:
|
| 110 |
+
kwargs.update(parse_kwargs())
|
| 111 |
+
|
| 112 |
+
if args.check:
|
| 113 |
+
print('Check "' + BertClassifier.__name__ + '" instance creation..')
|
| 114 |
+
model = BertClassifier(**kwargs)
|
| 115 |
+
|
| 116 |
+
app = init_app(model_class=BertClassifier, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)
|
| 117 |
+
|
| 118 |
+
app.run(host=args.host, port=args.port, debug=args.debug)
|
| 119 |
+
|
| 120 |
+
else:
|
| 121 |
+
# for uWSGI use
|
| 122 |
+
app = init_app(model_class=BertClassifier)
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: "3.8"
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
bert_classifier:
|
| 5 |
+
container_name: bert_classifier
|
| 6 |
+
image: heartexlabs/label-studio-ml-backend:bertclass-master
|
| 7 |
+
build:
|
| 8 |
+
context: .
|
| 9 |
+
args:
|
| 10 |
+
TEST_ENV: ${TEST_ENV}
|
| 11 |
+
environment:
|
| 12 |
+
# If you are using this model for training, you have to connect it to Label Studio
|
| 13 |
+
- LABEL_STUDIO_HOST=http://localhost:8080
|
| 14 |
+
- LABEL_STUDIO_API_KEY=your-api-key
|
| 15 |
+
# Use any model for [AutoModelForSequenceClassification](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#automodelforsequenceclassification)
|
| 16 |
+
- BASELINE_MODEL_NAME=bert-base-multilingual-cased
|
| 17 |
+
# - BASELINE_MODEL_NAME=google/electra-small-discriminator
|
| 18 |
+
# The model directory for the fine-tuned checkpoints (relative to $MODEL_DIR)
|
| 19 |
+
- FINETUNED_MODEL_NAME=finetuned_model
|
| 20 |
+
# The number of labeled tasks to download from Label Studio before starting training
|
| 21 |
+
- START_TRAINING_EACH_N_UPDATES=10
|
| 22 |
+
# Learning rate
|
| 23 |
+
- LEARNING_RATE=2e-5
|
| 24 |
+
# Number of epochs
|
| 25 |
+
- NUM_TRAIN_EPOCHS=3
|
| 26 |
+
# Weight decay
|
| 27 |
+
- WEIGHT_DECAY=0.01
|
| 28 |
+
# specify these parameters if you want to use basic auth for the model server
|
| 29 |
+
- BASIC_AUTH_USER=
|
| 30 |
+
- BASIC_AUTH_PASS=
|
| 31 |
+
# set the log level for the model server
|
| 32 |
+
- LOG_LEVEL=DEBUG
|
| 33 |
+
# any other parameters that you want to pass to the model server
|
| 34 |
+
- ANY=PARAMETER
|
| 35 |
+
# specify the number of workers and threads for the model server
|
| 36 |
+
- WORKERS=1
|
| 37 |
+
- THREADS=8
|
| 38 |
+
# specify the model directory (likely you don't need to change this)
|
| 39 |
+
- MODEL_DIR=/data/models
|
| 40 |
+
ports:
|
| 41 |
+
- "9090:9090"
|
| 42 |
+
volumes:
|
| 43 |
+
- "./data/server:/data"
|
| 44 |
+
- "./data/.cache:/root/.cache"
|
model.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
import pathlib
|
| 5 |
+
import label_studio_sdk
|
| 6 |
+
|
| 7 |
+
from typing import List, Dict, Optional
|
| 8 |
+
from label_studio_ml.model import LabelStudioMLBase
|
| 9 |
+
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
|
| 10 |
+
from transformers import pipeline
|
| 11 |
+
from label_studio_sdk.label_interface.objects import PredictionValue
|
| 12 |
+
from label_studio_ml.response import ModelResponse
|
| 13 |
+
from datasets import Dataset
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if torch.cuda.is_available():
|
| 19 |
+
device = torch.device("cuda")
|
| 20 |
+
print('There are %d GPU(s) available.' % torch.cuda.device_count())
|
| 21 |
+
print('We will use the GPU:', torch.cuda.get_device_name(0))
|
| 22 |
+
else:
|
| 23 |
+
print('No GPU available, using the CPU instead.')
|
| 24 |
+
device = torch.device("cpu")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BertClassifier(LabelStudioMLBase):
|
| 28 |
+
"""
|
| 29 |
+
BERT-based text classification model for Label Studio
|
| 30 |
+
|
| 31 |
+
This model uses the Hugging Face Transformers library to fine-tune a BERT model for text classification.
|
| 32 |
+
Use any model for [AutoModelForSequenceClassification](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#automodelforsequenceclassification)
|
| 33 |
+
The model is trained on the labeled data from Label Studio and then used to make predictions on new data.
|
| 34 |
+
|
| 35 |
+
Parameters:
|
| 36 |
+
-----------
|
| 37 |
+
LABEL_STUDIO_HOST : str
|
| 38 |
+
The URL of the Label Studio instance
|
| 39 |
+
LABEL_STUDIO_API_KEY : str
|
| 40 |
+
The API key for the Label Studio instance
|
| 41 |
+
START_TRAINING_EACH_N_UPDATES : int
|
| 42 |
+
The number of labeled tasks to download from Label Studio before starting training
|
| 43 |
+
LEARNING_RATE : float
|
| 44 |
+
The learning rate for the model training
|
| 45 |
+
NUM_TRAIN_EPOCHS : int
|
| 46 |
+
The number of epochs for model training
|
| 47 |
+
WEIGHT_DECAY : float
|
| 48 |
+
The weight decay for the model training
|
| 49 |
+
baseline_model_name : str
|
| 50 |
+
The name of the baseline model to use for training
|
| 51 |
+
MODEL_DIR : str
|
| 52 |
+
The directory to save the trained model
|
| 53 |
+
finetuned_model_name : str
|
| 54 |
+
The name of the finetuned model
|
| 55 |
+
"""
|
| 56 |
+
LABEL_STUDIO_HOST = os.getenv('LABEL_STUDIO_HOST', 'http://localhost:8080')
|
| 57 |
+
LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY')
|
| 58 |
+
START_TRAINING_EACH_N_UPDATES = int(os.getenv('START_TRAINING_EACH_N_UPDATES', 10))
|
| 59 |
+
LEARNING_RATE = float(os.getenv('LEARNING_RATE', 2e-5))
|
| 60 |
+
NUM_TRAIN_EPOCHS = int(os.getenv('NUM_TRAIN_EPOCHS', 3))
|
| 61 |
+
WEIGHT_DECAY = float(os.getenv('WEIGHT_DECAY', 0.01))
|
| 62 |
+
baseline_model_name = os.getenv('BASELINE_MODEL_NAME', 'bert-base-multilingual-cased')
|
| 63 |
+
MODEL_DIR = os.getenv('MODEL_DIR', './results')
|
| 64 |
+
finetuned_model_name = os.getenv('FINETUNED_MODEL_NAME', 'finetuned-model')
|
| 65 |
+
_model = None
|
| 66 |
+
|
| 67 |
+
def get_labels(self):
|
| 68 |
+
li = self.label_interface
|
| 69 |
+
from_name, _, _ = li.get_first_tag_occurence('Choices', 'Text')
|
| 70 |
+
tag = li.get_tag(from_name)
|
| 71 |
+
return tag.labels
|
| 72 |
+
|
| 73 |
+
def setup(self):
|
| 74 |
+
self.set("model_version", f'{self.__class__.__name__}-v0.0.1')
|
| 75 |
+
|
| 76 |
+
def _lazy_init(self):
|
| 77 |
+
if not self._model:
|
| 78 |
+
try:
|
| 79 |
+
chk_path = str(pathlib.Path(self.MODEL_DIR) / self.finetuned_model_name)
|
| 80 |
+
self._model = pipeline("text-classification", model=chk_path, tokenizer=chk_path)
|
| 81 |
+
except:
|
| 82 |
+
# if finetuned model is not available, use the baseline model, with the labels from the label_interface
|
| 83 |
+
self._model = pipeline(
|
| 84 |
+
"text-classification",
|
| 85 |
+
model=self.baseline_model_name,
|
| 86 |
+
tokenizer=self.baseline_model_name)
|
| 87 |
+
|
| 88 |
+
labels = self.get_labels()
|
| 89 |
+
self._model.model.config.id2label = {i: label for i, label in enumerate(labels)}
|
| 90 |
+
self._model.model.config.label2id = {label: i for i, label in enumerate(labels)}
|
| 91 |
+
|
| 92 |
+
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
|
| 93 |
+
""" Write your inference logic here
|
| 94 |
+
:param tasks: [Label Studio tasks in JSON format](https://labelstud.io/guide/task_format.html)
|
| 95 |
+
:param context: [Label Studio context in JSON format](https://labelstud.io/guide/ml_create#Implement-prediction-logic)
|
| 96 |
+
:return predictions: [Predictions array in JSON format](https://labelstud.io/guide/export.html#Label-Studio-JSON-format-of-annotated-tasks)
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
# TODO: this may result in single-time timeout for large models - consider adjusting the timeout on Label Studio side
|
| 100 |
+
self._lazy_init()
|
| 101 |
+
|
| 102 |
+
li = self.label_interface
|
| 103 |
+
from_name, to_name, value = li.get_first_tag_occurence('Choices', 'Text')
|
| 104 |
+
texts = [self.preload_task_data(task, task['data'][value]) for task in tasks]
|
| 105 |
+
|
| 106 |
+
model_predictions = self._model(texts)
|
| 107 |
+
predictions = []
|
| 108 |
+
for prediction in model_predictions:
|
| 109 |
+
logger.debug(f"Prediction: {prediction}")
|
| 110 |
+
region = li.get_tag(from_name).label(prediction['label'])
|
| 111 |
+
pv = PredictionValue(
|
| 112 |
+
score=prediction['score'],
|
| 113 |
+
result=[region],
|
| 114 |
+
model_version=self.get('model_version')
|
| 115 |
+
)
|
| 116 |
+
predictions.append(pv)
|
| 117 |
+
|
| 118 |
+
return ModelResponse(predictions=predictions)
|
| 119 |
+
|
| 120 |
+
def fit(self, event, data, **additional_params):
|
| 121 |
+
"""Download dataset from Label Studio and prepare data for training in BERT"""
|
| 122 |
+
if event not in ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'):
|
| 123 |
+
logger.info(f"Skip training: event {event} is not supported")
|
| 124 |
+
return
|
| 125 |
+
project_id = data['annotation']['project']
|
| 126 |
+
|
| 127 |
+
# dowload annotated tasks from Label Studio
|
| 128 |
+
ls = label_studio_sdk.Client(self.LABEL_STUDIO_HOST, self.LABEL_STUDIO_API_KEY)
|
| 129 |
+
project = ls.get_project(id=project_id)
|
| 130 |
+
tasks = project.get_labeled_tasks()
|
| 131 |
+
|
| 132 |
+
logger.info(f"Downloaded {len(tasks)} labeled tasks from Label Studio")
|
| 133 |
+
logger.debug(f"Tasks: {tasks}")
|
| 134 |
+
if len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0 and event != 'START_TRAINING':
|
| 135 |
+
# skip training if the number of tasks is not divisible by START_TRAINING_EACH_N_UPDATES
|
| 136 |
+
logger.info(f"Skip training: the number of tasks is not divisible by {self.START_TRAINING_EACH_N_UPDATES}")
|
| 137 |
+
return
|
| 138 |
+
|
| 139 |
+
from_name, to_name, value = self.label_interface.get_first_tag_occurence('Choices', 'Text')
|
| 140 |
+
|
| 141 |
+
ds_raw = {
|
| 142 |
+
'id': [],
|
| 143 |
+
'text': [],
|
| 144 |
+
'label': []
|
| 145 |
+
}
|
| 146 |
+
for task in tasks:
|
| 147 |
+
for annotation in task['annotations']:
|
| 148 |
+
if 'result' in annotation:
|
| 149 |
+
for result in annotation['result']:
|
| 150 |
+
if 'choices' in result['value']:
|
| 151 |
+
ds_raw['id'].append(task['id'])
|
| 152 |
+
text = self.preload_task_data(task, task['data'][value])
|
| 153 |
+
ds_raw['text'].append(text)
|
| 154 |
+
ds_raw['label'].append(result['value']['choices'])
|
| 155 |
+
|
| 156 |
+
hf_dataset = Dataset.from_dict(ds_raw)
|
| 157 |
+
logger.debug(f"Dataset: {hf_dataset}")
|
| 158 |
+
|
| 159 |
+
labels = self.get_labels()
|
| 160 |
+
label_to_id = {label: i for i, label in enumerate(labels)}
|
| 161 |
+
id_to_label = {i: label for i, label in enumerate(labels)}
|
| 162 |
+
logger.debug(f"Labels: {labels}")
|
| 163 |
+
|
| 164 |
+
# Preprocess the dataset
|
| 165 |
+
tokenizer = AutoTokenizer.from_pretrained(self.baseline_model_name)
|
| 166 |
+
|
| 167 |
+
def preprocess_function(examples):
|
| 168 |
+
return tokenizer(examples["text"], truncation=True, padding=True)
|
| 169 |
+
|
| 170 |
+
tokenized_datasets = hf_dataset.map(preprocess_function, batched=True)
|
| 171 |
+
logger.debug(f"Tokenized dataset: {tokenized_datasets}")
|
| 172 |
+
|
| 173 |
+
# Convert labels to ids
|
| 174 |
+
def label_to_id_function(examples):
|
| 175 |
+
examples["label"] = [label_to_id[label] for label in examples["label"]]
|
| 176 |
+
return examples
|
| 177 |
+
|
| 178 |
+
tokenized_datasets = tokenized_datasets.map(label_to_id_function)
|
| 179 |
+
|
| 180 |
+
# Load model with custom config
|
| 181 |
+
logger.info(f"Start training the model {self.finetuned_model_name}")
|
| 182 |
+
config = AutoConfig.from_pretrained(self.baseline_model_name, num_labels=len(labels))
|
| 183 |
+
logger.debug(f"Config: {config}")
|
| 184 |
+
model = AutoModelForSequenceClassification.from_pretrained(self.baseline_model_name, config=config)
|
| 185 |
+
model.config.id2label = id_to_label
|
| 186 |
+
model.config.label2id = label_to_id
|
| 187 |
+
logger.debug(f"Model: {model}")
|
| 188 |
+
|
| 189 |
+
# Define training arguments
|
| 190 |
+
training_args = TrainingArguments(
|
| 191 |
+
output_dir=str(pathlib.Path(self.MODEL_DIR) / 'training_output'),
|
| 192 |
+
learning_rate=2e-5,
|
| 193 |
+
evaluation_strategy="no",
|
| 194 |
+
num_train_epochs=3,
|
| 195 |
+
weight_decay=0.01,
|
| 196 |
+
log_level='info'
|
| 197 |
+
)
|
| 198 |
+
logger.debug(f"Training arguments: {training_args}")
|
| 199 |
+
|
| 200 |
+
# Initialize Trainer
|
| 201 |
+
trainer = Trainer(
|
| 202 |
+
model=model,
|
| 203 |
+
args=training_args,
|
| 204 |
+
train_dataset=tokenized_datasets,
|
| 205 |
+
tokenizer=tokenizer,
|
| 206 |
+
)
|
| 207 |
+
logger.debug(f"Trainer: {trainer}")
|
| 208 |
+
|
| 209 |
+
# Train the model
|
| 210 |
+
trainer.train()
|
| 211 |
+
|
| 212 |
+
chk_path = str(pathlib.Path(self.MODEL_DIR) / self.finetuned_model_name)
|
| 213 |
+
logger.info(f"Model is trained and saved as {chk_path}")
|
| 214 |
+
trainer.save_model(chk_path)
|
requirements-base.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gunicorn==22.0.0
|
| 2 |
+
label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git
|
requirements-test.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytest
|
| 2 |
+
pytest-cov
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers==4.38.0
|
| 2 |
+
datasets==2.18.0
|
| 3 |
+
accelerate==0.28.0
|
test_api.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains tests for the API of your model. You can run these tests by installing test requirements:
|
| 3 |
+
|
| 4 |
+
```bash
|
| 5 |
+
pip install -r requirements-test.txt
|
| 6 |
+
```
|
| 7 |
+
Then execute `pytest` in the directory of this file.
|
| 8 |
+
|
| 9 |
+
- Change `NewModel` to the name of the class in your model.py file.
|
| 10 |
+
- Change the `request` and `expected_response` variables to match the input and output of your model.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
import json
|
| 15 |
+
from model import BertClassifier
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@pytest.fixture
|
| 19 |
+
def client():
|
| 20 |
+
from _wsgi import init_app
|
| 21 |
+
app = init_app(model_class=BertClassifier)
|
| 22 |
+
app.config['TESTING'] = True
|
| 23 |
+
with app.test_client() as client:
|
| 24 |
+
yield client
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_predict(client):
|
| 28 |
+
request = {
|
| 29 |
+
'tasks': [{
|
| 30 |
+
'data': {
|
| 31 |
+
'text': 'Today is a great day to play football.'
|
| 32 |
+
}
|
| 33 |
+
}],
|
| 34 |
+
# Your labeling configuration here
|
| 35 |
+
'label_config':
|
| 36 |
+
'<View>'
|
| 37 |
+
'<Text name="text" value="$text" />'
|
| 38 |
+
'<Choices name="topic" toName="text" choice="single">'
|
| 39 |
+
'<Choice value="sports" />'
|
| 40 |
+
'<Choice value="politics" />'
|
| 41 |
+
'<Choice value="technology" />'
|
| 42 |
+
'</Choices>'
|
| 43 |
+
'</View>'
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
expected_response_results = [{
|
| 47 |
+
'result': [{
|
| 48 |
+
'from_name': 'topic',
|
| 49 |
+
'to_name': 'text',
|
| 50 |
+
'type': 'choices',
|
| 51 |
+
'value': {'choices': ['sports']}
|
| 52 |
+
}]
|
| 53 |
+
}]
|
| 54 |
+
|
| 55 |
+
response = client.post('/predict', data=json.dumps(request), content_type='application/json')
|
| 56 |
+
assert response.status_code == 200
|
| 57 |
+
response = json.loads(response.data)
|
| 58 |
+
assert len(expected_response_results) == len(response['results'])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# TODO
|
| 62 |
+
# Implement test_fit()
|