Boyun7 commited on
Commit
03d5bce
·
1 Parent(s): a5953ac

upload all files

Browse files
Files changed (12) hide show
  1. .gitattributes +0 -34
  2. README copy.md +28 -0
  3. app.py +94 -0
  4. checkpoints/best_model_993.pth +3 -0
  5. dataset.py +253 -0
  6. demo.py +163 -0
  7. evaluate.py +327 -0
  8. label_mapping.json +27 -0
  9. model.py +159 -0
  10. prepare_dataset.py +247 -0
  11. requirements.txt +10 -0
  12. train.py +286 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
README copy.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Pest and Disease Classification 🌿
3
+ emoji: 🌱
4
+ colorFrom: green
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: "4.44.1"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # 🌿 Pest and Disease Classification Demo
13
+
14
+ This demo provides a simple web interface for classifying **pests and diseases in citrus leaves**.
15
+
16
+ ## 🧠 Model
17
+ The model is based on a CNN backbone (ResNet50 by default) trained on a labeled dataset of citrus plant leaves.
18
+
19
+ - **Framework:** PyTorch
20
+ - **Interface:** Gradio
21
+ - **Backbone:** ResNet50
22
+ - **Task:** Image classification
23
+
24
+ ## 🚀 How to Use
25
+ 1. Click **“Upload Image”** and select a photo of a citrus leaf.
26
+ 2. The app will output the **predicted pest or disease category** with confidence scores.
27
+
28
+ ## 📂 Repository Structure
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Demo for Pest and Disease Classification
3
+ For Hugging Face Space Deployment
4
+ """
5
+
6
+ import torch
7
+ from PIL import Image
8
+ import json
9
+ import gradio as gr
10
+ from torchvision import transforms
11
+
12
+ from model import create_model
13
+
14
+
15
class PestDiseasePredictor:
    """Simple predictor class.

    Wraps a trained checkpoint plus its label mapping and exposes a single
    predict(image) call returning per-class probabilities.
    """

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        # Fall back to CPU when CUDA is unavailable (e.g. on a free Space).
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Label mapping: JSON object keys are strings, convert to int ids.
        with open(label_mapping_path, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        self.id_to_label = {int(idx): name for idx, name in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        # Rebuild the architecture without pretrained weights, then load the
        # trained state dict from the checkpoint.
        self.model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False
        )
        ckpt = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Same preprocessing as validation: resize + ImageNet normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"✅ Model loaded from {checkpoint_path}")
        print(f"💻 Device: {self.device}")
        print(f"📚 Classes: {self.num_classes}")

    def predict(self, image):
        """Return {class_name: probability} sorted by descending probability."""
        # PIL images may be grayscale/RGBA; the model expects 3 channels.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
            scores = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()

        ranked = {self.id_to_label[i]: float(s) for i, s in enumerate(scores)}
        return dict(sorted(ranked.items(), key=lambda kv: kv[1], reverse=True))
63
+
64
+
65
# ========== For Hugging Face Space ==========
# FIX: the only checkpoint shipped in this commit is
# checkpoints/best_model_993.pth, and the README documents a ResNet50
# backbone.  The previous values ("checkpoints/best_efficientnet_b3.pth" /
# 'efficientnet_b3') pointed at a file that does not exist in the repo,
# which crashed the Space at startup with FileNotFoundError.
checkpoint_path = "checkpoints/best_model_993.pth"
label_mapping_path = "label_mapping.json"
backbone = 'resnet50'
device = "cuda"  # PestDiseasePredictor falls back to CPU automatically

predictor = PestDiseasePredictor(
    checkpoint_path=checkpoint_path,
    label_mapping_path=label_mapping_path,
    backbone=backbone,
    device=device
)


def predict_image(image):
    """Gradio callback: return sorted class probabilities for one image."""
    # Gradio passes None when the user clears / submits no image.
    if image is None:
        return None
    return predictor.predict(image)


demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=10, label="Predictions"),
    title="🌿 Pest and Disease Classification",
    description="Upload an image of a citrus leaf to classify its pest or disease type.",
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
checkpoints/best_model_993.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e18b22b823871125c07933e128e7afed92110da462fee5781462fed1066b4e33
3
+ size 138717293
dataset.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Dataset and DataLoader for Pest and Disease Classification
3
+ """
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import pandas as pd
10
+ import json
11
+ from pathlib import Path
12
+
13
+
14
class PestDiseaseDataset(Dataset):
    """Custom Dataset for loading pest and disease images.

    The CSV must contain `image_path`, `label` and `split` columns; rows are
    filtered down to the requested split.  Labels are mapped to integer ids
    via a JSON mapping file.
    """

    def __init__(self, csv_file, label_mapping_file, split='train', transform=None):
        """
        Args:
            csv_file (str): Path to CSV file with image paths and labels
            label_mapping_file (str): Path to JSON file with label mappings
            split (str): One of 'train', 'val', or 'test'
            transform (callable, optional): Optional transform applied to images
        """
        frame = pd.read_csv(csv_file)
        self.df = frame[frame['split'] == split].reset_index(drop=True)

        # JSON object keys come back as strings; id_to_label needs int keys.
        with open(label_mapping_file, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        self.label_to_id = mapping['label_to_id']
        self.id_to_label = {int(key): name for key, name in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        self.transform = transform
        self.split = split

        print(f"Loaded {split} set: {len(self.df)} images")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (image_tensor, label_id) for the row at `idx`."""
        record = self.df.iloc[idx]

        # Decode the image; force RGB so grayscale/RGBA files work too.
        image = Image.open(record['image_path']).convert('RGB')

        label = self.label_to_id[record['label']]

        if self.transform:
            image = self.transform(image)

        return image, label

    def get_label_name(self, label_id):
        """Convert label ID back to label name"""
        return self.id_to_label[label_id]
68
+
69
+
70
def get_transforms(split='train', img_size=224):
    """Build the preprocessing pipeline for one dataset split.

    Args:
        split (str): 'train', 'val', or 'test'
        img_size (int): Target image size (224 suits most pretrained models)

    Returns:
        transforms.Compose: Composed transforms
    """
    # Shared tail: tensor conversion + ImageNet normalization.
    common = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]

    if split == 'train':
        # Training gets geometric + photometric augmentation.
        augmentation = [
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        ]
        return transforms.Compose(augmentation + common)

    # Validation/test: deterministic resize only, no augmentation.
    return transforms.Compose([transforms.Resize((img_size, img_size))] + common)
102
+
103
+
104
def get_dataloaders(csv_file='dataset.csv',
                    label_mapping_file='label_mapping.json',
                    batch_size=32,
                    img_size=224,
                    num_workers=4):
    """Create train, validation, and test dataloaders.

    Args:
        csv_file (str): Path to dataset CSV
        label_mapping_file (str): Path to label mapping JSON
        batch_size (int): Batch size
        img_size (int): Image size fed to the model
        num_workers (int): Number of data-loading workers

    Returns:
        dict: 'train'/'val'/'test' dataloaders, 'num_classes', and the
        underlying 'datasets'.
    """
    # One dataset per split, each with its own transform pipeline.
    datasets = {
        split: PestDiseaseDataset(
            csv_file=csv_file,
            label_mapping_file=label_mapping_file,
            split=split,
            transform=get_transforms(split, img_size)
        )
        for split in ('train', 'val', 'test')
    }

    # Only the training loader shuffles; pin_memory speeds host->GPU copies.
    loaders = {
        split: DataLoader(
            datasets[split],
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            pin_memory=True
        )
        for split in ('train', 'val', 'test')
    }

    return {
        'train': loaders['train'],
        'val': loaders['val'],
        'test': loaders['test'],
        'num_classes': datasets['train'].num_classes,
        'datasets': {
            'train': datasets['train'],
            'val': datasets['val'],
            'test': datasets['test']
        }
    }
181
+
182
+
183
def calculate_class_weights(csv_file='dataset.csv', label_mapping_file='label_mapping.json'):
    """Compute inverse-frequency class weights from the training split.

    Returns:
        torch.Tensor: One weight per class, for a weighted loss function.
    """
    df = pd.read_csv(csv_file)
    train_df = df[df['split'] == 'train']

    with open(label_mapping_file, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    label_to_id = mapping['label_to_id']
    num_classes = mapping['num_classes']

    # Tally training samples per class id.
    counts = {}
    for name in train_df['label']:
        cid = label_to_id[name]
        counts[cid] = counts.get(cid, 0) + 1

    # weight_i = N / (C * n_i); classes absent from train default to count 1.
    total = len(train_df)
    weights = torch.FloatTensor(
        [total / (num_classes * counts.get(cid, 1)) for cid in range(num_classes)]
    )

    print("\nClass weights:")
    for cid, w in enumerate(weights):
        print(f"  Class {cid}: {w:.4f}")

    return weights
219
+
220
+
221
if __name__ == "__main__":
    # Smoke-test the dataloaders and the class-weight computation.
    print("Testing Pest and Disease Dataloader")
    print("=" * 60)

    loaders = get_dataloaders(batch_size=8, img_size=224, num_workers=0)
    class_weights = calculate_class_weights()

    print("\n" + "=" * 60)
    print("Testing batch loading...")
    print("=" * 60)

    train_loader = loaders['train']
    train_dataset = loaders['datasets']['train']

    # Inspect a single batch, then stop.
    for images, labels in train_loader:
        print(f"\nBatch shape: {images.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"Image dtype: {images.dtype}")
        print(f"Labels: {labels.tolist()}")
        print(f"Label names: {[train_dataset.get_label_name(l.item()) for l in labels]}")
        # Normalized tensors should span roughly [-2.6, 2.7].
        print(f"\nImage value range: [{images.min():.3f}, {images.max():.3f}]")
        break

    print("\n" + "=" * 60)
    print("Dataloader test completed successfully!")
    print("=" * 60)
demo.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Demo for Pest and Disease Classification
3
+ Upload an image and get prediction
4
+ """
5
+
6
+ import torch
7
+ from PIL import Image
8
+ import json
9
+ import argparse
10
+ import gradio as gr
11
+ from torchvision import transforms
12
+
13
+ from model import create_model
14
+
15
+
16
class PestDiseasePredictor:
    """Simple predictor class.

    Loads a trained checkpoint plus label mapping and classifies one PIL
    image at a time.
    """

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        # Use CPU when CUDA is not available.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Label mapping: JSON keys are strings -> convert to int ids.
        with open(label_mapping_path, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        self.id_to_label = {int(idx): name for idx, name in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        # Rebuild the architecture and load the trained weights.
        self.model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False
        )
        ckpt = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Validation-style preprocessing: resize + ImageNet normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded from {checkpoint_path}")
        print(f"Device: {self.device}")
        print(f"Classes: {self.num_classes}")

    def predict(self, image):
        """
        Predict class for input image

        Args:
            image: PIL Image

        Returns:
            dict: {class_name: probability}, sorted best-first
        """
        # Force 3 channels; grayscale/RGBA inputs would break the model.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        batch = self.transform(image).unsqueeze(0)
        batch = batch.to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            scores = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()

        # Map ids to human-readable names and sort by descending probability.
        ranked = {self.id_to_label[i]: float(s) for i, s in enumerate(scores)}
        return dict(sorted(ranked.items(), key=lambda kv: kv[1], reverse=True))
86
+
87
+
88
def create_demo(predictor):
    """Build the Gradio interface around `predictor`."""

    def predict_image(image):
        # Gradio passes None when no image was supplied.
        if image is None:
            return None
        return predictor.predict(image)

    return gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=gr.Label(num_top_classes=10, label="Predictions"),
        title="🌿 Pest and Disease Classification",
        description="Upload an image of a citrus plant leaf to classify if it's healthy or has pests/diseases.",
        examples=None,
        theme=gr.themes.Soft(),
        allow_flagging="never"
    )
112
+
113
+
114
def main(args):
    """Build the predictor from CLI args and launch the Gradio demo."""
    print("Starting Pest and Disease Classification Demo...")
    print("=" * 60)

    predictor = PestDiseasePredictor(
        checkpoint_path=args.checkpoint,
        label_mapping_path=args.label_mapping,
        backbone=args.backbone,
        device=args.device
    )

    demo = create_demo(predictor)

    print("\n" + "=" * 60)
    print("Launching demo...")
    print("=" * 60)

    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Demo for Pest and Disease Classification')

    parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                        help='Path to label mapping JSON')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'],
                        help='Model backbone')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use')
    parser.add_argument('--host', type=str, default='127.0.0.1',
                        help='Server host')
    parser.add_argument('--port', type=int, default=7860,
                        help='Server port')
    parser.add_argument('--share', action='store_true',
                        help='Create public link')

    main(parser.parse_args())
