Spaces:
Sleeping
Sleeping
upload all files
Browse files- .gitattributes +0 -34
- README copy.md +28 -0
- app.py +94 -0
- checkpoints/best_model_993.pth +3 -0
- dataset.py +253 -0
- demo.py +163 -0
- evaluate.py +327 -0
- label_mapping.json +27 -0
- model.py +159 -0
- prepare_dataset.py +247 -0
- requirements.txt +10 -0
- train.py +286 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README copy.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Pest and Disease Classification 🌿
|
| 3 |
+
emoji: 🌱
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: "4.44.1"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# 🌿 Pest and Disease Classification Demo
|
| 13 |
+
|
| 14 |
+
This demo provides a simple web interface for classifying **pests and diseases in citrus leaves**.
|
| 15 |
+
|
| 16 |
+
## 🧠 Model
|
| 17 |
+
The model is based on a CNN backbone (ResNet50 by default) trained on a labeled dataset of citrus plant leaves.
|
| 18 |
+
|
| 19 |
+
- **Framework:** PyTorch
|
| 20 |
+
- **Interface:** Gradio
|
| 21 |
+
- **Backbone:** ResNet50
|
| 22 |
+
- **Task:** Image classification
|
| 23 |
+
|
| 24 |
+
## 🚀 How to Use
|
| 25 |
+
1. Click **“Upload Image”** and select a photo of a citrus leaf.
|
| 26 |
+
2. The app will output the **predicted pest or disease category** with confidence scores.
|
| 27 |
+
|
| 28 |
+
## 📂 Repository Structure
|
app.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple Demo for Pest and Disease Classification
|
| 3 |
+
For Hugging Face Space Deployment
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import json
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
|
| 12 |
+
from model import create_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PestDiseasePredictor:
    """Load a trained classifier checkpoint and score PIL images.

    Bundles the label-mapping lookup, checkpoint restoration and the
    inference transform pipeline behind a single ``predict`` call.
    """

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        # Fall back to CPU when CUDA is not available on this machine.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Class-id -> human-readable name lookup from the mapping JSON.
        with open(label_mapping_path, 'r', encoding='utf-8') as handle:
            mapping = json.load(handle)
        self.id_to_label = {int(class_id): name for class_id, name in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        # Rebuild the architecture; weights come from the checkpoint, not ImageNet.
        self.model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False,
        )
        state = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Deterministic inference preprocessing (ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        print(f"✅ Model loaded from {checkpoint_path}")
        print(f"💻 Device: {self.device}")
        print(f"📚 Classes: {self.num_classes}")

    def predict(self, image):
        """Return {class name: probability}, sorted most-likely first."""
        if image.mode != 'RGB':
            image = image.convert('RGB')

        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
            scores = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()

        ranked = sorted(
            ((self.id_to_label[idx], float(score)) for idx, score in enumerate(scores)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return dict(ranked)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ========== For Hugging Face Space ==========
# NOTE(fix): the only checkpoint uploaded with this repo is
# checkpoints/best_model_993.pth, and the README documents a ResNet50
# backbone.  The previous values ("checkpoints/best_efficientnet_b3.pth",
# 'efficientnet_b3') pointed at a file that does not exist in the repo,
# so the Space crashed during startup.
checkpoint_path = "checkpoints/best_model_993.pth"
label_mapping_path = "label_mapping.json"
backbone = 'resnet50'
device = "cuda"  # PestDiseasePredictor silently falls back to CPU when CUDA is absent

predictor = PestDiseasePredictor(
    checkpoint_path=checkpoint_path,
    label_mapping_path=label_mapping_path,
    backbone=backbone,
    device=device
)

def predict_image(image):
    """Gradio callback: return {label: probability} for an uploaded PIL image."""
    if image is None:
        return None
    return predictor.predict(image)

demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=10, label="Predictions"),
    title="🌿 Pest and Disease Classification",
    description="Upload an image of a citrus leaf to classify its pest or disease type.",
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
|
checkpoints/best_model_993.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e18b22b823871125c07933e128e7afed92110da462fee5781462fed1066b4e33
|
| 3 |
+
size 138717293
|
dataset.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch Dataset and DataLoader for Pest and Disease Classification
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import json
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PestDiseaseDataset(Dataset):
    """Custom Dataset for loading pest and disease images.

    Rows come from a CSV with 'image_path', 'label' and 'split' columns;
    labels are converted to integer ids through a JSON mapping file.
    """

    def __init__(self, csv_file, label_mapping_file, split='train', transform=None):
        """
        Args:
            csv_file (str): Path to CSV file with image paths and labels.
            label_mapping_file (str): Path to JSON file with label mappings.
            split (str): One of 'train', 'val', or 'test'.
            transform (callable, optional): Optional transform applied to images.
        """
        # Keep only the rows belonging to the requested split.
        frame = pd.read_csv(csv_file)
        self.df = frame[frame['split'] == split].reset_index(drop=True)

        # Bidirectional label <-> id lookup tables.
        with open(label_mapping_file, 'r', encoding='utf-8') as handle:
            mapping = json.load(handle)
        self.label_to_id = mapping['label_to_id']
        self.id_to_label = {int(key): value for key, value in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        self.transform = transform
        self.split = split

        print(f"Loaded {split} set: {len(self.df)} images")

    def __len__(self):
        """Number of samples in this split."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Returns:
            image: Transformed image tensor.
            label: Label ID (integer).
        """
        record = self.df.iloc[idx]

        # Decode the image as RGB regardless of its on-disk mode.
        picture = Image.open(record['image_path']).convert('RGB')
        target = self.label_to_id[record['label']]

        if self.transform:
            picture = self.transform(picture)

        return picture, target

    def get_label_name(self, label_id):
        """Convert a label ID back to its label name."""
        return self.id_to_label[label_id]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_transforms(split='train', img_size=224):
    """Build the preprocessing pipeline for one dataset split.

    Args:
        split (str): 'train', 'val', or 'test'.
        img_size (int): Target square image size (224 suits most pretrained models).

    Returns:
        transforms.Compose: Composed transforms; only the train split augments.
    """
    # Shared tail: convert to tensor and apply ImageNet normalization.
    tail = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]

    if split != 'train':
        # Validation/test: deterministic resize + normalize only.
        return transforms.Compose([transforms.Resize((img_size, img_size))] + tail)

    # Training: geometric and photometric augmentation ahead of the shared tail.
    augmented_head = [
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    ]
    return transforms.Compose(augmented_head + tail)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_dataloaders(csv_file='dataset.csv',
                    label_mapping_file='label_mapping.json',
                    batch_size=32,
                    img_size=224,
                    num_workers=4):
    """Create train, validation, and test dataloaders.

    Args:
        csv_file (str): Path to dataset CSV.
        label_mapping_file (str): Path to label mapping JSON.
        batch_size (int): Batch size used by every loader.
        img_size (int): Image size fed to the models.
        num_workers (int): Worker processes per DataLoader.

    Returns:
        dict: 'train'/'val'/'test' DataLoaders, 'num_classes', and the
        underlying datasets under 'datasets'.
    """
    splits = ('train', 'val', 'test')

    # One dataset per split, each with its split-appropriate transforms.
    split_datasets = {
        name: PestDiseaseDataset(
            csv_file=csv_file,
            label_mapping_file=label_mapping_file,
            split=name,
            transform=get_transforms(name, img_size),
        )
        for name in splits
    }

    # Only the training loader shuffles; all pin memory for faster host->GPU copies.
    split_loaders = {
        name: DataLoader(
            split_datasets[name],
            batch_size=batch_size,
            shuffle=(name == 'train'),
            num_workers=num_workers,
            pin_memory=True,
        )
        for name in splits
    }

    return {
        'train': split_loaders['train'],
        'val': split_loaders['val'],
        'test': split_loaders['test'],
        'num_classes': split_datasets['train'].num_classes,
        'datasets': {
            'train': split_datasets['train'],
            'val': split_datasets['val'],
            'test': split_datasets['test'],
        },
    }
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def calculate_class_weights(csv_file='dataset.csv', label_mapping_file='label_mapping.json'):
    """Calculate class weights for handling an imbalanced dataset.

    Weights are inverse-frequency over the *train* split only:
    ``total / (num_classes * count)``; classes absent from the train
    split fall back to a count of 1 to avoid division by zero.

    Returns:
        torch.Tensor: Class weights suitable for a weighted loss function.
    """
    frame = pd.read_csv(csv_file)
    train_rows = frame[frame['split'] == 'train']

    with open(label_mapping_file, 'r', encoding='utf-8') as handle:
        info = json.load(handle)
    name_to_id = info['label_to_id']
    n_classes = info['num_classes']

    # Tally training samples per class id.
    per_class = {}
    for name in train_rows['label']:
        class_id = name_to_id[name]
        per_class[class_id] = per_class.get(class_id, 0) + 1

    # Inverse-frequency weighting; unseen classes use a count of 1.
    n_total = len(train_rows)
    weights = torch.FloatTensor(
        [n_total / (n_classes * per_class.get(class_id, 1)) for class_id in range(n_classes)]
    )

    print("\nClass weights:")
    for class_id, value in enumerate(weights):
        print(f" Class {class_id}: {value:.4f}")

    return weights
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
    # Smoke-test the dataloader pipeline end to end.
    print("Testing Pest and Disease Dataloader")
    print("=" * 60)

    loaders = get_dataloaders(batch_size=8, img_size=224, num_workers=0)
    class_weights = calculate_class_weights()

    print("\n" + "=" * 60)
    print("Testing batch loading...")
    print("=" * 60)

    train_loader = loaders['train']
    train_dataset = loaders['datasets']['train']

    # Pull a single batch instead of looping and breaking.
    images, labels = next(iter(train_loader))
    print(f"\nBatch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Image dtype: {images.dtype}")
    print(f"Labels: {labels.tolist()}")
    print(f"Label names: {[train_dataset.get_label_name(l.item()) for l in labels]}")

    # Sanity-check the normalized pixel range.
    print(f"\nImage value range: [{images.min():.3f}, {images.max():.3f}]")

    print("\n" + "=" * 60)
    print("Dataloader test completed successfully!")
    print("=" * 60)
|
demo.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple Demo for Pest and Disease Classification
|
| 3 |
+
Upload an image and get prediction
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import json
|
| 9 |
+
import argparse
|
| 10 |
+
import gradio as gr
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
|
| 13 |
+
from model import create_model
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PestDiseasePredictor:
    """Simple predictor: restores a checkpoint and classifies PIL images."""

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        # Use CUDA only when it is actually available; otherwise run on CPU.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # id -> class-name lookup table from the mapping JSON.
        with open(label_mapping_path, 'r', encoding='utf-8') as handle:
            mapping = json.load(handle)
        self.id_to_label = {int(key): value for key, value in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        # Recreate the network topology, then restore the trained weights.
        self.model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False,
        )
        snapshot = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(snapshot['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Deterministic inference preprocessing (ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        print(f"Model loaded from {checkpoint_path}")
        print(f"Device: {self.device}")
        print(f"Classes: {self.num_classes}")

    def predict(self, image):
        """
        Predict class for input image

        Args:
            image: PIL Image

        Returns:
            dict: {class_name: probability}, sorted by descending probability
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')

        tensor = self.transform(image).unsqueeze(0)
        tensor = tensor.to(self.device)

        with torch.no_grad():
            logits = self.model(tensor)
            distribution = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()

        # Map each class id to its probability, highest first.
        scored = {self.id_to_label[idx]: float(prob) for idx, prob in enumerate(distribution)}
        return dict(sorted(scored.items(), key=lambda item: item[1], reverse=True))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def create_demo(predictor):
    """Build the Gradio interface around *predictor*."""

    def predict_image(image):
        """Prediction function for Gradio; None input yields no prediction."""
        return None if image is None else predictor.predict(image)

    return gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=gr.Label(num_top_classes=10, label="Predictions"),
        title="🌿 Pest and Disease Classification",
        description="Upload an image of a citrus plant leaf to classify if it's healthy or has pests/diseases.",
        examples=None,
        theme=gr.themes.Soft(),
        allow_flagging="never"
    )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def main(args):
    """Wire up the predictor and launch the Gradio demo."""
    print("Starting Pest and Disease Classification Demo...")
    print("=" * 60)

    # Build the predictor from command-line options.
    predictor = PestDiseasePredictor(
        checkpoint_path=args.checkpoint,
        label_mapping_path=args.label_mapping,
        backbone=args.backbone,
        device=args.device,
    )

    demo = create_demo(predictor)

    print("\n" + "=" * 60)
    print("Launching demo...")
    print("=" * 60)

    demo.launch(server_name=args.host, server_port=args.port, share=args.share)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
    # Command-line entry point: parse options, then hand off to main().
    parser = argparse.ArgumentParser(description='Demo for Pest and Disease Classification')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                        help='Path to label mapping JSON')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'],
                        help='Model backbone')
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'],
                        help='Device to use')
    parser.add_argument('--host', type=str, default='127.0.0.1',
                        help='Server host')
    parser.add_argument('--port', type=int, default=7860,
                        help='Server port')
    parser.add_argument('--share', action='store_true',
                        help='Create public link')
    main(parser.parse_args())
|
evaluate.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation script for Pest and Disease Classification
|
| 3 |
+
Generate confusion matrix, classification report, and per-class metrics
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import seaborn as sns
|
| 10 |
+
from sklearn.metrics import confusion_matrix, classification_report, f1_score
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from dataset import get_dataloaders
|
| 16 |
+
from model import create_model
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def evaluate_model(model, dataloader, device, dataset):
    """Evaluate *model* over every batch of *dataloader*.

    NOTE(review): the *dataset* parameter is unused here; kept for
    interface compatibility with existing callers.

    Returns:
        predictions (np.ndarray): predicted label ids.
        true_labels (np.ndarray): ground-truth label ids.
        accuracy (float): fraction of correct predictions.
    """
    model.eval()
    predicted_chunks = []
    truth_chunks = []

    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # Argmax over the class dimension gives the hard prediction.
            logits = model(batch_inputs)
            batch_preds = logits.argmax(dim=1)

            predicted_chunks.extend(batch_preds.cpu().numpy())
            truth_chunks.extend(batch_targets.cpu().numpy())

    predictions = np.array(predicted_chunks)
    true_labels = np.array(truth_chunks)
    accuracy = np.mean(predictions == true_labels)

    return predictions, true_labels, accuracy
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def plot_confusion_matrix(y_true, y_pred, class_names, save_path='confusion_matrix.png'):
    """Plot and save the confusion matrix, in counts and percentages.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        class_names: List of class names (axis tick labels).
        save_path: Path for the count figure; the percentage figure gets
            a '_percent' suffix before the extension.

    Returns:
        np.ndarray: The raw (count) confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Row-normalized percentages (each true class sums to 100).
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
    percent_path = str(save_path).replace('.png', '_percent.png')

    # (matrix, annotation format, colorbar label, title, output path, log prefix)
    panels = [
        (cm, 'd', 'Count', 'Confusion Matrix', save_path,
         "Confusion matrix saved to"),
        (cm_percent, '.1f', 'Percentage (%)', 'Confusion Matrix (Percentage)', percent_path,
         "Confusion matrix (percentage) saved to"),
    ]

    for matrix, annot_fmt, bar_label, title, out_path, log_prefix in panels:
        plt.figure(figsize=(12, 10))
        sns.heatmap(matrix, annot=True, fmt=annot_fmt, cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names,
                    cbar_kws={'label': bar_label})
        plt.title(title, fontsize=16, pad=20)
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        print(f"{log_prefix} {out_path}")

    plt.close('all')

    return cm
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def generate_classification_report(y_true, y_pred, class_names, save_path='classification_report.txt'):
    """Print, save, and return a detailed classification report.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        class_names: List of class names (index == label id).
        save_path: Destination .txt; a sibling .json is written alongside.

    Returns:
        dict: Per-class precision/recall/f1/support plus an 'overall' entry.
    """
    report_text = classification_report(
        y_true, y_pred,
        target_names=class_names,
        digits=4
    )

    print("\n" + "=" * 80)
    print("Classification Report")
    print("=" * 80)
    print(report_text)

    with open(save_path, 'w', encoding='utf-8') as out:
        out.write("Classification Report\n")
        out.write("=" * 80 + "\n")
        out.write(report_text)
    print(f"\nClassification report saved to {save_path}")

    # Per-class breakdown (local import mirrors the original module layout).
    from sklearn.metrics import precision_recall_fscore_support
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None
    )

    metrics = {
        name: {
            'precision': float(precision[idx]),
            'recall': float(recall[idx]),
            'f1-score': float(f1[idx]),
            'support': int(support[idx]),
        }
        for idx, name in enumerate(class_names)
    }

    # Aggregate metrics over all classes.
    metrics['overall'] = {
        'accuracy': float(np.mean(y_true == y_pred)),
        'macro_avg_f1': float(np.mean(f1)),
        'weighted_avg_f1': float(f1_score(y_true, y_pred, average='weighted')),
    }

    metrics_path = str(save_path).replace('.txt', '.json')
    with open(metrics_path, 'w', encoding='utf-8') as out:
        json.dump(metrics, out, indent=2, ensure_ascii=False)
    print(f"Metrics JSON saved to {metrics_path}")

    return metrics
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def plot_per_class_metrics(metrics, class_names, save_path='per_class_metrics.png'):
    """
    Draw grouped bars of precision, recall, and F1-score per class.

    Args:
        metrics: Per-class metrics dict (as produced by generate_classification_report)
        class_names: Ordered list of class names (keys into *metrics*)
        save_path: Destination path for the saved figure
    """
    bar_width = 0.25
    positions = np.arange(len(class_names))

    # Pull the three series out of the metrics dict in class order;
    # dict insertion order fixes the bar drawing order below.
    series = {
        'Precision': [metrics[name]['precision'] for name in class_names],
        'Recall': [metrics[name]['recall'] for name in class_names],
        'F1-Score': [metrics[name]['f1-score'] for name in class_names],
    }

    fig, ax = plt.subplots(figsize=(14, 6))
    for offset, (label, values) in zip((-bar_width, 0, bar_width), series.items()):
        ax.bar(positions + offset, values, bar_width, label=label, alpha=0.8)

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Metrics', fontsize=14, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    # Leave headroom above 1.0 so full-score bars are not clipped.
    ax.set_ylim([0, 1.1])

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Per-class metrics plot saved to {save_path}")
    plt.close()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def main(args):
    """Main evaluation function.

    Loads the dataset split and checkpoint given in *args*, runs inference,
    then writes a confusion matrix, a classification report, and a per-class
    metrics plot into args.output_dir.
    """
    print("Pest and Disease Classification Evaluation")
    print("=" * 80)
    print(f"Configuration:")
    print(f"  Checkpoint: {args.checkpoint}")
    print(f"  Split: {args.split}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Device: {args.device}")
    print("=" * 80)

    # Set device — silently falls back to CPU when CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")

    # Load data (all three splits come back in one dict of loaders/datasets)
    print("\nLoading datasets...")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Get class names, ordered by class id so indices line up with model outputs.
    dataset = loaders['datasets'][args.split]
    class_names = [dataset.get_label_name(i) for i in range(dataset.num_classes)]
    print(f"Classes: {class_names}")

    # Create model — pretrained=False because the checkpoint supplies all weights.
    print(f"\nCreating model: {args.backbone}")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=False
    )

    # Load checkpoint (map_location keeps GPU-trained checkpoints loadable on CPU)
    print(f"\nLoading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    if 'val_acc' in checkpoint:
        print(f"Checkpoint validation accuracy: {checkpoint['val_acc']:.4f}")

    # Evaluate on the requested split
    print(f"\nEvaluating on {args.split} set...")
    dataloader = loaders[args.split]
    predictions, true_labels, accuracy = evaluate_model(model, dataloader, device, dataset)

    print(f"\n{args.split.capitalize()} Set Accuracy: {accuracy:.4f}")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)

    # Generate confusion matrix
    print("\nGenerating confusion matrix...")
    cm = plot_confusion_matrix(
        true_labels, predictions, class_names,
        save_path=output_dir / f'confusion_matrix_{args.split}.png'
    )

    # Generate classification report (also writes a JSON metrics file)
    print("\nGenerating classification report...")
    metrics = generate_classification_report(
        true_labels, predictions, class_names,
        save_path=output_dir / f'classification_report_{args.split}.txt'
    )

    # Plot per-class metrics
    print("\nGenerating per-class metrics plot...")
    plot_per_class_metrics(
        metrics, class_names,
        save_path=output_dir / f'per_class_metrics_{args.split}.png'
    )

    print("\n" + "=" * 80)
    print("Evaluation complete!")
    print(f"Results saved to {output_dir}/")
    print("=" * 80)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate Pest and Disease Classifier')

    # Data parameters
    parser.add_argument('--csv_file', type=str, default='dataset.csv',
                        help='Path to dataset CSV')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                        help='Path to label mapping JSON')

    # Model parameters — the backbone must match the one the checkpoint was trained with.
    parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'],
                        help='Model backbone')

    # Evaluation parameters
    parser.add_argument('--split', type=str, default='test',
                        choices=['train', 'val', 'test'],
                        help='Dataset split to evaluate')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--img_size', type=int, default=224,
                        help='Image size')

    # System parameters
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading workers')
    parser.add_argument('--output_dir', type=str, default='evaluation_results',
                        help='Directory to save results')

    args = parser.parse_args()
    main(args)
