vimal-yuvabe commited on
Commit
d6684f1
·
1 Parent(s): 9ed1c3d

added bounding box features

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -0
  2. app.py +56 -2
Dockerfile CHANGED
@@ -13,6 +13,8 @@ COPY --chown=user ./requirements.txt requirements.txt
13
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
  RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
15
  RUN pip install torchxrayvision
 
 
16
 
17
  COPY --chown=user . /app
18
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
13
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
  RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
15
  RUN pip install torchxrayvision
16
+ RUN pip install grad-cam
17
+ RUN pip install opencv-python
18
 
19
  COPY --chown=user . /app
20
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,11 +1,64 @@
1
  from fastapi import FastAPI,Query,HTTPException
2
  import torchxrayvision as xrv
3
  import skimage, torch, torchvision
4
- import requests
 
 
 
5
 
6
 
 
7
  app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  model = xrv.models.DenseNet(weights="densenet121-res224-all")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  @app.get("/")
10
  def greet_json():
11
  return {"Hello": "World!"}
@@ -28,8 +81,9 @@ def predict(image_url:str = Query(..., description="URL to a chest X-ray image")
28
  for k,v in prediction.items():
29
  pred_output.update({k:round(v,2)})
30
 
 
31
 
32
- return {"prediction_result":pred_output,"bounding_box":{pred_label:((139,61),(224,132))}}
33
  except Exception as e:
34
  print(e)
35
  raise HTTPException(status_code=400, detail=f"Failed to fetch/process image: {str(e)}")
 
1
  from fastapi import FastAPI,Query,HTTPException
2
  import torchxrayvision as xrv
3
  import skimage, torch, torchvision
4
+ import cv2
5
+ import numpy as np
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
8
 
9
 
10
+ from fastapi.middleware.cors import CORSMiddleware
11
  app = FastAPI()
12
+ # Add the frontend origin here
13
+ origins = [
14
+ "http://localhost:8080", # Your frontend running on port 8080
15
+ "http://127.0.0.1:8080"
16
+ ]
17
+
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=origins, # OR ["*"] only during dev
21
+ allow_credentials=True,
22
+ allow_methods=["*"],
23
+ allow_headers=["*"],
24
+ )
25
+
26
+
27
  model = xrv.models.DenseNet(weights="densenet121-res224-all")
28
+
29
+
30
+
31
+
32
+
33
+ def show_anomaly_bounding_box(img_tensor, model, class_index=None):
34
+ target_layer = model.features[-1]
35
+
36
+ cam = GradCAM(model=model, target_layers=[target_layer])
37
+
38
+ with torch.no_grad():
39
+ outputs = model(img_tensor[None, ...])
40
+ pred_index = class_index if class_index is not None else torch.argmax(outputs[0]).item()
41
+
42
+ grayscale_cam = cam(input_tensor=img_tensor[None, ...],
43
+ targets=[ClassifierOutputTarget(pred_index)])
44
+ grayscale_cam = grayscale_cam[0, :]
45
+
46
+ input_img = img_tensor.numpy()[0]
47
+ input_img_norm = (input_img - input_img.min()) / (input_img.max() - input_img.min())
48
+ input_img_rgb = cv2.cvtColor((input_img_norm * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
49
+
50
+ cam_resized = cv2.resize(grayscale_cam, (224, 224))
51
+ cam_uint8 = (cam_resized * 255).astype(np.uint8)
52
+ _, thresh = cv2.threshold(cam_uint8, 100, 255, cv2.THRESH_BINARY)
53
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
54
+ bounding_box = ()
55
+ for cnt in contours:
56
+ x, y, w, h = cv2.boundingRect(cnt)
57
+ bounding_box = ((x,y),(x+w,y+h))
58
+ # cv2.rectangle(input_img_rgb, (x, y), (x + w, y + h), (0, 255, 0), 2)
59
+
60
+ return bounding_box
61
+
62
  @app.get("/")
63
  def greet_json():
64
  return {"Hello": "World!"}
 
81
  for k,v in prediction.items():
82
  pred_output.update({k:round(v,2)})
83
 
84
+ get_bounding_box = show_anomaly_bounding_box(img,model=model)
85
 
86
+ return {"prediction_result":pred_output,"bounding_box":{pred_label:get_bounding_box}}
87
  except Exception as e:
88
  print(e)
89
  raise HTTPException(status_code=400, detail=f"Failed to fetch/process image: {str(e)}")