Jacksonnavigator7 commited on
Commit
6b6edf3
·
verified ·
1 Parent(s): c579501

Update app

Browse files
Files changed (1) hide show
  1. app +33 -25
app CHANGED
@@ -1,32 +1,40 @@
1
  import gradio as gr
2
- from transformers import AutoImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
  import torch
 
 
5
 
6
- # Load your model and processor (replace with your model’s Hugging Face ID or local path)
7
- model_name = "your-username/your-bird-classifier" # e.g., "dennisjooo/Birds-Classifier-EfficientNetB2"
8
- processor = AutoImageProcessor.from_pretrained(model_name)
9
- model = AutoModelForImageClassification.from_pretrained(model_name)
 
 
 
 
10
 
11
- # Prediction function
12
- def classify_bird(image):
13
- # Process the image
14
- inputs = processor(image, return_tensors="pt")
15
- with torch.no_grad():
16
- outputs = model(**inputs).logits
17
- # Get the predicted label
18
- predicted_idx = outputs.argmax(-1).item()
19
- label = model.config.id2label[predicted_idx]
20
- return f"Predicted bird species: {label}"
21
 
22
- # Create the Gradio interface
23
- interface = gr.Interface(
24
- fn=classify_bird,
25
- inputs=gr.Image(type="pil", label="Upload a bird image"),
26
- outputs=gr.Textbox(label="Prediction"),
27
- title="Bird Species Classifier",
28
- description="Upload an image of a bird to identify its species!"
29
- )
30
 
31
- # Launch the app
32
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  from PIL import Image
3
  import torch
4
+ import pickle
5
+ import torchvision.transforms as transforms
6
 
7
+ # Load the pickled model
8
+ model_path = "birds_classifier.pkl"
9
+ try:
10
+ with open(model_path, "rb") as f:
11
+ model = pickle.load(f)
12
+ model.eval() # Set model to evaluation mode
13
+ except Exception as e:
14
+ raise Exception(f"Failed to load model: {str(e)}")
15
 
16
+ # Define image preprocessing (adjust these transforms based on your model's training)
17
+ preprocess = transforms.Compose([
18
+ transforms.Resize((224, 224)), # Match your model's expected input size
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet defaults
21
+ ])
 
 
 
 
22
 
23
+ # Replace with your actual list of bird species (in the order the model was trained)
24
+ class_labels = ["Sparrow", "Eagle", "Blue Jay", "Cardinal"] # Update this!
 
 
 
 
 
 
25
 
26
+ # Prediction function
27
+ def classify_bird(image):
28
+ try:
29
+ if image is None:
30
+ return "Please upload an image of a bird."
31
+
32
+ # Preprocess the uploaded image
33
+ img = preprocess(image).unsqueeze(0) # Add batch dimension
34
+
35
+ # Make prediction automatically
36
+ with torch.no_grad():
37
+ outputs = model(img) # Model outputs logits or probabilities
38
+
39
+ # Get the predicted species
40
+ predicted_idx = outputs.argmax(-1).item()