# apple_pearSort / app.py
import os
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from PIL import Image
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import torchvision
zip_path = "Fruits_dataset.zip"
extract_root = "./unzipped"
dataset_folder = None
# Step 1: Extract ZIP
if not os.path.exists(extract_root):
    print("Extracting dataset from uploaded zip...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_root)
else:
    print("Dataset already extracted.")
# Step 2: Auto-detect the dataset folder (looking for class folders like Apples, Pears, etc.)
for root, dirs, files in os.walk(extract_root):
    # os.walk already guarantees every entry in `dirs` is a directory, so it is
    # enough to check that at least one subfolder exists and holds image files.
    if dirs:
        subdir = os.path.join(root, dirs[0])
        if any(f.lower().endswith(('.png', '.jpg', '.jpeg')) for f in os.listdir(subdir)):
            dataset_folder = root
            break

# Step 3: Handle case when nothing found
if dataset_folder is None:
    raise RuntimeError("❌ Could not find dataset folder with class image directories.")
print(f"✅ Detected dataset folder: {dataset_folder}")
print("✅ Classes:", os.listdir(dataset_folder))
# Step 4: Show image dimensions per class
for cls in os.listdir(dataset_folder):
    cls_path = os.path.join(dataset_folder, cls)
    if os.path.isdir(cls_path):
        for img_file in os.listdir(cls_path):
            img_path = os.path.join(cls_path, img_file)
            try:
                img = Image.open(img_path)
                print(f"{img_path}: {img.size}")
            except Exception as e:
                print(f"⚠️ Failed to open {img_path}: {e}")
# -------------------------------
# 2. Load Data
# -------------------------------
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    # Map each RGB channel from [0, 1] to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = datasets.ImageFolder(root=dataset_folder, transform=transform)
class_names = dataset.classes
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
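# Note: random_split draws a fresh random permutation every run, so the 80/20
# split differs between restarts. For a reproducible split (an optional tweak,
# not part of the original script), random_split accepts a seeded generator:
#
#     train_dataset, val_dataset = random_split(
#         dataset, [train_size, val_size],
#         generator=torch.Generator().manual_seed(42))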
# -------------------------------
# 3. Define Model
# -------------------------------
class FruitCNN(nn.Module):
    def __init__(self, num_classes):
        super(FruitCNN, self).__init__()
        # Three conv blocks, each halving the spatial resolution via 2x2 max pooling
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        return self.fc_layers(x)
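# Shape check for the first Linear layer: each MaxPool2d(2, 2) halves the
# spatial dimensions, so the 64x64 input becomes 32 -> 16 -> 8 after the three
# pooling stages, and the flattened feature vector has 128 * 8 * 8 = 8192 values.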
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FruitCNN(len(class_names)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# -------------------------------
# 4. Train Model
# -------------------------------
epoch_logs = []
num_epochs = 15

for epoch in range(num_epochs):
    model.train()
    running_loss = 0
    correct, total = 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    acc = 100 * correct / total
    epoch_logs.append(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f}, Accuracy: {acc:.2f}%")
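# The per-epoch logs are only surfaced later, in the Gradio "Model Summary"
# tab. To also watch progress on the console while the Space starts up
# (optional), print each entry inside the loop as it is appended:
#
#     print(epoch_logs[-1])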
# -------------------------------
# 5. Evaluate
# -------------------------------
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

conf_mat = confusion_matrix(all_labels, all_preds)
clf_report = classification_report(all_labels, all_preds, target_names=class_names)
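# Optional (not in the original script): persist the trained weights so a
# restart can skip retraining, e.g.
#
#     torch.save(model.state_dict(), "fruit_cnn.pt")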
# -------------------------------
# 6. Gradio Functions
# -------------------------------
def describe_model():
    fig, axes = plt.subplots(2, 1, figsize=(8, 10))

    # Confusion Matrix
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    axes[0].set_title("Confusion Matrix")

    # Show sample predictions (no_grad: inference only, no autograd graph needed)
    images, labels = next(iter(val_loader))
    with torch.no_grad():
        outputs = model(images.to(device))
    _, preds = torch.max(outputs, 1)
    grid_img = torchvision.utils.make_grid(images[:8] / 2 + 0.5, nrow=4)  # undo normalization for display
    npimg = grid_img.numpy().transpose((1, 2, 0))
    axes[1].imshow(npimg)
    axes[1].set_title("Sample Predictions: " + " | ".join([class_names[p.item()] for p in preds[:8]]))
    axes[1].axis('off')

    # Summary (without image_shapes)
    summary = "\n".join(epoch_logs)
    return fig, summary + "\n\n" + clf_report
def predict_image(img):
    img = img.convert("RGB")
    transform_pred = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img_tensor = transform_pred(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.softmax(output, dim=1)[0]
        conf, pred_idx = torch.max(probs, dim=0)
    label = class_names[pred_idx.item()]
    return f"Predicted: {label} ({conf.item()*100:.2f}%)"
# -------------------------------
# 7. Launch Gradio
# -------------------------------
with gr.Blocks() as desc_block:
    with gr.Tab("Model Summary"):
        desc_output = gr.Plot(label="Evaluation Results")
        report_output = gr.Textbox(label="Training Summary and Classification Report")
        desc_button = gr.Button("Generate Summary")
        desc_button.click(fn=describe_model, outputs=[desc_output, report_output])
    with gr.Tab("Upload & Predict"):
        image_input = gr.Image(type="pil", label="Upload Fruit Image", image_mode="RGB", height=256, width=256)
        pred_output = gr.Textbox(label="Prediction Result")
        predict_button = gr.Button("Classify Image")
        predict_button.click(fn=predict_image, inputs=image_input, outputs=pred_output)

desc_block.launch()