michaela299 commited on
Commit
361cbfe
·
1 Parent(s): 643d9c2

Restore app files

Browse files
Files changed (5) hide show
  1. best_model.pth +3 -0
  2. data_pipeline.py +156 -0
  3. model.py +30 -0
  4. requirements.txt +8 -0
  5. ui.py +71 -0
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1633d1300f4e2ae689ac619603499c5dacc876496079d72492e21254c7e3f9c9
3
+ size 20831138
data_pipeline.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torch.utils.data import DataLoader, default_collate
from torchvision import transforms
from datasets import load_dataset
import torch.utils.data

# Normalization statistics (ImageNet conventions) and model input resolution.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 256

# Training pipeline: resize first, then augment geometry and appearance,
# and finally convert to a normalized tensor.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),

    # Geometric augmentations.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),

    # Photometric augmentations.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),

    # Tensor conversion + normalization (always last).
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
])

# Validation/test pipeline: deterministic — resize and normalize only,
# no augmentation, so evaluation is reproducible.
val_test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
])
36
+
37
def apply_transforms(batch, transform_pipeline):
    """Apply *transform_pipeline* to every image in *batch*, in place.

    Images are forced to 3-channel RGB before the pipeline runs (some
    dataset images may be palette or grayscale). Labels are converted to
    a tensor so that ``default_collate`` can stack them into a batch.
    Returns the mutated batch dict.
    """
    def _prepare(image):
        return transform_pipeline(image.convert("RGB"))

    batch['image'] = [_prepare(image) for image in batch['image']]
    # Tensor labels are required for batching by default_collate.
    batch['label'] = torch.tensor(batch['label'])
    return batch
43
+
44
def get_dataloaders(batch_size=32, use_prototype=True):
    """
    Load the PlantVillage dataset, split it 70/15/15, and build DataLoaders.

    Args:
        batch_size: Number of samples per batch for all three loaders.
        use_prototype: When True, work on a 20% subsample for fast iteration.

    Returns:
        (train_loader, val_loader, test_loader).

    NOTE TO TEAM: The dataloaders yield a dictionary.
    Access batches using:
        batch = next(iter(loader))
        images = batch['image']
        labels = batch['label']
    """
    print("Loading and preparing dataset...")

    # Load the full dataset from Hugging Face.
    full_dataset = load_dataset("DScomp380/plant_village", split='train')

    if use_prototype:
        # Keep 20% of the data for prototyping (fixed seed → reproducible subset).
        print(f"Using 20% prototype dataset (approx {len(full_dataset) * 0.2:.0f} images)...")
        data_subset = full_dataset.train_test_split(test_size=0.8, seed=42)['train']
    else:
        print(f"Using 100% full dataset ({len(full_dataset)} images)...")
        data_subset = full_dataset

    # 70/15/15 split: first carve off 30%, then halve it into val/test.
    train_val_test_split = data_subset.train_test_split(test_size=0.3, seed=42)
    train_dataset = train_val_test_split['train']

    val_test_split = train_val_test_split['test'].train_test_split(test_size=0.5, seed=42)
    val_dataset = val_test_split['train']
    test_dataset = val_test_split['test']

    # FIX: the message previously said "prototype" even when the full dataset
    # was in use; label it according to the actual subset.
    subset_label = "prototype" if use_prototype else "full dataset"
    print(f"Total images in {subset_label}: {len(data_subset)}")
    print(f"Training images: {len(train_dataset)}")
    print(f"Validation images: {len(val_dataset)}")
    print(f"Test images: {len(test_dataset)}")
    print("--------------------")

    # Per-split transforms; set_transform applies lazily on batch access.
    train_dataset.set_transform(lambda batch: apply_transforms(batch, train_transform))
    val_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))
    test_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))

    def _make_loader(dataset, shuffle):
        # default_collate stacks the tensors produced by apply_transforms.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=default_collate,
        )

    # Only the training loader shuffles.
    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader
