File size: 3,676 Bytes
d3b130f
 
 
d6684f1
 
 
 
dcb533e
 
d6684f1
dcb533e
 
d3b130f
d6684f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3b130f
dcb533e
d6684f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e961059
 
 
d3b130f
 
 
 
 
 
 
 
 
 
 
9ed1c3d
d3b130f
9ed1c3d
 
d3b130f
 
 
 
d6684f1
dcb533e
 
 
 
 
 
 
 
 
 
 
 
 
 
d3b130f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from fastapi import FastAPI,Query,HTTPException
import torchxrayvision as xrv
import skimage, torch, torchvision
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from transformers import pipeline
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
import requests

# Module-level application setup: FastAPI app, CORS, and both ML models.
# NOTE: both models are loaded eagerly at import time — startup blocks until
# the weights are downloaded/cached.
app = FastAPI()
# Add the frontend origin here
origins = [
    "http://localhost:8080",  # Your frontend running on port 8080
    "http://127.0.0.1:8080"
]

# Allow the browser frontend to call this API cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # OR ["*"] only during dev
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Multi-pathology chest X-ray classifier (18-label DenseNet-121, 224x224 input).
model = xrv.models.DenseNet(weights="densenet121-res224-all")
# Separate binary TB screening model (fine-tuned DINOv2, HuggingFace pipeline).
tb_classifier =pipeline("image-classification",model="vimal-humantics/dinov2-base-xray-224-finetuned-tb") 




def show_anomaly_bounding_box(img_tensor, model, class_index=None):
    """Localize the most salient anomaly region of an X-ray via Grad-CAM.

    Args:
        img_tensor: torch.FloatTensor of shape (1, H, W) — a single-channel
            preprocessed X-ray. Assumes H == W == 224 (the CAM heatmap is
            resized to 224x224) — TODO confirm against XRayResizer(224).
        model: the torchxrayvision DenseNet; its final feature layer is used
            as the Grad-CAM target layer.
        class_index: pathology index to explain. Defaults to the model's
            top-scoring class for this image.

    Returns:
        ((x1, y1), (x2, y2)) pixel corners of the box around the LARGEST
        high-activation region, or () when no region crosses the threshold.
    """
    cam = GradCAM(model=model, target_layers=[model.features[-1]])

    # Pick the class to explain. No gradient graph needed for this forward
    # pass; skip it entirely when the caller already chose a class.
    if class_index is None:
        with torch.no_grad():
            outputs = model(img_tensor[None, ...])
        class_index = torch.argmax(outputs[0]).item()

    # Grad-CAM performs its own forward/backward, so it must run OUTSIDE
    # any no_grad context.
    grayscale_cam = cam(input_tensor=img_tensor[None, ...],
                        targets=[ClassifierOutputTarget(class_index)])[0, :]

    # Binarize the heatmap (threshold 100/255) and trace activated regions.
    cam_resized = cv2.resize(grayscale_cam, (224, 224))
    cam_uint8 = (cam_resized * 255).astype(np.uint8)
    _, thresh = cv2.threshold(cam_uint8, 100, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return ()
    # Bug fix: the original loop overwrote the box each iteration and so
    # returned whichever contour OpenCV listed last; the largest activated
    # region is the intended localization.
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    return ((x, y), (x + w, y + h))

@app.get("/")
def greet_json():
    """Health-check endpoint returning a static greeting."""
    greeting = {"Hello": "World!"}
    return greeting

@app.get('/predict')
def predict(image_url:str = Query(..., description="URL to a chest X-ray image")):
    """Run multi-pathology and TB screening on an X-ray fetched from a URL.

    Args:
        image_url: publicly reachable URL of a chest X-ray image.

    Returns:
        dict with:
          - "prediction_result": {pathology name: score rounded to 2 dp}
          - "bounding_box": {top pathology label: Grad-CAM bounding box}
          - "tb_finding": probability-of-TB in [0, 1]

    Raises:
        HTTPException(400) for any fetch/decode/inference failure.
    """
    try:
        # --- torchxrayvision pathology model ---
        img = skimage.io.imread(image_url)
        img = xrv.datasets.normalize(img, 255)
        # Collapse RGB to one channel; grayscale sources arrive as 2-D arrays
        # and would crash on mean(2), so only reduce when a channel axis exists.
        if img.ndim == 3:
            img = img.mean(2)
        img = img[None, ...]
        transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)])
        img = torch.from_numpy(transform(img))

        # Inference only — no autograd bookkeeping needed here.
        # (show_anomaly_bounding_box runs its own gradient pass internally.)
        with torch.no_grad():
            outputs = model(img[None, ...])

        scores = outputs[0].numpy().tolist()
        pred_output = {k: round(v, 2) for k, v in zip(model.pathologies, scores)}
        pred_label = model.pathologies[outputs[0].argmax().item()]

        bounding_box = show_anomaly_bounding_box(img, model=model)

        # --- TB classifier (separate fine-tuned DINOv2 pipeline) ---
        # The image is fetched a second time because this model wants a PIL
        # image, not the xrv tensor. Timeout prevents a hung remote host from
        # stalling the worker indefinitely.
        response = requests.get(image_url, stream=True, timeout=30)
        image = Image.open(response.raw)
        tb_finding = tb_classifier(images=image)
        tb_label = tb_finding[0]['label']
        tb_score = round(tb_finding[0]['score'], 2)
        # Normalize so tb_output is always "probability of TB" regardless of
        # which label the classifier reported.
        tb_output = 1 - tb_score if tb_label == "normal" else tb_score

        return {"prediction_result": pred_output,
                "bounding_box": {pred_label: bounding_box},
                "tb_finding": tb_output}
    except Exception as e:
        # Surface any failure as a client-visible 400, preserving the cause.
        raise HTTPException(status_code=400, detail=f"Failed to fetch/process image: {str(e)}") from e