dwililiya commited on
Commit
a9f69a2
·
verified ·
1 Parent(s): d62cf84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -13
app.py CHANGED
@@ -1,31 +1,42 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  from torchvision import transforms
4
  from PIL import Image
 
5
 
6
- # Load the model using Hugging Face pipeline
7
  MODEL_NAME = "dwililiya/sugarcane-plant-diseases-classification"
8
- classifier = pipeline("image-classification", model=MODEL_NAME)
 
9
 
10
- # Define class names based on your dataset
 
 
 
 
 
 
11
  class_names = ['Bacterial Blight', 'Healthy', 'Mosaic', 'Red Rot', 'Rust', 'Yellow']
12
 
13
  def predict(image):
14
- # Use the classifier to predict
15
- predictions = classifier(image)
16
-
17
- # Get the predicted class and confidence score
18
- predicted_class = predictions[0]['label']
19
- confidence = predictions[0]['score']
 
 
 
20
 
21
  return predicted_class, confidence
22
 
23
  # Gradio interface
24
  iface = gr.Interface(
25
  fn=predict,
26
- inputs=gr.inputs.Image(type="file", label="Upload Sugarcane Leaf Image"),
27
- outputs=[gr.outputs.Label(num_top_classes=1, label="Predicted Class"),
28
- gr.outputs.Textbox(label="Confidence Score")],
29
  title="Sugarcane Plant Diseases Classification",
30
  description="Upload an image of a sugarcane leaf to classify its disease.",
31
  )
 
1
  import gradio as gr
2
+ from transformers import AutoModelForImageClassification, AutoConfig
3
  from torchvision import transforms
4
  from PIL import Image
5
+ import torch
6
 
7
+ # Load the model
8
  MODEL_NAME = "dwililiya/sugarcane-plant-diseases-classification"
9
+ config = AutoConfig.from_pretrained(MODEL_NAME)
10
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME, config=config)
11
 
12
+ # Define a transform to prepare the image
13
+ transform = transforms.Compose([
14
+ transforms.Resize((256, 256)),
15
+ transforms.ToTensor(),
16
+ ])
17
+
18
+ # Define class names
19
  class_names = ['Bacterial Blight', 'Healthy', 'Mosaic', 'Red Rot', 'Rust', 'Yellow']
20
 
21
  def predict(image):
22
+ # Transform the image
23
+ image = transform(image).unsqueeze(0) # Add batch dimension
24
+
25
+ # Perform inference
26
+ with torch.no_grad():
27
+ outputs = model(image)
28
+ _, predicted = torch.max(outputs.logits, 1)
29
+ predicted_class = class_names[predicted.item()]
30
+ confidence = torch.softmax(outputs.logits, dim=1)[0][predicted].item()
31
 
32
  return predicted_class, confidence
33
 
34
  # Gradio interface
35
  iface = gr.Interface(
36
  fn=predict,
37
+ inputs=gr.Image(type="pil", label="Upload Sugarcane Leaf Image"), # Change to 'pil'
38
+ outputs=[gr.Label(num_top_classes=1, label="Predicted Class"),
39
+ gr.Textbox(label="Confidence Score")],
40
  title="Sugarcane Plant Diseases Classification",
41
  description="Upload an image of a sugarcane leaf to classify its disease.",
42
  )