109
+
110
if __name__ == "__main__":
    print("Running data_pipeline.py as a standalone script...")

    # Smoke-test the pipeline with a tiny batch size on the prototype subset.
    train_loader, val_loader, test_loader = get_dataloaders(batch_size=4, use_prototype=True)

    def _smoke_test(loader, title):
        # Pull a single batch and verify the dict layout and tensor shapes.
        print(f"\n--- Testing {title} Loader ---")
        try:
            batch = next(iter(loader))
            images = batch['image']
            labels = batch['label']

            print(f"Image batch shape: {images.shape}")
            print(f"Label batch shape: {labels.shape}")

            # Expect a full batch of 4 normalized RGB tensors + 4 labels.
            assert images.shape == (4, 3, IMAGE_SIZE, IMAGE_SIZE)
            assert labels.shape == (4,)
            print(f"{title} loader test PASSED.")

        except Exception as e:
            print(f"{title} loader test FAILED: {e}")

    _smoke_test(train_loader, "Train")
    _smoke_test(val_loader, "Validation")

    print("\nData pipeline script finished.")
model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class BaselineCNN(nn.Module):
    """Baseline 3-block CNN for 256x256 RGB leaf images.

    Each block is conv -> batchnorm -> ReLU -> 2x2 max-pool, halving the
    spatial size (256 -> 128 -> 64 -> 32), followed by a single linear
    classification head.

    NOTE: attribute names (conv1, bn1, ..., fc) are state_dict keys used by
    saved checkpoints (e.g. best_model.pth) — do not rename them.
    """

    def __init__(self, num_classes=39):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # One shared pooling module (stateless, so reuse is safe).
        self.pool = nn.MaxPool2d(2, 2)

        # 128 channels * 32 * 32 spatial after three halvings of 256.
        self.fc = nn.Linear(128 * 32 * 32, num_classes)

    def forward(self, x):
        # Run the three conv blocks in order, then flatten and classify.
        blocks = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in blocks:
            x = self.pool(F.relu(bn(conv(x))))
        return self.fc(torch.flatten(x, 1))
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ datasets
5
+ clearml
6
+ pytest
7
+ scikit-learn
8
+ matplotlib
ui.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
from model import BaselineCNN
from data_pipeline import val_test_transform, IMAGE_SIZE
import torch

from datasets import load_dataset

# Human-readable class names come from the same dataset the model was trained on.
dataset = load_dataset("DScomp380/plant_village", split="train")
CLASS_NAMES = dataset.features["label"].names


# Load the trained model for CPU inference.
# FIX: derive the head size from the label set instead of hard-coding 39 —
# a mismatch here would make load_state_dict fail on the fc layer.
CLASSES = len(CLASS_NAMES)
model = BaselineCNN(num_classes=CLASSES)
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
model.eval()  # disable dropout/batchnorm training behavior for inference
17
+
18
def predict(input_image):
    """Classify a leaf image and return the top-10 classes with probabilities.

    Returns a dict mapping class name -> probability, the format that
    gr.Label consumes directly.
    """
    # Resize/normalize exactly as during validation, then add a leading
    # batch dimension so the model receives a batch of size 1.
    image_tensor = val_test_transform(input_image).unsqueeze(0)

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(image_tensor)

    # Softmax over the class dimension; index [0] drops the batch dimension.
    probabilities = torch.nn.functional.softmax(output, dim=1)[0]

    num_predictions_to_show = 10

    # Largest probabilities and their class indices, in descending order.
    top_probs, top_class_indices = torch.topk(probabilities, num_predictions_to_show)

    # Map each class name to its probability as a plain Python float.
    return {
        CLASS_NAMES[class_index.item()]: probability.item()
        for probability, class_index in zip(top_probs, top_class_indices)
    }
49
+
50
+
51
# Gradio UI: an image upload on the left, top-10 predictions on the right.
with gr.Blocks(title="Plant Disease Classifier") as app:
    gr.Markdown("# Plant Disease Classification")
    gr.Markdown("Upload an image of a plant leaf to classify its disease.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Leaf Image")
        label_output = gr.Label(label="Predicted Disease")

    # Placeholder for sample images; add file paths to `examples` to enable.
    gr.Examples(examples=[], inputs=image_input)

    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)

# FIX: removed leftover commented-out gr.Interface arguments (dead code).

if __name__ == "__main__":
    # NOTE(review): presumably ssr_mode=False works around server-side
    # rendering issues when hosted (e.g. on Spaces) — confirm before changing.
    app.launch(ssr_mode=False)
71
+