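# NOTE: the next three lines appear to be page-header residue from the site this
# file was copied from (uploader caption, commit message, commit id), not code:
# Samgityyyy's picture
# Create app.py
# 2a8edf3 verified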
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms, models
from torch import nn
import matplotlib.pyplot as plt
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import time
# Select the compute device: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Open the dataset as a stream so it is not materialised in full up front.
print("Loading dataset...")
dataset = load_dataset("deep-plants/AGM", split="train", streaming=True)

# Only the first `sample_size` records of the stream are pulled into memory
# for this demo; real training would use more data.
sample_size = 1000
dataset_list = list(dataset.take(sample_size))

# Separate every record into its image and its label.
# NOTE(review): assumes each record exposes 'image' and 'label' keys —
# confirm against the dataset schema.
images = []
labels = []
for record in dataset_list:
    images.append(record['image'])
    labels.append(record['label'])

# Hold out 20% of the sample for testing; the fixed seed keeps the split
# reproducible between runs.
train_images, test_images, train_labels, test_labels = train_test_split(
    images,
    labels,
    test_size=0.2,
    random_state=42,
)
print(f"Training samples: {len(train_images)}")
print(f"Testing samples: {len(test_images)}")
# Define EfficientNet-B0 model
class PlantClassifier(nn.Module):
    """EfficientNet-B0 backbone with a freshly initialised classification head.

    The pretrained network is loaded from torchvision and its ``classifier``
    module is replaced by ``Dropout(0.2)`` followed by a ``Linear`` layer that
    maps the backbone's feature width to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 18, which this file
            describes as the AGM dataset's class count.
    """

    def __init__(self, num_classes=18):  # AGM dataset has 18 classes
        super().__init__()
        # Load the ImageNet-pretrained backbone. torchvision 0.13 deprecated
        # the boolean `pretrained` flag in favour of the `weights` enum, so
        # prefer the enum when it exists and fall back to the legacy flag on
        # older torchvision releases. IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` resolves to.
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is not None:
            self.effnet = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.effnet = models.efficientnet_b0(pretrained=True)

        # Replace the classifier head: reuse the input width of the stock
        # Linear layer (index 1 of the original classifier Sequential).
        num_features = self.effnet.classifier[1].in_features
        self.effnet.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_features, num_classes),
        )

    def forward(self, x):
        """Return the class logits produced by the wrapped network for ``x``."""
        return self.effnet(x)
# Initialize model and move it to the selected device
model = PlantClassifier(num_classes=18).to(device)

# Values shared by both preprocessing pipelines.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation (flip, small rotation,
# brightness/contrast jitter), then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: the same resize and normalisation, with no random
# augmentation so predictions are deterministic for a given image.
test_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
# Training function (simplified for Space demo)
def train_model(epochs=1, max_samples=100, learning_rate=1e-4):
    """Fine-tune the module-level ``model`` on a small slice of the training set.

    The previous version only ran forward passes and counted correct
    predictions, so the weights were never updated. This version computes a
    cross-entropy loss and performs one optimizer step per image.

    Reads the module-level ``model``, ``train_images``, ``train_labels``,
    ``train_transform`` and ``device``.

    Args:
        epochs: Number of passes over the selected samples.
        max_samples: Upper bound on how many training samples are used per
            epoch (the first ``max_samples`` entries of the training lists).
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        The module-level ``model`` (updated in place).
    """
    print("Starting training...")
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Materialise the slice once so every epoch sees the same samples and the
    # progress message can report the real sample count.
    samples = list(zip(train_images[:max_samples], train_labels[:max_samples]))
    num_samples = len(samples)

    for epoch in range(epochs):
        correct = 0
        total = 0
        for i, (img, label) in enumerate(samples):
            try:
                # Preprocess image into a batch of size one
                img_tensor = train_transform(img).unsqueeze(0).to(device)
                # CrossEntropyLoss expects integer class indices (int64).
                label_tensor = torch.tensor([label], dtype=torch.long).to(device)

                # Forward pass, loss, backward pass, parameter update
                optimizer.zero_grad()
                outputs = model(img_tensor)
                loss = criterion(outputs, label_tensor)
                loss.backward()
                optimizer.step()

                # Running training accuracy, measured on the pre-update logits
                predicted = outputs.detach().argmax(dim=1)
                correct += (predicted == label_tensor).sum().item()
                total += 1
                if i % 20 == 0:
                    print(f"Epoch {epoch+1}, Batch {i}/{num_samples}")
            except Exception as e:
                # Best effort for the demo: report the bad sample and move on
                # instead of aborting the whole run.
                print(f"Error processing image {i}: {e}")
                continue
        accuracy = 100 * correct / total if total > 0 else 0
        print(f"Epoch {epoch+1} completed. Accuracy: {accuracy:.2f}%")
    print("Training completed!")
    return model
# Class names for AGM dataset (you should replace with actual class names)
_CLASS_NAMES = (
    "Wheat", "Rice", "Maize", "Barley", "Oats", "Soybean", "Cotton",
    "Sunflower", "Potato", "Tomato", "Pepper", "Cucumber", "Carrot",
    "Onion", "Apple", "Orange", "Grape", "Strawberry",
)


def _class_label(class_index):
    """Return the display name for ``class_index``, or ``Class N`` if unnamed."""
    if 0 <= class_index < len(_CLASS_NAMES):
        return _CLASS_NAMES[class_index]
    return f"Class {class_index}"


def _plot_top_predictions(names, percentages):
    """Build a horizontal bar chart of ``percentages`` labelled with ``names``."""
    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates, which leaks memory in a long-running server
    # unless each figure is explicitly closed. Imported locally so the file's
    # import block does not need to change.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    y_pos = list(range(len(names)))
    ax.barh(y_pos, percentages, align='center')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(names)
    ax.set_xlabel('Probability (%)')
    ax.set_title(f'Top {len(names)} Predictions')
    ax.set_xlim(0, 100)
    # Print the numeric value just to the right of each bar.
    for row, value in enumerate(percentages):
        ax.text(value + 1, row, f'{value:.1f}%', va='center')
    fig.tight_layout()
    return fig


# Prediction function
def predict_plant(image):
    """Classify ``image`` and report the most probable classes.

    Reads the module-level ``model``, ``test_transform`` and ``device``.

    Args:
        image: The uploaded image (the UI supplies a PIL image), or ``None``
            when nothing was uploaded.

    Returns:
        A ``(text, figure)`` tuple. ``text`` has one ``"<name>: <pct>%"`` line
        per prediction (up to three) and ``figure`` is a matplotlib bar chart
        of the same values. On failure ``text`` starts with ``"Error: "`` and
        ``figure`` is ``None``.
    """
    # Clicking the button without an upload passes None; answer with a clear
    # message rather than an opaque transform error.
    if image is None:
        return "Error: no image provided", None

    try:
        # Preprocess the uploaded image into a batch of size one
        img_tensor = test_transform(image).unsqueeze(0).to(device)

        # Make prediction
        model.eval()
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Top predictions (at most three, fewer if the head has fewer classes)
        top_k = min(3, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, top_k)

        # Convert to plain Python values once; the text and the chart are both
        # built from these numbers instead of re-parsing formatted strings.
        names = [_class_label(idx) for idx in top_catid.tolist()]
        percentages = [prob * 100 for prob in top_prob.tolist()]

        results = [f"{name}: {pct:.2f}%" for name, pct in zip(names, percentages)]
        return "\n".join(results), _plot_top_predictions(names, percentages)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error: {str(e)}", None
# Train the model (this will run when the Space starts).
# A training failure is reported on stdout and the app carries on starting,
# so the UI below is still created with whatever weights `model` has.
try:
    print("Training model...")
    trained_model = train_model(epochs=1)  # Just 1 epoch for demo
    print("Model trained successfully!")
except Exception as training_error:
    print(f"Training failed: {training_error}")
# Create Gradio interface
# NOTE(review): the original indentation was lost; the nesting below follows
# the conventional Blocks layout (two columns in one row, footer notes under
# the row) — confirm against the deployed app.
with gr.Blocks(title="Plant Classifier") as demo:
    # Page heading and short instruction
    gr.Markdown("# 🌱 Plant Classifier using EfficientNet-B0")
    gr.Markdown("Upload a plant image to classify it using EfficientNet-B0")

    with gr.Row():
        # Left column: image upload and the trigger button
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_btn = gr.Button("Classify Plant", variant="primary")
        # Right column: textual predictions and the probability chart
        with gr.Column():
            text_output = gr.Textbox(label="Predictions")
            plot_output = gr.Plot(label="Probability Distribution")

    # Run the classifier on click and fan the two results out to the outputs
    submit_btn.click(
        fn=predict_plant,
        inputs=image_input,
        outputs=[text_output, plot_output],
    )

    # Footer notes, one Markdown component per line
    for info_line in (
        "### Dataset Information",
        "- **Dataset**: deep-plants/AGM",
        "- **Classes**: 18 plant crops",
        "- **Model**: EfficientNet-B0 (pre-trained on ImageNet)",
        "- **Training**: 1 epoch on 100 samples (demo)",
    ):
        gr.Markdown(info_line)

# Launch the app only when executed as a script
if __name__ == "__main__":
    demo.launch()