Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms, models | |
| from torch import nn | |
| import matplotlib.pyplot as plt | |
| from datasets import load_dataset | |
| from sklearn.model_selection import train_test_split | |
| import time | |
# --- Runtime device --------------------------------------------------------
# Use the GPU when PyTorch reports one as available; otherwise stay on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Dataset sample --------------------------------------------------------
# Streaming mode iterates the remote dataset lazily instead of downloading
# the whole split up front.
print("Loading dataset...")
dataset = load_dataset("deep-plants/AGM", split="train", streaming=True)

# Only the first `sample_size` records of the stream are materialized; this
# is a demonstration-sized subset, not a full training set.
sample_size = 1000
dataset_list = list(dataset.take(sample_size))

# Separate the records into parallel lists of inputs and targets.
# NOTE(review): assumes every record exposes 'image' and 'label' keys --
# confirm against the dataset schema.
images = [record["image"] for record in dataset_list]
labels = [record["label"] for record in dataset_list]

# --- Train / test split ----------------------------------------------------
# 80/20 split; the fixed random_state keeps the partition reproducible.
train_images, test_images, train_labels, test_labels = train_test_split(
    images,
    labels,
    test_size=0.2,
    random_state=42,
)
print(f"Training samples: {len(train_images)}")
print(f"Testing samples: {len(test_images)}")
# Define EfficientNet-B0 model
class PlantClassifier(nn.Module):
    """EfficientNet-B0 backbone with a freshly initialized classification head.

    The backbone is loaded with ImageNet-pretrained weights and its original
    classifier is replaced by ``Dropout(0.2) -> Linear(num_features,
    num_classes)``.

    Args:
        num_classes: Number of output logits. Defaults to 18 (the class count
            this script uses for the AGM dataset).
    """

    def __init__(self, num_classes=18):
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum. IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # used to select, so behaviour is unchanged; the fallback keeps older
        # torchvision releases working.
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is not None:
            self.effnet = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.effnet = models.efficientnet_b0(pretrained=True)

        # Replace the classifier head. Index 1 of the stock classifier is the
        # final Linear layer, whose input width we reuse.
        num_features = self.effnet.classifier[1].in_features
        self.effnet.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_features, num_classes),
        )

    def forward(self, x):
        """Return the raw (un-normalized) class logits for the batch ``x``."""
        return self.effnet(x)
# Build the classifier and place its parameters on the selected device.
model = PlantClassifier(num_classes=18)
model = model.to(device)

