Spaces:
Sleeping
Sleeping
| import os | |
| import zipfile | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader, random_split | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import classification_report, confusion_matrix | |
| import torchvision | |
| import os | |
| import zipfile | |
| from PIL import Image | |
zip_path = "Fruits_dataset.zip"
extract_root = "./unzipped"
dataset_folder = None

# Step 1: Extract ZIP
# The archive is only unpacked when extract_root does not exist yet; an
# existing directory is trusted as a previous, complete extraction.
if not os.path.exists(extract_root):
    print("Extracting dataset from uploaded zip...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_root)
else:
    print("Dataset already extracted.")

# Step 2: Auto-detect the dataset folder (looking for class folders like Apples, Pears, etc.)
# The first directory whose first sub-directory directly contains image files wins.
for root, dirs, files in os.walk(extract_root):
    # Editing dirs in place steers os.walk:
    #  * "__MACOSX" is the metadata tree macOS adds to zip archives; its
    #    "._*.jpg" entries would otherwise pass the extension check below.
    #  * sorting makes the probed sub-directory (dirs[0]) and the traversal
    #    order independent of the filesystem's listing order.
    dirs[:] = sorted(d for d in dirs if d != "__MACOSX")
    if not dirs:
        continue
    # Only the first candidate class folder is probed for image files.
    subdir = os.path.join(root, dirs[0])
    if any(f.lower().endswith(('.png', '.jpg', '.jpeg')) for f in os.listdir(subdir)):
        dataset_folder = root
        break

# Step 3: Handle case when nothing found
if dataset_folder is None:
    raise RuntimeError("❌ Could not find dataset folder with class image directories.")
print(f"✅ Detected dataset folder: {dataset_folder}")
print("✅ Classes:", os.listdir(dataset_folder))

# Step 4: Show image dimensions per class
for cls in os.listdir(dataset_folder):
    cls_path = os.path.join(dataset_folder, cls)
    if not os.path.isdir(cls_path):
        continue
    for img_file in os.listdir(cls_path):
        img_path = os.path.join(cls_path, img_file)
        try:
            # Context manager closes the underlying file handle right away
            # instead of leaking one open file per inspected image.
            with Image.open(img_path) as img:
                print(f"{img_path}: {img.size}")
        except Exception as e:
            # Purely diagnostic pass: report unreadable files and keep going.
            print(f"⚠️ Failed to open {img_path}: {e}")
# -------------------------------
# 2. Load Data
# -------------------------------
# Preprocessing applied to every dataset image: resize to 64x64, convert to a
# tensor, then normalise with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# One class per sub-directory of dataset_folder.
dataset = datasets.ImageFolder(root=dataset_folder, transform=transform)
class_names = dataset.classes

# 80 % of the samples go to training, the remainder to validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Only the training loader shuffles; both use batches of 32.
loader_options = {"batch_size": 32}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
# -------------------------------
# 3. Define Model
# -------------------------------
class FruitCNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a two-layer classifier.

    The first linear layer expects 128 * 8 * 8 features, i.e. 3-channel
    64x64 inputs after three 2x2 poolings.
    """

    def __init__(self, num_classes):
        """Build the network; ``num_classes`` is the size of the output layer."""
        super().__init__()
        # Feature extractor: channels 3 -> 32 -> 64 -> 128, each stage halving
        # the spatial resolution.
        feature_stages = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_layers = nn.Sequential(*feature_stages)

        # Classifier head: flatten, hidden layer of 256 units with dropout,
        # then one logit per class.
        classifier_stages = [
            nn.Flatten(),
            nn.Linear(in_features=128 * 8 * 8, out_features=256),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(in_features=256, out_features=num_classes),
        ]
        self.fc_layers = nn.Sequential(*classifier_stages)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.fc_layers(self.conv_layers(x))
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FruitCNN(len(class_names)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# -------------------------------
# 4. Train Model
# -------------------------------
# One human-readable line per epoch; describe_model() shows these later.
epoch_logs = []
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    # running_loss accumulates the per-batch loss values, so the logged
    # "Loss" is a sum over all batches of the epoch, not an average.
    running_loss = 0
    correct = total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Book-keeping for the epoch summary.
        running_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    acc = 100 * correct / total
    epoch_logs.append(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f}, Accuracy: {acc:.2f}%")
# -------------------------------
# 5. Evaluate
# -------------------------------
# Collect predictions and ground-truth labels over the whole validation set.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# The validation split is random, so a class can be absent from it.  Passing
# the full label set explicitly keeps the confusion matrix the same size as
# class_names (it is plotted with those tick labels) and stops
# classification_report from raising when the number of classes seen in the
# data differs from len(target_names).
label_ids = list(range(len(class_names)))
conf_mat = confusion_matrix(all_labels, all_preds, labels=label_ids)
clf_report = classification_report(
    all_labels,
    all_preds,
    labels=label_ids,
    target_names=class_names,
    zero_division=0,  # classes without samples/predictions score 0, no warning
)
# -------------------------------
# 6. Gradio Functions
# -------------------------------
def describe_model():
    """Build the evaluation figure and the textual training summary.

    Returns:
        A ``(figure, text)`` pair: the matplotlib figure holds the confusion
        matrix heatmap (top) and a grid of up to eight validation images
        titled with their predicted classes (bottom); the text is the
        per-epoch training log followed by the classification report.
    """
    fig, axes = plt.subplots(2, 1, figsize=(8, 10))

    # Confusion Matrix
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    axes[0].set_title("Confusion Matrix")

    # Show sample predictions
    images, labels = next(iter(val_loader))
    # Only the first eight images are displayed, so only those are classified.
    sample = images[:8]
    model.eval()  # no-op if already in eval mode; keeps dropout disabled
    with torch.no_grad():  # inference only: do not build an autograd graph
        outputs = model(sample.to(device))
    _, preds = torch.max(outputs, 1)

    # x / 2 + 0.5 undoes Normalize(mean=0.5, std=0.5) for display.
    grid_img = torchvision.utils.make_grid(sample / 2 + 0.5, nrow=4)
    npimg = grid_img.numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
    axes[1].imshow(npimg)
    axes[1].set_title("Sample Predictions: " + " | ".join([class_names[p.item()] for p in preds]))
    axes[1].axis('off')

    # Unregister the figure from pyplot so repeated clicks do not pile up
    # open figures; the Figure object itself stays valid for the caller.
    plt.close(fig)

    # Summary (without image_shapes)
    summary = "\n".join(epoch_logs)
    return fig, summary + "\n\n" + clf_report
def predict_image(img):
    """Classify one uploaded image and return a human-readable result string.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when
            the user pressed the button without uploading anything.

    Returns:
        ``"Predicted: <class> (<confidence>%)"``, or a short prompt asking
        for an upload when ``img`` is ``None``.
    """
    # Guard: without an upload there is nothing to classify.
    if img is None:
        return "Please upload an image first."

    # Reuse the module-level preprocessing pipeline so inference always sees
    # exactly the same resize/normalisation as training did.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0).to(device)

    model.eval()  # no-op if already in eval mode; keeps dropout disabled
    with torch.no_grad():
        output = model(img_tensor)
        probs = torch.softmax(output, dim=1)[0]

    conf, pred_idx = torch.max(probs, dim=0)
    label = class_names[pred_idx.item()]
    return f"Predicted: {label} ({conf.item()*100:.2f}%)"
# -------------------------------
# 7. Launch Gradio
# -------------------------------
# Two-tab UI: one tab renders the evaluation summary, the other classifies an
# uploaded image.
with gr.Blocks() as desc_block:
    with gr.Tab("Model Summary"):
        desc_output = gr.Plot(label="Evaluation Results")
        report_output = gr.Textbox(label="Training Summary and Classification Report")
        desc_button = gr.Button("Generate Summary")
        # The button fills both the plot and the text box from describe_model().
        desc_button.click(
            fn=describe_model,
            outputs=[desc_output, report_output],
        )

    with gr.Tab("Upload & Predict"):
        image_input = gr.Image(
            type="pil",
            label="Upload Fruit Image",
            image_mode="RGB",
            height=256,
            width=256,
        )
        pred_output = gr.Textbox(label="Prediction Result")
        predict_button = gr.Button("Classify Image")
        predict_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=pred_output,
        )

desc_block.launch()