evaluate.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation script for Pest and Disease Classification
3
+ Generate confusion matrix, classification report, and per-class metrics
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ from sklearn.metrics import confusion_matrix, classification_report, f1_score
11
+ import argparse
12
+ import json
13
+ from pathlib import Path
14
+
15
+ from dataset import get_dataloaders
16
+ from model import create_model
17
+
18
+
19
def evaluate_model(model, dataloader, device, dataset):
    """Run inference over `dataloader` and collect predictions.

    Args:
        model: Trained classifier (switched to eval mode here).
        dataloader: Iterable yielding (inputs, labels) batches.
        device: torch device to run inference on.
        dataset: Unused; kept for call-site compatibility.

    Returns:
        predictions (np.ndarray), true_labels (np.ndarray), accuracy (float)
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = model(inputs.to(device))
            # Predicted class = argmax over the logits.
            pred_chunks.extend(torch.max(logits, 1)[1].cpu().numpy())
            label_chunks.extend(labels.cpu().numpy())

    predictions = np.array(pred_chunks)
    true_labels = np.array(label_chunks)
    accuracy = np.mean(predictions == true_labels)

    return predictions, true_labels, accuracy
48
+
49
+
50
def plot_confusion_matrix(y_true, y_pred, class_names, save_path='confusion_matrix.png'):
    """Render the confusion matrix as count and percentage heatmaps.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names
        save_path: Where to save the count figure; the percentage figure
            gets a '_percent' suffix.

    Returns:
        np.ndarray: The raw confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Row-normalize so each row (true class) sums to 100%.
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    def _draw(matrix, fmt, cbar_label, title, path):
        # One styled heatmap per matrix variant, saved at print quality.
        plt.figure(figsize=(12, 10))
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names,
                    cbar_kws={'label': cbar_label})
        plt.title(title, fontsize=16, pad=20)
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')

    _draw(cm, 'd', 'Count', 'Confusion Matrix', save_path)
    print(f"Confusion matrix saved to {save_path}")

    save_path_percent = str(save_path).replace('.png', '_percent.png')
    _draw(cm_percent, '.1f', 'Percentage (%)', 'Confusion Matrix (Percentage)', save_path_percent)
    print(f"Confusion matrix (percentage) saved to {save_path_percent}")

    plt.close('all')

    return cm
