Hameed0342j commited on
Commit
7bd18af
·
verified ·
1 Parent(s): be236a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -13
app.py CHANGED
@@ -1,25 +1,44 @@
1
  import gradio as gr
2
  import torch
3
- from FlowerClassificationModel import FlowerClassificationModel
 
 
4
 
5
  # Load the model
6
- model = FlowerClassificationModel()
7
- model.load_state_dict(torch.load("flower_classification_model.pth"))
8
- model.eval()
9
 
 
 
 
 
 
 
 
 
10
  def classify_flower(image):
11
- # Add your image preprocessing logic here
12
- # Perform prediction with the model
13
- prediction = model(image) # Assuming this is how your model works
14
- return prediction
 
 
 
 
 
 
 
 
15
 
16
- # Gradio Interface
17
- interface = gr.Interface(
18
  fn=classify_flower,
19
  inputs="image",
20
- outputs="label",
21
  title="Flower Classification",
22
- description="Upload an image to classify the flower."
23
  )
24
 
25
- interface.launch()
 
 
1
  import gradio as gr
2
  import torch
3
+ from FlowerClassificationModel import FlowerClassificationModel # Replace with your model's class name
4
+ from torchvision import transforms
5
+ from PIL import Image
6
 
7
  # Load the model
8
+ model = FlowerClassificationModel() # Instantiate your model
9
+ model.load_state_dict(torch.load("flower_classification_model.pth", map_location=torch.device('cpu')))
10
+ model.eval() # Set the model to evaluation mode
11
 
12
+ # Define image preprocessing
13
+ preprocess = transforms.Compose([
14
+ transforms.Resize((224, 224)), # Adjust to your model's input size
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
17
+ ])
18
+
19
+ # Define the prediction function
20
  def classify_flower(image):
21
+ # Preprocess the input image
22
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
23
+ input_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
24
+
25
+ # Perform prediction
26
+ with torch.no_grad():
27
+ output = model(input_tensor)
28
+ _, predicted = torch.max(output, 1)
29
+
30
+ # Map prediction to class label
31
+ labels = ["Class1", "Class2", "Class3", "Class4", "Class5"] # Replace with your actual class names
32
+ return labels[predicted.item()]
33
 
34
+ # Create the Gradio interface
35
+ demo = gr.Interface(
36
  fn=classify_flower,
37
  inputs="image",
38
+ outputs="text",
39
  title="Flower Classification",
40
+ description="Upload an image to classify the flower type."
41
  )
42
 
43
+ # Launch the app
44
+ demo.launch()