leannetanyt committed
Commit b966c37 · 1 Parent(s): 2a3ee0b

feat: update model config

Files changed (1): lionguard2.py (+4 −4)
lionguard2.py CHANGED
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig, PreTrainedModel
 
-INPUT_DIMENSION = 3072 # length of OpenAI embeddings
+INPUT_DIMENSION = 3072 # length of Gemini embeddings
 
 CATEGORIES = {
     "binary": ["binary"],
@@ -56,7 +56,7 @@ class LionGuard2Model(PreTrainedModel):
 
     def __init__(self, config: LionGuard2Config):
         """
-        LionGuard2 is a localised content moderation model that flags whether text violates the following categories:
+        LionGuard2.1 is a localised content moderation model that flags whether text violates the following categories:
 
         1. `hateful`: Text that discriminates, criticizes, insults, denounces, or dehumanizes a person or group on the basis of a protected identity.
 
@@ -94,14 +94,14 @@ class LionGuard2Model(PreTrainedModel):
 
         Lastly, there is an additional `binary` category (#7) which flags whether the text is unsafe in general.
 
-        The model takes in as input text, after it has been encoded with OpenAI's `text-embedding-3-small` model.
+        The model takes in an input text that has been encoded with Gemini's `gemini-embedding-001` model.
 
         The model outputs the probabilities of each category being true.
 
        ================================
 
        Args:
-            input_dim: The dimension of the input embeddings. This defaults to 3072, which is the dimension of the embeddings from OpenAI's `text-embedding-3-small` model. This should not be changed.
+            input_dim: The dimension of the input embeddings. This defaults to 3072, which is the dimension of the embeddings from Gemini's `gemini-embedding-001` model. This should not be changed.
            label_names: The names of the labels. This defaults to the keys of the CATEGORIES dictionary. This should not be changed.
            categories: The categories of the labels. This defaults to the CATEGORIES dictionary. This should not be changed.
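The docstring changes above pin down the model's I/O contract: a 3072-dimensional `gemini-embedding-001` embedding goes in, and one probability per category comes out. A minimal sketch of that contract, using a hypothetical `TinyModerationHead` (the actual LionGuard2Model architecture is not shown in this diff):

```python
import torch
import torch.nn as nn

INPUT_DIMENSION = 3072  # length of Gemini `gemini-embedding-001` embeddings
NUM_LABELS = 8          # assumed: 7 harm categories plus the overall `binary` flag


class TinyModerationHead(nn.Module):
    """Hypothetical stand-in for LionGuard2Model's classifier head.

    Only illustrates the contract described in the docstring: a 3072-dim
    embedding in, one probability per category out.
    """

    def __init__(self, input_dim: int = INPUT_DIMENSION, num_labels: int = NUM_LABELS):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Independent sigmoid per label: categories are not mutually exclusive,
        # so this is multi-label classification rather than softmax over classes.
        return torch.sigmoid(self.linear(embeddings))


head = TinyModerationHead()
probs = head(torch.randn(2, INPUT_DIMENSION))
print(tuple(probs.shape))  # (2, 8)
```

Because each label gets its own sigmoid, every output lies in [0, 1] and a single text can trip several categories at once (e.g. both `hateful` and `binary`).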