ronithsharmila commited on
Commit
fdaab4a
·
verified ·
1 Parent(s): e6b3256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -1,36 +1,44 @@
1
  import gradio as gr
2
- from PIL import Image
3
- from model import load_model
4
- from torchvision import transforms
5
  import torch
 
 
 
6
 
7
- # Load your model
8
- model = load_model('best_model_augmented.pth')
9
-
10
- # Define the transformations for input images
11
  transform = transforms.Compose([
12
  transforms.Resize((224, 224)),
13
  transforms.ToTensor(),
14
  ])
15
 
16
- # Define the class labels
 
 
 
17
  class_names = ['Normal', 'Monkeypox', 'Chickenpox', 'Measles']
18
 
19
  def predict(image):
20
  # Preprocess the image
21
- image = transform(image).unsqueeze(0) # Add batch dimension
 
 
 
 
22
  with torch.no_grad():
 
23
  outputs = model(image)
24
- _, predicted = torch.max(outputs, 1)
25
- return class_names[predicted.item()]
 
 
26
 
27
- # Create Gradio interface
28
  iface = gr.Interface(
29
- fn=predict,
30
- inputs=gr.inputs.Image(type="pil", label="Upload an Image"),
31
- outputs=gr.outputs.Label(num_top_classes=4, label="Prediction"),
32
- live=True
 
33
  )
34
 
35
- # Launch the interface
36
  iface.launch()
 
1
  import gradio as gr
 
 
 
2
  import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from model import CustomModel, load_model
6
 
7
+ # Define the transformation
 
 
 
8
  transform = transforms.Compose([
9
  transforms.Resize((224, 224)),
10
  transforms.ToTensor(),
11
  ])
12
 
13
+ # Load the model
14
+ model = load_model('best_model_augmented.pth')
15
+
16
+ # Define the class names
17
  class_names = ['Normal', 'Monkeypox', 'Chickenpox', 'Measles']
18
 
19
  def predict(image):
20
  # Preprocess the image
21
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
22
+ image = transform(image)
23
+ image = image.unsqueeze(0) # Add batch dimension
24
+
25
+ # Run inference
26
  with torch.no_grad():
27
+ model.eval()
28
  outputs = model(image)
29
+ _, preds = torch.max(outputs, 1)
30
+ predicted_class = class_names[preds.item()]
31
+
32
+ return predicted_class
33
 
34
+ # Define the Gradio interface
35
  iface = gr.Interface(
36
+ fn=predict,
37
+ inputs=gr.inputs.Image(type="numpy", label="Upload Image"),
38
+ outputs=gr.outputs.Label(num_top_classes=1, label="Predicted Class"),
39
+ title="Monkeypox Classifier",
40
+ description="Upload an image of skin lesions to classify the disease."
41
  )
42
 
43
+ # Launch the app
44
  iface.launch()