104
+
105
+
106
def generate_classification_report(y_true, y_pred, class_names, save_path='classification_report.txt'):
    """Write the sklearn classification report plus a per-class metrics JSON.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names
        save_path: Path for the text report; the JSON goes next to it.

    Returns:
        dict: Per-class and overall metrics.
    """
    report = classification_report(
        y_true, y_pred,
        target_names=class_names,
        digits=4
    )

    print("\n" + "=" * 80)
    print("Classification Report")
    print("=" * 80)
    print(report)

    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("Classification Report\n")
        f.write("=" * 80 + "\n")
        f.write(report)

    print(f"\nClassification report saved to {save_path}")

    # Per-class precision/recall/F1 with support counts.
    from sklearn.metrics import precision_recall_fscore_support
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None
    )

    metrics = {
        name: {
            'precision': float(precision[i]),
            'recall': float(recall[i]),
            'f1-score': float(f1[i]),
            'support': int(support[i]),
        }
        for i, name in enumerate(class_names)
    }

    # Aggregate metrics alongside the per-class ones.
    metrics['overall'] = {
        'accuracy': float(np.mean(y_true == y_pred)),
        'macro_avg_f1': float(np.mean(f1)),
        'weighted_avg_f1': float(f1_score(y_true, y_pred, average='weighted'))
    }

    # Mirror the text report as machine-readable JSON next to it.
    metrics_path = str(save_path).replace('.txt', '.json')
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print(f"Metrics JSON saved to {metrics_path}")

    return metrics
