# NOTE: the original file capture included Hugging Face Spaces page chrome
# ("Spaces: Sleeping") above the code — scraper residue, not part of the app.
| from fastapi import FastAPI,Query,HTTPException | |
| import torchxrayvision as xrv | |
| import skimage, torch, torchvision | |
| import cv2 | |
| import numpy as np | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from transformers import pipeline | |
| from PIL import Image | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import requests | |
app = FastAPI()

# Add the frontend origin here.
origins = [
    "http://localhost:8080",  # Your frontend running on port 8080
    "http://127.0.0.1:8080"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # OR ["*"] only during dev
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Multi-label chest X-ray pathology classifier from torchxrayvision
# (presumably 224x224 input, per the "res224" weight name — confirm).
model = xrv.models.DenseNet(weights="densenet121-res224-all")
# TB image classifier pulled from the Hugging Face hub at startup
# (network download on first run).
tb_classifier = pipeline("image-classification", model="vimal-humantics/dinov2-base-xray-224-finetuned-tb")
def show_anomaly_bounding_box(img_tensor, model, class_index=None):
    """Locate the most salient anomaly region in an X-ray via Grad-CAM.

    Args:
        img_tensor: single-image tensor, assumed shape (1, H, W) — the code
            adds the batch dim itself (TODO confirm against callers).
        model: classifier exposing a ``features`` sequential (DenseNet-style);
            Grad-CAM attaches to its last feature block.
        class_index: pathology index to explain; defaults to the model's
            argmax prediction for this image.

    Returns:
        ``((x1, y1), (x2, y2))`` corners of the bounding box around the
        largest high-activation contour, or ``()`` when no region exceeds
        the activation threshold.
    """
    target_layer = model.features[-1]
    cam = GradCAM(model=model, target_layers=[target_layer])

    # Pick the class to explain. This forward pass is only needed when the
    # caller did not fix a class, and it needs no gradients.
    if class_index is None:
        with torch.no_grad():
            outputs = model(img_tensor[None, ...])
        pred_index = torch.argmax(outputs[0]).item()
    else:
        pred_index = class_index

    grayscale_cam = cam(input_tensor=img_tensor[None, ...],
                        targets=[ClassifierOutputTarget(pred_index)])[0, :]

    # Threshold the 0-1 CAM (at 100/255 ≈ 0.39) and extract the
    # high-activation regions as contours.
    cam_resized = cv2.resize(grayscale_cam, (224, 224))
    cam_uint8 = (cam_resized * 255).astype(np.uint8)
    _, thresh = cv2.threshold(cam_uint8, 100, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return ()
    # BUG FIX: the original loop overwrote the box each iteration, returning
    # whichever contour happened to come last. Keep the largest-area region,
    # i.e. the dominant activation. (Dead image-conversion code that only fed
    # a commented-out cv2.rectangle call was removed.)
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)
    return ((x, y), (x + w, y + h))
def greet_json():
    """Return a static greeting payload (health-check style response)."""
    greeting = {"Hello": "World!"}
    return greeting
def predict(image_url: str = Query(..., description="URL to a chest X-ray image")):
    """Run pathology, localization, and TB classification on a remote X-ray.

    Args:
        image_url: URL of a chest X-ray image (query parameter).

    Returns:
        dict with:
          - ``prediction_result``: {pathology name: score rounded to 2 dp}
          - ``bounding_box``: {top pathology: Grad-CAM box ``((x1,y1),(x2,y2))``
            or ``()`` when no region was found}
          - ``tb_finding``: score normalized so it always reads as
            "likelihood of TB" regardless of which label the classifier emits.

    Raises:
        HTTPException: 400 when the image cannot be fetched or processed.
    """
    try:
        img = skimage.io.imread(image_url)
        img = xrv.datasets.normalize(img, 255)

        # BUG FIX: a grayscale source image has no channel axis, so the
        # original unconditional img.mean(2) crashed on it. Collapse
        # channels only when they exist.
        if img.ndim == 3:
            img = img.mean(2)
        img = img[None, ...]  # -> (1, H, W) as the xrv transforms expect

        transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
        )
        img = torch.from_numpy(transform(img))

        # Pure inference — no autograd needed here.
        with torch.no_grad():
            outputs = model(img[None, ...])

        scores = outputs[0].detach().numpy().tolist()
        pred_output = {name: round(score, 2) for name, score in zip(model.pathologies, scores)}
        pred_label = model.pathologies[outputs[0].argmax().item()]

        # Grad-CAM computes its own gradients, so it runs outside no_grad.
        bounding_box = show_anomaly_bounding_box(img, model=model)

        # TB detection — the HF pipeline takes the raw (PIL) image.
        # BUG FIX: check the HTTP status before decoding, otherwise a 404
        # page surfaces as an opaque PIL error.
        resp = requests.get(image_url, stream=True)
        resp.raise_for_status()
        image = Image.open(resp.raw)
        tb_finding = tb_classifier(images=image)
        tb_label = tb_finding[0]['label']
        tb_score = round(tb_finding[0]['score'], 2)
        # Flip the score when the top label is "normal" so the returned
        # value is always the TB likelihood.
        tb_output = 1 - tb_score if tb_label == "normal" else tb_score

        return {
            "prediction_result": pred_output,
            "bounding_box": {pred_label: bounding_box},
            "tb_finding": tb_output,
        }
    except Exception as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=400, detail=f"Failed to fetch/process image: {str(e)}") from e