Upload modeling_chexficient.py
Browse files- modeling_chexficient.py +2 -1
modeling_chexficient.py
CHANGED
|
@@ -72,6 +72,7 @@ class CheXficientModel(PreTrainedModel):
|
|
| 72 |
# ===== Encoders =====
|
| 73 |
self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
|
| 74 |
self.text_encoder = TextEncoder(model_name=config.text_model_name)
|
|
|
|
| 75 |
|
| 76 |
# ===== Projection heads =====
|
| 77 |
self.image_projection = load_projection_head(
|
|
@@ -108,7 +109,7 @@ class CheXficientModel(PreTrainedModel):
|
|
| 108 |
else:
|
| 109 |
raise NotImplementedError("Not supported pooling method : %s", self.text_pooling)
|
| 110 |
|
| 111 |
-
text_embeddings = self.text_projection(text_features)
|
| 112 |
|
| 113 |
text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
|
| 114 |
|
|
|
|
| 72 |
# ===== Encoders =====
|
| 73 |
self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
|
| 74 |
self.text_encoder = TextEncoder(model_name=config.text_model_name)
|
| 75 |
+
self.text_pooling = 'eos'
|
| 76 |
|
| 77 |
# ===== Projection heads =====
|
| 78 |
self.image_projection = load_projection_head(
|
|
|
|
| 109 |
else:
|
| 110 |
raise NotImplementedError("Not supported pooling method : %s", self.text_pooling)
|
| 111 |
|
| 112 |
+
text_embeddings = self.text_projection(text_features)
|
| 113 |
|
| 114 |
text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
|
| 115 |
|