168
+
169
+
170
def plot_per_class_metrics(metrics, class_names, save_path='per_class_metrics.png'):
    """Bar-chart precision/recall/F1 for every class.

    Args:
        metrics: Dict produced by generate_classification_report
        class_names: List of class names
        save_path: Path to save the figure
    """
    # Extract the three series from the metrics dict, in class order.
    series = {
        'Precision': [metrics[name]['precision'] for name in class_names],
        'Recall': [metrics[name]['recall'] for name in class_names],
        'F1-Score': [metrics[name]['f1-score'] for name in class_names],
    }

    positions = np.arange(len(class_names))
    width = 0.25
    offsets = (-width, 0.0, width)

    fig, ax = plt.subplots(figsize=(14, 6))
    # Three grouped bars per class, side by side.
    for offset, (label, values) in zip(offsets, series.items()):
        ax.bar(positions + offset, values, width, label=label, alpha=0.8)

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Metrics', fontsize=14, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim([0, 1.1])

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Per-class metrics plot saved to {save_path}")
    plt.close()
203
+ plt.close()
204
+
205
+
206
def main(args):
    """Load a checkpoint, evaluate one split, and write all reports."""
    print("Pest and Disease Classification Evaluation")
    print("=" * 80)
    print(f"Configuration:")
    print(f"  Checkpoint: {args.checkpoint}")
    print(f"  Split: {args.split}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Device: {args.device}")
    print("=" * 80)

    # Fall back to CPU when CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")

    print("\nLoading datasets...")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Human-readable class names in id order.
    dataset = loaders['datasets'][args.split]
    class_names = [dataset.get_label_name(i) for i in range(dataset.num_classes)]
    print(f"Classes: {class_names}")

    print(f"\nCreating model: {args.backbone}")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=False
    )

    print(f"\nLoading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    if 'val_acc' in checkpoint:
        print(f"Checkpoint validation accuracy: {checkpoint['val_acc']:.4f}")

    print(f"\nEvaluating on {args.split} set...")
    predictions, true_labels, accuracy = evaluate_model(
        model, loaders[args.split], device, dataset
    )
    print(f"\n{args.split.capitalize()} Set Accuracy: {accuracy:.4f}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)

    print("\nGenerating confusion matrix...")
    cm = plot_confusion_matrix(
        true_labels, predictions, class_names,
        save_path=output_dir / f'confusion_matrix_{args.split}.png'
    )

    print("\nGenerating classification report...")
    metrics = generate_classification_report(
        true_labels, predictions, class_names,
        save_path=output_dir / f'classification_report_{args.split}.txt'
    )

    print("\nGenerating per-class metrics plot...")
    plot_per_class_metrics(
        metrics, class_names,
        save_path=output_dir / f'per_class_metrics_{args.split}.png'
    )

    print("\n" + "=" * 80)
    print("Evaluation complete!")
    print(f"Results saved to {output_dir}/")
    print("=" * 80)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate Pest and Disease Classifier')

    # Data parameters
    parser.add_argument('--csv_file', type=str, default='dataset.csv',
                        help='Path to dataset CSV')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                        help='Path to label mapping JSON')

    # Model parameters
    parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'],
                        help='Model backbone')

    # Evaluation parameters
    parser.add_argument('--split', type=str, default='test',
                        choices=['train', 'val', 'test'],
                        help='Dataset split to evaluate')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--img_size', type=int, default=224,
                        help='Image size')

    # System parameters
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading workers')
    parser.add_argument('--output_dir', type=str, default='evaluation_results',
                        help='Directory to save results')

    main(parser.parse_args())
