lukeingawesome committed on
Commit
82069f1
·
verified ·
1 Parent(s): 21c8bfc

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +0 -34
  2. model.py +1 -5
README.md CHANGED
@@ -120,40 +120,6 @@ python inference.py \
120
  --previous_image /path/to/previous.png
121
  ```
122
 
123
- ## Model Architecture
124
-
125
- ```
126
- IMAGE ENCODER:
127
- Input: current CXR [B, 3, 448, 448] + previous CXR [B, 3, 448, 448]
128
- |
129
- +-- ResNet-50 backbone (shared weights, processes both images)
130
- | -> patch features [B, 2048, 14, 14]
131
- |
132
- +-- 1x1 Conv projection (2048 -> 256)
133
- |
134
- +-- Vision Transformer Pooler (3 blocks, 8 heads)
135
- | -> temporal difference features [B, 256, 14, 14]
136
- |
137
- +-- Concatenate [static, temporal] -> [B, 512, 14, 14]
138
- |
139
- +-- MLP Projector (512 -> 128)
140
- -> image embedding [B, 128] <-- get_embeddings()
141
-
142
- TEXT ENCODER:
143
- Input: tokenized text
144
- |
145
- +-- CXR-BERT (12 layers, 768-dim)
146
- | -> CLS token [B, 768]
147
- |
148
- +-- LayerNorm + Linear (768 -> 128)
149
- -> text embedding [B, 128] <-- encode_text()
150
-
151
- CLASSIFIER:
152
- image embedding [B, 128]
153
- |
154
- +-- Linear (128 -> 64) -> ReLU -> Linear (64 -> 1)
155
- -> change probability [B] <-- get_interval_change_prediction()
156
- ```
157
 
158
  ## Preprocessing Raw Images
159
 
 
120
  --previous_image /path/to/previous.png
121
  ```
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  ## Preprocessing Raw Images
125
 
model.py CHANGED
@@ -481,11 +481,7 @@ class TILAModel(_BASE_CLASS):
481
  tokens = self._tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
482
  tokens = {k: v.to(device) for k, v in tokens.items()}
483
  self.eval()
484
- # Run text encoder in float32 for numerical stability
485
- with torch.autocast(device_type=device.type if isinstance(device, torch.device) else "cuda", enabled=False):
486
- self.text_encoder.float()
487
- emb = self.text_encoder(tokens)
488
- self.text_encoder.to(next(self.image_encoder.parameters()).dtype)
489
  return F.normalize(emb.float(), p=2, dim=1)
490
 
491
  @torch.no_grad()
 
481
  tokens = self._tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
482
  tokens = {k: v.to(device) for k, v in tokens.items()}
483
  self.eval()
484
+ emb = self.text_encoder(tokens)
 
 
 
 
485
  return F.normalize(emb.float(), p=2, dim=1)
486
 
487
  @torch.no_grad()