import os
import zipfile

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from PIL import Image
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

zip_path = "Fruits_dataset.zip"
extract_root = "./unzipped"
dataset_folder = None

# -------------------------------
# 1. Extract and Locate Dataset
# -------------------------------

# Step 1: Extract ZIP (skipped if already extracted)
if not os.path.exists(extract_root):
    print("Extracting dataset from uploaded zip...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_root)
else:
    print("Dataset already extracted.")

# Step 2: Auto-detect the dataset folder (looking for class folders like Apples, Pears, etc.)
# Note: entries in `dirs` from os.walk are directories by construction, so it is
# enough to check that subfolders exist and that the first one contains images.
for root, dirs, files in os.walk(extract_root):
    if dirs:
        subdir = os.path.join(root, dirs[0])
        if any(f.lower().endswith((".png", ".jpg", ".jpeg")) for f in os.listdir(subdir)):
            dataset_folder = root
            break

# Step 3: Handle case when nothing found
if dataset_folder is None:
    raise RuntimeError("❌ Could not find dataset folder with class image directories.")

print(f"✅ Detected dataset folder: {dataset_folder}")
print("✅ Classes:", os.listdir(dataset_folder))

# Step 4: Show image dimensions per class
for cls in os.listdir(dataset_folder):
    cls_path = os.path.join(dataset_folder, cls)
    if os.path.isdir(cls_path):
        for img_file in os.listdir(cls_path):
            img_path = os.path.join(cls_path, img_file)
            try:
                img = Image.open(img_path)
                print(f"{img_path}: {img.size}")
            except Exception as e:
                print(f"⚠️ Failed to open {img_path}: {e}")

# -------------------------------
# 2. Load Data
# -------------------------------
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    # Map pixel values from [0, 1] to [-1, 1]; the single-element mean/std
    # broadcasts across all three RGB channels.
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.ImageFolder(root=dataset_folder, transform=transform)
class_names = dataset.classes

# 80/20 train/validation split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# -------------------------------
# 3. Define Model
# -------------------------------
class FruitCNN(nn.Module):
    def __init__(self, num_classes):
        super(FruitCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        return self.fc_layers(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FruitCNN(len(class_names)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
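# Not part of the original script: a minimal sanity check of the flatten size.
# The three stride-2 max pools shrink the 64x64 input to 8x8, so conv_layers
# emits tensors of shape (N, 128, 8, 8) and the first Linear layer must take
# 128 * 8 * 8 input features. A dummy forward pass verifies this before training.
with torch.no_grad():
    dummy = torch.zeros(1, 3, 64, 64, device=device)
    assert model.conv_layers(dummy).shape[1:] == (128, 8, 8)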
# -------------------------------
# 4. Train Model
# -------------------------------
epoch_logs = []
num_epochs = 15

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct, total = 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    acc = 100 * correct / total
    # Report the per-batch average loss so the number is comparable across
    # datasets of different sizes.
    avg_loss = running_loss / len(train_loader)
    epoch_logs.append(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")

# -------------------------------
# 5. Evaluate
# -------------------------------
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

conf_mat = confusion_matrix(all_labels, all_preds)
clf_report = classification_report(all_labels, all_preds, target_names=class_names)

# -------------------------------
# 6. Gradio Functions
# -------------------------------
def describe_model():
    fig, axes = plt.subplots(2, 1, figsize=(8, 10))

    # Confusion matrix
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    axes[0].set_title("Confusion Matrix")

    # Sample predictions from one validation batch
    images, labels = next(iter(val_loader))
    with torch.no_grad():
        outputs = model(images.to(device))
    _, preds = torch.max(outputs, 1)

    # Undo the Normalize((0.5,), (0.5,)) transform before display
    grid_img = torchvision.utils.make_grid(images[:8] / 2 + 0.5, nrow=4)
    npimg = grid_img.numpy().transpose((1, 2, 0))
    axes[1].imshow(npimg)
    axes[1].set_title("Sample Predictions: " + " | ".join(class_names[p.item()] for p in preds[:8]))
    axes[1].axis('off')

    # Training summary
    summary = "\n".join(epoch_logs)
    return fig, summary + "\n\n" + clf_report

def predict_image(img):
    img = img.convert("RGB")
    transform_pred = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    img_tensor = transform_pred(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img_tensor)
    probs = torch.softmax(output, dim=1)[0]
    conf, pred_idx = torch.max(probs, dim=0)
    label = class_names[pred_idx.item()]
    return f"Predicted: {label} ({conf.item()*100:.2f}%)"

# -------------------------------
# 7. Launch Gradio
# -------------------------------
desc_block = gr.Blocks()
with desc_block:
    with gr.Tab("Model Summary"):
        desc_output = gr.Plot(label="Evaluation Results")
        report_output = gr.Textbox(label="Training Summary and Classification Report")
        desc_button = gr.Button("Generate Summary")
        desc_button.click(fn=describe_model, outputs=[desc_output, report_output])
    with gr.Tab("Upload & Predict"):
        image_input = gr.Image(type="pil", label="Upload Fruit Image",
                               image_mode="RGB", height=256, width=256)
        pred_output = gr.Textbox(label="Prediction Result")
        predict_button = gr.Button("Classify Image")
        predict_button.click(fn=predict_image, inputs=image_input, outputs=pred_output)

desc_block.launch()
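# -------------------------------
# Optional: persist the trained model
# -------------------------------
# A hedged sketch, not part of the original script: saving the weights lets
# the classifier be reloaded without retraining. The filename "fruit_cnn.pt"
# is an assumption. Note that desc_block.launch() blocks until the server is
# stopped, so place the save before launching (or launch with
# prevent_thread_lock=True) if it must run first.
#
# torch.save(model.state_dict(), "fruit_cnn.pt")
#
# model = FruitCNN(len(class_names)).to(device)
# model.load_state_dict(torch.load("fruit_cnn.pt", map_location=device))
# model.eval()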