kyanmahajan commited on
Commit
4e022d2
·
verified ·
1 Parent(s): 67627fb

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +12 -0
  2. lung.py +132 -0
  3. lung_xray_classify.pth +3 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . /app
6
+
7
+ RUN pip install --upgrade pip
8
+ RUN pip install -r requirements.txt
9
+
10
+ EXPOSE 7860
11
+
12
+ CMD ["python", "app.py"]
lung.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from flask import Flask, jsonify, request, render_template
6
+ from PIL import Image
7
+ import os
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+ import numpy as np
12
+ import cv2
13
+ import cv2
14
+ import torch
15
+
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image
18
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19
+
20
+ from flask_cors import CORS
21
+
22
+
23
+
24
+ def get_edge_for_visualization(pil_img):
25
+
26
+ w, h = pil_img.size
27
+ scale = 256 / min(w, h)
28
+ new_w, new_h = int(w * scale), int(h * scale)
29
+ resized = pil_img.resize((new_w, new_h), Image.BILINEAR)
30
+
31
+
32
+ left = (resized.width - 224) // 2
33
+ top = (resized.height - 224) // 2
34
+ cropped = resized.crop((left, top, left + 224, top + 224))
35
+
36
+
37
+ gray = cv2.cvtColor(np.array(cropped), cv2.COLOR_RGB2GRAY)
38
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
39
+ edges = 255 - cv2.Canny(blurred, 30, 80)
40
+
41
+
42
+ edge_rgb = np.stack([edges]*3, axis=-1).astype(np.float32) / 255.0
43
+
44
+ return edge_rgb
45
+
46
+ app = Flask(__name__)
47
+ CORS(app)
48
+ os.makedirs("static", exist_ok=True)
49
+
50
+ # Device setup
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+
53
+
54
+ # Transform setup (same as training)
55
+ data_transforms = transforms.Compose([
56
+ transforms.Resize(256),
57
+ transforms.CenterCrop(224),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(
60
+ mean=[0.485, 0.456, 0.406],
61
+ std=[0.229, 0.224, 0.225]
62
+ )
63
+ ])
64
+
65
+
66
+ model = models.resnet18(pretrained=False)
67
+ model.fc = nn.Linear(model.fc.in_features, 3)
68
+
69
+ model.load_state_dict(torch.load("lung_xray_classify.pth", map_location=device))
70
+ model.to(device)
71
+ model.eval()
72
+
73
+
74
+ class_names = ['COVID', 'HEALTHY', 'PNEUMONIA']
75
+
76
+ # @app.route("/")
77
+ # def home():
78
+ # return render_template("index.html")
79
+
80
+ @app.route("/predict_lung", methods=["POST"])
81
+ def predict():
82
+ if "file" not in request.files:
83
+ return jsonify({"error": "No file provided"}), 400
84
+
85
+ file = request.files["file"]
86
+ filepath = os.path.join("static", file.filename)
87
+ file.save(filepath)
88
+
89
+ try:
90
+ image = Image.open(filepath).convert("RGB")
91
+ input_tensor = data_transforms(image).unsqueeze(0).to(device)
92
+
93
+ with torch.no_grad():
94
+ output = model(input_tensor)
95
+ pred_idx = torch.argmax(output, dim=1).item()
96
+ pred_label = class_names[pred_idx]
97
+ classes = [ClassifierOutputTarget(pred_idx)];
98
+ target_layer = [model.layer4[-1]]
99
+
100
+ cam = GradCAM(model = model, target_layers = target_layer)
101
+
102
+ heatmap = cam(input_tensor = input_tensor, targets = classes);
103
+ edge_img = get_edge_for_visualization(image)
104
+ cam_image = show_cam_on_image(edge_img, heatmap[0], use_rgb=True)
105
+ input_img = np.array(image);
106
+ input_img = input_img.astype(np.float32)/255
107
+ input_img = cv2.resize(input_img, (224,224))
108
+ cam_image_real = show_cam_on_image(input_img, heatmap[0], use_rgb=True)
109
+ cam_image_path = os.path.join("static", f"cam_{file.filename}")
110
+ cv2.imwrite(cam_image_path, cv2.cvtColor(cam_image_real, cv2.COLOR_RGB2BGR))
111
+ file={
112
+ "prediction": pred_label,
113
+ "cam_image_url": f"/static/lung_cam_{file.filename}"
114
+ }
115
+ print(file)
116
+ return jsonify(file)
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+ except Exception as e:
126
+ return jsonify({"error": str(e)}), 500
127
+
128
+
129
+ if __name__ == '__main__':
130
+ app.run(debug=True)
131
+
132
+
lung_xray_classify.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3704b0e4825e0d049f13227d73e472fed43a8cd06d89dbb109a8d63084cc104c
3
+ size 44792424
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ torch==2.2.0
4
+ torchvision==0.17.0
5
+ pillow
6
+ numpy
7
+ matplotlib
8
+ opencv-python-headless
9
+ git+https://github.com/jacobgil/pytorch-grad-cam.git