from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
class VisionConfig(PretrainedConfig):
    """Configuration for the vision tower of the CXR alignment model.

    Thin wrapper over ``PretrainedConfig`` whose attributes are populated
    dynamically from an experiment-level dict via :meth:`from_exp_config`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def from_exp_config(vision_config: dict):
        """Build a ``VisionConfig`` from an experiment config dict.

        Args:
            vision_config: dict with at least a ``"model_type"`` key and,
                for hub-backed backbones, a ``"pretrained_name_or_path"`` key.
                The dict is NOT modified (a copy is taken).

        Returns:
            VisionConfig: config whose attributes merge the experiment dict
            with the pretrained model's own config (pretrained values win,
            matching the original ``dict.update`` ordering).

        Raises:
            NotImplementedError: for an unrecognized ``model_type``.
        """
        # Copy so the caller's dict is not mutated by update() below.
        vision_config = dict(vision_config)
        model_type = vision_config["model_type"]
        if model_type in (
            "siglip_vision_model",
            "clip_vision_model",
            "dinov2",
            "sam",
            "raddino",
        ):
            # Pull the full pretrained config; its fields take precedence
            # over the experiment-supplied ones.
            config = AutoConfig.from_pretrained(
                vision_config["pretrained_name_or_path"]
            ).to_dict()
            vision_config.update(config)
        elif model_type == "xrayclip":
            config = AutoConfig.from_pretrained(
                vision_config["pretrained_name_or_path"]
            ).to_dict()
            # AutoConfig resolves to the underlying backbone's type; restore
            # the project-specific "xrayclip" tag before merging.
            config["model_type"] = "xrayclip"
            vision_config.update(config)
        elif model_type in ("biomedclip", "m3ae"):
            # These backbones need no hub lookup; use the experiment dict as-is.
            pass
        else:
            raise NotImplementedError(
                f"Unsupported vision model_type: {model_type!r}"
            )
        return VisionConfig(**vision_config)
class TextConfig(PretrainedConfig):
    """Configuration for the text encoder of the CXR alignment model.

    Stores ``model_type`` explicitly; all other fields flow through to
    ``PretrainedConfig`` via ``**kwargs``.
    """

    def __init__(
        self,
        model_type,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type

    @staticmethod
    def from_exp_config(
        text_config: dict,
    ):
        """Create a ``TextConfig`` from an experiment config dict.

        Raises:
            NotImplementedError: when ``model_type`` is not a supported
            text backbone.
        """
        supported = {
            "siglip_text_model",
            "clip_text_model",
            "mpnet",
            "biomedclip",
            "bioclinicalmpbert",
        }
        # Guard clause: reject unknown backbones up front.
        if text_config["model_type"] not in supported:
            raise NotImplementedError()
        return TextConfig(**text_config)
class AlignTransformerConfig(PretrainedConfig):
    """Configuration for the alignment transformer.

    The transformer reuses Dinov2's configuration schema; experiment-level
    overrides are layered on top of ``Dinov2Config`` defaults.
    """

    def __init__(
        self,
        model_type: str = "align_transformer",
        projector_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type
        # Nested projector settings; kept opaque (dict or None).
        self.projector_config = projector_config

    @staticmethod
    def from_exp_config(
        align_transformer_config: dict,
    ):
        """Build an ``AlignTransformerConfig`` from an experiment dict.

        Args:
            align_transformer_config: Dinov2-style fields plus an optional
                ``"projector_config"`` entry. The dict is NOT modified
                (a copy is taken before ``pop``).

        Returns:
            AlignTransformerConfig: Dinov2 defaults merged with the
            experiment dict — experiment values win on conflict (the
            right-hand operand of ``|`` takes precedence).
        """
        # Copy so the caller's dict is not mutated by pop() below.
        align_transformer_config = dict(align_transformer_config)
        projector_config = align_transformer_config.pop("projector_config", None)
        # Dinov2Config fills in schema defaults for fields the experiment
        # dict does not set.
        defaults = Dinov2Config(**align_transformer_config).to_dict()
        return AlignTransformerConfig(
            **(defaults | align_transformer_config),
            projector_config=projector_config,
        )
class CxrAlignConfig(PretrainedConfig):
    """Top-level composite config bundling the vision, text, and
    alignment-transformer sub-configs of the CXR alignment model."""

    # Tells transformers this config is composed of sub-configs.
    is_composition = True

    def __init__(
        self,
        vision_config: dict,
        text_config: dict,
        align_transformer_config: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Materialize each sub-config from its experiment-level dict.
        built_vision = VisionConfig.from_exp_config(vision_config)
        built_text = TextConfig.from_exp_config(text_config)
        built_align = AlignTransformerConfig.from_exp_config(
            align_transformer_config
        )
        self.vision_config = built_vision
        self.text_config = built_text
        self.align_transformer_config = built_align
        # Retain any extra kwargs for later inspection.
        self.kwargs = kwargs