# Shared preprocessing constants: the input resolution fed to the network and
# the per-channel normalization statistics (the standard ImageNet values used
# with torchvision's pretrained models).
_INPUT_SIZE = (224, 224)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation (flip, small rotation,
# brightness/contrast jitter), then tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic -- same resize and normalization as
# training, without any augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def _build_batch(samples, first_index):
    """Turn ``(image, label)`` pairs into a stacked input/target tensor pair.

    Samples that cannot be preprocessed are reported and skipped so that one
    bad record does not abort the run. Returns ``(None, None)`` when no sample
    in the slice survived.
    """
    tensors, targets = [], []
    for offset, (img, label) in enumerate(samples):
        try:
            # Normalize() is configured for 3 channels, so grayscale/RGBA
            # images must be converted first.
            if getattr(img, "mode", "RGB") != "RGB":
                img = img.convert("RGB")
            tensor = train_transform(img)
            target = int(label)
        except Exception as e:  # best-effort: skip the sample, keep training
            print(f"Error processing image {first_index + offset}: {e}")
            continue
        tensors.append(tensor)
        targets.append(target)

    if not tensors:
        return None, None
    inputs = torch.stack(tensors).to(device)
    labels_tensor = torch.tensor(targets, dtype=torch.long, device=device)
    return inputs, labels_tensor


# Training function (simplified for Space demo)
def train_model(epochs=1, max_samples=100, batch_size=16, learning_rate=1e-3):
    """Fine-tune the module-level ``model`` on a small slice of the training data.

    The previous version only ran forward passes and counted correct
    predictions, so no weights were ever updated. This version computes a
    cross-entropy loss, back-propagates it and applies Adam updates.

    Args:
        epochs: Number of passes over the selected samples.
        max_samples: How many leading entries of ``train_images`` /
            ``train_labels`` to use (kept small for the demo).
        batch_size: Number of samples per optimization step.
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        The module-level ``model`` (updated in place).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    print("Starting training...")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    samples = list(zip(train_images[:max_samples], train_labels[:max_samples]))
    n_batches = (len(samples) + batch_size - 1) // batch_size

    model.train()
    for epoch in range(epochs):
        correct = 0
        total = 0
        for batch_idx in range(n_batches):
            start = batch_idx * batch_size
            inputs, targets = _build_batch(samples[start:start + batch_size], start)
            if inputs is None:
                continue

            try:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
            except Exception as e:  # best-effort, as in the per-sample handling
                print(f"Error training on batch {batch_idx + 1}: {e}")
                continue

            # Running accuracy over the samples seen during this epoch.
            predicted = outputs.detach().argmax(dim=1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)
            print(
                f"Epoch {epoch+1}, Batch {batch_idx+1}/{n_batches}, "
                f"Loss: {loss.item():.4f}"
            )

        accuracy = 100 * correct / total if total > 0 else 0
        print(f"Epoch {epoch+1} completed. Accuracy: {accuracy:.2f}%")

    print("Training completed!")
    return model
# Class names for AGM dataset (you should replace with actual class names)
CLASS_NAMES = (
    "Wheat", "Rice", "Maize", "Barley", "Oats", "Soybean", "Cotton",
    "Sunflower", "Potato", "Tomato", "Pepper", "Cucumber", "Carrot",
    "Onion", "Apple", "Orange", "Grape", "Strawberry",
)


# Prediction function
def predict_plant(image):
    """Classify an uploaded image and chart the most likely classes.

    Args:
        image: The image supplied by the Gradio ``Image`` input (a PIL image,
            or ``None`` when nothing was uploaded).

    Returns:
        A ``(text, figure)`` tuple: one ``"<class>: <pct>%"`` line per top
        prediction, and a horizontal bar chart of those probabilities. On
        failure the text is ``"Error: <message>"`` and the figure is ``None``.
    """
    # Imported here so the figure can be built without pyplot's global
    # figure registry (see the note where the figure is created).
    from matplotlib.figure import Figure

    if image is None:
        return "Error: no image provided", None

    try:
        # Normalize() is configured for 3 channels; convert anything else.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        # Preprocess the uploaded image into a 1-sample batch.
        img_tensor = test_transform(image).unsqueeze(0).to(device)

        # Inference only: eval mode and no gradient tracking.
        model.eval()
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Top predictions (at most 3; fewer if the head has fewer classes).
        k = min(3, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)
        percentages = [p * 100 for p in top_prob.tolist()]
        # .tolist() yields plain ints, so fallback labels read "Class 5"
        # instead of "Class tensor(5)".
        names = [
            CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"Class {idx}"
            for idx in top_idx.tolist()
        ]
        results = [f"{name}: {pct:.2f}%" for name, pct in zip(names, percentages)]

        # Build the chart on a standalone Figure rather than plt.subplots():
        # pyplot keeps every figure alive until it is explicitly closed, which
        # leaks memory in a long-running server handling many requests.
        fig = Figure(figsize=(10, 5))
        ax = fig.subplots()
        y_pos = np.arange(len(names))
        ax.barh(y_pos, percentages, align='center')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(names)
        ax.set_xlabel('Probability (%)')
        ax.set_title(f'Top {k} Predictions')
        ax.set_xlim(0, 100)
        # Annotate each bar with its value, just to the right of the bar end.
        for row, pct in enumerate(percentages):
            ax.text(pct + 1, row, f'{pct:.1f}%', va='center')
        fig.tight_layout()

        return "\n".join(results), fig
    except Exception as e:
        # UI boundary: report the failure in the text box rather than
        # crashing the request handler.
        return f"Error: {str(e)}", None
# Kick off the demo training run at import time, i.e. whenever the Space
# starts. A failure is reported but deliberately does not stop the app from
# launching with whatever weights the model currently has.
try:
    print("Training model...")
    trained_model = train_model(epochs=1)  # single epoch keeps startup short
except Exception as exc:
    print(f"Training failed: {exc}")
else:
    print("Model trained successfully!")
# Assemble the Gradio UI: an upload column on the left, results on the right,
# and a short fact sheet underneath.
with gr.Blocks(title="Plant Classifier") as demo:
    gr.Markdown("# 🌱 Plant Classifier using EfficientNet-B0")
    gr.Markdown("Upload a plant image to classify it using EfficientNet-B0")

    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_btn = gr.Button("Classify Plant", variant="primary")

        # Right column: textual predictions and the probability chart.
        with gr.Column():
            text_output = gr.Textbox(label="Predictions")
            plot_output = gr.Plot(label="Probability Distribution")

    # predict_plant returns (text, figure), matching the two outputs in order.
    submit_btn.click(
        fn=predict_plant,
        inputs=image_input,
        outputs=[text_output, plot_output],
    )

    # Static information about the demo, one Markdown component per line.
    gr.Markdown("### Dataset Information")
    for _info_line in (
        "- **Dataset**: deep-plants/AGM",
        "- **Classes**: 18 plant crops",
        "- **Model**: EfficientNet-B0 (pre-trained on ImageNet)",
        "- **Training**: 1 epoch on 100 samples (demo)",
    ):
        gr.Markdown(_info_line)

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()