Abhiyoshi committed on
Commit
f5fafcc
·
verified ·
1 Parent(s): 1b8cc7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -206
app.py CHANGED
@@ -1,207 +1,207 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import transforms
5
- from PIL import Image
6
- import cv2
7
- import numpy as np
8
- import requests
9
- import os
10
- from typing import Tuple, Dict
11
-
12
- # CustomViT model definition
13
- class PatchEmbedding(nn.Module):
14
- def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
15
- super().__init__()
16
- self.img_size = img_size
17
- self.patch_size = patch_size
18
- self.n_patches = (img_size // patch_size) ** 2
19
- self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
20
-
21
- def forward(self, x):
22
- x = self.proj(x)
23
- x = x.flatten(2)
24
- x = x.transpose(1, 2)
25
- return x
26
-
27
- class Attention(nn.Module):
28
- def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
29
- super().__init__()
30
- self.n_heads = n_heads
31
- self.scale = (dim // n_heads) ** -0.5
32
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
33
- self.attn_drop = nn.Dropout(attn_drop)
34
- self.proj = nn.Linear(dim, dim)
35
- self.proj_drop = nn.Dropout(proj_drop)
36
-
37
- def forward(self, x):
38
- B, N, C = x.shape
39
- qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
40
- q, k, v = qkv.unbind(0)
41
- attn = (q @ k.transpose(-2, -1)) * self.scale
42
- attn = attn.softmax(dim=-1)
43
- attn = self.attn_drop(attn)
44
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
45
- x = self.proj(x)
46
- x = self.proj_drop(x)
47
- return x
48
-
49
- class TransformerBlock(nn.Module):
50
- def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
51
- super().__init__()
52
- self.norm1 = nn.LayerNorm(dim)
53
- self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
54
- self.norm2 = nn.LayerNorm(dim)
55
- mlp_hidden_dim = int(dim * mlp_ratio)
56
- self.mlp = nn.Sequential(
57
- nn.Linear(dim, mlp_hidden_dim),
58
- nn.GELU(),
59
- nn.Dropout(drop),
60
- nn.Linear(mlp_hidden_dim, dim),
61
- nn.Dropout(drop)
62
- )
63
-
64
- def forward(self, x):
65
- x = x + self.attn(self.norm1(x))
66
- x = x + self.mlp(self.norm2(x))
67
- return x
68
-
69
- class CustomViT(nn.Module):
70
- def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=2, embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.):
71
- super().__init__()
72
- self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
73
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
74
- self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
75
- self.pos_drop = nn.Dropout(p=drop_rate)
76
- self.blocks = nn.ModuleList([
77
- TransformerBlock(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, drop_rate)
78
- for _ in range(depth)
79
- ])
80
- self.norm = nn.LayerNorm(embed_dim)
81
- self.head = nn.Linear(embed_dim, num_classes)
82
-
83
- def forward(self, x):
84
- B = x.shape[0]
85
- x = self.patch_embed(x)
86
- cls_tokens = self.cls_token.expand(B, -1, -1)
87
- x = torch.cat((cls_tokens, x), dim=1)
88
- x = x + self.pos_embed
89
- x = self.pos_drop(x)
90
- for block in self.blocks:
91
- x = block(x)
92
- x = self.norm(x)
93
- x = x[:, 0]
94
- x = self.head(x)
95
- return x
96
-
97
- # Helper functions
98
- def load_model(model_path: str) -> CustomViT:
99
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100
- model = CustomViT(num_classes=2)
101
- state_dict = torch.load(model_path, map_location=device)
102
-
103
- # Remove 'module.' prefix if present
104
- if all(k.startswith('module.') for k in state_dict.keys()):
105
- state_dict = {k[7:]: v for k, v in state_dict.items()}
106
-
107
- model.load_state_dict(state_dict)
108
- model.to(device)
109
- model.eval()
110
- return model
111
-
112
- def preprocess_image(image: np.ndarray) -> torch.Tensor:
113
- # Convert numpy array to PIL Image
114
- if isinstance(image, np.ndarray):
115
- image = Image.fromarray(image)
116
-
117
- transform = transforms.Compose([
118
- transforms.Resize((224, 224)),
119
- transforms.ToTensor(),
120
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
121
- ])
122
- return transform(image).unsqueeze(0)
123
-
124
- def predict_image(image: np.ndarray, model: CustomViT) -> Tuple[np.ndarray, Dict[str, float]]:
125
- device = next(model.parameters()).device
126
-
127
- # Preprocess the image
128
- image_tensor = preprocess_image(image)
129
-
130
- # Make prediction
131
- with torch.no_grad():
132
- outputs = model(image_tensor.to(device))
133
- probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
134
-
135
- # Create visualization
136
- visualization = image.copy()
137
- height, width = visualization.shape[:2]
138
-
139
- # Add prediction overlay
140
- result = "Leprosy" if probabilities[0] > probabilities[1] else "No Leprosy"
141
- confidence = float(probabilities[0] if result == "Leprosy" else probabilities[1])
142
-
143
- # Add text to image
144
- color = (0, 0, 255) if result == "Leprosy" else (0, 255, 0)
145
- cv2.putText(visualization, f"{result}: {confidence:.2%}",
146
- (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
147
-
148
- # Convert BGR to RGB for Gradio
149
- visualization = cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB)
150
-
151
- # Prepare labels dictionary
152
- labels = {
153
- "Leprosy": float(probabilities[0]),
154
- "No Leprosy": float(probabilities[1])
155
- }
156
-
157
- return visualization, labels
158
-
159
- # Download example images
160
- file_urls = [
161
- 'https://www.dropbox.com/scl/fi/onrg1u9tqegh64nsfmxgr/lp2.jpg?rlkey=2vgw5n6abqmyismg16mdd1v3n&dl=1',
162
- 'https://www.dropbox.com/scl/fi/xq103ic7ovuuei3l9e8jf/lp1.jpg?rlkey=g7d9khyyc6wplv0ljd4mcha60&dl=1',
163
- 'https://www.dropbox.com/scl/fi/fagkh3gnio2pefdje7fb9/Non_Leprosy_210823_86_jpg.rf.5bb80a7704ecc6c8615574cad5d074c5.jpg?rlkey=ks8afue5gsx5jqvxj3u9mbjmg&dl=1',
164
- ]
165
-
166
- def download_example_images():
167
- examples = []
168
- for i, url in enumerate(file_urls):
169
- filename = f"example_{i}.jpg"
170
- if not os.path.exists(filename):
171
- response = requests.get(url)
172
- with open(filename, 'wb') as f:
173
- f.write(response.content)
174
- examples.append(filename)
175
- return examples
176
-
177
- # Main Gradio interface
178
- def create_gradio_interface():
179
- # Load the model
180
- model = load_model('best_custom_vit_mo50.pth')
181
-
182
- # Create inference function
183
- def inference(image):
184
- return predict_image(image, model)
185
-
186
- # Download example images
187
- examples = download_example_images()
188
-
189
- # Create Gradio interface
190
- interface = gr.Interface(
191
- fn=inference,
192
- inputs=gr.Image(),
193
- outputs=[
194
- gr.Image(label="Prediction Visualization"),
195
- gr.Label(label="Classification Probabilities")
196
- ],
197
- title="Leprosy Detection using Vision Transformer",
198
- description="Upload an image to detect signs of leprosy using a Vision Transformer model.",
199
- examples=examples,
200
- cache_examples=False
201
- )
202
-
203
- return interface
204
-
205
- if __name__ == "__main__":
206
- interface = create_gradio_interface()
207
  interface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import cv2
7
+ import numpy as np
8
+ import requests
9
+ import os
10
+ from typing import Tuple, Dict
11
+
12
+ # CustomViT model definition
13
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    performs the patch split and the linear projection in a single step.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches along one side, squared, gives the sequence length.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` to patch embeddings of shape (batch, patches, embed_dim)."""
        feature_map = self.proj(x)
        # Collapse the spatial grid into one axis, then move it ahead of the
        # channel axis so each row is one patch embedding.
        return feature_map.flatten(2).transpose(1, 2)
26
+
27
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one fused linear layer (``qkv``);
    the per-head results are concatenated and passed through ``proj``.
    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.n_heads = n_heads
        head_dim = dim // n_heads
        # Scale logits by 1/sqrt(head_dim) before the softmax.
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, channels)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.n_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, self.n_heads, head_dim)
        # Reorder to (3, batch, heads, tokens, head_dim) and split q / k / v.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel axis.
        context = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(context))
48
+
49
class TransformerBlock(nn.Module):
    """Encoder block: self-attention then an MLP, each with a residual.

    LayerNorm is applied to the input of each sub-layer (pre-norm), and the
    sub-layer output is added back onto the un-normalized input.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)

        # Hidden width of the feed-forward network.
        hidden_features = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, hidden_features),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_features, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run attention and the MLP over ``x``, keeping its shape."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
68
+
69
class CustomViT(nn.Module):
    """Vision Transformer classifier built from the blocks in this file.

    Images are split into patch embeddings, a class token is prepended,
    position embeddings are added, and the sequence runs through ``depth``
    transformer blocks. The normalized class token feeds a linear head.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=2, embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # One extra position for the class token.
        seq_len = 1 + self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                TransformerBlock(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, drop_rate)
            )
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        tokens = self.patch_embed(x)

        # Broadcast the single learned class token across the batch.
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Only the class token (position 0) is used for classification.
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)
96
+
97
+ # Helper functions
98
def load_model(model_path: str) -> CustomViT:
    """Build a two-class ``CustomViT`` and load its weights from ``model_path``.

    The model is placed on CUDA when available, otherwise on CPU, and is
    returned in eval mode.

    Args:
        model_path: Path to a file holding a state dict readable by
            ``torch.load``.

    Returns:
        The model with the loaded weights.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = CustomViT(num_classes=2)

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    weights = torch.load(model_path, map_location=device)

    # Strip the 'module.' prefix when every key carries it (presumably a
    # checkpoint saved from a DataParallel-wrapped model - confirm).
    prefix = 'module.'
    if all(name.startswith(prefix) for name in weights):
        weights = {name[len(prefix):]: tensor for name, tensor in weights.items()}

    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    return model
111
+
112
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Turn an image into a normalized, batched tensor for the model.

    Args:
        image: A numpy image array or a PIL image. Arrays are converted to
            PIL first; PIL images are used directly.

    Returns:
        A tensor of shape (1, 3, 224, 224).
    """
    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Normalize below has exactly three per-channel statistics, so grayscale
    # or RGBA input would fail there. Force three channels up front.
    image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add the leading batch dimension.
    return transform(image).unsqueeze(0)
123
+
124
def predict_image(image: np.ndarray, model: CustomViT) -> Tuple[np.ndarray, Dict[str, float]]:
    """Classify ``image`` and return an annotated copy plus class probabilities.

    Args:
        image: Input picture as a numpy array. It is treated as RGB, which
            is the channel order ``gr.Image()`` hands to a callback by
            default.
        model: Classifier to run; inference happens on whichever device
            holds its parameters.

    Returns:
        ``(visualization, labels)``: a copy of ``image`` with the predicted
        class and confidence drawn in the top-left corner, and a dict
        mapping each class name to its probability.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided")

    device = next(model.parameters()).device

    # Preprocess the image
    image_tensor = preprocess_image(image)

    # Make prediction
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # NOTE(review): class index 0 = "Leprosy" follows the existing mapping;
    # confirm against the label order used in training.
    p_leprosy = float(probabilities[0])
    p_no_leprosy = float(probabilities[1])

    # Colors are RGB because the array is RGB end to end: red for a positive
    # finding, green otherwise.
    if p_leprosy > p_no_leprosy:
        result, confidence, color = "Leprosy", p_leprosy, (255, 0, 0)
    else:
        result, confidence, color = "No Leprosy", p_no_leprosy, (0, 255, 0)

    # Draw on a copy so the caller's array is left untouched. No BGR<->RGB
    # conversion is applied: the input is already RGB, and swapping channels
    # would exchange red and blue in the displayed image.
    visualization = image.copy()
    cv2.putText(visualization, f"{result}: {confidence:.2%}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

    # Prepare labels dictionary
    labels = {
        "Leprosy": p_leprosy,
        "No Leprosy": p_no_leprosy,
    }

    return visualization, labels
158
+
159
+ # Download example images
160
file_urls = [
    'https://www.dropbox.com/scl/fi/onrg1u9tqegh64nsfmxgr/lp2.jpg?rlkey=2vgw5n6abqmyismg16mdd1v3n&dl=1',
    'https://www.dropbox.com/scl/fi/xq103ic7ovuuei3l9e8jf/lp1.jpg?rlkey=g7d9khyyc6wplv0ljd4mcha60&dl=1',
    'https://www.dropbox.com/scl/fi/fagkh3gnio2pefdje7fb9/Non_Leprosy_210823_86_jpg.rf.5bb80a7704ecc6c8615574cad5d074c5.jpg?rlkey=ks8afue5gsx5jqvxj3u9mbjmg&dl=1',
]


def download_example_images():
    """Fetch the example images listed in ``file_urls`` into the working directory.

    Each URL is saved as ``example_<index>.jpg``. Files that already exist
    are reused without touching the network. A download that fails (network
    error, timeout, or non-2xx status) is reported and skipped, so no error
    page is ever written to disk as if it were an image.

    Returns:
        The filenames of the examples that are available locally, in URL
        order.
    """
    examples = []
    for i, url in enumerate(file_urls):
        filename = f"example_{i}.jpg"
        if not os.path.exists(filename):
            try:
                # A timeout keeps startup from hanging on an unresponsive host.
                response = requests.get(url, timeout=30)
                response.raise_for_status()
            except requests.RequestException as exc:
                print(f"Could not download example image {url}: {exc}")
                continue
            # Write only after a successful response.
            with open(filename, 'wb') as f:
                f.write(response.content)
        examples.append(filename)
    return examples
176
+
177
+ # Main Gradio interface
178
def create_gradio_interface():
    """Assemble the Gradio app around the leprosy classifier.

    Loads the model from ``best_custom_vit_mo50.pth``, fetches the example
    images, and wires both into a ``gr.Interface`` with one image input and
    two outputs (annotated image, class probabilities).

    Returns:
        The configured, not yet launched, ``gr.Interface``.
    """
    # Load once; the callback below closes over this instance.
    model = load_model('best_custom_vit_mo50.pth')

    def inference(image):
        return predict_image(image, model)

    example_files = download_example_images()

    image_input = gr.Image()
    result_outputs = [
        gr.Image(label="Prediction Visualization"),
        gr.Label(label="Classification Probabilities"),
    ]

    return gr.Interface(
        fn=inference,
        inputs=image_input,
        outputs=result_outputs,
        title="Leprosy Detection",
        description="Upload an image to detect signs of leprosy.",
        examples=example_files,
        cache_examples=False,
    )
204
+
205
if __name__ == "__main__":
    # Build the UI and start the server only when run as a script, so that
    # importing this module has no side effects.
    interface = create_gradio_interface()
    interface.launch()