label_mapping.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "label_to_id": {
3
+ "介殼蟲": 0,
4
+ "健康植株-椪柑": 1,
5
+ "健康植株-茂谷柑": 2,
6
+ "油斑病": 3,
7
+ "潛葉蛾": 4,
8
+ "潰瘍病": 5,
9
+ "煤煙病": 6,
10
+ "薊馬": 7,
11
+ "蚜蟲": 8,
12
+ "黑點病": 9
13
+ },
14
+ "id_to_label": {
15
+ "0": "介殼蟲",
16
+ "1": "健康植株-椪柑",
17
+ "2": "健康植株-茂谷柑",
18
+ "3": "油斑病",
19
+ "4": "潛葉蛾",
20
+ "5": "潰瘍病",
21
+ "6": "煤煙病",
22
+ "7": "薊馬",
23
+ "8": "蚜蟲",
24
+ "9": "黑點病"
25
+ },
26
+ "num_classes": 10
27
+ }
model.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Classification Models for Pest and Disease Detection
3
+ Supports multiple pretrained backbones: ResNet, EfficientNet, MobileNet
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision.models as models
9
+
10
+
11
class PestDiseaseClassifier(nn.Module):
    """
    General classifier with a pretrained backbone for transfer learning.

    A torchvision backbone has its original classification head replaced by
    ``nn.Identity`` and is followed by a small dropout/linear head mapping the
    pooled features to ``num_classes`` logits.
    """

    def __init__(self, num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
        """
        Args:
            num_classes (int): Number of output classes
            backbone (str): Backbone architecture ('resnet50', 'resnet101',
                'efficientnet_b0', 'efficientnet_b3', 'mobilenet_v2')
            pretrained (bool): Use pretrained weights
            dropout (float): Dropout rate for regularization

        Raises:
            ValueError: if *backbone* is not one of the supported names.
        """
        super(PestDiseaseClassifier, self).__init__()

        self.backbone_name = backbone
        self.num_classes = num_classes

        # Constructors grouped by where torchvision exposes the feature width:
        # ResNets carry it on `.fc`, the others on `.classifier[1]`.
        fc_head_family = {
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
        }
        classifier_head_family = {
            'efficientnet_b0': models.efficientnet_b0,
            'efficientnet_b3': models.efficientnet_b3,
            'mobilenet_v2': models.mobilenet_v2,
        }

        if backbone in fc_head_family:
            self.backbone = fc_head_family[backbone](pretrained=pretrained)
            num_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        elif backbone in classifier_head_family:
            self.backbone = classifier_head_family[backbone](pretrained=pretrained)
            num_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Unknown backbone: {backbone}")

        # Custom classifier head on top of the pooled backbone features
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

        print(f"Model created: {backbone}")
        print(f"  Features: {num_features}")
        print(f"  Classes: {num_classes}")
        print(f"  Pretrained: {pretrained}")

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor [batch_size, 3, H, W]
        Returns:
            logits: Output tensor [batch_size, num_classes]
        """
        return self.classifier(self.backbone(x))

    def freeze_backbone(self):
        """Freeze backbone parameters for fine-tuning (head stays trainable)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("Backbone frozen")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters for full fine-tuning."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen")
96
+
97
+
98
def create_model(num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
    """
    Factory function to create a classifier model.

    Args:
        num_classes (int): Number of classes
        backbone (str): Model architecture
        pretrained (bool): Use pretrained weights
        dropout (float): Dropout rate

    Returns:
        PestDiseaseClassifier instance
    """
    return PestDiseaseClassifier(
        num_classes=num_classes,
        backbone=backbone,
        pretrained=pretrained,
        dropout=dropout,
    )
118
+
119
+
120
def count_parameters(model):
    """Print and return the (total, trainable) parameter counts of *model*."""
    total_params = 0
    trainable_params = 0
    # Single pass over parameters instead of two generator sums
    for param in model.parameters():
        n = param.numel()
        total_params += n
        if param.requires_grad:
            trainable_params += n

    print(f"\nModel Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")
    print(f"  Non-trainable: {total_params - trainable_params:,}")

    return total_params, trainable_params
131
+
132
+
133
if __name__ == "__main__":
    """Test model creation"""
    # Smoke test: build each backbone and push a dummy batch through it.
    print("Testing Pest and Disease Classification Models")
    print("=" * 60)

    for backbone in ('resnet50', 'efficientnet_b0', 'mobilenet_v2'):
        print(f"\nTesting {backbone}...")
        print("-" * 60)

        model = create_model(num_classes=10, backbone=backbone, pretrained=True)
        count_parameters(model)

        # Forward a 2-image dummy batch to verify shapes end-to-end
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)

        print(f"  Input shape: {dummy_input.shape}")
        print(f"  Output shape: {output.shape}")
        print(f"  Output range: [{output.min():.3f}, {output.max():.3f}]")

    print("\n" + "=" * 60)
    print("Model test completed successfully!")
    print("=" * 60)
