# MedResNet RadioCaptioner
MedResNet RadioCaptioner is a custom medical image captioning model designed to generate descriptive text for medical images (X-rays, MRIs, CT scans). It uses a memory-efficient architecture that combines a ResNet-50 encoder, a graph propagation layer for spatial reasoning, and a 5-layer Transformer decoder.
This model was developed to address memory constraints on TPUs while maintaining the ability to capture spatial relationships in medical imagery, which is crucial for generating accurate clinical reports.
## Model Architecture
The model follows an encoder-decoder structure with a custom graph-based intermediate layer:
**Vision Encoder (ResNet-50):**
- Base model: `microsoft/resnet-50`.
- Input resolution: $224 \times 224$.
- Output: a $7 \times 7$ feature map with 2048 channels (49 spatial tokens).
- Rather than pooling the feature map into a single global vector (as CLS-token or pooled-output approaches do), the model flattens the spatial grid into 49 tokens to preserve localization information.
- The encoder is unfrozen and trained end-to-end.
- Note: The `classifier.1.weight` key is intentionally ignored (reported as UNEXPECTED) during weight loading, because the pre-trained classification head is removed and replaced with the custom captioning decoder.
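The flattening step can be checked in isolation. This is a minimal sketch that assumes a random tensor in place of the ResNet-50 `last_hidden_state` (so nothing is downloaded) and a hypothetical `EMBED_DIM = 512`; the real embedding size is read from the checkpoint's `config.json`. It shows that the `permute`/`reshape` used in `encode_images` yields 49 tokens in row-major order, so token `h * 7 + w` holds grid cell `(h, w)`.

```python
import torch
import torch.nn as nn

EMBED_DIM = 512  # hypothetical; the real value comes from config.json

# Stand-in for ResNet-50's last_hidden_state: [batch, channels, height, width]
feature_map = torch.randn(2, 2048, 7, 7)

B, C, H, W = feature_map.shape
# Same reshaping as encode_images: [B, C, H, W] -> [B, H*W, C]
tokens = feature_map.permute(0, 2, 3, 1).reshape(B, H * W, C)

# Equivalent formulation: flatten the spatial dims, then move channels last
assert torch.equal(tokens, feature_map.flatten(2).transpose(1, 2))

# Row-major order: token h * W + w is the feature vector of grid cell (h, w)
h, w = 3, 5
assert torch.equal(tokens[:, h * W + w], feature_map[:, :, h, w])

# Linear projection from 2048 channels to the decoder's embedding size
vis_proj = nn.Linear(C, EMBED_DIM)
visual_tokens = vis_proj(tokens)
print(tuple(tokens.shape), tuple(visual_tokens.shape))  # (2, 49, 2048) (2, 49, 512)
```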
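**Projection & Graph Propagation:**
- Visual projection: a linear layer that projects the 2048 ResNet-50 channels to the embedding dimension (`embed_dim` in the checkpoint's `config.json`).
- `GraphPropagationLayer`: a custom module that performs message passing over the 49 visual tokens to refine each feature with its spatial context. At each of three steps (`max_steps=3`), every token is averaged with its neighbors in a window of size 3, the average passes through a linear layer, and the result is added back to the token (residual connection) before a Hardtanh activation. Neighbors are taken along the flattened, row-major token sequence with wrap-around (`torch.roll`), so the window is one-dimensional over the $7 \times 7$ grid rather than a full 2-D $3 \times 3$ neighborhood.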
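To see which grid cells end up averaged together, the diagnostic sketch below (not part of the model) feeds the position indices themselves through the same `torch.roll` calls that `GraphPropagationLayer.forward` uses with its default `window_size=3`. The helper name `rolled_neighbors` is hypothetical. Because the roll runs over the flattened 49-token axis, the first cell of a row is grouped with the last cell of the previous row, and token 0 wraps around to token 48.

```python
import torch

GRID = 7         # 7x7 ResNet-50 feature map
WINDOW_SIZE = 3  # default window_size of GraphPropagationLayer


def rolled_neighbors(num_tokens: int, window_size: int) -> torch.Tensor:
    """For each flattened position j, list the source positions that
    torch.roll(x, shifts=i, dims=1), for i in [-pad, pad], places at j."""
    pad = window_size // 2
    # Use the position index as the "feature" so the roll reveals its sources
    positions = torch.arange(num_tokens).view(1, num_tokens, 1)
    rolled = [torch.roll(positions, shifts=i, dims=1) for i in range(-pad, pad + 1)]
    stacked = torch.stack(rolled, dim=2)  # [1, L, window, 1], as in forward()
    return stacked[0, :, :, 0]            # [L, window]


nbrs = rolled_neighbors(GRID * GRID, WINDOW_SIZE)
for j in (0, 7, 48):
    cell = divmod(j, GRID)
    cells = [divmod(int(k), GRID) for k in nbrs[j]]
    print(f"token {j} = cell {cell} averages cells {cells}")
# token 0 = cell (0, 0) averages cells [(0, 1), (0, 0), (6, 6)]
# token 7 = cell (1, 0) averages cells [(1, 1), (1, 0), (0, 6)]
# token 48 = cell (6, 6) averages cells [(0, 0), (6, 6), (6, 5)]
```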
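**Text Decoder (Transformer):**
- A 5-layer Transformer decoder (8 attention heads, feed-forward dimension 1024) with learned positional embeddings.
- Generates text autoregressively with causal masking, cross-attending to the 49 graph-refined visual tokens.
- Uses the standard BERT tokenizer (`bert-base-uncased`) for vocabulary mapping.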
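Causal masking can be illustrated with the decoder configuration used in the inference code (5 layers, 8 heads, feed-forward size 1024). This is a minimal sketch with random inputs and a hypothetical `EMBED_DIM = 512`; the helper `causal_mask` is illustrative. It builds the additive upper-triangular mask and checks that changing the newest caption token leaves the outputs at earlier positions unchanged.

```python
import torch
import torch.nn as nn

EMBED_DIM = 512  # hypothetical; must be divisible by nhead=8


def causal_mask(seq_len: int) -> torch.Tensor:
    # 0.0 where attention is allowed, -inf above the diagonal (future positions)
    return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)


layer = nn.TransformerDecoderLayer(
    d_model=EMBED_DIM, nhead=8, dim_feedforward=1024, batch_first=True
)
decoder = nn.TransformerDecoder(layer, num_layers=5).eval()  # eval() disables dropout

torch.manual_seed(0)
memory = torch.randn(2, 49, EMBED_DIM)  # 49 visual tokens per image
tgt = torch.randn(2, 4, EMBED_DIM)      # embeddings of the 4 caption tokens so far

with torch.no_grad():
    out = decoder(tgt, memory, tgt_mask=causal_mask(4))
    tgt_changed = tgt.clone()
    tgt_changed[:, -1] = torch.randn(2, EMBED_DIM)  # alter only the newest token
    out_changed = decoder(tgt_changed, memory, tgt_mask=causal_mask(4))

print(causal_mask(4))
print(tuple(out.shape))  # (2, 4, 512)
# Earlier positions cannot attend to the newest token, so their outputs match
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))
```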
### Architecture Diagram
*[Figure: Architecture Diagram]*
## Intended Uses & Limitations
### Intended Uses
- Automated Report Generation: Generating preliminary captions for medical imaging datasets.
- Research: Testing graph propagation techniques on vision-language tasks.
- Education: Demonstrating encoder-decoder architectures for medical AI.
### Limitations
- Not a Clinical Tool: This model is a research prototype. Do not use it for medical diagnosis or clinical decision support.
- Domain-specific: Performance may vary significantly depending on the imaging modality (X-ray vs. CT) and on the dataset the model was fine-tuned on.
- Memory vs. Accuracy: This model prioritizes memory efficiency and speed over the potentially higher accuracy of larger ViT-based encoders such as SigLIP.
## How to use
Because this model defines a custom architecture (`ImageCaptioningModel` and `GraphPropagationLayer`), it cannot be loaded with the standard `AutoModel` loaders. Define both classes in your script, then load the weights into an instance of `ImageCaptioningModel`.
### Installation
```bash
pip install torch transformers pillow torchvision huggingface_hub
```
### Inference Code
```python
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, BertTokenizer


# 1. Define the architecture. Module and attribute names must match the training
#    code so that the checkpoint's state_dict keys line up.
class GraphPropagationLayer(nn.Module):
    def __init__(self, embed_dim, window_size=3, max_steps=3):
        super().__init__()
        self.window_size = window_size
        self.max_steps = max_steps
        self.msg_linear = nn.Linear(embed_dim, embed_dim)
        self.activation = nn.Hardtanh()

    def forward(self, x):
        # x: [B, L, D] visual tokens (L = 49 for a 7x7 grid)
        pad = self.window_size // 2
        for _ in range(self.max_steps):
            # Gather neighbors along the flattened token axis (wraps around at the ends)
            neighbors = [torch.roll(x, shifts=i, dims=1) for i in range(-pad, pad + 1)]
            aggregated = torch.stack(neighbors, dim=2).mean(dim=2)  # [B, L, D]
            messages = self.msg_linear(aggregated)
            x = self.activation(x + messages)  # residual update, clipped to [-1, 1]
        return x


class ImageCaptioningModel(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # ResNet-50 backbone; its last_hidden_state is the un-pooled [B, 2048, 7, 7] map
        self.encoder = AutoModel.from_pretrained("microsoft/resnet-50")
        encoder_dim = self.encoder.config.hidden_sizes[-1]  # 2048
        self.vis_proj = nn.Linear(encoder_dim, embed_dim)
        self.graph_layer = GraphPropagationLayer(embed_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings for up to 2000 decoder positions
        self.pos_encoder = nn.Parameter(torch.randn(1, 2000, embed_dim))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=8, dim_feedforward=1024, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=5)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def encode_images(self, images):
        hidden = self.encoder(pixel_values=images).last_hidden_state  # [B, C, 7, 7]
        B, C, H, W = hidden.shape
        # Flatten the grid row by row: [B, C, H, W] -> [B, H*W, C]
        visual_features = hidden.permute(0, 2, 3, 1).reshape(B, H * W, C)
        return self.graph_layer(self.vis_proj(visual_features))

    @torch.no_grad()
    def generate(self, images, max_len=80, start_token_id=101, end_token_id=102, pad_token_id=0):
        """Greedy decoding; defaults are the bert-base-uncased [CLS]/[SEP]/[PAD] ids."""
        self.eval()
        device = images.device
        memory = self.encode_images(images)
        batch_size = images.shape[0]
        generated = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for _ in range(max_len):
            seq_len = generated.size(1)
            tgt_emb = self.embedding(generated) + self.pos_encoder[:, :seq_len, :]
            # Additive causal mask: -inf above the diagonal blocks future positions
            tgt_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
            )
            tgt_key_padding_mask = generated == pad_token_id
            output = self.decoder(
                tgt=tgt_emb,
                memory=memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
            )
            next_token = self.fc_out(output[:, -1, :]).argmax(dim=-1)  # [B]
            # Sequences that already emitted [SEP] only receive padding from here on
            next_token = next_token.masked_fill(finished, pad_token_id)
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            finished |= next_token == end_token_id
            if finished.all():
                break
        return generated


# 2. Load assets
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Download the model files from the Hub
repo_id = "erfanasghariyan/resnet50_radiocaptioner"
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", subfolder="model_checkpoint")
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", subfolder="model_checkpoint")
with open(config_path, "r") as f:
    config = json.load(f)

# 3. Instantiate the model and load the fine-tuned weights
model = ImageCaptioningModel(vocab_size=config["vocab_size"], embed_dim=config["embed_dim"]).to(device)
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

# 4. Preprocess (ImageNet statistics, as in training) and run inference
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = Image.open("path_to_xray.jpg").convert("RGB")
inputs = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
generated_ids = model.generate(
    inputs,
    max_len=80,
    start_token_id=tokenizer.cls_token_id,
    end_token_id=tokenizer.sep_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
print("Caption:", tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```
## Training Details
* **Hardware:** Google TPU v3-8.
* **Optimizer:** AdamW.
* **Learning Rate:** Initial value of $1 \times 10^{-4}$, with decay.
* **Preprocessing:** Images resized to $224 \times 224$ and normalized using ImageNet mean/std statistics.
* **Loss Function:** CrossEntropyLoss applied to decoder outputs.
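The card does not include the training loop, so the sketch below only illustrates one common way to combine the listed pieces, under stated assumptions: teacher forcing with targets shifted by one token, padding excluded via `ignore_index` (assumed), AdamW at $1 \times 10^{-4}$, and a linear decay to 10% as a placeholder for the unspecified schedule. A tiny embedding-plus-linear stand-in replaces the real model so the example runs offline; `caption_loss`, `START_ID`, and `TOY_VOCAB` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

PAD_ID = 0        # [PAD] id in bert-base-uncased
START_ID = 1      # hypothetical start id inside the toy vocabulary below
TOY_VOCAB = 100   # hypothetical small vocabulary so the demo runs offline
TOTAL_STEPS = 10


def caption_loss(logits: torch.Tensor, captions: torch.Tensor, pad_id: int = PAD_ID) -> torch.Tensor:
    """Teacher-forced cross-entropy: logits[:, t] is scored against captions[:, t + 1]."""
    targets = captions[:, 1:]  # the token that should follow each input position
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # [B * (T - 1), V]
        targets.reshape(-1),                  # [B * (T - 1)]
        ignore_index=pad_id,                  # assumption: padded positions are excluded
    )


# Toy stand-in for the decoder path (token embedding -> vocabulary logits); NOT the real model
toy_model = nn.Sequential(nn.Embedding(TOY_VOCAB, 32), nn.Linear(32, TOY_VOCAB))

optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-4)
# Assumption: the card only says "with decay"; linear decay to 10% is a placeholder
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.1, total_iters=TOTAL_STEPS
)

torch.manual_seed(0)
captions = torch.randint(2, TOY_VOCAB, (4, 12))  # 4 fake captions of 12 tokens
captions[:, 0] = START_ID
captions[:, -3:] = PAD_ID                        # last 3 positions are padding

losses = []
for _ in range(TOTAL_STEPS):
    logits = toy_model(captions[:, :-1])  # inputs: every token except the last
    loss = caption_loss(logits, captions)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, "
      f"lr {optimizer.param_groups[0]['lr']:.1e}")
```

Shifting the targets by one position is what lets a single forward pass score every next-token prediction in the caption at once, which matches the causal masking used by the decoder.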
### Training Performance
The model was trained for one epoch on a medical imaging dataset. Both values below are training metrics from that epoch.
| Metric | Value |
| :--- | :--- |
| **Training Loss** | 0.7590 |
| **Training Perplexity (PPL)** | 2.1361 |
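The reported perplexity is consistent with the usual definition $\mathrm{PPL} = \exp(\mathcal{L})$, where $\mathcal{L}$ is the mean per-token cross-entropy in nats (natural log, as computed by PyTorch's `CrossEntropyLoss`). The check below simply recomputes it from the training loss in the table.

```python
import math

training_loss = 0.7590  # mean cross-entropy per token (natural log), from the table
perplexity = math.exp(training_loss)
print(f"PPL = exp({training_loss}) = {perplexity:.4f}")
```

Read informally, a perplexity of about 2.14 means that on the training data the model is, on average, about as uncertain as a uniform choice between roughly two tokens at each step.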
## Evaluation Results & Inference
After the first epoch, the model was qualitatively evaluated on three sample images. The table below compares the model-generated captions with the ground-truth captions.
### Sample Inference Comparisons
| Sample | Region | Generated Caption | Ground Truth |
| :--- | :--- | :--- | :--- |
| **1** | Hip | "mri scan of the hip femoral revealing normal tissue architecture" | "mr scan of the hip pelvic revealing no abnormal findings" |
| **2** | Knee | "mri scan of the knee articular showing soft tissue fluid collection in the region" | "mr scan examination of the knee tibiofemoral identifies soft tissue fluid collection" |
| **3** | Ankle | "mri scan of the ankle foot ankle with evident chondral abnormality" | "mr imaging study of the ankle foot tarsal confirms chondral abnormality" |
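**Observation:** On these three samples, the model correctly identifies the anatomical region (hip, knee, ankle) and the key finding (no abnormality in sample 1, soft tissue fluid collection in sample 2, chondral abnormality in sample 3). Differences appear in anatomical descriptors (e.g., "femoral" vs. "pelvic"), in phrasing ("mri scan" vs. "mr scan"), and in a repeated word in sample 3 ("ankle foot ankle").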
## Acknowledgements
* **ResNet:** He et al. for the ResNet-50 architecture.
* **Transformers:** Hugging Face ecosystem for tokenizer and base model integration.
* **SigLIP Inspiration:** The initial architecture was designed around MedSigLIP-448 and later adapted to ResNet-50 for memory efficiency.