Spaces:
No application file
No application file
| import numpy as np | |
| import torch | |
| import clip | |
| from clip.model import CLIP | |
| from torchvision.transforms.transforms import Compose | |
| from PIL import Image | |
| from typing import List | |
class ModalityClip:
    """Zero-shot imaging-modality classifier built on OpenAI CLIP.

    The candidate modality descriptions are tokenized once at construction.
    :meth:`identify` then scores an image against all of them and returns
    the index of the best-matching description.
    """

    def __init__(self, modality: List[str], model_name: str = "ViT-B/32",
                 device: "str | None" = None) -> None:
        """Load the CLIP model and pre-tokenize the candidate labels.

        Args:
            modality: Text descriptions of the candidate modalities, e.g.
                ``["chest x-ray", "knee mri"]``. The index returned by
                :meth:`identify` refers to this list.
            model_name: CLIP checkpoint name passed to ``clip.load``
                (default ``"ViT-B/32"``, the previously hard-coded value).
            device: Torch device string. Defaults to ``"cuda"`` when
                ``torch.cuda.is_available()``, otherwise ``"cpu"``.

        Raises:
            ValueError: If ``modality`` is empty (nothing to classify against).
        """
        if not modality:
            raise ValueError("modality must contain at least one label")
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        # Copy the list so later mutation of the caller's list cannot desync
        # the labels from the pre-tokenized text tensor built below.
        self.modality = list(modality)
        self.text = clip.tokenize(self.modality).to(self.device)

    def identify(self, img) -> int:
        """Return the index in ``self.modality`` that best matches ``img``.

        Args:
            img: Either an already-loaded image accepted by
                ``self.preprocess`` (the transform returned by ``clip.load``),
                or a filesystem path (``str`` / ``os.PathLike``) to an image
                file. A path is opened with PIL and closed after preprocessing.

        Returns:
            A plain Python ``int`` index into ``self.modality``.
        """
        import os  # local import: only needed for the path check below

        if isinstance(img, (str, os.PathLike)):
            # The CLIP transform needs an image object, not a filename. Open
            # the file here and release the handle once it is a tensor.
            with Image.open(img) as pil_img:
                image = self.preprocess(pil_img)
        else:
            image = self.preprocess(img)
        # Add a batch dimension of 1 before the forward pass.
        image = image.unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits_per_image, _ = self.model(image, self.text)
        # Softmax is monotonic, so the arg-max of the raw logits is the same
        # winner as the arg-max of the softmax probabilities.
        return int(logits_per_image.argmax(dim=-1)[0].item())
if __name__ == "__main__":
    # Candidate modality descriptions; identify() returns an index into this list.
    modality = ["panoramic dental x-ray", "chest x-ray", "knee mri",
                "Mammography", "knee x-ray"]
    identifier = ModalityClip(modality)
    # Point this at a local medical image, e.g.
    # "dental/periodontals/Subject No.186.jpg".
    # Open the file with PIL so the CLIP preprocessing transform receives an
    # image object rather than a filename.
    with Image.open("chest.jpg") as img:
        index = identifier.identify(img)
    print(f"This image is a {modality[index]}")