|
|
import sys
|
|
|
from pathlib import Path
|
|
|
# Make the repository root importable so sibling packages resolve.
# NOTE(review): Path() resolves the *current working directory*, not this
# file's location, so the computed root depends on where the process was
# launched from. Consider Path(__file__).resolve().parent.parent — confirm
# against how this module is actually invoked.
parent_root = Path().resolve().parent.parent

sys.path.append(str(parent_root))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import PretrainedConfig, DonutSwinConfig, GemmaConfig, CONFIG_MAPPING, SiglipVisionConfig
|
|
|
from typing import Tuple, Literal
|
|
|
|
|
|
|
|
|
|
|
|
class PamConfig(PretrainedConfig):
    """Configuration for the PAM module.

    Records the geometry of the student and teacher feature maps
    (spatial dimensions and embedding widths) together with the strategy
    used to map the student token sequence onto the teacher grid.
    """

    model_type = "pam"

    def __init__(
        self,
        sequence_mapping_layer_type: Literal["linear_projection","bilinear_interpolation"] = "bilinear_interpolation",
        student_fmap_dim: Tuple[int,int]=(80,60),
        student_embedding_dim: int = 1024,
        teacher_fmap_dim: Tuple[int,int] = (64,64),
        teacher_embedding_dim: int = 1152,
        **kwargs,
    ):
        # Persist every PAM-specific field on the instance before handing
        # the remaining generic options to PretrainedConfig.
        for field_name, field_value in (
            ("sequence_mapping_layer_type", sequence_mapping_layer_type),
            ("student_fmap_dim", student_fmap_dim),
            ("student_embedding_dim", student_embedding_dim),
            ("teacher_fmap_dim", teacher_fmap_dim),
            ("teacher_embedding_dim", teacher_embedding_dim),
        ):
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
class SwinPamVisionEncoderConfig(PretrainedConfig):
    """Composite vision-encoder configuration: a Donut-Swin backbone plus a PAM head.

    Each sub-config may be passed as an instantiated config object or as a
    plain dict (as produced by ``to_dict()`` / JSON round-trips), in which
    case it is re-hydrated into the proper config class here.
    """

    model_type = "swinpam"
    sub_configs = {"encoder_config": DonutSwinConfig, "pam_config": PamConfig}

    def __init__(
        self,
        encoder_config: DonutSwinConfig = None,
        pam_config: PamConfig = None,
        **kwargs
    ):
        self.encoder_config = encoder_config
        self.pam_config = pam_config

        if isinstance(self.encoder_config, dict):
            # Default to the Donut-Swin backbone when the dict carries no
            # type tag (setdefault mutates the dict, like the original code).
            encoder_config.setdefault("model_type", "donut-swin")
            if encoder_config["model_type"] == "donut-swin":
                self.encoder_config = DonutSwinConfig(**encoder_config)
            else:
                # Any other encoder type is resolved through the transformers
                # config registry.
                print(f"Encoder type: {encoder_config['model_type']}")
                self.encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config)

        if isinstance(self.pam_config, dict):
            # BUGFIX: the original indexed pam_config["model_type"] directly
            # and raised KeyError when the key was absent (the defaulting
            # code was commented out); default the missing tag to "pam".
            if pam_config.get("model_type", "pam") == "pam":
                self.pam_config = PamConfig(**pam_config)
            else:
                raise ValueError(f"pam_config['model_type'] should be 'pam', got {pam_config['model_type']}")

        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