|
label_mapping.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"label_to_id": {
|
| 3 |
+
"介殼蟲": 0,
|
| 4 |
+
"健康植株-椪柑": 1,
|
| 5 |
+
"健康植株-茂谷柑": 2,
|
| 6 |
+
"油斑病": 3,
|
| 7 |
+
"潛葉蛾": 4,
|
| 8 |
+
"潰瘍病": 5,
|
| 9 |
+
"煤煙病": 6,
|
| 10 |
+
"薊馬": 7,
|
| 11 |
+
"蚜蟲": 8,
|
| 12 |
+
"黑點病": 9
|
| 13 |
+
},
|
| 14 |
+
"id_to_label": {
|
| 15 |
+
"0": "介殼蟲",
|
| 16 |
+
"1": "健康植株-椪柑",
|
| 17 |
+
"2": "健康植株-茂谷柑",
|
| 18 |
+
"3": "油斑病",
|
| 19 |
+
"4": "潛葉蛾",
|
| 20 |
+
"5": "潰瘍病",
|
| 21 |
+
"6": "煤煙病",
|
| 22 |
+
"7": "薊馬",
|
| 23 |
+
"8": "蚜蟲",
|
| 24 |
+
"9": "黑點病"
|
| 25 |
+
},
|
| 26 |
+
"num_classes": 10
|
| 27 |
+
}
|
model.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Classification Models for Pest and Disease Detection
|
| 3 |
+
Supports multiple pretrained backbones: ResNet, EfficientNet, MobileNet
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torchvision.models as models
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PestDiseaseClassifier(nn.Module):
    """
    General classifier with a pretrained backbone for transfer learning.

    The torchvision backbone's own classification head is replaced with
    nn.Identity so it emits a feature vector, and a small two-layer MLP head
    maps those features to `num_classes` logits.
    """

    # Backbones whose final head lives at `.fc` (ResNet family) versus
    # `.classifier` (EfficientNet / MobileNet family). Grouping them removes
    # five near-identical if/elif branches from __init__.
    _FC_HEAD_BACKBONES = ('resnet50', 'resnet101')
    _CLASSIFIER_HEAD_BACKBONES = ('efficientnet_b0', 'efficientnet_b3', 'mobilenet_v2')

    def __init__(self, num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
        """
        Args:
            num_classes (int): Number of output classes
            backbone (str): Backbone architecture ('resnet50', 'resnet101', 'efficientnet_b0',
                           'efficientnet_b3', 'mobilenet_v2')
            pretrained (bool): Use pretrained weights
            dropout (float): Dropout rate for regularization

        Raises:
            ValueError: If `backbone` is not one of the supported names.
        """
        super(PestDiseaseClassifier, self).__init__()

        self.backbone_name = backbone
        self.num_classes = num_classes

        # Build the feature extractor and learn its output width.
        self.backbone, num_features = self._build_backbone(backbone, pretrained)

        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

        print(f"Model created: {backbone}")
        print(f"  Features: {num_features}")
        print(f"  Classes: {num_classes}")
        print(f"  Pretrained: {pretrained}")

    @classmethod
    def _build_backbone(cls, backbone, pretrained):
        """Instantiate the torchvision backbone, strip its head, and return (module, feature_dim).

        NOTE(review): the `pretrained=` keyword is deprecated in torchvision
        >= 0.13 in favour of `weights=`; kept here for compatibility with the
        torchvision version this project runs against — confirm before upgrading.
        """
        if backbone in cls._FC_HEAD_BACKBONES:
            net = getattr(models, backbone)(pretrained=pretrained)
            num_features = net.fc.in_features
            net.fc = nn.Identity()
        elif backbone in cls._CLASSIFIER_HEAD_BACKBONES:
            net = getattr(models, backbone)(pretrained=pretrained)
            # EfficientNet/MobileNet heads are Sequential(dropout, Linear).
            num_features = net.classifier[1].in_features
            net.classifier = nn.Identity()
        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        return net, num_features

    def forward(self, x):
        """
        Forward pass
        Args:
            x: Input tensor [batch_size, 3, H, W]
        Returns:
            logits: Output tensor [batch_size, num_classes]
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

    def freeze_backbone(self):
        """Freeze backbone parameters for fine-tuning (only the head keeps training)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("Backbone frozen")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def create_model(num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
    """
    Factory function to create a classifier.

    Args:
        num_classes (int): Number of classes
        backbone (str): Model architecture
        pretrained (bool): Use pretrained weights
        dropout (float): Dropout rate

    Returns:
        model: PestDiseaseClassifier instance
    """
    return PestDiseaseClassifier(
        num_classes=num_classes,
        backbone=backbone,
        pretrained=pretrained,
        dropout=dropout,
    )
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def count_parameters(model):
    """Print and return the (total, trainable) parameter counts of *model*."""
    total_params = 0
    trainable_params = 0
    # Single pass over the parameters instead of two generator sweeps.
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    print(f"\nModel Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")
    print(f"  Non-trainable: {total_params - trainable_params:,}")

    return total_params, trainable_params
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Smoke test: build each supported backbone and push one dummy batch through it.
if __name__ == "__main__":
    """Test model creation"""
    print("Testing Pest and Disease Classification Models")
    print("=" * 60)

    # Test different backbones
    backbones = ['resnet50', 'efficientnet_b0', 'mobilenet_v2']

    for backbone in backbones:
        print(f"\nTesting {backbone}...")
        print("-" * 60)

        # NOTE: pretrained=True downloads ImageNet weights on first run.
        model = create_model(num_classes=10, backbone=backbone, pretrained=True)
        count_parameters(model)

        # Test forward pass with a 2-image, 224x224 RGB batch
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)

        print(f"  Input shape: {dummy_input.shape}")
        print(f"  Output shape: {output.shape}")
        print(f"  Output range: [{output.min():.3f}, {output.max():.3f}]")

    print("\n" + "=" * 60)
    print("Model test completed successfully!")
    print("=" * 60)
|
prepare_dataset.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pest and Disease Classification Dataset Preparation Script
|
| 3 |
+
- Scan data folders
|
| 4 |
+
- Analyze image distribution
|
| 5 |
+
- Generate train/val/test CSV files
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from sklearn.model_selection import train_test_split
|
| 13 |
+
import json
|
| 14 |
+
|
| 15 |
+
# Configuration parameters
DATA_DIR = "Data"            # root folder holding the A./B./C. category subfolders
OUTPUT_CSV = "dataset.csv"   # combined train/val/test CSV written by save_dataset()
TRAIN_RATIO = 0.7            # split ratios must sum to 1.0 (asserted in split_dataset)
VAL_RATIO = 0.15
TEST_RATIO = 0.15
RANDOM_SEED = 42             # shared seed for numpy and sklearn splits (reproducibility)

# Set random seed
np.random.seed(RANDOM_SEED)
|
| 25 |
+
|
| 26 |
+
def scan_dataset():
    """Scan DATA_DIR and collect all image information.

    Walks the three top-level category folders, maps each known subfolder to
    its class label, and opens every image once to record its dimensions.

    Returns:
        tuple: (data_list, image_sizes) where data_list is a list of dicts
        with image_path/label/main_category/plant_type/width/height, and
        image_sizes is a list of (width, height) tuples.
    """
    data_list = []
    image_sizes = []

    # Maps top-level folder -> subfolder name -> class label.
    # Subfolders not listed here are skipped with a warning.
    category_mapping = {
        "A.健康植株": {
            "椪柑": "健康植株-椪柑",
            "茂谷柑": "健康植株-茂谷柑"
        },
        "B.病害": {
            "1.病害-潰瘍病": "潰瘍病",
            "2.病害-煤煙病": "煤煙病",
            "3.病害-油斑病": "油斑病",
            "4.病害-黑點病2": "黑點病"
        },
        "C.蟲害": {
            "1.蟲害-薊馬": "薊馬",
            "2.蟲害-潛葉蛾": "潛葉蛾",
            "3.蟲害-蚜蟲": "蚜蟲",
            "4.蟲害-介殼蟲": "介殼蟲"
        }
    }

    print("Scanning dataset...")

    for main_dir in ["A.健康植株", "B.病害", "C.蟲害"]:
        main_path = Path(DATA_DIR) / main_dir

        if not main_path.exists():
            print(f"Warning: {main_path} does not exist")
            continue

        # Iterate through subdirectories
        for sub_dir in main_path.iterdir():
            if not sub_dir.is_dir():
                continue

            # Determine class label; unknown folders are skipped, not fatal.
            try:
                label = category_mapping[main_dir][sub_dir.name]
                print(f"  Processing: {main_dir}/{sub_dir.name} -> {label}")
            except KeyError:
                print(f"  Warning: Unknown subdirectory {main_dir}/{sub_dir.name}, skipping...")
                continue

            # Store plant type info: healthy plants keep their cultivar name,
            # disease/pest images get the generic citrus label.
            if main_dir == "A.健康植株":
                plant_type = sub_dir.name  # Ponkan or Murcott
            else:
                plant_type = "柑橘"

            # Scan images (case-insensitive extension match)
            image_files = (list(sub_dir.glob("*.jpg")) + list(sub_dir.glob("*.JPG")) +
                           list(sub_dir.glob("*.jpeg")) + list(sub_dir.glob("*.JPEG")) +
                           list(sub_dir.glob("*.png")) + list(sub_dir.glob("*.PNG")))

            for img_path in image_files:
                try:
                    # Get image dimensions; `with` closes the file handle promptly.
                    with Image.open(img_path) as img:
                        width, height = img.size
                        image_sizes.append((width, height))

                    data_list.append({
                        'image_path': str(img_path),
                        'label': label,
                        # e.g. "A.健康植株" -> "健康植株"
                        'main_category': main_dir.split('.')[1],
                        'plant_type': plant_type,
                        'width': width,
                        'height': height
                    })
                except Exception as e:
                    # Corrupt/unreadable files are reported but don't abort the scan.
                    print(f"Warning: Cannot read {img_path}: {e}")

    return data_list, image_sizes
|
| 102 |
+
|
| 103 |
+
def analyze_dataset(data_list, image_sizes):
    """Print summary statistics for the scanned images and return them as a DataFrame."""
    frame = pd.DataFrame(data_list)

    print("\n" + "="*60)
    print("Dataset Statistics")
    print("="*60)

    # Overall counts and per-class distribution
    print(f"\nTotal images: {len(frame)}")
    print(f"\nClass distribution:")
    counts = frame['label'].value_counts()
    for name, qty in counts.items():
        print(f"  {name}: {qty} images ({qty/len(frame)*100:.1f}%)")

    # Image geometry summary (skipped when no sizes were collected)
    if image_sizes:
        ws, hs = zip(*image_sizes)
        print(f"\nImage size analysis:")
        print(f"  Width: min={min(ws)}, max={max(ws)}, avg={np.mean(ws):.0f}")
        print(f"  Height: min={min(hs)}, max={max(hs)}, avg={np.mean(hs):.0f}")

        # How uniform are the resolutions?
        distinct_sizes = set(image_sizes)
        print(f"  Unique sizes: {len(distinct_sizes)}")
        if len(distinct_sizes) <= 5:
            print(f"  Main sizes: {list(distinct_sizes)[:5]}")

    # Flag class imbalance: largest class vs. smallest class
    imbalance = counts.max() / counts.min()
    print(f"\nClass imbalance ratio: {imbalance:.2f}x")
    if imbalance > 3:
        print("  Warning: Severe class imbalance detected. Consider using weighted loss or data augmentation")

    return frame
|
| 140 |
+
|
| 141 |
+
def split_dataset(df, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Stratified train/val/test split; returns one DataFrame with a 'split' column."""
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1"

    print("\n" + "="*60)
    print("Splitting Dataset (Stratified Sampling)")
    print("="*60)

    # Carve off the test set first, stratified on the class label.
    remainder, test_part = train_test_split(
        df,
        test_size=test_ratio,
        stratify=df['label'],
        random_state=RANDOM_SEED
    )

    # Split the remainder into train and validation; the ratio is rescaled
    # because the remainder no longer contains the test fraction.
    train_part, val_part = train_test_split(
        remainder,
        test_size=val_ratio / (train_ratio + val_ratio),
        stratify=remainder['label'],
        random_state=RANDOM_SEED
    )

    # Tag each partition on its own copy (avoids SettingWithCopy issues).
    train_part = train_part.copy()
    val_part = val_part.copy()
    test_part = test_part.copy()

    train_part['split'] = 'train'
    val_part['split'] = 'val'
    test_part['split'] = 'test'

    # Recombine into a single frame with a fresh index.
    final_df = pd.concat([train_part, val_part, test_part], ignore_index=True)

    # Report the size and class balance of every partition.
    print(f"\nTrain set: {len(train_part)} images ({len(train_part)/len(df)*100:.1f}%)")
    print(train_part['label'].value_counts().to_string())

    print(f"\nValidation set: {len(val_part)} images ({len(val_part)/len(df)*100:.1f}%)")
    print(val_part['label'].value_counts().to_string())

    print(f"\nTest set: {len(test_part)} images ({len(test_part)/len(df)*100:.1f}%)")
    print(test_part['label'].value_counts().to_string())

    return final_df
|
| 189 |
+
|
| 190 |
+
def save_dataset(df, output_path):
    """Write the dataset CSV plus a JSON label<->id mapping; return label_to_id."""
    # utf-8-sig adds a BOM so Excel renders the Chinese class names correctly.
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"\nDataset saved to: {output_path}")

    # Assign ids in sorted label order so the mapping is reproducible.
    labels = sorted(df['label'].unique())
    label_to_id = {name: i for i, name in enumerate(labels)}
    id_to_label = {i: name for name, i in label_to_id.items()}

    # Save label mapping next to the script (fixed filename).
    mapping_file = "label_mapping.json"
    with open(mapping_file, 'w', encoding='utf-8') as f:
        json.dump({
            'label_to_id': label_to_id,
            'id_to_label': id_to_label,
            'num_classes': len(labels)
        }, f, ensure_ascii=False, indent=2)

    print(f"Label mapping saved to: {mapping_file}")
    print(f"\nLabel mapping ({len(labels)} classes):")
    for name, i in label_to_id.items():
        print(f"  {i}: {name}")

    return label_to_id
|
| 216 |
+
|
| 217 |
+
def main():
    """Run the full preparation pipeline: scan, analyze, split, save."""
    print("Pest and Disease Dataset Preparation Tool")
    print("="*60)

    # 1. Walk the data directory and collect per-image records.
    data_list, image_sizes = scan_dataset()
    if not data_list:
        print("Error: No images found!")
        return

    # 2. Print summary statistics and build the DataFrame.
    df = analyze_dataset(data_list, image_sizes)

    # 3. Stratified train/val/test split using the module ratios.
    final_df = split_dataset(df, TRAIN_RATIO, VAL_RATIO, TEST_RATIO)

    # 4. Persist the CSV and the label mapping.
    label_to_id = save_dataset(final_df, OUTPUT_CSV)

    print("\n" + "="*60)
    print("Dataset preparation completed!")
    print("="*60)
    print("\nNext steps:")
    print("  1. Check dataset.csv and label_mapping.json")
    print("  2. Run data loader test script")
    print("  3. Start model training")
|
| 245 |
+
|
| 246 |
+
# Run the preparation pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
pillow
|
| 4 |
+
gradio==4.39.0
|
| 5 |
+
huggingface_hub==0.25.2
|
| 6 |
+
rich
|
| 7 |
+
seaborn
|
| 8 |
+
# NOTE: 'pathlib' removed — it is part of the Python 3 standard library; the
# PyPI package of that name is an obsolete Python 2 backport that can shadow
# the stdlib module. scikit-learn and matplotlib are used by evaluate.py and
# prepare_dataset.py and belong here instead.
scikit-learn
matplotlib
|
| 9 |
+
pandas
|
| 10 |
+
pydantic==2.10.6
|
train.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple Training Script for Pest and Disease Classification
|
| 3 |
+
Using Rich for progress display
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import json
|
| 11 |
+
import argparse
|
| 12 |
+
from rich.console import Console
|
| 13 |
+
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn
|
| 14 |
+
from rich.table import Table
|
| 15 |
+
from rich.panel import Panel
|
| 16 |
+
|
| 17 |
+
from dataset import get_dataloaders, calculate_class_weights
|
| 18 |
+
from model import create_model
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
console = Console()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def train_epoch(model, dataloader, criterion, optimizer, device, progress, task):
    """Run one optimization pass over *dataloader*.

    Advances the Rich progress bar one tick per batch and returns the
    sample-weighted mean loss and the accuracy for the epoch.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: forward, loss, backward, update.
        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        predicted = logits.argmax(dim=1)
        batch_loss.backward()
        optimizer.step()

        # Weight the running loss by batch size so the epoch mean is per-sample.
        batch_size = batch_inputs.size(0)
        loss_sum += batch_loss.item() * batch_size
        correct += (predicted == batch_labels).sum().item()
        seen += batch_size

        progress.update(task, advance=1)

    return loss_sum / seen, correct / seen
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def validate_epoch(model, dataloader, criterion, device, progress, task):
    """Run one evaluation pass over *dataloader* (no gradient updates).

    Advances the Rich progress bar one tick per batch and returns the
    sample-weighted mean loss and the accuracy for the epoch.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    # Inference only — disable autograd for speed and memory.
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)
            predicted = logits.argmax(dim=1)

            batch_size = batch_inputs.size(0)
            loss_sum += batch_loss.item() * batch_size
            correct += (predicted == batch_labels).sum().item()
            seen += batch_size

            progress.update(task, advance=1)

    return loss_sum / seen, correct / seen
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device, save_dir):
    """
    Simple training loop with Rich progress display
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    best_val_acc = 0.0
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    console.print("\n[bold green]Starting Training[/bold green]")

    for epoch in range(num_epochs):
        console.print(f"\n[bold cyan]Epoch {epoch+1}/{num_epochs}[/bold cyan]")

        progress_columns = (
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
        )
        with Progress(*progress_columns, console=console) as progress:
            # One pass over the training set
            train_task = progress.add_task("[red]Training...", total=len(train_loader))
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, device, progress, train_task
            )

            # One pass over the validation set
            val_task = progress.add_task("[green]Validating...", total=len(val_loader))
            val_loss, val_acc = validate_epoch(
                model, val_loader, criterion, device, progress, val_task
            )

        # Per-epoch metrics table
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Split", style="cyan")
        table.add_column("Loss", justify="right", style="yellow")
        table.add_column("Accuracy", justify="right", style="green")
        table.add_row("Train", f"{train_loss:.4f}", f"{train_acc:.4f}")
        table.add_row("Val", f"{val_loss:.4f}", f"{val_acc:.4f}")
        console.print(table)

        # Record metrics for later plotting/analysis
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Same payload is used for both best-model and periodic checkpoints
        snapshot = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }

        # Keep the weights with the best validation accuracy so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(snapshot, save_dir / 'best_model.pth')
            console.print(f"[bold green]✓ Saved best model (Val Acc: {val_acc:.4f})[/bold green]")

        # Periodic checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(snapshot, save_dir / f'checkpoint_epoch_{epoch+1}.pth')
            console.print(f"[yellow]Checkpoint saved at epoch {epoch+1}[/yellow]")

    # Persist the full metric history as JSON
    with open(save_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    console.print(f"\n[bold green]Training Complete![/bold green]")
    console.print(f"[bold]Best Val Acc: {best_val_acc:.4f}[/bold]")
    console.print(f"[bold]Results saved to: {save_dir}/[/bold]")

    return model, history
def main(args):
    """Entry point: build data, model, loss, and optimizer from CLI args,
    then run the training loop.

    Args:
        args: argparse.Namespace carrying the fields declared by the CLI
            parser (csv_file, label_mapping, backbone, dropout, batch_size,
            img_size, epochs, lr, optimizer, weight_decay, use_class_weights,
            device, num_workers, save_dir).

    Returns:
        (model, history) tuple from train_model. Returning it is new but
        backward-compatible (the function previously returned None).

    Raises:
        ValueError: if args.optimizer is not one of 'adam'/'adamw'/'sgd'.
    """
    # Show the run configuration up front
    config_panel = Panel.fit(
        f"""[bold]Configuration[/bold]
Backbone: {args.backbone}
Batch Size: {args.batch_size}
Image Size: {args.img_size}
Epochs: {args.epochs}
Learning Rate: {args.lr}
Optimizer: {args.optimizer}
Device: {args.device}
Class Weights: {args.use_class_weights}""",
        title="Training Settings",
        border_style="blue"
    )
    console.print(config_panel)

    # Fall back to CPU when CUDA is unavailable
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    console.print(f"\n[bold]Using device: {device}[/bold]")

    # Data loaders (train/val plus class count)
    console.print("\n[bold]Loading datasets...[/bold]")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Model
    console.print(f"\n[bold]Creating model: {args.backbone}[/bold]")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=True,
        dropout=args.dropout
    )
    model = model.to(device)

    # Loss function — optionally weighted to counter class imbalance
    if args.use_class_weights:
        class_weights = calculate_class_weights(args.csv_file, args.label_mapping)
        class_weights = class_weights.to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        console.print("[bold]Using weighted CrossEntropyLoss[/bold]")
    else:
        criterion = nn.CrossEntropyLoss()
        console.print("[bold]Using CrossEntropyLoss[/bold]")

    # Optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay)
    else:
        # argparse restricts choices on the CLI, but guard programmatic
        # callers: previously an unknown value caused a NameError below.
        raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")

    # Train
    model, history = train_model(
        model=model,
        train_loader=loaders['train'],
        val_loader=loaders['val'],
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=args.epochs,
        device=device,
        save_dir=args.save_dir
    )
    return model, history
