Safetensors
custom_code
RadZero / configuration_radzero.py
nobottle
add auto map
44f556a
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
class VisionConfig(PretrainedConfig):
    """Configuration object for the vision encoder.

    Every keyword argument is forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def from_exp_config(vision_config: dict) -> "VisionConfig":
        """Build a ``VisionConfig`` from an experiment-level dictionary.

        For model types backed by a pretrained checkpoint, the configuration
        is loaded with ``AutoConfig.from_pretrained`` from
        ``vision_config["pretrained_name_or_path"]`` and merged on top of the
        experiment values, so loaded entries win on key collisions.

        Args:
            vision_config: Experiment settings. Must contain ``"model_type"``;
                must also contain ``"pretrained_name_or_path"`` for the model
                types that load a pretrained configuration. The dictionary is
                not modified.

        Returns:
            A ``VisionConfig`` built from the merged settings.

        Raises:
            KeyError: If a required key is missing.
            NotImplementedError: If ``model_type`` is not supported.
        """
        # Work on a shallow copy: merging the pretrained values in place would
        # silently alter the caller's experiment dictionary.
        merged = dict(vision_config)
        model_type = merged["model_type"]

        if model_type in (
            "siglip_vision_model",
            "clip_vision_model",
            "dinov2",
            "sam",
            "raddino",
        ):
            pretrained = AutoConfig.from_pretrained(
                merged["pretrained_name_or_path"]
            ).to_dict()
            merged.update(pretrained)
        elif model_type == "xrayclip":
            pretrained = AutoConfig.from_pretrained(
                merged["pretrained_name_or_path"]
            ).to_dict()
            # Keep the experiment's model type rather than whatever the
            # loaded configuration reports.
            pretrained["model_type"] = "xrayclip"
            merged.update(pretrained)
        elif model_type in ("biomedclip", "m3ae"):
            # No pretrained configuration is loaded for these types; the
            # experiment values are used as given.
            pass
        else:
            raise NotImplementedError(
                f"Unsupported vision model_type: {model_type!r}"
            )

        return VisionConfig(**merged)
class TextConfig(PretrainedConfig):
    """Configuration object for the text encoder.

    ``model_type`` is stored on the instance; all remaining keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    def __init__(self, model_type, **kwargs):
        super().__init__(**kwargs)
        self.model_type = model_type

    @staticmethod
    def from_exp_config(text_config: dict):
        """Build a ``TextConfig`` from an experiment-level dictionary.

        Args:
            text_config: Experiment settings; must contain ``"model_type"``.
                Every entry is passed to the constructor as a keyword
                argument.

        Returns:
            A ``TextConfig`` built from ``text_config``.

        Raises:
            KeyError: If ``"model_type"`` is missing.
            NotImplementedError: If ``model_type`` is not a supported type.
        """
        supported_types = (
            "siglip_text_model",
            "clip_text_model",
            "mpnet",
            "biomedclip",
            "bioclinicalmpbert",
        )
        # Guard clause: reject unknown encoders before constructing anything.
        if text_config["model_type"] not in supported_types:
            raise NotImplementedError()
        return TextConfig(**text_config)
class AlignTransformerConfig(PretrainedConfig):
    """Configuration object for the alignment transformer.

    Stores ``model_type`` and ``projector_config`` on the instance and
    forwards all remaining keyword arguments to ``PretrainedConfig``.
    """

    def __init__(
        self,
        model_type: str = "align_transformer",
        projector_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type
        self.projector_config = projector_config

    @staticmethod
    def from_exp_config(
        align_transformer_config: dict,
    ) -> "AlignTransformerConfig":
        """Build an ``AlignTransformerConfig`` from an experiment dictionary.

        The experiment values (without ``"projector_config"``) are used to
        create a ``Dinov2Config``. Its dictionary form supplies a value for
        every field the experiment did not set, and the experiment values
        take precedence on key collisions.

        Args:
            align_transformer_config: Experiment settings. May contain
                ``"projector_config"``, which is passed through separately.
                The dictionary is not modified.

        Returns:
            An ``AlignTransformerConfig`` built from the merged settings.
        """
        # Pop from a shallow copy: removing "projector_config" from the
        # caller's dictionary would make a second call with the same dict
        # silently lose the projector settings.
        overrides = dict(align_transformer_config)
        projector_config = overrides.pop("projector_config", None)

        base_values = Dinov2Config(**overrides).to_dict()

        return AlignTransformerConfig(
            **(base_values | overrides),
            projector_config=projector_config,
        )
class CxrAlignConfig(PretrainedConfig):
    """Composite configuration bundling vision, text and alignment settings.

    Each of the three dictionaries is converted into its dedicated
    configuration object through that class's ``from_exp_config`` factory.
    """

    is_composition = True

    def __init__(
        self,
        vision_config: dict,
        text_config: dict,
        align_transformer_config: dict,
        **kwargs,
    ):
        """Create the composite configuration.

        Args:
            vision_config: Settings handed to ``VisionConfig.from_exp_config``.
            text_config: Settings handed to ``TextConfig.from_exp_config``.
            align_transformer_config: Settings handed to
                ``AlignTransformerConfig.from_exp_config``.
            **kwargs: Forwarded to ``PretrainedConfig`` and also kept on the
                instance as ``self.kwargs``.
        """
        super().__init__(**kwargs)

        # Sub-configurations are built in this order: vision, text, alignment.
        self.vision_config = VisionConfig.from_exp_config(vision_config)
        self.text_config = TextConfig.from_exp_config(text_config)
        self.align_transformer_config = AlignTransformerConfig.from_exp_config(
            align_transformer_config
        )

        # NOTE(review): the extra keyword arguments are stored as an
        # attribute as well; confirm they are meant to be part of the
        # persisted configuration.
        self.kwargs = kwargs