prepare_dataset.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pest and Disease Classification Dataset Preparation Script
3
+ - Scan data folders
4
+ - Analyze image distribution
5
+ - Generate train/val/test CSV files
6
+ """
7
+
8
+ import pandas as pd
9
+ import numpy as np
10
+ from pathlib import Path
11
+ from PIL import Image
12
+ from sklearn.model_selection import train_test_split
13
+ import json
14
+
15
# Configuration parameters
DATA_DIR = "Data"           # root folder containing the A./B./C. category directories
OUTPUT_CSV = "dataset.csv"  # combined CSV (all rows, with a 'split' column)
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
RANDOM_SEED = 42            # fixed seed so the stratified splits are reproducible

# Set random seed
np.random.seed(RANDOM_SEED)
25
+
26
def scan_dataset():
    """Scan the DATA_DIR tree and collect one record per readable image.

    Returns:
        data_list: list of dicts with image_path / label / main_category /
            plant_type / width / height for every image PIL can open.
        image_sizes: list of (width, height) tuples, parallel to data_list.
    """
    data_list = []
    image_sizes = []

    # Maps raw folder names to the canonical class label used for training.
    category_mapping = {
        "A.健康植株": {
            "椪柑": "健康植株-椪柑",
            "茂谷柑": "健康植株-茂谷柑"
        },
        "B.病害": {
            "1.病害-潰瘍病": "潰瘍病",
            "2.病害-煤煙病": "煤煙病",
            "3.病害-油斑病": "油斑病",
            "4.病害-黑點病2": "黑點病"
        },
        "C.蟲害": {
            "1.蟲害-薊馬": "薊馬",
            "2.蟲害-潛葉蛾": "潛葉蛾",
            "3.蟲害-蚜蟲": "蚜蟲",
            "4.蟲害-介殼蟲": "介殼蟲"
        }
    }

    print("Scanning dataset...")

    for main_dir in ["A.健康植株", "B.病害", "C.蟲害"]:
        main_path = Path(DATA_DIR) / main_dir

        if not main_path.exists():
            print(f"Warning: {main_path} does not exist")
            continue

        # Iterate through subdirectories
        for sub_dir in main_path.iterdir():
            if not sub_dir.is_dir():
                continue

            # Determine class label
            try:
                label = category_mapping[main_dir][sub_dir.name]
                print(f"  Processing: {main_dir}/{sub_dir.name} -> {label}")
            except KeyError:
                print(f"  Warning: Unknown subdirectory {main_dir}/{sub_dir.name}, skipping...")
                continue

            # Store plant type info
            if main_dir == "A.健康植株":
                plant_type = sub_dir.name  # Ponkan or Murcott
            else:
                plant_type = "柑橘"

            # Scan images case-insensitively.  Filtering on suffix.lower()
            # (rather than one glob per case variant) avoids double-counting
            # files on case-insensitive filesystems (macOS/Windows), and
            # sorting makes the scan order deterministic across platforms.
            image_extensions = {'.jpg', '.jpeg', '.png'}
            image_files = sorted(
                p for p in sub_dir.iterdir()
                if p.is_file() and p.suffix.lower() in image_extensions
            )

            for img_path in image_files:
                try:
                    # Get image dimensions (also validates the file is readable)
                    with Image.open(img_path) as img:
                        width, height = img.size
                        image_sizes.append((width, height))

                    data_list.append({
                        'image_path': str(img_path),
                        'label': label,
                        'main_category': main_dir.split('.')[1],
                        'plant_type': plant_type,
                        'width': width,
                        'height': height
                    })
                except Exception as e:
                    print(f"Warning: Cannot read {img_path}: {e}")

    return data_list, image_sizes
102
+
103
def analyze_dataset(data_list, image_sizes):
    """Print summary statistics for the scanned images and return them as a DataFrame."""
    df = pd.DataFrame(data_list)

    banner = "=" * 60
    print("\n" + banner)
    print("Dataset Statistics")
    print(banner)

    # Overall counts plus each class's share of the dataset
    print(f"\nTotal images: {len(df)}")
    print(f"\nClass distribution:")
    label_counts = df['label'].value_counts()
    for label, count in label_counts.items():
        print(f"  {label}: {count} images ({count/len(df)*100:.1f}%)")

    # Resolution statistics (skipped when no image could be opened)
    if image_sizes:
        widths, heights = zip(*image_sizes)
        print(f"\nImage size analysis:")
        print(f"  Width: min={min(widths)}, max={max(widths)}, avg={np.mean(widths):.0f}")
        print(f"  Height: min={min(heights)}, max={max(heights)}, avg={np.mean(heights):.0f}")

        # How many distinct resolutions appear in the data
        unique_sizes = set(image_sizes)
        print(f"  Unique sizes: {len(unique_sizes)}")
        if len(unique_sizes) <= 5:
            print(f"  Main sizes: {list(unique_sizes)[:5]}")

    # Ratio of the largest to the smallest class — a rough imbalance gauge
    max_count = label_counts.max()
    min_count = label_counts.min()
    imbalance_ratio = max_count / min_count
    print(f"\nClass imbalance ratio: {imbalance_ratio:.2f}x")
    if imbalance_ratio > 3:
        print("  Warning: Severe class imbalance detected. Consider using weighted loss or data augmentation")

    return df
140
+
141
def split_dataset(df, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Stratified train/val/test split; returns a single DataFrame with a 'split' column."""
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1"

    line = "=" * 60
    print("\n" + line)
    print("Splitting Dataset (Stratified Sampling)")
    print(line)

    # Carve off the test set first, stratifying on the class label
    train_val_df, test_df = train_test_split(
        df,
        test_size=test_ratio,
        stratify=df['label'],
        random_state=RANDOM_SEED
    )

    # The validation fraction must be rescaled relative to the remaining data
    remaining_val_fraction = val_ratio / (train_ratio + val_ratio)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=remaining_val_fraction,
        stratify=train_val_df['label'],
        random_state=RANDOM_SEED
    )

    # Tag each partition, working on copies to avoid SettingWithCopy warnings
    tagged = []
    for part, split_name in ((train_df, 'train'), (val_df, 'val'), (test_df, 'test')):
        part = part.copy()
        part['split'] = split_name
        tagged.append(part)
    train_df, val_df, test_df = tagged

    # Merge all splits back into one frame
    final_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

    # Report per-split sizes and class distributions
    print(f"\nTrain set: {len(train_df)} images ({len(train_df)/len(df)*100:.1f}%)")
    print(train_df['label'].value_counts().to_string())

    print(f"\nValidation set: {len(val_df)} images ({len(val_df)/len(df)*100:.1f}%)")
    print(val_df['label'].value_counts().to_string())

    print(f"\nTest set: {len(test_df)} images ({len(test_df)/len(df)*100:.1f}%)")
    print(test_df['label'].value_counts().to_string())

    return final_df
