| from transformers import PretrainedConfig |
| from typing import List, Optional, Tuple |
|
|
|
|
class CXRConfig(PretrainedConfig):
    """Configuration container for the ``cxr_basic`` model type.

    Every constructor argument is stored as an instance attribute of the
    same name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the argument descriptions below are inferred from the
    parameter names and defaults only -- confirm them against the model
    code that consumes this config.

    Args:
        backbone: Identifier of the encoder backbone.
        feature_dim: Size of the shared feature representation.
        seg_dropout: Dropout rate for the segmentation branch.
        cls_dropout: Dropout rate for the classification branch.
        seg_num_classes: Number of segmentation classes.
        cls_num_classes: Number of classification classes.
        in_chans: Number of input image channels.
        img_size: Input image size as a pair of ints.
        decoder_n_blocks: Number of decoder blocks.
        decoder_channels: Per-block decoder channel counts. ``None``
            selects ``[256, 128, 64, 32, 16]``.
        encoder_channels: Per-stage encoder channel counts. ``None``
            selects ``[24, 48, 64, 160, 256]``.
        decoder_center_block: Whether the decoder uses a center block.
        decoder_norm_layer: Name of the decoder normalization layer.
        decoder_attention_type: Optional decoder attention type name.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "cxr_basic"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_s",
        feature_dim: int = 256,
        seg_dropout: float = 0.1,
        cls_dropout: float = 0.1,
        seg_num_classes: int = 4,
        cls_num_classes: int = 5,
        in_chans: int = 1,
        img_size: Tuple[int, int] = (320, 320),
        decoder_n_blocks: int = 5,
        decoder_channels: Optional[List[int]] = None,
        encoder_channels: Optional[List[int]] = None,
        decoder_center_block: bool = False,
        decoder_norm_layer: str = "bn",
        decoder_attention_type: Optional[str] = None,
        **kwargs,
    ):
        # Build the default lists per call: a list literal in the signature
        # would be a single object shared by every instance, so mutating one
        # config's channels would silently change all others.
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]
        if encoder_channels is None:
            encoder_channels = [24, 48, 64, 160, 256]

        self.backbone = backbone
        self.feature_dim = feature_dim
        self.seg_dropout = seg_dropout
        self.cls_dropout = cls_dropout
        self.seg_num_classes = seg_num_classes
        self.cls_num_classes = cls_num_classes
        self.in_chans = in_chans
        self.img_size = img_size
        self.decoder_n_blocks = decoder_n_blocks
        self.decoder_channels = decoder_channels
        self.encoder_channels = encoder_channels
        self.decoder_center_block = decoder_center_block
        self.decoder_norm_layer = decoder_norm_layer
        self.decoder_attention_type = decoder_attention_type
        # Model-specific attributes are assigned before delegating the
        # remaining keyword arguments to the base class.
        super().__init__(**kwargs)
|
|