if __name__ == "__main__":
    # Build the CLI. Argument order is preserved so --help output matches.
    arg_parser = argparse.ArgumentParser(
        description='Simple Training for Pest and Disease Classifier')

    # Dataset inputs
    arg_parser.add_argument('--csv_file', type=str, default='dataset.csv')
    arg_parser.add_argument('--label_mapping', type=str, default='label_mapping.json')

    # Model architecture
    arg_parser.add_argument(
        '--backbone', type=str, default='resnet50',
        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                 'efficientnet_b3', 'mobilenet_v2'])
    arg_parser.add_argument('--dropout', type=float, default=0.3)

    # Optimization hyperparameters
    arg_parser.add_argument('--batch_size', type=int, default=64)
    arg_parser.add_argument('--img_size', type=int, default=224)
    arg_parser.add_argument('--epochs', type=int, default=50)
    arg_parser.add_argument('--lr', type=float, default=0.001)
    arg_parser.add_argument(
        '--optimizer', type=str, default='adamw', choices=['adam', 'adamw', 'sgd'])
    arg_parser.add_argument('--weight_decay', type=float, default=0.01)
    arg_parser.add_argument('--use_class_weights', action='store_true')

    # Runtime / system settings
    arg_parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'])
    arg_parser.add_argument('--num_workers', type=int, default=8)
    arg_parser.add_argument('--save_dir', type=str, default='checkpoints')

    cli_args = arg_parser.parse_args()
    main(cli_args)