Commit b966c37
Parent(s): 2a3ee0b
feat: update model config

lionguard2.py  (+4 -4)
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig, PreTrainedModel
 
-INPUT_DIMENSION = 3072  # length of
+INPUT_DIMENSION = 3072  # length of Gemini embeddings
 
 CATEGORIES = {
     "binary": ["binary"],
@@ -56,7 +56,7 @@ class LionGuard2Model(PreTrainedModel):
 
     def __init__(self, config: LionGuard2Config):
         """
-        LionGuard2 is a localised content moderation model that flags whether text violates the following categories:
+        LionGuard2.1 is a localised content moderation model that flags whether text violates the following categories:
 
        1. `hateful`: Text that discriminates, criticizes, insults, denounces, or dehumanizes a person or group on the basis of a protected identity.
 
@@ -94,14 +94,14 @@ class LionGuard2Model(PreTrainedModel):
 
        Lastly, there is an additional `binary` category (#7) which flags whether the text is unsafe in general.
 
-        The model takes in
+        The model takes in an input text that has been encoded with Gemini's `gemini-embedding-001` model.
 
        The model outputs the probabilities of each category being true.
 
        ================================
 
        Args:
-            input_dim: The dimension of the input embeddings. This defaults to 3072, which is the dimension of the embeddings from
+            input_dim: The dimension of the input embeddings. This defaults to 3072, which is the dimension of the embeddings from Gemini's `gemini-embedding-001` model. This should not be changed.
             label_names: The names of the labels. This defaults to the keys of the CATEGORIES dictionary. This should not be changed.
             categories: The categories of the labels. This defaults to the CATEGORIES dictionary. This should not be changed.
 
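For context, the docstring changed in this commit describes an interface where the model consumes a 3072-dimensional `gemini-embedding-001` vector and emits a probability per category. A minimal sketch of that interface in plain Python, assuming a toy linear head with a sigmoid per label (the weights and the two-label set here are illustrative placeholders, not the model's actual parameters or full category list):

```python
import math
import random

INPUT_DIMENSION = 3072  # length of Gemini `gemini-embedding-001` embeddings

# Illustrative subset of labels; the diff names only `hateful` and `binary`.
LABELS = ["hateful", "binary"]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def score(embedding, weights, biases):
    """Toy multi-label head: one logit per label, squashed to a probability."""
    return {
        label: sigmoid(sum(w * e for w, e in zip(weights[label], embedding)) + biases[label])
        for label in LABELS
    }


# Stand-in for a real embedding and for trained parameters.
random.seed(0)
embedding = [random.gauss(0.0, 1.0) for _ in range(INPUT_DIMENSION)]
weights = {label: [random.gauss(0.0, 0.01) for _ in range(INPUT_DIMENSION)] for label in LABELS}
biases = {label: 0.0 for label in LABELS}

probs = score(embedding, weights, biases)
print({label: round(p, 3) for label, p in probs.items()})
```

Each output is an independent probability rather than a softmax over categories, which matches the docstring's description of per-category "probabilities of each category being true" plus a separate `binary` flag.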