ronithsharmila commited on
Commit
0326be8
·
verified ·
1 Parent(s): 4e6d175

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -47
app.py CHANGED
@@ -1,63 +1,67 @@
1
  import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
- from torchvision import transforms
5
  from PIL import Image
 
6
 
7
- # Define your model class
8
- class CustomModel(nn.Module):
9
- def __init__(self):
10
- super(CustomModel, self).__init__()
11
- # Define your layers here (example)
12
- self.fc = nn.Linear(512, 2)
13
-
14
- def forward(self, x):
15
- x = self.fc(x)
16
- return x
17
-
18
- # Function to load your model
19
- def load_model(model_path):
20
- model = CustomModel()
21
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
22
- model.eval()
23
- return model
24
-
25
- # Function to preprocess the image
26
- def preprocess_image(image):
27
- transform = transforms.Compose([
28
- transforms.Resize((224, 224)),
29
- transforms.ToTensor(),
30
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
- ])
32
  image = Image.fromarray(image)
33
  image = transform(image).unsqueeze(0)
34
- return image
35
-
36
- # Function to predict using the model
37
- def predict(image):
38
- image = preprocess_image(image)
39
  with torch.no_grad():
40
- output = model(image)
41
- prediction = torch.argmax(output, dim=1).item()
42
- class_names = ['Monkeypox', 'Healthy']
43
- return class_names[prediction]
 
44
 
45
- # Load the model
46
- model = load_model('best_model_augmented.pth')
 
 
 
 
 
47
 
48
- # Create the Gradio interface
49
- title = "<h1 style='text-align: center; margin-bottom: 10px;'>Pox Classifier</h1>"
50
- class_names_text = "<h3 style='text-align: center; margin-top: -10px;'>[Monkeypox, Healthy]</h3>"
51
- twitter_link = "<p style='text-align: center;'><a href='https://x.com/ronith_sharmila' target='_blank'>@your_twitter_handle</a></p>"
52
- support_link = "<p style='text-align: center;'><a href='https://paypal.me/ronithsharmila?country.x=US&locale.x=en_US' target='_blank'>Support the work <strong>here</strong></a></p>"
53
- disclaimer = "<p style='text-align: center; color: gray;'>Disclaimer: This model is for educational purposes only and should not be used for medical diagnosis.</p>"
 
 
 
 
 
 
 
54
 
 
55
  app = gr.Interface(
56
- fn=predict,
57
  inputs=gr.Image(type="numpy", label="Upload Image"),
58
- outputs=gr.Label(label="Prediction"),
59
  title=title,
60
- description=class_names_text + twitter_link + support_link + disclaimer
 
61
  )
62
 
 
63
  app.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import torchvision.transforms as transforms
 
4
  from PIL import Image
5
+ from model import load_model
6
 
7
+ # Load model
8
+ model_path = 'best_model_augmented.pth'
9
+ model = load_model(model_path)
10
+
11
+ # Define preprocessing transformations
12
+ transform = transforms.Compose([
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(),
15
+ ])
16
+
17
+ # Define the class names
18
+ class_names = ['Normal', 'Monkeypox', 'Chickenpox', 'Measles']
19
+
20
+ def predict_image(image):
21
+ # Preprocess the image
 
 
 
 
 
 
 
 
 
 
22
  image = Image.fromarray(image)
23
  image = transform(image).unsqueeze(0)
24
+
25
+ # Perform inference
26
+ model.eval()
 
 
27
  with torch.no_grad():
28
+ outputs = model(image)
29
+ _, predicted = torch.max(outputs, 1)
30
+ predicted_class = class_names[predicted.item()]
31
+
32
+ return predicted_class
33
 
34
+ # HTML for custom styling
35
+ title = (
36
+ "<h1 style='text-align: center; font-family: Arial, sans-serif;'>Pox Classifier</h1>"
37
+ "<p style='text-align: center; font-family: Arial, sans-serif;'>"
38
+ "<a href='https://x.com/ronith_sharmila' target='_blank'>@your_twitter_handle</a>"
39
+ "</p>"
40
+ )
41
 
42
+ description = (
43
+ "<p style='text-align: center; font-family: Arial, sans-serif;'>"
44
+ "Upload an image of a skin lesion to classify it into one of the following categories:</p>"
45
+ "<p style='text-align: center; font-family: Arial, sans-serif; font-weight: bold;'>"
46
+ "Normal, Monkeypox, Chickenpox, Measles</p>"
47
+ "<p style='text-align: center; font-family: Arial, sans-serif;'>"
48
+ "Please note that this model is not a substitute for professional medical advice, diagnosis, or treatment. "
49
+ "Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition."
50
+ "</p>"
51
+ "<p style='text-align: center; font-family: Arial, sans-serif;'>"
52
+ "<a href='https://paypal.me/ronithsharmila?country.x=US&locale.x=en_US' target='_blank'>Support my work here</a>"
53
+ "</p>"
54
+ )
55
 
56
+ # Gradio app interface
57
  app = gr.Interface(
58
+ fn=predict_image,
59
  inputs=gr.Image(type="numpy", label="Upload Image"),
60
+ outputs=gr.Textbox(label="Prediction"),
61
  title=title,
62
+ description=description,
63
+ theme="default" # Use the default theme or specify a different theme if available
64
  )
65
 
66
+ # Launch the app
67
  app.launch()