189
+
190
def save_dataset(df, output_path):
    """Write the split CSV plus a label<->id mapping JSON; return label_to_id."""
    # utf-8-sig adds a BOM so spreadsheet tools render the Chinese labels correctly
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"\nDataset saved to: {output_path}")

    # Deterministic ids: labels sorted alphabetically, numbered from 0
    unique_labels = sorted(df['label'].unique())
    label_to_id = {}
    id_to_label = {}
    for idx, label in enumerate(unique_labels):
        label_to_id[label] = idx
        id_to_label[idx] = label

    # Persist both directions of the mapping plus the class count
    mapping_file = "label_mapping.json"
    payload = {
        'label_to_id': label_to_id,
        'id_to_label': id_to_label,
        'num_classes': len(unique_labels)
    }
    with open(mapping_file, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    print(f"Label mapping saved to: {mapping_file}")
    print(f"\nLabel mapping ({len(unique_labels)} classes):")
    for label, idx in label_to_id.items():
        print(f"  {idx}: {label}")

    return label_to_id
216
+
217
def main():
    """Run the full pipeline: scan -> analyze -> split -> save."""
    print("Pest and Disease Dataset Preparation Tool")
    print("=" * 60)

    # 1. Scan dataset
    data_list, image_sizes = scan_dataset()
    if not data_list:
        print("Error: No images found!")
        return

    # 2. Analyze dataset
    df = analyze_dataset(data_list, image_sizes)

    # 3. Split dataset (ratios come from the module-level config)
    final_df = split_dataset(df, TRAIN_RATIO, VAL_RATIO, TEST_RATIO)

    # 4. Save dataset and the label mapping
    label_to_id = save_dataset(final_df, OUTPUT_CSV)

    print("\n" + "=" * 60)
    print("Dataset preparation completed!")
    print("=" * 60)
    print("\nNext steps:")
    print("  1. Check dataset.csv and label_mapping.json")
    print("  2. Run data loader test script")
    print("  3. Start model training")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ gradio==4.39.0
5
+ huggingface_hub==0.25.2
6
+ rich
7
+ seaborn
8
+ # NOTE: 'pathlib' removed — it is part of the Python 3 standard library; the obsolete PyPI backport can shadow it and break on Python 3
+ scikit-learn
+ numpy
9
+ pandas
10
+ pydantic==2.10.6
train.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Training Script for Pest and Disease Classification
3
+ Using Rich for progress display
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from pathlib import Path
10
+ import json
11
+ import argparse
12
+ from rich.console import Console
13
+ from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn
14
+ from rich.table import Table
15
+ from rich.panel import Panel
16
+
17
+ from dataset import get_dataloaders, calculate_class_weights
18
+ from model import create_model
19
+
20
+
21
+ console = Console()
22
+
23
+
24
def train_epoch(model, dataloader, criterion, optimizer, device, progress, task):
    """Run one optimization pass over *dataloader*; returns (mean loss, accuracy).

    Advances the Rich *task* on *progress* by one step per batch.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_labels)
        _, preds = torch.max(outputs, 1)

        loss.backward()
        optimizer.step()

        # Undo the batch-mean so the epoch loss is sample-weighted
        batch_size = batch_inputs.size(0)
        loss_sum += loss.item() * batch_size
        correct += torch.sum(preds == batch_labels.data)
        seen += batch_size

        progress.update(task, advance=1)

    return loss_sum / seen, (correct.double() / seen).item()
52
+
53
+
54
def validate_epoch(model, dataloader, criterion, device, progress, task):
    """Evaluate one pass over *dataloader* with gradients disabled; returns (mean loss, accuracy).

    Advances the Rich *task* on *progress* by one step per batch.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            outputs = model(batch_inputs)
            loss = criterion(outputs, batch_labels)
            _, preds = torch.max(outputs, 1)

            # Undo the batch-mean so the epoch loss is sample-weighted
            batch_size = batch_inputs.size(0)
            loss_sum += loss.item() * batch_size
            correct += torch.sum(preds == batch_labels.data)
            seen += batch_size

            progress.update(task, advance=1)

    return loss_sum / seen, (correct.double() / seen).item()
79
+
80
+
81
def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device, save_dir):
    """
    Simple training loop with Rich progress display.

    Per epoch: train, validate, print a loss/accuracy table, and checkpoint.
    The best model (by validation accuracy) is written to
    ``save_dir/best_model.pth``, a periodic checkpoint every 10 epochs, and
    the metric history to ``save_dir/training_history.json``.

    Args:
        model: nn.Module to optimize (already moved to *device* by the caller)
        train_loader: dataloader yielding (inputs, labels) training batches
        val_loader: dataloader yielding (inputs, labels) validation batches
        criterion: loss function
        optimizer: torch optimizer over model.parameters()
        num_epochs (int): number of epochs to run
        device: device passed through to train_epoch / validate_epoch
        save_dir (str or Path): directory (created if missing) for outputs

    Returns:
        (model, history): the trained model and the per-epoch metric dict.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    best_val_acc = 0.0
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    console.print("\n[bold green]Starting Training[/bold green]")

    for epoch in range(num_epochs):
        console.print(f"\n[bold cyan]Epoch {epoch+1}/{num_epochs}[/bold cyan]")

        # One Rich Progress context per epoch: a train bar then a val bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
            console=console
        ) as progress:

            # Training
            train_task = progress.add_task(
                "[red]Training...",
                total=len(train_loader)
            )
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer,
                device, progress, train_task
            )

            # Validation
            val_task = progress.add_task(
                "[green]Validating...",
                total=len(val_loader)
            )
            val_loss, val_acc = validate_epoch(
                model, val_loader, criterion, device,
                progress, val_task
            )

        # Create results table summarizing this epoch
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Split", style="cyan")
        table.add_column("Loss", justify="right", style="yellow")
        table.add_column("Accuracy", justify="right", style="green")

        table.add_row("Train", f"{train_loss:.4f}", f"{train_acc:.4f}")
        table.add_row("Val", f"{val_loss:.4f}", f"{val_acc:.4f}")

        console.print(table)

        # Save history (plain floats so it serializes to JSON cleanly)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Save best model whenever validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }, save_dir / 'best_model.pth')
            console.print(f"[bold green]✓ Saved best model (Val Acc: {val_acc:.4f})[/bold green]")

        # Save checkpoint every 10 epochs regardless of accuracy
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }, save_dir / f'checkpoint_epoch_{epoch+1}.pth')
            console.print(f"[yellow]Checkpoint saved at epoch {epoch+1}[/yellow]")

    # Save training history once all epochs are done
    with open(save_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    console.print(f"\n[bold green]Training Complete![/bold green]")
    console.print(f"[bold]Best Val Acc: {best_val_acc:.4f}[/bold]")
    console.print(f"[bold]Results saved to: {save_dir}/[/bold]")

    return model, history
180
+
181
+
182
def main(args):
    """Main training function: wires up data, model, loss and optimizer from
    CLI args, then runs the training loop."""
    # Print configuration panel (rich markup; the f-string body is the panel text)
    config_panel = Panel.fit(
        f"""[bold]Configuration[/bold]
