"""Run RadZero on a single image/text pair and print the similarity outputs."""

import warnings
from pathlib import Path

import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

from utils import model_inference

# Suppress every UserWarning (not only specific ones) for cleaner logs.
# NOTE(review): this filter is process-wide and also hides potentially useful
# torch/transformers warnings; narrow it with `message=` / `module=` once the
# noisy source is identified.
warnings.filterwarnings("ignore", category=UserWarning)

# Hugging Face Hub repository hosting the tokenizer, image processor and model.
MODEL_ID = "Deepnoid/RadZero"


def load_model(device, dtype, model_id=MODEL_ID):
    """Load the tokenizer, image processor and model for ``model_id``.

    Args:
        device: Passed to ``from_pretrained`` as ``device_map``, so the model
            weights are placed on this device.
        dtype: ``torch_dtype`` used when loading the model weights.
        model_id: Hub repository id (or local path) to load all three
            components from. Defaults to ``MODEL_ID``.

    Returns:
        A dict with keys ``"tokenizer"``, ``"image_processor"`` and ``"model"``.
        The dict is unpacked as keyword arguments into ``model_inference``,
        so the keys must match that function's parameter names.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id,
        # Allows executing code shipped in the Hub repo; presumably required
        # because the model architecture is custom — TODO confirm.
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )
    # Inference only: put layers such as dropout into evaluation mode.
    model.eval()
    return {
        "tokenizer": tokenizer,
        "image_processor": image_processor,
        "model": model,
    }


def main():
    """Load the model, run one image/text inference and print the results."""
    # Fall back to CPU so the example also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    image_path = "cxr_image.jpg"
    # Fail fast, before the comparatively slow model download/load,
    # if the input image is missing.
    if not Path(image_path).is_file():
        raise FileNotFoundError(f"Input image not found: {image_path}")

    models = load_model(device, dtype)

    similarity_prob, similarity_map = model_inference(
        image_path, "There is fibrosis", **models
    )

    print(similarity_prob)
    print(similarity_map.min())
    print(similarity_map.max())
    print(similarity_map.shape)


if __name__ == "__main__":
    main()