cwangrun commited on
Commit
ba779f8
·
verified ·
1 Parent(s): 57e8231

Upload modeling_chexficient.py

Browse files
Files changed (1) hide show
  1. modeling_chexficient.py +2 -1
modeling_chexficient.py CHANGED
@@ -72,6 +72,7 @@ class CheXficientModel(PreTrainedModel):
72
  # ===== Encoders =====
73
  self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
74
  self.text_encoder = TextEncoder(model_name=config.text_model_name)
 
75
 
76
  # ===== Projection heads =====
77
  self.image_projection = load_projection_head(
@@ -108,7 +109,7 @@ class CheXficientModel(PreTrainedModel):
108
  else:
109
  raise NotImplementedError("Not supported pooling method : %s", self.text_pooling)
110
 
111
- text_embeddings = self.text_projection(text_features) if self.projection else text_features
112
 
113
  text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
114
 
 
72
  # ===== Encoders =====
73
  self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
74
  self.text_encoder = TextEncoder(model_name=config.text_model_name)
75
+ self.text_pooling = 'eos'
76
 
77
  # ===== Projection heads =====
78
  self.image_projection = load_projection_head(
 
109
  else:
110
  raise NotImplementedError("Not supported pooling method : %s", self.text_pooling)
111
 
112
+ text_embeddings = self.text_projection(text_features)
113
 
114
  text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
115