dni138 committed
Commit 46607f2 · verified · 1 Parent(s): 74f31e8

Update README.md


Adds a description of what Mozilla AI is doing and provides credit to the appropriate model builder.

Files changed (1): README.md (+69 −3)
README.md CHANGED
@@ -4,9 +4,75 @@ pipeline_tag: text-classification
  tags:
  - model_hub_mixin
  - pytorch_model_hub_mixin
  ---

  This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- - Code: [More Information Needed]
- - Paper: [More Information Needed]
- - Docs: [More Information Needed]
  tags:
  - model_hub_mixin
  - pytorch_model_hub_mixin
+ base_model:
+ - cross-encoder/stsb-roberta-base
  ---

  This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+
+ For full documentation of this model, please see the official [model card](https://huggingface.co/govtech/stsb-roberta-base-off-topic) from govtech, who built the model.
+
+ Mozilla AI has made it possible to load `govtech/stsb-roberta-base-off-topic` via `from_pretrained`. To do this, first pull the `CrossEncoderWithMLP` model architecture from their model card and add `PyTorchModelHubMixin` as an inherited class. See this [article](https://huggingface.co/docs/hub/en/models-uploading#upload-a-pytorch-model-using-huggingfacehub).
+
+ Then, you can do the following:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel, AutoTokenizer
+ from huggingface_hub import PyTorchModelHubMixin
+
+ class CrossEncoderWithMLP(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, base_model, num_labels=2):
+         super(CrossEncoderWithMLP, self).__init__()
+
+         # Existing cross-encoder model
+         self.base_model = base_model
+         # Hidden size of the base model
+         hidden_size = base_model.config.hidden_size
+         # MLP layers applied after the cross-encoder
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size // 2),  # Input: the pooled sentence-pair embedding
+             nn.ReLU(),
+             nn.Linear(hidden_size // 2, hidden_size // 4),  # Reduce the size of the layer
+             nn.ReLU()
+         )
+         # Classifier head
+         self.classifier = nn.Linear(hidden_size // 4, num_labels)
+
+     def forward(self, input_ids, attention_mask):
+         # Encode the pair of sentences in one pass
+         outputs = self.base_model(input_ids, attention_mask)
+         pooled_output = outputs.pooler_output
+         # Pass the pooled output through the MLP layers
+         mlp_output = self.mlp(pooled_output)
+         # Pass the final MLP output through the classifier
+         logits = self.classifier(mlp_output)
+         return logits
+
+ tokenizer = AutoTokenizer.from_pretrained("cross-encoder/stsb-roberta-base")
+ base_model = AutoModel.from_pretrained("cross-encoder/stsb-roberta-base")
+ off_topic = CrossEncoderWithMLP.from_pretrained("mozilla-ai/stsb-roberta-base-off-topic", base_model=base_model)
+
+ # Then you can build a predict function that utilizes the tokenizer
+
+ def predict(model, tokenizer, sentence1, sentence2, max_length=512, device="cpu"):
+     encoding = tokenizer(
+         sentence1,
+         sentence2,
+         return_tensors="pt",
+         truncation=True,
+         padding="max_length",
+         max_length=max_length,
+         return_token_type_ids=False
+     )
+     input_ids = encoding["input_ids"].to(device)
+     attention_mask = encoding["attention_mask"].to(device)
+
+     with torch.no_grad():
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+         probabilities = torch.softmax(outputs, dim=1)
+         predicted_label = torch.argmax(probabilities, dim=1).item()
+
+     return predicted_label, probabilities.cpu().numpy()
+ ```
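
For intuition on what `predict` returns, here is a small pure-Python sketch of its softmax-then-argmax post-processing step. The `softmax_argmax` helper and the logit values are illustrative assumptions, not real model outputs:

```python
import math

def softmax_argmax(logits):
    # Numerically stable softmax over one example's logits
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Argmax picks the index of the highest probability as the label
    return probs.index(max(probs)), probs

# With two logits, label 0 wins whenever the first logit is larger
label, probs = softmax_argmax([1.2, -0.3])
```

The same logic applies per row of the model's `(batch, num_labels)` logits tensor, which is what `torch.softmax(outputs, dim=1)` and `torch.argmax(..., dim=1)` compute in the README's code.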