---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
language:
- en
metrics:
- accuracy
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
---
|
|
|
|
|
This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:

- Library: [More Information Needed]
- Docs: [More Information Needed]
|
|
|
|
|
The model class looks like the following:

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import BertModel


class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, dataset: str, num_classes: int, dropout: float = 0.5):
        super().__init__()
        self.model_name = "bert-base-uncased"
        print(f"Loading BERT model {self.model_name} for {dataset} dataset...")
        self.bert = BertModel.from_pretrained(self.model_name)
        self.dropout = nn.Dropout(dropout)
        # BERT-base hidden size is 768; out features = number of classes
        self.linear = nn.Linear(768, num_classes)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        # With return_dict=False, BertModel returns (last_hidden_state, pooled_output)
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer
```
|
|
|
|
|
The H&M dataset has 89 classes. Loading the model looks like this:

```python
model = BertClassifier.from_pretrained("CDL-RecSys/BERT-uncased-hm-category-classifier")
```
|
|
|
|
|
Before classifying the input sequence, the text needs to be tokenized with the `BertTokenizer`:

```python
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "Short, sleeveless dress in an airy cotton weave that is open at the back with a tie."

tokens = tokenizer(text, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]

with torch.no_grad():
    output = model(input_ids, attention_mask)
predicted_class = output.argmax(dim=1)  # should be 13 (Dress)
```