Spaces:
Sleeping
Sleeping
File size: 3,676 Bytes
d3b130f d6684f1 dcb533e d6684f1 dcb533e d3b130f d6684f1 d3b130f dcb533e d6684f1 e961059 d3b130f 9ed1c3d d3b130f 9ed1c3d d3b130f d6684f1 dcb533e d3b130f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import io

from fastapi import FastAPI, Query, HTTPException
import torchxrayvision as xrv
import skimage
import torch
import torchvision
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from transformers import pipeline
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
import requests
app = FastAPI()

# Browser origins allowed to call this API (the frontend dev server on port 8080).
# Add further frontend origins here; use ["*"] only during local development.
origins = [
    "http://localhost:8080",
    "http://127.0.0.1:8080",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# torchxrayvision DenseNet-121 (224px "all" weights); produces one score per
# entry of model.pathologies.
model = xrv.models.DenseNet(weights="densenet121-res224-all")
# Inference only: keep BatchNorm on its running statistics so requests neither
# get per-image normalisation nor mutate the shared model. Idempotent if the
# library already switched the model to eval mode.
model.eval()

# Hugging Face image classifier for tuberculosis; presumably binary
# ("normal" vs. TB) given how its labels are interpreted in predict().
tb_classifier = pipeline(
    "image-classification",
    model="vimal-humantics/dinov2-base-xray-224-finetuned-tb",
)
def show_anomaly_bounding_box(img_tensor, model, class_index=None, threshold=100):
    """Return a bounding box around the Grad-CAM hot spot for one output class.

    Args:
        img_tensor: one image without a batch axis (a batch axis is added here
            with ``img_tensor[None, ...]``). Presumably the (1, 224, 224)
            tensor built in ``predict``; its last two dims are taken as (H, W).
        model: classifier whose ``features[-1]`` layer is the Grad-CAM target.
        class_index: output index to explain. When None, the arg-max of the
            model's output for this image is used.
        threshold: 0-255 cutoff on the heat map (scaled to uint8); pixels
            strictly above it count as "hot".

    Returns:
        ``((x1, y1), (x2, y2))`` corners of the bounding rectangle of the
        largest hot region, or ``()`` when no pixel exceeds *threshold*.
    """
    batch = img_tensor[None, ...]
    if class_index is None:
        with torch.no_grad():
            class_index = torch.argmax(model(batch)[0]).item()

    # NOTE(review): GradCAM hooks the shared global model; concurrent requests
    # (FastAPI runs sync endpoints in a threadpool) may interleave activations.
    # Consider serialising model access with a lock.
    # The context manager removes the hooks GradCAM registers on the layer.
    with GradCAM(model=model, target_layers=[model.features[-1]]) as cam:
        heatmap = cam(
            input_tensor=batch,
            targets=[ClassifierOutputTarget(class_index)],
        )[0, :]

    # Match the heat map to the input's spatial size; cv2 expects (width, height).
    height, width = img_tensor.shape[-2:]
    heatmap = cv2.resize(heatmap, (width, height))
    # Assumes heat-map values lie in [0, 1], as the * 255 scaling implies.
    heat_uint8 = (heatmap * 255).astype(np.uint8)
    _, mask = cv2.threshold(heat_uint8, threshold, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    rects = [cv2.boundingRect(contour) for contour in contours]
    if not rects:
        return ()
    # Several regions can pass the threshold; report the one with the largest
    # bounding-box area rather than whichever contour happens to come last.
    x, y, w, h = max(rects, key=lambda rect: rect[2] * rect[3])
    return ((x, y), (x + w, y + h))
@app.get("/")
def greet_json():
    """Serve the root path with a constant greeting payload.

    Touches no models; presumably doubles as a liveness check for the service.
    """
    payload = {"Hello": "World!"}
    return payload
# Preprocessing applied to every request before the torchxrayvision model:
# square center-crop, then resize to 224 (the "res224" in the weight name).
_XRAY_TRANSFORM = torchvision.transforms.Compose(
    [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
)

# Seconds to wait on the remote image server before giving up.
_FETCH_TIMEOUT_S = 15


def _fetch_image_bytes(image_url, timeout=_FETCH_TIMEOUT_S):
    """Download *image_url* once and return the response body as bytes.

    Raises:
        requests.RequestException: on connection failure, timeout, or an
            HTTP 4xx/5xx status.
    """
    # NOTE(review): image_url comes straight from the query string, so the
    # server fetches arbitrary URLs on the caller's behalf (SSRF risk).
    # Consider a host allow-list if this service is publicly reachable.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    # .content honours Content-Encoding (e.g. gzip), unlike the .raw stream.
    return response.content


def _load_xray_tensor(image_bytes):
    """Decode *image_bytes* into the single-channel tensor the model consumes.

    Pixels are rescaled by ``xrv.datasets.normalize`` assuming an 8-bit
    maximum of 255, colour images are collapsed to one channel, and the
    result is cropped/resized by ``_XRAY_TRANSFORM``.
    """
    pixels = skimage.io.imread(io.BytesIO(image_bytes))
    pixels = xrv.datasets.normalize(pixels, 255)
    if pixels.ndim == 3:
        # Drop a trailing alpha channel (LA / RGBA) so opacity does not skew
        # the intensity average, then average the remaining channels.
        if pixels.shape[2] in (2, 4):
            pixels = pixels[..., :-1]
        pixels = pixels.mean(2)
    # Greyscale input is already 2-D; either way, add the leading channel axis.
    pixels = _XRAY_TRANSFORM(pixels[None, ...])
    return torch.from_numpy(pixels)


def _tb_probability(image):
    """Return the TB score for a PIL *image*, rounded to 2 decimal places."""
    top = tb_classifier(images=image)[0]
    score = top["score"]
    # Assumes a binary normal-vs-TB model: a confident "normal" verdict is
    # turned into the complementary TB score. Round only at the end so the
    # subtraction does not reintroduce float noise.
    probability = 1 - score if top["label"] == "normal" else score
    return round(probability, 2)


@app.get('/predict')
def predict(image_url: str = Query(..., description="URL to a chest X-ray image")):
    """Analyse a chest X-ray fetched from *image_url*.

    Returns:
        dict with keys
          ``prediction_result``: {pathology: model score rounded to 2 dp};
          ``bounding_box``: {top-scoring pathology: ((x1, y1), (x2, y2)) or ()};
          ``tb_finding``: TB score from ``tb_classifier`` rounded to 2 dp.

    Raises:
        HTTPException: status 400 for any failure while downloading,
            decoding, or running the models.
    """
    try:
        image_bytes = _fetch_image_bytes(image_url)
        img = _load_xray_tensor(image_bytes)

        # Scores only; no autograd graph is needed for this pass.
        with torch.no_grad():
            scores = model(img[None, ...])[0]
        pred_output = {
            pathology: round(score, 2)
            for pathology, score in zip(model.pathologies, scores.tolist())
        }
        pred_class = scores.argmax().item()
        pred_label = model.pathologies[pred_class]

        # Pass the class chosen above so the heat map explains pred_label and
        # Grad-CAM does not repeat the forward pass just to pick it.
        bounding_box = show_anomaly_bounding_box(img, model=model, class_index=pred_class)

        # TB detection reuses the downloaded bytes instead of a second fetch.
        tb_output = _tb_probability(Image.open(io.BytesIO(image_bytes)))
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail=f"Failed to fetch/process image: {str(e)}") from e
    return {
        "prediction_result": pred_output,
        "bounding_box": {pred_label: bounding_box},
        "tb_finding": tb_output,
    }
|