Upload model
- README.md +199 -0
- config.json +33 -0
- configuration_graph_clip.py +67 -0
- model.safetensors +3 -0
- modeling_graph_clip.py +278 -0
README.md ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
config.json ADDED
@@ -0,0 +1,33 @@
{
  "alpha": 0.5,
  "architectures": [
    "GraphCLIPModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_graph_clip.GraphCLIPConfig",
    "AutoModel": "modeling_graph_clip.GraphCLIPModel"
  },
  "graph_config": {
    "dropout": 0.2,
    "embedding_dim": 512,
    "ffn_embedding_dim": 512,
    "hidden_size": 512,
    "model_type": "graphormer",
    "num_hidden_layers": 6
  },
  "graph_pair_type": "image",
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "graph_clip",
  "pretrained_graphormer_hub_id": "helena-balabin/pretrained_graphormer_vg_image_graphs",
  "pretrained_model_name_or_path": "openai/clip-vit-base-patch32",
  "projection_dim": 512,
  "text_config": {
    "model_type": "clip_text_model"
  },
  "torch_dtype": "float32",
  "transformers_version": "4.45.2",
  "vision_config": {
    "model_type": "clip_vision_model"
  }
}
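Because `config.json` registers the custom classes under `auto_map`, the checkpoint can be loaded through the standard Auto classes once remote code is trusted. A minimal sketch, assuming a placeholder repo id (`your-username/graph-clip` is not the real repository name) and that the `nsd_compositionality` package imported by `modeling_graph_clip.py` is installed in the loading environment:

```python
from transformers import AutoConfig, AutoModel

# NOTE: placeholder repo id; substitute the actual Hub repository of this upload.
repo_id = "your-username/graph-clip"

# trust_remote_code=True lets transformers import configuration_graph_clip.py and
# modeling_graph_clip.py from the repo, as declared in the "auto_map" field above.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.graph_pair_type)  # "image" in this checkpoint
print(config.alpha)            # 0.5
```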
configuration_graph_clip.py ADDED
@@ -0,0 +1,67 @@
"""Configuration class for the custom Graph-based CLIP model incorporating Image, Text, and Graph inputs."""

from typing import Optional, Union

from transformers import CLIPConfig
from transformers.models.deprecated.graphormer.configuration_graphormer import GraphormerConfig


class GraphCLIPConfig(CLIPConfig):
    r"""
    Configuration for GraphCLIPModel, which extends CLIP with a Graphormer encoder.

    Args:
        graph_config (`Union[dict, GraphormerConfig]`):
            Configuration (or dict) for the Graphormer graph encoder.
        graph_pair_type (`str`, *optional*, defaults to `"text"`):
            Which modality to pair against the graph in contrastive loss.
            One of `"text"` or `"image"`.
        pretrained_model_name_or_path (`str`, *optional*):
            If set, vision & text heads will be loaded from this CLIP checkpoint.
        pretrained_graphormer_hub_id (`str`, *optional*):
            If set, the Graphormer will be loaded from this HuggingFace Hub model ID.
        alpha (`float`, *optional*, defaults to 0.5):
            Weight for combining image-text and graph-pair contrastive losses.
        **kwargs:
            All remaining kwargs will be passed to the base `CLIPConfig` (e.g., `projection_dim`,
            `vision_layers`, `text_layers`, etc.).
    """

    model_type = "graph_clip"

    def __init__(
        self,
        graph_config: Union[dict, GraphormerConfig] = GraphormerConfig(
            hidden_size=512,
            embedding_dim=512,
            ffn_embedding_dim=512,
            num_hidden_layers=6,
            dropout=0.1,
        ),
        graph_pair_type: str = "text",
        pretrained_model_name_or_path: Optional[str] = None,
        pretrained_graphormer_hub_id: Optional[str] = None,
        alpha: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # build or assign the graph encoder config
        if isinstance(graph_config, dict):
            self.graph_config = GraphormerConfig(**graph_config)
        else:
            self.graph_config = graph_config

        # which modality to pair the graph with
        if graph_pair_type not in ("text", "image"):
            raise ValueError("`graph_pair_type` must be either 'text' or 'image'")
        self.graph_pair_type = graph_pair_type

        # if provided, load CLIP vision/text from this checkpoint
        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # if provided, load pretrained Graphormer from this HuggingFace Hub model ID
        self.pretrained_graphormer_hub_id = pretrained_graphormer_hub_id

        # alpha for the contrastive loss
        self.alpha = alpha
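For reference, a configuration equivalent to the `config.json` above can be built programmatically. A small sketch, assuming `configuration_graph_clip.py` is on the import path; the dict form of `graph_config` is converted to a `GraphormerConfig` inside `__init__`:

```python
from configuration_graph_clip import GraphCLIPConfig

# Mirror the values stored in config.json above.
config = GraphCLIPConfig(
    graph_config={
        "hidden_size": 512,
        "embedding_dim": 512,
        "ffn_embedding_dim": 512,
        "num_hidden_layers": 6,
        "dropout": 0.2,
    },
    graph_pair_type="image",
    pretrained_model_name_or_path="openai/clip-vit-base-patch32",
    pretrained_graphormer_hub_id="helena-balabin/pretrained_graphormer_vg_image_graphs",
    alpha=0.5,
    projection_dim=512,
)

# Writes a config.json like the one shown above.
config.save_pretrained("./graph_clip_checkpoint")
```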
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:949597cdaf0150ce353b0306c1be1287537b06534a57c47d1bc655f2f8f2e74e
size 663769988
modeling_graph_clip.py ADDED
@@ -0,0 +1,278 @@
"""Contrastive Learning-Based Graph, Image, and Text Model."""

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import TrainerCallback
from transformers import CLIPModel, CLIPTextModel, CLIPVisionModel, GraphormerModel
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, BaseModelOutputWithPooling, ModelOutput
from transformers.models.clip.modeling_clip import clip_loss

from nsd_compositionality.models.graph_clip_model.configuration_graph_clip import GraphCLIPConfig


class LossLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, model=None, **kwargs):  # type: ignore
        """Log losses from the model during training.

        Args:
            args: Training arguments.
            state: Training state.
            control: Control object for training.
            logs: Dictionary of logs to update.
            model: The model being trained.
        """
        if logs is None or model is None:
            return
        add = {}
        if getattr(model, "last_loss_image_text", None) is not None:
            add["loss_image_text"] = model.last_loss_image_text
        if getattr(model, "last_loss_graph_pair", None) is not None:
            add["loss_graph_pair"] = model.last_loss_graph_pair
        if add:
            logs.update(add)


@dataclass
class GraphCLIPOutput(ModelOutput):
    """
    Custom output class for GraphCLIPModel.

    Attributes:
        loss (torch.FloatTensor, optional): Loss value if return_loss is True.
        logits_image_text (torch.FloatTensor): Logits for image-text pairs.
        logits_graph_pair (torch.FloatTensor): Logits for graph-text or graph-image pairs.
        image_embeds (torch.FloatTensor): Image embeddings.
        graph_embeds (torch.FloatTensor): Graph embeddings.
        text_embeds (torch.FloatTensor): Text embeddings.
        vision_model_output (BaseModelOutputWithPooling): Output from the vision model.
        text_model_output (BaseModelOutputWithPooling): Output from the text model.
        graph_model_output (BaseModelOutputWithNoAttention): Output from the graph model.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_image_text: torch.FloatTensor = None
    logits_graph_pair: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    graph_embeds: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    vision_model_output: BaseModelOutputWithPooling = None
    text_model_output: BaseModelOutputWithPooling = None
    graph_model_output: BaseModelOutputWithNoAttention = None


class GraphCLIPModel(CLIPModel):
    config_class = GraphCLIPConfig

    def __init__(self, config: GraphCLIPConfig):
        # Specify configs
        super().__init__(config)
        graph_config = config.graph_config
        self.alpha = getattr(config, "alpha", 0.5)

        # If "pretrained_model_name_or_path" is in config, load the pretrained vision and text models
        if config.pretrained_model_name_or_path:
            self.vision_model = CLIPVisionModel.from_pretrained(
                config.pretrained_model_name_or_path,
            ).vision_model
            self.text_model = CLIPTextModel.from_pretrained(
                config.pretrained_model_name_or_path,
            )

        # Initialize Graphormer model - load pretrained if specified
        if config.pretrained_graphormer_hub_id:
            self.graph_model = GraphormerModel.from_pretrained(config.pretrained_graphormer_hub_id)
        else:
            self.graph_model = GraphormerModel._from_config(graph_config)

        # Projection layer for graph embeddings
        self.graph_projection = nn.Linear(graph_config.hidden_size, config.projection_dim, bias=False)

        # Determine the graph pair type (either "text" or "image")
        self.graph_pair_type = config.graph_pair_type  # Should be "text" or "image"

        # For logging component losses
        self.last_loss_image_text: Optional[torch.Tensor] = None
        self.last_loss_graph_pair: Optional[torch.Tensor] = None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        graph_input: Optional[dict] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # noqa
    ) -> Union[Tuple, GraphCLIPOutput]:
        """
        Forward pass of GraphCLIP Model with three modalities: image, graph, and text.

        Args:
            input_ids (torch.LongTensor): Tokenized text input IDs.
            pixel_values (torch.FloatTensor): Batch of images.
            graph_input (dict, optional): Dictionary of inputs for the Graphormer encoder.
            attention_mask (torch.LongTensor, optional): Attention mask for the text encoder.
            position_ids (torch.LongTensor, optional): Position IDs for text encoder.
            return_loss (bool, optional): Whether to compute the contrastive loss, default is True.
            output_attentions (bool, optional): Whether to output attentions.
            output_hidden_states (bool, optional): Whether to output hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput object.
            **kwargs: Additional keyword arguments.

        Returns:
            GraphCLIPOutput: Custom output object containing logits and embeddings.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Process images through the CLIP vision encoder
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[1]  # Pooled output
        image_embeds = self.visual_projection(image_embeds)

        # Process text input through CLIP text encoder
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_embeds = text_outputs[1]  # Pooled output
        text_embeds = self.text_projection(text_embeds)

        # Process graph input through Graphormer (if provided)
        graph_outputs = None
        graph_embeds = None
        if graph_input is not None:
            graph_outputs = self.graph_model(
                **graph_input,
            )
            # Use the special graph token for graph representation
            graph_embeds = graph_outputs.last_hidden_state[:, 0, :]
            graph_embeds = self.graph_projection(graph_embeds)

        # Normalize the projected features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
        if graph_embeds is not None:
            graph_embeds = graph_embeds / graph_embeds.norm(p=2, dim=-1, keepdim=True)

        # Compute scaled cosine similarity logits
        logit_scale = self.logit_scale.exp()
        logits_image_text = logit_scale * torch.matmul(image_embeds, text_embeds.t())

        # Compute graph pair logits based on the specified pair type (if graph input is provided)
        logits_graph_pair = None
        if graph_embeds is not None:
            if self.graph_pair_type == "text":
                logits_graph_pair = logit_scale * torch.matmul(graph_embeds, text_embeds.t())
            elif self.graph_pair_type == "image":
                logits_graph_pair = logit_scale * torch.matmul(graph_embeds, image_embeds.t())
            else:
                raise ValueError("Invalid graph_pair_type. Must be 'text' or 'image'.")

        loss = None
        if return_loss:
            # Compute contrastive loss for the specified pairs
            loss_image_text = clip_loss(logits_image_text)
            # Store for logging
            try:
                self.last_loss_image_text = loss_image_text.detach().mean()
            except Exception:
                self.last_loss_image_text = None

            if logits_graph_pair is not None:
                loss_graph_pair = clip_loss(logits_graph_pair)
                try:
                    self.last_loss_graph_pair = loss_graph_pair.detach().mean()
                except Exception:
                    self.last_loss_graph_pair = None
                loss = (1.0 - self.alpha) * loss_image_text + self.alpha * loss_graph_pair
            else:
                self.last_loss_graph_pair = None
                loss = loss_image_text

        if not return_dict:
            output = (
                logits_image_text,
                logits_graph_pair,
                image_embeds,
                graph_embeds,
                text_embeds,
                vision_outputs,
                text_outputs,
                graph_outputs,
            )
            return ((loss,) + output) if loss is not None else output

        return GraphCLIPOutput(
            loss=loss,
            logits_image_text=logits_image_text,
            logits_graph_pair=logits_graph_pair,
            image_embeds=image_embeds,
            graph_embeds=graph_embeds,
            text_embeds=text_embeds,
            vision_model_output=vision_outputs,
            text_model_output=text_outputs,
            graph_model_output=graph_outputs,
        )

    def freeze_layers(self, freeze_vision: bool = False, freeze_text: bool = False, freeze_graph: bool = False):
        """
        Freeze or unfreeze layers of the vision, text, and graph backbones.

        Args:
            freeze_vision (bool): Whether to freeze the vision backbone.
            freeze_text (bool): Whether to freeze the text backbone.
            freeze_graph (bool): Whether to freeze the graph backbone.
        """
        if freeze_vision:
            for param in self.vision_model.parameters():
                param.requires_grad = False

        if freeze_text:
            for param in self.text_model.parameters():
                param.requires_grad = False

        if freeze_graph:
            for param in self.graph_model.parameters():
                param.requires_grad = False

    def unfreeze_partial_layers(self, model_part: str, num_layers: int):
        """
        Unfreeze the last `num_layers` of a specific model part.

        Args:
            model_part (str): The part of the model to unfreeze ('vision', 'text', or 'graph').
            num_layers (int): Number of layers to unfreeze from the end.
        """
        if model_part == "vision":
            layers = list(self.vision_model.encoder.layers)
        elif model_part == "text":
            layers = list(self.text_model.text_model.encoder.layers)
        elif model_part == "graph":
            layers = list(self.graph_model.graph_encoder.layers)
        else:
            raise ValueError("Invalid model_part. Must be 'vision', 'text', or 'graph'.")

        # Freeze all layers first
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = False

        # Unfreeze the last `num_layers`
        for layer in layers[-num_layers:]:
            for param in layer.parameters():
                param.requires_grad = True
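Taken together, a fine-tuning setup with these classes might look roughly like the sketch below. It is only a sketch under stated assumptions: the `nsd_compositionality` package that these modules belong to is installed, and the dataset/collator wiring (which must supply `input_ids`, `pixel_values`, and `graph_input` batches as the `forward` signature expects) is omitted.

```python
# Hedged sketch: assumes the nsd_compositionality package is importable; data pipeline omitted.
from nsd_compositionality.models.graph_clip_model.configuration_graph_clip import GraphCLIPConfig
from nsd_compositionality.models.graph_clip_model.modeling_graph_clip import (
    GraphCLIPModel,
    LossLoggingCallback,
)

config = GraphCLIPConfig(
    graph_pair_type="image",
    pretrained_model_name_or_path="openai/clip-vit-base-patch32",
    pretrained_graphormer_hub_id="helena-balabin/pretrained_graphormer_vg_image_graphs",
    alpha=0.5,
)
model = GraphCLIPModel(config)

# Freeze all three backbones, then unfreeze only the last two encoder blocks of each.
model.freeze_layers(freeze_vision=True, freeze_text=True, freeze_graph=True)
model.unfreeze_partial_layers("vision", num_layers=2)
model.unfreeze_partial_layers("text", num_layers=2)
model.unfreeze_partial_layers("graph", num_layers=2)

# During training, passing callbacks=[LossLoggingCallback()] to a transformers Trainer copies
# model.last_loss_image_text and model.last_loss_graph_pair into the Trainer's logged metrics.
```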