Jacksonnavigator7 commited on
Commit
ecfbb2d
·
verified ·
1 Parent(s): cbebaf5

Rename app to app.py

Browse files
Files changed (2) hide show
  1. app +0 -40
  2. app.py +46 -0
app DELETED
@@ -1,40 +0,0 @@
1
- import gradio as gr
2
- from PIL import Image
3
- import torch
4
- import pickle
5
- import torchvision.transforms as transforms
6
-
7
- # Load the pickled model
8
- model_path = "birds_classifier.pkl"
9
- try:
10
- with open(model_path, "rb") as f:
11
- model = pickle.load(f)
12
- model.eval() # Set model to evaluation mode
13
- except Exception as e:
14
- raise Exception(f"Failed to load model: {str(e)}")
15
-
16
- # Define image preprocessing (adjust these transforms based on your model's training)
17
- preprocess = transforms.Compose([
18
- transforms.Resize((224, 224)), # Match your model's expected input size
19
- transforms.ToTensor(),
20
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet defaults
21
- ])
22
-
23
- # Replace with your actual list of bird species (in the order the model was trained)
24
- class_labels = ["Sparrow", "Eagle", "Blue Jay", "Cardinal"] # Update this!
25
-
26
- # Prediction function
27
- def classify_bird(image):
28
- try:
29
- if image is None:
30
- return "Please upload an image of a bird."
31
-
32
- # Preprocess the uploaded image
33
- img = preprocess(image).unsqueeze(0) # Add batch dimension
34
-
35
- # Make prediction automatically
36
- with torch.no_grad():
37
- outputs = model(img) # Model outputs logits or probabilities
38
-
39
- # Get the predicted species
40
- predicted_idx = outputs.argmax(-1).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pickle
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # Load the trained model
7
+ with open("bird_classifier.pkl", "rb") as f:
8
+ model = pickle.load(f)
9
+
10
+ # Get class names automatically from the model
11
+ try:
12
+ class_names = model.classes_ # Works for scikit-learn models
13
+ except AttributeError:
14
+ # If the model doesn't have classes_, you'd need a fallback or custom logic
15
+ raise ValueError("Model does not have 'classes_' attribute. Please provide class names manually or adjust the code.")
16
+
17
+ # Define the prediction function
18
+ def classify_bird(image):
19
+ # Preprocess the image (adjust this based on how your model was trained)
20
+ img = Image.fromarray(image.astype("uint8"), "RGB") # Convert to PIL Image
21
+ img = img.resize((224, 224)) # Example resize, adjust to your model's input size
22
+ img_array = np.array(img) / 255.0 # Normalize if your model expects this
23
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
24
+
25
+ # Make prediction
26
+ prediction = model.predict(img_array) # Adjust based on your model's method
27
+
28
+ # Handle prediction output
29
+ if len(prediction.shape) > 1: # If prediction is a probability array (e.g., softmax output)
30
+ predicted_class = class_names[np.argmax(prediction)]
31
+ else: # If prediction is a single class index (e.g., scikit-learn's default)
32
+ predicted_class = class_names[prediction[0]]
33
+
34
+ return predicted_class
35
+
36
+ # Create the Gradio interface
37
+ interface = gr.Interface(
38
+ fn=classify_bird, # Prediction function
39
+ inputs=gr.Image(type="numpy"), # Input is an image, returned as NumPy array
40
+ outputs=gr.Textbox(), # Output is text (bird species)
41
+ title="Bird Classifier",
42
+ description="Upload an image of a bird and get its species predicted!"
43
+ )
44
+
45
+ # Launch the app
46
+ interface.launch()