Backbone: {args.backbone}
Batch Size: {args.batch_size}
Image Size: {args.img_size}
Epochs: {args.epochs}
Learning Rate: {args.lr}
Optimizer: {args.optimizer}
Device: {args.device}
Class Weights: {args.use_class_weights}""",
        title="Training Settings",
        border_style="blue"
    )
    console.print(config_panel)

    # Set device — silently falls back to CPU when CUDA is unavailable
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    console.print(f"\n[bold]Using device: {device}[/bold]")

    # Load data (loaders dict also carries 'num_classes' from the label mapping)
    console.print("\n[bold]Loading datasets...[/bold]")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Create model — class count comes from the label mapping, not the CLI
    console.print(f"\n[bold]Creating model: {args.backbone}[/bold]")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=True,
        dropout=args.dropout
    )
    model = model.to(device)

    # Loss function — optionally weighted to counter class imbalance
    if args.use_class_weights:
        class_weights = calculate_class_weights(args.csv_file, args.label_mapping)
        class_weights = class_weights.to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        console.print("[bold]Using weighted CrossEntropyLoss[/bold]")
    else:
        criterion = nn.CrossEntropyLoss()
        console.print("[bold]Using CrossEntropyLoss[/bold]")

    # Optimizer — argparse `choices` guarantees exactly one branch matches
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay)

    # Train model (writes checkpoints and history into args.save_dir)
    model, history = train_model(
        model=model,
        train_loader=loaders['train'],
        val_loader=loaders['val'],
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=args.epochs,
        device=device,
        save_dir=args.save_dir
    )
255
+
256
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Simple Training for Pest and Disease Classifier')

    # Data parameters
    parser.add_argument('--csv_file', type=str, default='dataset.csv')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json')

    # Model parameters
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'])
    parser.add_argument('--dropout', type=float, default=0.3)

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--optimizer', type=str, default='adamw',
                        choices=['adam', 'adamw', 'sgd'])
    parser.add_argument('--weight_decay', type=float, default=0.01)
    # store_true flag: enables inverse-frequency class weighting in the loss
    parser.add_argument('--use_class_weights', action='store_true')

    # System parameters
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'])
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--save_dir', type=str, default='checkpoints')

    args = parser.parse_args()
    main(args)