| from __future__ import annotations | |
| from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config | |
class LEGConfig(DebertaV2Config):
    """DeBERTa-v2 configuration carrying four additional guardrail fields.

    On top of everything ``DebertaV2Config`` accepts, this class stores
    ``base_model_name``, ``inference_max_length``, ``prompt_threshold`` and
    ``word_threshold`` as plain instance attributes.  How those values are
    consumed is decided by code outside this class.
    """

    model_type = "leg-1.0-guardrail"

    def __init__(
        self,
        base_model_name: str = "",
        inference_max_length: int = 512,
        prompt_threshold: float = 0.5,
        word_threshold: float = 0.5,
        **kwargs,
    ):
        """Build the config.

        Args:
            base_model_name: Stored as-is on the instance; defaults to "".
                Presumably the identifier of the underlying checkpoint --
                confirm against the code that reads it.
            inference_max_length: Stored as-is; defaults to 512.
            prompt_threshold: Stored as-is; defaults to 0.5.
            word_threshold: Stored as-is; defaults to 0.5.
            **kwargs: Passed through untouched to ``DebertaV2Config``.
        """
        # The parent must be initialised first so that its own attributes
        # (including any ``num_labels`` handling) exist before ours are set.
        super().__init__(**kwargs)

        # Attach the subclass-specific fields, in declaration order.
        extra_fields = (
            ("base_model_name", base_model_name),
            ("inference_max_length", inference_max_length),
            ("prompt_threshold", prompt_threshold),
            ("word_threshold", word_threshold),
        )
        for field_name, field_value in extra_fields:
            setattr(self, field_name, field_value)

        # Fall back to two labels only when the parent left the attribute
        # missing or explicitly set to None.
        existing_label_count = getattr(self, "num_labels", None)
        if existing_label_count is None:
            self.num_labels = 2