class SiglipPAMVisionEncoderConfig(PretrainedConfig):
    """Composite vision-encoder configuration: a SigLIP vision backbone plus a PAM head.

    Each sub-config may be passed as an instantiated config object or as a
    plain dict (as produced by ``to_dict()`` / JSON round-trips), in which
    case it is re-hydrated into the proper config class here.
    """

    model_type = "siglippam"
    sub_configs = {"encoder_config": SiglipVisionConfig, "pam_config": PamConfig}

    def __init__(
        self,
        encoder_config: SiglipVisionConfig = None,
        pam_config: PamConfig = None,
        **kwargs
    ):
        self.encoder_config = encoder_config
        self.pam_config = pam_config

        if isinstance(self.encoder_config, dict):
            # Default the type tag when absent (setdefault mutates the dict,
            # matching the original assignment).
            encoder_config.setdefault("model_type", "siglip_vision_model")
            if encoder_config["model_type"] == "siglip_vision_model":
                self.encoder_config = SiglipVisionConfig(**encoder_config)
            else:
                # Only a SigLIP vision backbone is supported by this wrapper.
                raise ValueError(f"Need siglip_model_type, got {encoder_config['model_type']}")

        if isinstance(self.pam_config, dict):
            # BUGFIX: the original indexed pam_config["model_type"] directly
            # and raised KeyError when the key was absent; default it to "pam"
            # (consistent with SwinPamVisionEncoderConfig).
            if pam_config.get("model_type", "pam") == "pam":
                self.pam_config = PamConfig(**pam_config)
            else:
                raise ValueError(f"pam_config['model_type'] should be 'pam', got {pam_config['model_type']}")

        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
