dvtiendat committed on
Commit a7f04f4 · 1 Parent(s): 1a3b25f
.gitignore ADDED
@@ -0,0 +1,3 @@
+ project1_datdvt
+ flagged
+ __pycache__
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Dat Dao
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
dataset/dataset.py ADDED
File without changes
design/design.css ADDED
@@ -0,0 +1,102 @@
+ .container {
+     max-width: 1200px;
+     margin: 0 auto;
+ }
+
+ .heading {
+     background-image: linear-gradient(45deg, #00B894, #56a0f0);
+     background-clip: text;
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     color: transparent;
+     font-size: 3.5em !important;
+     font-weight: bold;
+ }
+
+ .primary-button {
+     background: linear-gradient(90deg, #00B894, #56a0f0) !important;
+     border: none !important;
+     box-shadow: 0 4px 15px rgba(0, 184, 148, 0.2) !important;
+ }
+
+ .primary-button:hover {
+     transform: translateY(-2px);
+     box-shadow: 0 6px 20px rgba(0, 184, 148, 0.3) !important;
+ }
+
+ .results-container {
+     text-align: center;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     background: rgba(0, 184, 148, 0.1);
+     border-radius: 10px;
+     padding: 20px;
+ }
+
+ .confidence-high {
+     text-align: center;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     color: #00B894 !important;
+     font-weight: bold;
+ }
+
+ .confidence-medium {
+     text-align: center;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     color: #FFA502 !important;
+     font-weight: bold;
+ }
+
+ .confidence-low {
+     text-align: center;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     color: #FF4757 !important;
+     font-weight: bold;
+ }
+
+ .diagnosis-text {
+     font-size: 16px;
+     text-align: center;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     padding: 10px;
+     box-sizing: border-box;
+     border: none;
+     background-color: #2f3640;
+     color: #f5f6fa;
+ }
+
+ .image-controls {
+     background: rgba(9, 132, 227, 0.1);
+     border-radius: 8px;
+     padding: 15px;
+     margin-top: 10px;
+ }
+
+ .accordion {
+     border: none !important;
+     box-shadow: none !important;
+ }
+
+ .accordion:hover {
+     background: rgba(255, 255, 255, 0.05) !important;
+ }
+
+ [data-testid="image"] {
+     border: 2px ridge #00B894;
+     border-radius: 10px;
+     transition: all 0.3s ease;
+ }
+
+ [data-testid="image"]:hover {
+     border-color: #56a0f0;
+     box-shadow: 0 0 10px rgba(9, 132, 227, 0.3);
+ }
images/COVID/covid_1579.png ADDED
images/Healthy/Normal (1).png ADDED
images/Non COVID/non_COVID (11905).png ADDED
interface.py ADDED
@@ -0,0 +1,131 @@
+ import gradio as gr
+ from pipeline import *
+
+ def get_css(css_path):
+     with open(css_path, 'r') as f:
+         custom = f.read()
+
+     return custom
+
+ def create_interface():
+     custom = get_css('design/design.css')
+     processor = Pipeline()
+
+     with gr.Blocks(css=custom, theme=gr.themes.Soft(primary_hue='teal', secondary_hue='blue')) as interface:
+         with gr.Column(variant="compact"):
+             gr.Markdown("# Lung Radiography Analysis", elem_classes='heading')
+             gr.Markdown("""
+             Upload or drop a chest X-ray image for COVID-19 diagnosis and analysis.
+             """)
+         with gr.Row(equal_height=True):
+             # [UPLOAD IMAGE SECTION]
+             with gr.Column():
+                 input_image = gr.Image(
+                     label="Upload Chest X-ray",
+                     height=400,
+                     elem_classes="upload-image"
+                 )
+
+                 # [BUTTONS]
+                 with gr.Row():
+                     submit_btn = gr.Button("Analyze Image", variant="primary", elem_classes='primary-button', scale=2)
+                     clear_btn = gr.Button('Clear', variant='secondary', scale=1)
+
+             with gr.Column():
+                 with gr.Group(elem_classes='results-container'):
+                     output_image = gr.Image(
+                         label="COVID-19 Analysis",
+                         visible=False,
+                         height=400
+                     )
+
+                 with gr.Row(equal_height=True):
+                     diagnosis_label = gr.Label(label="Diagnosis Conclusion", elem_classes='results-container')
+                     confidence_label = gr.Label(label="Confidence Score", elem_classes='results-container')
+
+                 with gr.Row():
+                     diagnosis_text = gr.Textbox(
+                         label="Diagnosis Details",
+                         visible=False,
+                         container=False
+                     )
+
+         # [HELP SECTION]
+         with gr.Accordion("Information", open=False):
+             gr.Markdown("""
+             ### Tutorial
+             1. Click the upload button or drag and drop a chest X-ray image.
+             2. Choose 'Analyze Image'.
+             3. Review the results:
+                 - For COVID cases: view highlighted infection regions.
+                 - For Non-COVID/Healthy cases: review the detailed diagnosis text.
+             """)
+
+         def clear_inputs():
+             return {
+                 input_image: None,
+                 output_image: gr.update(visible=False),
+                 diagnosis_text: gr.update(visible=False),
+                 diagnosis_label: None,
+                 confidence_label: None
+             }
+
+         def handle_prediction(image, opacity=0.4):
+             prediction, confidence, output_img, analysis_text = processor.process_image(
+                 image, overlay_opacity=opacity
+             )
+
+             if prediction is None:  # no image uploaded; reset the result widgets
+                 return {
+                     diagnosis_label: None,
+                     confidence_label: None,
+                     output_image: gr.update(visible=False),
+                     diagnosis_text: gr.update(visible=False)
+                 }
+
+             confidence_class = (
+                 "confidence-high" if confidence > 90
+                 else "confidence-medium" if confidence > 70
+                 else "confidence-low"
+             )
+
+             is_covid = output_img is not None
+
+             return {
+                 diagnosis_label: prediction,
+                 confidence_label: gr.update(
+                     value=f"Confidence: {confidence:.2f}%",
+                     elem_classes=[confidence_class]
+                 ),
+                 output_image: gr.update(value=output_img, visible=is_covid),
+                 diagnosis_text: gr.update(value=analysis_text, visible=True)
+             }
+
+         submit_btn.click(
+             fn=handle_prediction,
+             inputs=[input_image],
+             outputs=[
+                 diagnosis_label,
+                 confidence_label,
+                 output_image,
+                 diagnosis_text,
+             ]
+         )
+
+         clear_btn.click(
+             fn=clear_inputs,
+             inputs=[],
+             outputs=[
+                 input_image,
+                 output_image,
+                 diagnosis_text,
+                 diagnosis_label,
+                 confidence_label
+             ]
+         )
+
+     return interface
+
+ if __name__ == "__main__":
+     interface = create_interface()
+     interface.launch(share=True)
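If the temporary public tunnel from `share=True` is not wanted, the app can be served locally instead; a minimal sketch of the same entry point (`server_name` and `server_port` are standard Gradio `launch` options):

```python
# run from the repo root so the relative design/ and weights/ paths resolve
from interface import create_interface

demo = create_interface()
demo.launch(server_name='127.0.0.1', server_port=7860)  # local-only serving
```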
models/classification_models/ResNet.py ADDED
@@ -0,0 +1,8 @@
+ import torch.nn as nn
+ import torchvision.models as models
+
+ weights = models.ResNet50_Weights.DEFAULT
+ resnet_model = models.resnet50(weights=weights)
+
+ # swap the ImageNet head for the 3 classes used here: COVID, Non-COVID, Healthy
+ resnet_model.fc = nn.Linear(resnet_model.fc.in_features, 3)
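A dummy forward pass is enough to sanity-check the new head; a minimal sketch (no trained checkpoint needed for the shape check, though importing the module downloads the ImageNet weights on first run):

```python
import torch
from models.classification_models.ResNet import resnet_model

resnet_model.eval()
with torch.inference_mode():
    logits = resnet_model(torch.randn(1, 3, 256, 256))  # dummy RGB batch
print(logits.shape)  # torch.Size([1, 3]): one logit per class
```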
models/segmentation_models/ResnetUnet.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+ def basic_block(in_channels, out_channels):
+     block = nn.Sequential(
+         nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+         nn.BatchNorm2d(out_channels),
+         nn.ReLU(inplace=True),
+         nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+         nn.BatchNorm2d(out_channels),
+         nn.ReLU(inplace=True)
+     )
+     return block
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.basic_block = basic_block(in_channels, out_channels)
+         # as wired below, in_channels = down_channels + skip_channels and
+         # out_channels = skip_channels, so in_channels - out_channels is the incoming channel count
+         self.up_sample = nn.ConvTranspose2d(in_channels - out_channels, in_channels - out_channels, 2, 2)
+
+     def forward(self, down, skip):
+         x = self.up_sample(down)
+         x = torch.cat([x, skip], dim=1)
+         x = self.basic_block(x)
+         return x
+
+ class ResNetUnet(nn.Module):
+     def __init__(self, n_classes=1, freeze=True):
+         super().__init__()
+         backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+
+         self.encoder1 = nn.Sequential(
+             backbone.conv1,
+             backbone.bn1,
+             backbone.relu
+         )
+         self.maxpool = backbone.maxpool
+         self.encoder2 = backbone.layer1
+         self.encoder3 = backbone.layer2
+         self.encoder4 = backbone.layer3
+         self.encoder5 = backbone.layer4
+
+         if freeze:
+             self._freeze_backbone()
+
+         self.decoder5 = DecoderBlock(2048 + 1024, 1024)
+         self.decoder4 = DecoderBlock(1024 + 512, 512)
+         self.decoder3 = DecoderBlock(512 + 256, 256)
+         self.decoder2 = DecoderBlock(256 + 64, 64)
+         self.decoder1 = nn.Sequential(
+             nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
+             nn.Conv2d(32, 32, kernel_size=3, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(inplace=True)
+         )
+         self.out = nn.Conv2d(32, n_classes, kernel_size=1)
+
+     def _freeze_backbone(self):
+         layers = [self.encoder1, self.encoder2, self.encoder3,
+                   self.encoder4, self.encoder5]
+
+         for layer in layers:
+             for param in layer.parameters():
+                 param.requires_grad = False
+
+     def forward(self, x):
+         e1 = self.encoder1(x)
+         p1 = self.maxpool(e1)
+         e2 = self.encoder2(p1)
+         e3 = self.encoder3(e2)
+         e4 = self.encoder4(e3)
+         e5 = self.encoder5(e4)
+
+         d5 = self.decoder5(e5, e4)
+         d4 = self.decoder4(d5, e3)
+         d3 = self.decoder3(d4, e2)
+         d2 = self.decoder2(d3, e1)
+         d1 = self.decoder1(d2)
+         out = self.out(d1)
+
+         return out
+
+ ResNetUnetmodel_50 = ResNetUnet(n_classes=1, freeze=True)
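The `DecoderBlock` channel arithmetic is easy to break when editing, so a quick shape check is worth running; a minimal sketch (input sides should be multiples of 32 so the skip concatenations line up):

```python
import torch
from models.segmentation_models.ResnetUnet import ResNetUnet

model = ResNetUnet(n_classes=1, freeze=True).eval()
with torch.inference_mode():
    mask_logits = model(torch.randn(1, 3, 256, 256))
print(mask_logits.shape)  # torch.Size([1, 1, 256, 256]): full-resolution, single-channel mask
```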
pipeline.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn.functional as F
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ import cv2
+ import numpy as np
+
+ from models.classification_models.ResNet import *
+ from models.segmentation_models.ResnetUnet import *
+
+ class Pipeline:
+     def __init__(self, img_size=256):
+         self.transform = self._get_transforms(img_size)
+         self.classification_model, self.segmentation_model = self._load_models()
+         self.class_names = ['COVID', 'Non-COVID', 'Healthy']
+
+     def _get_transforms(self, img_size):
+         return A.Compose([
+             A.LongestMaxSize(max_size=img_size),
+             # pad to img_size x img_size so both spatial dims stay divisible
+             # by 32, which the U-Net skip concatenations require
+             A.PadIfNeeded(min_height=img_size, min_width=img_size),
+             A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+             ToTensorV2(),
+         ])
+
+     def _load_models(self):
+         classification_model = resnet_model
+         classification_model.load_state_dict(torch.load('weights/classification_models/resnet50.pt', map_location='cpu'))
+         classification_model.eval()
+
+         segmentation_model = ResNetUnet()
+         checkpoint = torch.load('weights/segmentation_models/ResNetUnet_best.pt', map_location='cpu')
+         segmentation_model.load_state_dict(checkpoint['model_state_dict'])
+         segmentation_model.eval()
+
+         return classification_model, segmentation_model
+
+     def process_image(self, image, overlay_opacity=0.4):
+         if image is None:
+             return None, None, None, None
+
+         if image.ndim == 3 and image.shape[-1] == 4:  # drop the alpha channel if present
+             image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+         transformed = self.transform(image=image)
+         input_tensor = transformed['image'].unsqueeze(0)
+
+         with torch.inference_mode():
+             outputs = self.classification_model(input_tensor)
+             probs = F.softmax(outputs, dim=1)
+             pred_class = torch.argmax(probs, dim=1).item()
+             confidence = probs[0][pred_class].item() * 100
+
+         prediction = self.class_names[pred_class]
+
+         if prediction == 'COVID':
+             with torch.inference_mode():
+                 output = self.segmentation_model(input_tensor)
+                 output = torch.sigmoid(output)
+                 output = output.squeeze().cpu().numpy()
+                 binary_mask = (output > 0.5).astype(np.uint8) * 255
+
+             mask_resized = cv2.resize(binary_mask, (image.shape[1], image.shape[0]))
+
+             overlay = np.zeros_like(image)
+             overlay[mask_resized > 0] = [255, 0, 0]
+             blended = cv2.addWeighted(image, 1, overlay, overlay_opacity, 0)
+
+             analysis_text = (
+                 f"COVID-19 Detection Results:\n"
+                 f"• Infection detected with {confidence:.1f}% confidence\n"
+                 f"• Red overlay indicates areas of potential COVID-19 infection\n"
+                 f"• Recommended: Seek immediate medical attention"
+             )
+             return prediction, confidence, blended, analysis_text
+
+         elif prediction == 'Non-COVID':
+             analysis_text = (
+                 f"Non-COVID Lung Condition Detected:\n"
+                 f"• Confidence: {confidence:.1f}%\n"
+                 f"• Other lung abnormalities such as pneumonia or lung enlargement should be considered for further treatment\n"
+                 f"• Recommended: Consult a healthcare provider for a proper diagnosis"
+             )
+             return prediction, confidence, None, analysis_text
+
+         else:
+             analysis_text = (
+                 f"Healthy Lung Scan Results:\n"
+                 f"• Confidence: {confidence:.1f}%\n"
+                 f"• No significant abnormalities detected :)\n"
+                 f"• Regular check-ups and an apple a day are recommended"
+             )
+             return prediction, confidence, None, analysis_text
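For review, the pipeline can also be exercised headlessly against one of the sample images committed here; a minimal sketch (assumes the checkpoint files under `weights/` are present locally, since they are git-ignored in this repo):

```python
import cv2
from pipeline import Pipeline

pipe = Pipeline()
# cv2 reads BGR; convert to RGB to match what the Gradio image widget passes in
img = cv2.cvtColor(cv2.imread('images/COVID/covid_1579.png'), cv2.COLOR_BGR2RGB)
prediction, confidence, overlay, analysis_text = pipe.process_image(img, overlay_opacity=0.4)
print(prediction, f'{confidence:.1f}%')
print(analysis_text)
```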
utils/helper.py ADDED
File without changes
weights/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore