"""Gradio Space demo: classify plant images with a fine-tuned EfficientNet-B0.

Loads a small sample of the deep-plants/AGM dataset, fine-tunes the
classifier head for one demo epoch at startup, then serves a Gradio UI
that reports the top-3 predicted classes for an uploaded image.
"""
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn
from torchvision import models, transforms

# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load dataset (streaming avoids downloading the full dataset up front).
print("Loading dataset...")
dataset = load_dataset("deep-plants/AGM", split="train", streaming=True)

# Take a small sample for demonstration (1000 images).
# In real training, you'd use more data.
sample_size = 1000
dataset_list = list(dataset.take(sample_size))

# Extract images and labels.
# NOTE(review): assumes each item exposes 'image' and 'label' keys — verify
# against the AGM dataset schema.
images = [item['image'] for item in dataset_list]
labels = [item['label'] for item in dataset_list]

# Deterministic train/test split for the demo.
train_images, test_images, train_labels, test_labels = train_test_split(
    images, labels, test_size=0.2, random_state=42
)

print(f"Training samples: {len(train_images)}")
print(f"Testing samples: {len(test_images)}")


class PlantClassifier(nn.Module):
    """EfficientNet-B0 backbone with a replacement classification head.

    Args:
        num_classes: Number of output classes (the AGM dataset has 18).
    """

    def __init__(self, num_classes=18):
        super(PlantClassifier, self).__init__()
        # Load EfficientNet-B0 pre-trained on ImageNet.
        # `weights=` is the current torchvision API; `pretrained=True` is deprecated.
        self.effnet = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.DEFAULT
        )
        # Swap the stock 1000-class head for one sized to our classes.
        num_features = self.effnet.classifier[1].in_features
        self.effnet.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_features, num_classes)
        )

    def forward(self, x):
        return self.effnet(x)


# Global model instance used by both training and prediction.
model = PlantClassifier(num_classes=18).to(device)

# Training-time augmentation + ImageNet normalization statistics.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Inference-time preprocessing: no augmentation, same normalization.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def train_model(epochs=1):
    """Fine-tune the global model on a small demo subset.

    The previous version only ran forward passes — it never computed a loss
    or called backward(), so the model weights were never updated.  This
    version adds a criterion, an optimizer, and a proper update step.

    Args:
        epochs: Number of passes over the 100-sample demo subset.

    Returns:
        The fine-tuned global ``model``.
    """
    print("Starting training...")
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        correct = 0
        total = 0
        # Per-sample "batches" of size 1 for demo simplicity; a real run
        # would use a DataLoader with proper batching.
        for i, (img, label) in enumerate(zip(train_images[:100], train_labels[:100])):
            try:
                # Force 3 channels: dataset items may be grayscale or RGBA,
                # which would break the 3-channel Normalize.
                img_tensor = train_transform(img.convert('RGB')).unsqueeze(0).to(device)
                label_tensor = torch.tensor([label]).to(device)

                # Forward + backward + parameter update.
                optimizer.zero_grad()
                outputs = model(img_tensor)
                loss = criterion(outputs, label_tensor)
                loss.backward()
                optimizer.step()

                # Track running accuracy for the progress report.
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == label_tensor).sum().item()
                total += 1

                if i % 20 == 0:
                    print(f"Epoch {epoch+1}, Batch {i}/100")
            except Exception as e:
                # Best-effort demo: skip unreadable samples rather than abort.
                print(f"Error processing image {i}: {e}")
                continue

        accuracy = 100 * correct / total if total > 0 else 0
        print(f"Epoch {epoch+1} completed. Accuracy: {accuracy:.2f}%")

    print("Training completed!")
    return model


def predict_plant(image):
    """Classify an uploaded image and chart the top-3 class probabilities.

    Args:
        image: PIL image from the Gradio upload widget (may be None when
            the user clicks Classify without uploading).

    Returns:
        Tuple of (newline-joined prediction strings, matplotlib figure),
        or (error message, None) on failure.
    """
    if image is None:
        return "Error: no image provided", None
    try:
        # Force RGB so grayscale/RGBA uploads don't break normalization.
        img_tensor = test_transform(image.convert('RGB')).unsqueeze(0).to(device)

        model.eval()
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        top3_prob, top3_catid = torch.topk(probabilities, 3)

        # Class names for AGM dataset (you should replace with actual class names).
        class_names = [
            "Wheat", "Rice", "Maize", "Barley", "Oats", "Soybean",
            "Cotton", "Sunflower", "Potato", "Tomato", "Pepper",
            "Cucumber", "Carrot", "Onion", "Apple", "Orange",
            "Grape", "Strawberry"
        ]

        # Keep names/percentages as data and format strings only for display
        # (the previous version re-parsed its own formatted strings to get
        # the values back for the plot).
        names = []
        percents = []
        results = []
        for prob, catid in zip(top3_prob, top3_catid):
            idx = catid.item()
            class_name = class_names[idx] if idx < len(class_names) else f"Class {idx}"
            probability = prob.item() * 100
            names.append(class_name)
            percents.append(probability)
            results.append(f"{class_name}: {probability:.2f}%")

        # Horizontal bar chart of the top-3 probabilities.
        fig, ax = plt.subplots(figsize=(10, 5))
        y_pos = np.arange(len(results))
        ax.barh(y_pos, percents, align='center')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(names)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Top 3 Predictions')
        ax.set_xlim(0, 100)
        for i, v in enumerate(percents):
            ax.text(v + 1, i, f'{v:.1f}%', va='center')
        plt.tight_layout()

        return "\n".join(results), fig
    except Exception as e:
        return f"Error: {str(e)}", None


# Train the model (this will run when the Space starts).
try:
    print("Training model...")
    trained_model = train_model(epochs=1)  # Just 1 epoch for demo
    print("Model trained successfully!")
except Exception as e:
    print(f"Training failed: {e}")

# Build the Gradio UI: image upload on the left, predictions on the right.
with gr.Blocks(title="Plant Classifier") as demo:
    gr.Markdown("# 🌱 Plant Classifier using EfficientNet-B0")
    gr.Markdown("Upload a plant image to classify it using EfficientNet-B0")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_btn = gr.Button("Classify Plant", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(label="Predictions")
            plot_output = gr.Plot(label="Probability Distribution")

    submit_btn.click(
        fn=predict_plant,
        inputs=image_input,
        outputs=[text_output, plot_output]
    )

    gr.Markdown("### Dataset Information")
    gr.Markdown("- **Dataset**: deep-plants/AGM")
    gr.Markdown("- **Classes**: 18 plant crops")
    gr.Markdown("- **Model**: EfficientNet-B0 (pre-trained on ImageNet)")
    gr.Markdown("- **Training**: 1 epoch on 100 samples (demo)")

# Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()