Spaces:
Build error
Build error
import os
import sys

import cv2
import torch
import torchvision.transforms as T
from tqdm import tqdm

# The project-local modules below live inside the DenseMammogram checkout, so
# that directory must be on the import path before they are imported.
sys.path.append('DenseMammogram')
from models import get_FRCNN_model, Bilateral_model  # noqa: E402
import detection.transforms as transforms  # noqa: E402
from dataloaders import get_direction  # noqa: E402

# Prefer the GPU when one is visible; checkpoints are mapped onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint files, resolved relative to the current working directory.
FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'

# The bilateral model is constructed around the FRCNN instance.
# NOTE(review): whether Bilateral_model registers it as a submodule (so the
# second load_state_dict also overwrites the FRCNN weights) is defined in
# models.py, not here — confirm.
frcnn_model = get_FRCNN_model().to(device)
bilat_model = Bilateral_model(frcnn_model).to(device)

# torch.load unpickles the checkpoint, so only point these paths at trusted files.
frcnn_model.load_state_dict(torch.load(FRCNN_PATH, map_location=device))
bilat_model.load_state_dict(torch.load(BILAR_PATH, map_location=device))
def _read_image(path):
    """Load *path* with OpenCV, raising instead of silently returning None."""
    image = cv2.imread(path)
    # cv2.imread reports a missing or undecodable file by returning None rather
    # than raising; without this check the failure only surfaces later as an
    # opaque type error inside the torchvision transform.
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {path!r}")
    return image


def _draw_detections(image, output, threshold):
    """Draw every label-1 detection scoring above *threshold* onto *image*.

    Args:
        image: OpenCV image array; it is drawn on in place and also returned.
        output: Detection dict with parallel 'boxes', 'scores' and 'labels'
            tensors, as produced by the model in ``predict``.
        threshold: Detections need a score strictly greater than this.

    Returns:
        The same *image* object, annotated.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 3.6
    thickness = 6
    for box, score, label in zip(output['boxes'], output['scores'], output['labels']):
        if not (label == 1 and score > threshold):
            continue
        x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Same rounding as before: round to 2 decimals, scale to a percentage,
        # then round again to tidy float artifacts (e.g. 57.0 rather than 56.999...).
        text = f"Cancer: {round(round(score.item(), 2) * 100, 1)}%"
        # putText's origin is the text's bottom-left corner. Keep it at least
        # one text height below the top edge so labels on boxes near the top of
        # the image are not drawn off-canvas.
        (_, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        text_y = max(y1 - 40, text_height)
        cv2.putText(image, text, (x1, text_y), font, font_scale, (36, 255, 12), thickness)
    return image


def predict(left_file, right_file, threshold=0.80, baseIsLeft=True):
    """Run the bilateral model on a left/right image pair and annotate the left image.

    Args:
        left_file: Path of the left image. It is the first image of the pair
            given to the model and the image the detections are drawn on.
        right_file: Path of the right image, the second image of the pair.
        threshold: Only detections with label 1 and a score strictly greater
            than this value are drawn.
        baseIsLeft: If True, the left image is horizontally flipped before
            inference and the predicted boxes are flipped back afterwards, so
            they line up with the unflipped left image. If False, the right
            image is flipped instead and the boxes are used as returned.

    Returns:
        The left image as loaded by ``cv2.imread`` (OpenCV's BGR channel
        order), with a box and a "Cancer: NN.N%" label for each kept detection.

    Raises:
        FileNotFoundError: If OpenCV cannot read either image file.
    """
    model = bilat_model
    flip = transforms.RandomHorizontalFlip(1.0)  # p=1.0: always flips
    transform = T.Compose([T.ToPILImage(), T.ToTensor()])

    # The raw left image is kept for drawing, so the file is read only once.
    # The tensors built from it are independent copies, and inference finishes
    # before any drawing happens.
    left_image = _read_image(left_file)
    right_image = _read_image(right_file)
    img1 = transform(left_image)
    img2 = transform(right_image)

    with torch.no_grad():
        model.eval()
        # Mirror one view of the pair. NOTE(review): presumably this gives both
        # views the orientation the bilateral model was trained on — confirm.
        if baseIsLeft:
            img1, _ = flip(img1)
        else:
            img2, _ = flip(img2)
        # One pair in the batch; [0] takes that pair's detections.
        # NOTE(review): assumed to refer to the first (left) image — confirm.
        output = model([[img1.to(device), img2.to(device)]])[0]
        if baseIsLeft:
            # Undo the flip on the boxes so they match the original left image.
            _, output = flip(img1, output)

    return _draw_detections(left_image, output, threshold)