class DIVEdocConfig(PretrainedConfig):
    """Top-level configuration for the DIVEdoc vision-language model.

    Combines a composite vision encoder (Swin+PAM or SigLIP+PAM) with a
    Gemma text decoder. Sub-configs may arrive as config objects or as
    plain dicts (e.g. reloaded from JSON); dicts are re-hydrated here.
    """

    # Transient decoder cache — excluded from inference-time comparisons.
    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {"vision_config": SwinPamVisionEncoderConfig, "text_config": GemmaConfig}
    model_type = "DIVEdoc"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,          # label id ignored by the LM loss
        image_token_index=256000,   # vocab id of the image placeholder token
        vocab_size=257152,
        projection_dim=2048,        # width of the vision->text projection
        hidden_size=2048,
        **kwargs,
    ):
        # Underscore-prefixed attributes are treated as private by
        # PretrainedConfig; to_dict() below additionally strips _ignore_index.
        self._ignore_index = ignore_index
        self.image_token_index = image_token_index
        self._vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.vision_config = vision_config
        self.is_encoder_decoder = False

        if isinstance(self.vision_config, dict):
            # Default the type tag to the Swin+PAM composite encoder.
            vision_config["model_type"] = (
                vision_config["model_type"] if "model_type" in vision_config else "swinpam"
            )
            if vision_config["model_type"] == "swinpam":
                # NOTE(review): assumes the dict always carries both
                # "encoder_config" and "pam_config" keys (KeyError otherwise)
                # and silently drops any other keys — confirm against the
                # serialized format.
                self.vision_config = SwinPamVisionEncoderConfig(encoder_config=vision_config["encoder_config"],pam_config=vision_config["pam_config"])
            elif vision_config["model_type"] == "siglippam":
                self.vision_config = SiglipPAMVisionEncoderConfig(encoder_config=vision_config["encoder_config"],pam_config=vision_config["pam_config"])
            else:
                # Any other vision config type resolves via the registry.
                self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            # Default encoder built by the module-level factory defined below.
            self.vision_config = get_vision_config("swinpam")

        self.text_config = text_config
        if isinstance(self.text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            # Gemma-2B-like decoder defaults.
            self.text_config = CONFIG_MAPPING["gemma"](
                hidden_size=2048,
                num_hidden_layers=18,
                intermediate_size=16384,
                num_attention_heads=8,
                num_key_value_heads=1,
                is_encoder_decoder=False,
                vocab_size=vocab_size,
            )

        # One image token per cell of the teacher feature map (H * W).
        # NOTE(review): assumes vision_config ends up with a .pam_config
        # attribute — true for the branches above, but not guaranteed for a
        # registry-resolved config; confirm.
        self.text_config.num_image_tokens = self.vision_config.pam_config.teacher_fmap_dim[0] *\
            self.vision_config.pam_config.teacher_fmap_dim[1]
        self.vision_config.projection_dim = projection_dim
        super().__init__(**kwargs)

    def to_dict(self):
        """Serialize to a dict, hiding the private ``_ignore_index`` field."""
        output = super().to_dict()
        output.pop("_ignore_index", None)
        return output
|
|
|
|
|
|
def get_siglip_vision_config(image_size=[896,896],num_image_token = 4096,hidden_size = 768):
    """Build the SiglipVisionConfig used as the student vision encoder.

    Only the image size, token count and hidden width vary; every other
    hyper-parameter is fixed.
    """
    settings = dict(
        hidden_size=hidden_size,
        image_size=image_size,
        intermediate_size=2860,
        model_type="siglip_vision_model",
        num_attention_heads=8,
        num_hidden_layers=12,
        num_image_tokens=num_image_token,
        patch_size=14,
        projection_dim=2048,
        projector_hidden_act="gelu_fast",
        torch_dtype="float32",
        vision_use_head=False,
    )
    return SiglipVisionConfig(**settings)
|
|
|
|
|
|
def get_swin_vision_config(image_size=[2560,1920],hidden_size = 1024):
    """Build the DonutSwinConfig used as the student vision encoder.

    Only the image size and hidden width vary; every other hyper-parameter
    is fixed.
    """
    settings = dict(
        attention_probs_dropout_prob=0.0,
        depths=[2, 2, 14, 2],
        drop_path_rate=0.1,
        embed_dim=128,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        hidden_size=hidden_size,
        image_size=image_size,
        initializer_range=0.02,
        layer_norm_eps=1e-05,
        mlp_ratio=4.0,
        model_type="donut-swin",
        num_channels=3,
        num_heads=[4, 8, 16, 32],
        num_layers=4,
        patch_size=4,
        path_norm=True,
        qkv_bias=True,
        use_absolute_embeddings=False,
        window_size=10,
    )
    return DonutSwinConfig(**settings)
|
|
|
|
|
|
def get_vision_config( visual_encoder_type:Literal["swinpam","siglip80m"],
                      image_size=[2560,1920],
                      sequence_mapping_layer_type= "bilinear",
                      student_fmap_dim=(80,60),
                      student_embedding_dim= 1024,
                      teacher_fmap_dim= (64,64),
                      teacher_embedding_dim= 1152):
    """Build the composite vision-encoder config for the requested backbone.

    Args:
        visual_encoder_type: "swinpam" (Donut-Swin + PAM) or "siglip80m"
            (SigLIP + PAM).
        image_size: [height, width] of the input images.
        Remaining arguments parameterize the shared PamConfig.

    Returns:
        A SwinPamVisionEncoderConfig or SiglipPAMVisionEncoderConfig.

    Raises:
        ValueError: for an unknown ``visual_encoder_type``.
    """
    # NOTE(review): the default "bilinear" is not a member of PamConfig's
    # Literal ("linear_projection" | "bilinear_interpolation") — confirm
    # which spelling the model code dispatches on before changing either.
    pam_config = PamConfig(
        sequence_mapping_layer_type = sequence_mapping_layer_type,
        student_fmap_dim = student_fmap_dim,
        student_embedding_dim = student_embedding_dim,
        teacher_fmap_dim = teacher_fmap_dim,
        teacher_embedding_dim = teacher_embedding_dim)

    if visual_encoder_type == "swinpam":
        encoder_config = get_swin_vision_config(image_size=image_size,hidden_size = student_embedding_dim)
        return SwinPamVisionEncoderConfig(encoder_config=encoder_config,pam_config=pam_config)
    elif visual_encoder_type =="siglip80m":
        # BUGFIX: image_size is a [height, width] list, so the original
        # (image_size // 14) ** 2 always raised TypeError. Compute the
        # 14x14-patch grid per dimension instead.
        num_image_token = (image_size[0] // 14) * (image_size[1] // 14)
        encoder_config = get_siglip_vision_config(image_size=image_size,num_image_token = num_image_token, hidden_size = student_embedding_dim)
        return SiglipPAMVisionEncoderConfig(encoder_config=encoder_config,pam_config=pam_config)
    else:
        # BUGFIX: closing quote around 'siglip80m' was missing in the message.
        raise ValueError(f"Unknown visual encoder type, need 'swinpam' or 'siglip80m', got {visual_encoder_type}.")