import torch
import torch.nn as nn
import torchxrayvision as xrv


class MedicalImageEncoder(nn.Module):
    """Image encoder built on TorchXRayVision's DenseNet-121 backbone.

    The backbone's spatial feature map is flattened into a sequence of patch
    tokens. Each token is linearly projected to ``embed_dim``, which defaults
    to 768 so the tokens match PhoBERT's hidden size.

    With ``pretrained=True`` the backbone loads the ``weights`` checkpoint
    (default ``"densenet121-res224-chex"``). NOTE(review): the original
    docstring said the model was "pretrained on 200K+ X-rays (CheXpert, NIH,
    etc.)". The ``-chex`` identifier suggests CheXpert-only weights, so verify
    this against the TorchXRayVision documentation.

    Args:
        pretrained: Load pretrained backbone weights when True. Otherwise the
            backbone is randomly initialised.
        embed_dim: Output dimension of each projected token.
        weights: TorchXRayVision weight identifier, used only when
            ``pretrained`` is True.
    """

    def __init__(
        self,
        pretrained=True,
        embed_dim=768,
        weights="densenet121-res224-chex",
    ):
        super().__init__()
        self.model = xrv.models.DenseNet(weights=weights if pretrained else None)
        # Take the backbone's channel count from its classification head
        # before discarding the head, instead of hard-coding 1024.
        feat_dim = self.model.classifier.in_features
        # forward() never uses the classification head, so drop it.
        self.model.classifier = nn.Identity()
        self.projector = nn.Linear(feat_dim, embed_dim)

    def forward(self, x):
        """Encode an image batch into a sequence of projected patch tokens.

        Args:
            x: Image batch, passed directly to the backbone's ``features``
                module. NOTE(review): TorchXRayVision models expect 1-channel
                input scaled to [-1024, 1024]. The 7 x 7 grid noted below
                assumes 224 x 224 input. Confirm both against the caller's
                preprocessing.

        Returns:
            Tensor of shape [B, H*W, embed_dim]. With 224 x 224 input and
            default settings this is [B, 49, 768].
        """
        feat_map = self.model.features(x)  # [B, C, H, W], e.g. [B, 1024, 7, 7]
        # DenseNet's `features` stack ends with a BatchNorm layer. The
        # pretrained network applies ReLU after it before pooling, so do the
        # same here. The projector then sees the activations the backbone was
        # trained to produce.
        feat_map = torch.relu(feat_map)
        tokens = feat_map.flatten(2).transpose(1, 2)  # [B, H*W, C]
        return self.projector(tokens)  # [B, H*W, embed_dim]