Commit
·
0bfa212
1
Parent(s):
719f253
fix
Browse files
vision.py
CHANGED
|
@@ -192,7 +192,7 @@ class SiglipVisionModelOutput(ModelOutput):
|
|
| 192 |
|
| 193 |
|
| 194 |
class SiglipVisionEmbeddings(nn.Module):
|
| 195 |
-
def __init__(self, config:
|
| 196 |
super().__init__()
|
| 197 |
self.config = config
|
| 198 |
self.embed_dim = config.hidden_size
|
|
@@ -565,7 +565,7 @@ class SiglipMLP(nn.Module):
|
|
| 565 |
|
| 566 |
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
|
| 567 |
class SiglipEncoderLayer(nn.Module):
|
| 568 |
-
def __init__(self, config:
|
| 569 |
super().__init__()
|
| 570 |
self.embed_dim = config.hidden_size
|
| 571 |
self.self_attn = (
|
|
@@ -1001,7 +1001,7 @@ class SiglipEncoder(nn.Module):
|
|
| 1001 |
|
| 1002 |
|
| 1003 |
class SiglipVisionTransformer(nn.Module):
|
| 1004 |
-
def __init__(self, config:
|
| 1005 |
super().__init__()
|
| 1006 |
self.config = config
|
| 1007 |
embed_dim = config.hidden_size
|
|
@@ -1012,7 +1012,7 @@ class SiglipVisionTransformer(nn.Module):
|
|
| 1012 |
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
| 1013 |
|
| 1014 |
# @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1015 |
-
# @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=
|
| 1016 |
def forward(
|
| 1017 |
self,
|
| 1018 |
pixel_values,
|
|
@@ -1058,7 +1058,7 @@ class SiglipVisionTransformer(nn.Module):
|
|
| 1058 |
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
| 1059 |
"""Multihead Attention Pooling."""
|
| 1060 |
|
| 1061 |
-
def __init__(self, config:
|
| 1062 |
super().__init__()
|
| 1063 |
|
| 1064 |
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
|
@@ -1084,7 +1084,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
|
| 1084 |
# SIGLIP_START_DOCSTRING,
|
| 1085 |
# )
|
| 1086 |
class SiglipVisionModel(nn.Module):
|
| 1087 |
-
def __init__(self, config:
|
| 1088 |
super().__init__()
|
| 1089 |
|
| 1090 |
self.vision_model = SiglipVisionTransformer(config)
|
|
@@ -1096,7 +1096,7 @@ class SiglipVisionModel(nn.Module):
|
|
| 1096 |
# return self.vision_model.embeddings.patch_embedding
|
| 1097 |
|
| 1098 |
# @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1099 |
-
# @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=
|
| 1100 |
def forward(
|
| 1101 |
self,
|
| 1102 |
pixel_values,
|
|
|
|
| 192 |
|
| 193 |
|
| 194 |
class SiglipVisionEmbeddings(nn.Module):
|
| 195 |
+
def __init__(self, config: Img2HTMLVisionConfig):
|
| 196 |
super().__init__()
|
| 197 |
self.config = config
|
| 198 |
self.embed_dim = config.hidden_size
|
|
|
|
| 565 |
|
| 566 |
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
|
| 567 |
class SiglipEncoderLayer(nn.Module):
|
| 568 |
+
def __init__(self, config: Img2HTMLVisionConfig):
|
| 569 |
super().__init__()
|
| 570 |
self.embed_dim = config.hidden_size
|
| 571 |
self.self_attn = (
|
|
|
|
| 1001 |
|
| 1002 |
|
| 1003 |
class SiglipVisionTransformer(nn.Module):
|
| 1004 |
+
def __init__(self, config: Img2HTMLVisionConfig):
|
| 1005 |
super().__init__()
|
| 1006 |
self.config = config
|
| 1007 |
embed_dim = config.hidden_size
|
|
|
|
| 1012 |
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
| 1013 |
|
| 1014 |
# @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1015 |
+
# @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Img2HTMLVisionConfig)
|
| 1016 |
def forward(
|
| 1017 |
self,
|
| 1018 |
pixel_values,
|
|
|
|
| 1058 |
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
| 1059 |
"""Multihead Attention Pooling."""
|
| 1060 |
|
| 1061 |
+
def __init__(self, config: Img2HTMLVisionConfig):
|
| 1062 |
super().__init__()
|
| 1063 |
|
| 1064 |
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
|
|
|
| 1084 |
# SIGLIP_START_DOCSTRING,
|
| 1085 |
# )
|
| 1086 |
class SiglipVisionModel(nn.Module):
|
| 1087 |
+
def __init__(self, config: Img2HTMLVisionConfig):
|
| 1088 |
super().__init__()
|
| 1089 |
|
| 1090 |
self.vision_model = SiglipVisionTransformer(config)
|
|
|
|
| 1096 |
# return self.vision_model.embeddings.patch_embedding
|
| 1097 |
|
| 1098 |
# @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1099 |
+
# @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Img2HTMLVisionConfig)
|
| 1100 |
def forward(
|
| 1101 |
self,
|
| 1102 |
pixel_values,
|