ronithsharmila commited on
Commit
4e6d175
·
verified ·
1 Parent(s): 723f223

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -41
app.py CHANGED
@@ -1,58 +1,63 @@
1
  import gradio as gr
2
  import torch
3
- import torchvision.transforms as transforms
 
4
  from PIL import Image
5
- from model import load_model
6
 
7
- # Load model
8
- model_path = 'best_model_augmented.pth'
9
- model = load_model(model_path)
 
 
 
10
 
11
- # Define preprocessing transformations
12
- transform = transforms.Compose([
13
- transforms.Resize((224, 224)),
14
- transforms.ToTensor(),
15
- ])
16
 
17
- # Define the class names
18
- class_names = ['Normal', 'Monkeypox', 'Chickenpox', 'Measles']
 
 
 
 
19
 
20
- def predict_image(image):
21
- # Preprocess the image
 
 
 
 
 
22
  image = Image.fromarray(image)
23
  image = transform(image).unsqueeze(0)
24
-
25
- # Perform inference
26
- model.eval()
 
 
27
  with torch.no_grad():
28
- outputs = model(image)
29
- _, predicted = torch.max(outputs, 1)
30
- predicted_class = class_names[predicted.item()]
31
-
32
- return predicted_class
33
-
34
- # HTML for custom styling
35
- title = "<h1 style='text-align: center; font-family: Arial, sans-serif;'>PoxNet</h1>"
36
- description = (
37
- "<p style='text-align: center; font-family: Arial, sans-serif;'>"
38
- "Upload an image of a skin lesion to classify it into one of the following categories:</p>"
39
- "<p style='text-align: center; font-family: Arial, sans-serif; font-weight: bold;'>"
40
- "Normal, Monkeypox, Chickenpox, Measles</p>"
41
- "<p style='text-align: center; font-family: Arial, sans-serif;'>"
42
- "Please note that this model is not a substitute for professional medical advice, diagnosis, or treatment. "
43
- "Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition."
44
- "</p>"
45
- )
46
 
47
- # Gradio app interface
48
  app = gr.Interface(
49
- fn=predict_image,
50
  inputs=gr.Image(type="numpy", label="Upload Image"),
51
- outputs=gr.Textbox(label="Prediction"),
52
  title=title,
53
- description=description,
54
- theme="default" # Use the default theme or specify a different theme if available
55
  )
56
 
57
- # Launch the app
58
  app.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
  from PIL import Image
 
6
 
7
+ # Define your model class
8
+ class CustomModel(nn.Module):
9
+ def __init__(self):
10
+ super(CustomModel, self).__init__()
11
+ # Define your layers here (example)
12
+ self.fc = nn.Linear(512, 2)
13
 
14
+ def forward(self, x):
15
+ x = self.fc(x)
16
+ return x
 
 
17
 
18
+ # Function to load your model
19
+ def load_model(model_path):
20
+ model = CustomModel()
21
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
22
+ model.eval()
23
+ return model
24
 
25
+ # Function to preprocess the image
26
+ def preprocess_image(image):
27
+ transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
+ ])
32
  image = Image.fromarray(image)
33
  image = transform(image).unsqueeze(0)
34
+ return image
35
+
36
+ # Function to predict using the model
37
+ def predict(image):
38
+ image = preprocess_image(image)
39
  with torch.no_grad():
40
+ output = model(image)
41
+ prediction = torch.argmax(output, dim=1).item()
42
+ class_names = ['Monkeypox', 'Healthy']
43
+ return class_names[prediction]
44
+
45
+ # Load the model
46
+ model = load_model('best_model_augmented.pth')
47
+
48
+ # Create the Gradio interface
49
+ title = "<h1 style='text-align: center; margin-bottom: 10px;'>Pox Classifier</h1>"
50
+ class_names_text = "<h3 style='text-align: center; margin-top: -10px;'>[Monkeypox, Healthy]</h3>"
51
+ twitter_link = "<p style='text-align: center;'><a href='https://x.com/ronith_sharmila' target='_blank'>@your_twitter_handle</a></p>"
52
+ support_link = "<p style='text-align: center;'><a href='https://paypal.me/ronithsharmila?country.x=US&locale.x=en_US' target='_blank'>Support the work <strong>here</strong></a></p>"
53
+ disclaimer = "<p style='text-align: center; color: gray;'>Disclaimer: This model is for educational purposes only and should not be used for medical diagnosis.</p>"
 
 
 
 
54
 
 
55
  app = gr.Interface(
56
+ fn=predict,
57
  inputs=gr.Image(type="numpy", label="Upload Image"),
58
+ outputs=gr.Label(label="Prediction"),
59
  title=title,
60
+ description=class_names_text + twitter_link + support_link + disclaimer
 
61
  )
62
 
 
63
  app.launch()