GigaCheck-Classifier-Multi / configuration_gigacheck.py
iitolstykh's picture
Upload 2 files
88b272e verified
raw
history blame contribute delete
767 Bytes
from typing import Dict, Optional, Any
from transformers import MistralConfig
class GigaCheckConfig(MistralConfig):
    """Configuration for GigaCheck models built on a Mistral backbone.

    Extends ``MistralConfig`` with a few extra fields that are stored as-is on
    the config object. They are consumed by model code that is not part of
    this file.
    """

    def __init__(
        self,
        with_detr: bool = False,
        detr_config: Optional[Dict[str, Any]] = None,
        freeze_backbone: bool = False,
        id2label: Optional[Dict[int, str]] = None,
        num_labels: int = 2,
        max_length: int = 1024,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            with_detr: Flag stored as ``self.with_detr`` (presumably toggles a
                DETR-style head in the model -- confirm against model code).
            detr_config: Optional dict stored unchanged as ``self.detr_config``.
            freeze_backbone: Flag stored as ``self.freeze_backbone``.
            id2label: Mapping from class id to label name. Keys may be strings
                (as they are after a JSON round-trip) and are converted to
                ``int``. When given and non-empty it takes precedence over
                ``num_labels``, mirroring ``PretrainedConfig``; ``label2id`` is
                derived from it unless one is passed explicitly in ``kwargs``.
            num_labels: Number of classes. Used only when ``id2label`` is
                ``None`` or empty, in which case generic ``LABEL_i`` names are
                generated.
            max_length: Stored as ``self.max_length`` (presumably the maximum
                input length in tokens -- confirm against model code).
            **kwargs: Forwarded unchanged to ``MistralConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.with_detr = with_detr
        self.detr_config = detr_config
        self.freeze_backbone = freeze_backbone
        self.max_length = max_length

        if id2label:
            # JSON object keys are always strings, so a config reloaded from
            # disk carries "0", "1", ...; restore integer class ids.
            self.id2label = {int(k): v for k, v in id2label.items()}
            # Deliberately do NOT assign ``self.num_labels`` here: its setter
            # replaces id2label with generic names whenever the counts differ,
            # which would discard the caller's labels when ``num_labels`` was
            # left at its default.
            if kwargs.get("label2id") is None:
                # Keep the reverse map consistent with the real label names
                # instead of the generic one produced by the parent __init__.
                self.label2id = {name: idx for idx, name in self.id2label.items()}
        else:
            # No explicit labels: clear id2label so the ``num_labels`` setter
            # regenerates generic id2label/label2id maps of the requested size.
            self.id2label = None
            self.num_labels = num_labels