Upload folder using huggingface_hub
Browse files- modeling_MMRet_CLIP.py +9 -10
modeling_MMRet_CLIP.py
CHANGED
|
@@ -22,12 +22,12 @@ import torch.utils.checkpoint
|
|
| 22 |
from torch import nn
|
| 23 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
from PIL import Image
|
| 25 |
-
from
|
| 26 |
-
from
|
| 27 |
-
from
|
| 28 |
-
from
|
| 29 |
-
from
|
| 30 |
-
from
|
| 31 |
ModelOutput,
|
| 32 |
add_code_sample_docstrings,
|
| 33 |
add_start_docstrings,
|
|
@@ -37,7 +37,7 @@ from ...utils import (
|
|
| 37 |
logging,
|
| 38 |
replace_return_docstrings,
|
| 39 |
)
|
| 40 |
-
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
| 41 |
|
| 42 |
|
| 43 |
if is_flash_attn_2_available():
|
|
@@ -47,11 +47,10 @@ if is_flash_attn_2_available():
|
|
| 47 |
logger = logging.get_logger(__name__)
|
| 48 |
|
| 49 |
# General docstring
|
| 50 |
-
_CONFIG_FOR_DOC = "
|
| 51 |
-
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
|
| 52 |
|
| 53 |
# Image classification docstring
|
| 54 |
-
_IMAGE_CLASS_CHECKPOINT = "
|
| 55 |
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
| 56 |
|
| 57 |
|
|
|
|
| 22 |
from torch import nn
|
| 23 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
from PIL import Image
|
| 25 |
+
from transformers.activations import ACT2FN
|
| 26 |
+
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
| 27 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
| 28 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 29 |
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
| 30 |
+
from transformers.utils import (
|
| 31 |
ModelOutput,
|
| 32 |
add_code_sample_docstrings,
|
| 33 |
add_start_docstrings,
|
|
|
|
| 37 |
logging,
|
| 38 |
replace_return_docstrings,
|
| 39 |
)
|
| 40 |
+
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
| 41 |
|
| 42 |
|
| 43 |
if is_flash_attn_2_available():
|
|
|
|
| 47 |
logger = logging.get_logger(__name__)
|
| 48 |
|
| 49 |
# General docstring
|
| 50 |
+
_CONFIG_FOR_DOC = "MMRet_CLIP"
|
|
|
|
| 51 |
|
| 52 |
# Image classification docstring
|
| 53 |
+
_IMAGE_CLASS_CHECKPOINT = "JUNJIE/MMRet-large"
|
| 54 |
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
| 55 |
|
| 56 |
|