ronithsharmila commited on
Commit
9cfef3d
·
verified ·
1 Parent(s): bab77b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -1,44 +1,44 @@
1
  import gradio as gr
2
  import torch
3
- from torchvision import transforms
4
  from PIL import Image
5
- from model import CustomModel, load_model
6
 
7
- # Define the transformation
 
 
 
 
8
  transform = transforms.Compose([
9
  transforms.Resize((224, 224)),
10
  transforms.ToTensor(),
11
  ])
12
 
13
- # Load the model
14
- model = load_model('best_model_augmented.pth')
15
-
16
  # Define the class names
17
  class_names = ['Normal', 'Monkeypox', 'Chickenpox', 'Measles']
18
 
19
- def predict(image):
20
  # Preprocess the image
21
- image = Image.fromarray(image.astype('uint8'), 'RGB')
22
- image = transform(image)
23
- image = image.unsqueeze(0) # Add batch dimension
24
-
25
- # Run inference
26
  with torch.no_grad():
27
- model.eval()
28
  outputs = model(image)
29
- _, preds = torch.max(outputs, 1)
30
- predicted_class = class_names[preds.item()]
31
-
32
  return predicted_class
33
 
34
- # Define the Gradio interface using the updated API
35
- iface = gr.Interface(
36
- fn=predict,
37
  inputs=gr.Image(type="numpy", label="Upload Image"),
38
- outputs=gr.Label(num_top_classes=1, label="Predicted Class"),
39
- title="Monkeypox Classifier",
40
- description="Upload an image of skin lesions to classify the disease."
41
  )
42
 
43
  # Launch the app
44
- iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import torchvision.transforms as transforms
4
  from PIL import Image
5
+ from model import load_model
6
 
7
+ # Load model
8
+ model_path = 'best_model_augmented.pth'
9
+ model = load_model(model_path)
10
+
11
+ # Define preprocessing transformations
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
15
  ])
16
 
 
 
 
17
  # Define the class names
18
  class_names = ['Normal', 'Monkeypox', 'Chickenpox', 'Measles']
19
 
20
+ def predict_image(image):
21
  # Preprocess the image
22
+ image = Image.fromarray(image)
23
+ image = transform(image).unsqueeze(0)
24
+
25
+ # Perform inference
26
+ model.eval()
27
  with torch.no_grad():
 
28
  outputs = model(image)
29
+ _, predicted = torch.max(outputs, 1)
30
+ predicted_class = class_names[predicted.item()]
31
+
32
  return predicted_class
33
 
34
+ # Gradio app interface
35
+ app = gr.Interface(
36
+ fn=predict_image,
37
  inputs=gr.Image(type="numpy", label="Upload Image"),
38
+ outputs=gr.Textbox(label="Prediction"),
39
+ title="Pox Classifier (Normal, Monkeypox, Chickenpox, Measles)",
40
+ description="Upload an image of a skin lesion to classify it into one of the following categories: Normal, Monkeypox, Chickenpox, or Measles. Please note that this model is not a substitute for professional medical advice, diagnosis, or treatment. Always seek the advice of your physician or other qualified health provider with any questions you may have regarding a medical condition."
41
  )
42
 
43
  # Launch the app
44
+ app.launch()