AAAkater commited on
Commit
d10b7cf
·
verified ·
1 Parent(s): bf17302

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ images/BLAST1_008.jpg filter=lfs diff=lfs merge=lfs -text
2
+ images/BLAST1_016.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+ # checkpoints
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
README.md CHANGED
@@ -1,3 +1,66 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags:
4
+ - pytorch
5
+ - computer-vision
6
+ - image-classification
7
+ - rice-disease
8
+ license: mit
9
+ ---
10
+
11
+ # VGG16-CNN Rice Disease Classification Model
12
+
13
+ This model is designed for classifying rice plant diseases using a modified VGG16 architecture with additional CNN layers.
14
+
15
+ ## Model Description
16
+
17
+ ### Architecture
18
+ - Base model: VGG16 (pretrained on ImageNet)
19
+ - Additional custom CNN layer with:
20
+ - Conv2d(512, 64, kernel_size=3)
21
+ - ReLU activation
22
+ - BatchNorm2d
23
+ - MaxPool2d
24
+ - Custom classifier with:
25
+ - Linear layers (32*3*6 → 1024 → 5)
26
+ - Dropout (0.4)
27
+
28
+ ### Task
29
+ Image classification for rice plant diseases
30
+
31
+ ### Classes
32
+ 1. Bacterialblight
33
+ 2. Blast
34
+ 3. Brownspot
35
+ 4. Healthy
36
+ 5. Tungro
37
+
38
+ ## Training
39
+
40
+ The model uses transfer learning with a frozen VGG16 backbone.
41
+
42
+ ## Intended Use
43
+ - Primary intended use: Rice disease diagnosis through leaf image analysis
44
+ - Out-of-scope use: Should not be used for critical agricultural decisions without expert verification
45
+
46
+ ## Input
47
+ - RGB images
48
+ - Required size: 224x224 pixels
49
+ - Preprocessing:
50
+ - Normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
51
+
52
+ ## Limitations
53
+ Please note that this model should be used as a supportive tool and not as a sole decision-maker for disease diagnosis.
54
+
55
+ ## Model Author
56
+ [Your Name/Organization]
57
+
58
+ ## Citation
59
+ If you use this model, please cite:
60
+ ```
61
+ @software{vgg_cnn_rice_disease,
62
+ title={VGG16-CNN Rice Disease Classification Model},
63
+ version={0.1.0},
64
+ year={2024}
65
+ }
66
+ ```
checkpoints/vgg_net_model_50.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9da3747b37fa8286a07b7dd7eb5cfdfd734e46fcb94b099c346749625e841e05
3
+ size 557011773
images/BACTERAILBLIGHT3_001.jpg ADDED
images/BACTERAILBLIGHT3_002.jpg ADDED
images/BACTERAILBLIGHT3_003.jpg ADDED
images/BACTERAILBLIGHT3_004.jpg ADDED
images/BLAST1_008.jpg ADDED

Git LFS Details

  • SHA256: 8b1286cff8e4d2738addb6a9d83d0203e7a4002d73179ebc5b37558d659696bc
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
images/BLAST1_011.jpg ADDED
images/BLAST1_016.jpg ADDED

Git LFS Details

  • SHA256: c92e8632850b6924bda2ac6f59e1a9cd43955bdcfef8b8b486af9245b019e146
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
images/BLAST1_020.jpg ADDED
images/TUNGRO1_009.jpg ADDED
images/TUNGRO1_014.jpg ADDED
images/TUNGRO1_019.jpg ADDED
images/TUNGRO1_022.jpg ADDED
images/brownspot_orig_010.jpg ADDED
images/brownspot_orig_014.jpg ADDED
images/brownspot_orig_018.jpg ADDED
images/brownspot_orig_021.jpg ADDED
images/shape 13 .jpg ADDED
images/shape 19 .jpg ADDED
images/shape 7 .jpg ADDED
main.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from PIL import Image
4
+ from torch import Tensor
5
+ from torchvision import transforms
6
+
7
+ from model import VGG16WithCNN
8
+
9
+
10
def getModel(device: torch.device, model_path: str) -> nn.Module:
    """Load the trained VGG16WithCNN checkpoint and move it to *device*.

    Args:
        device: Target device (CPU or CUDA) for inference.
        model_path: Path to the saved ``state_dict`` checkpoint file.

    Returns:
        The model with trained weights loaded, placed on ``device``.
    """
    model = VGG16WithCNN(5)
    # Load the trained weights.  map_location is the bug fix: without it,
    # a checkpoint that was saved on a GPU cannot be loaded on a CPU-only
    # machine (torch.load raises because CUDA deserialization fails).
    model.load_state_dict(
        torch.load(
            model_path,
            weights_only=True,
            map_location=device,
        )
    )

    model.to(device)
    return model
22
+
23
+
24
def preprocess_image(image_path: str, image_size=(224, 224)):
    """Load an image file and convert it into a model-ready input tensor.

    The image is resized to ``image_size``, converted to a tensor,
    normalized with the ImageNet channel statistics, and given a leading
    batch dimension.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Open the image and force three channels (handles grayscale / RGBA).
    rgb_image: Image.Image = Image.open(image_path).convert("RGB")
    tensor: Tensor = pipeline(rgb_image)
    # Shape (C, H, W) -> (1, C, H, W): the model expects a batch axis.
    return tensor.unsqueeze(0)
43
+
44
+
45
def predict_single_image(
    image_path: str, model: nn.Module, device: torch.device, class_names: list[str]
) -> str:
    """Classify a single image file and return the predicted class name.

    Args:
        image_path: Path of the image to classify.
        model: Trained classification network.
        device: Device on which inference runs.
        class_names: Index-to-label mapping matching the model's outputs.

    Returns:
        The class name the model assigns to the image.
    """
    batch = preprocess_image(image_path).to(device)

    # Inference only: disable dropout/batchnorm updates and gradients.
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        best = logits.argmax(dim=1)

    return class_names[int(best.item())]
71
+
72
+
73
if __name__ == "__main__":
    # Manual smoke test: classify one bundled sample image.
    # NOTE: replace the paths below with your own checkpoint/image if needed.
    checkpoint_path = "./checkpoints/vgg_net_model_50.pth"

    labels = [
        "Bacterialblight",
        "Blast",
        "Brownspot",
        "Healthy",
        "Tungro",
    ]

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = getModel(device=run_device, model_path=checkpoint_path)

    sample_path = "./images/BLAST1_011.jpg"
    try:
        label = predict_single_image(
            sample_path, net, run_device, class_names=labels
        )
        print("\nSingle image prediction result:")
        print(f"Image: {sample_path}")
        print(f"Predicted label: {label}")
    except FileNotFoundError:
        print("Please provide a valid image path to test single image prediction")
model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ from PIL import Image
5
+ from torch import Tensor, nn
6
+ from torchvision.models import VGG, VGG16_Weights
7
+
8
+
9
class VGG16WithCNN(nn.Module):
    """Frozen VGG16 backbone plus a small trainable CNN head and classifier.

    The pretrained VGG16 feature extractor is frozen; only ``custom_cnn``
    and ``classifier`` are trained.  For 224x224 RGB inputs the VGG16
    features are (512, 7, 7); ``custom_cnn`` reduces them to (64, 3, 3),
    i.e. 576 values feeding the classifier.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.num_classes = num_classes
        self.vgg16: VGG = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

        # Freeze the backbone: transfer learning with fixed features.
        for param in self.vgg16.parameters():
            param.requires_grad = False

        self.custom_cnn = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier head.  NOTE: the layer order/indices (Flatten at 0,
        # Linear at 1, ...) must stay as-is so existing checkpoints with
        # keys like "classifier.1.weight" keep loading.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 64 channels * 3 * 3 spatial = 576.  The original wrote this
            # as 32*3*6 — same value, but misleading; fixed for clarity.
            nn.Linear(64 * 3 * 3, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x: Tensor):
        x = self.vgg16.features(x)
        x = self.custom_cnn(x)
        # The classifier starts with nn.Flatten(), so the manual
        # x.view(x.size(0), -1) the original performed here was redundant
        # and has been removed; outputs are unchanged.
        return self.classifier(x)
pyproject.toml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "vgg-cnn"
3
+ version = "0.1.0"
4
+ description = "VGG16-based CNN for rice plant disease classification"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "huggingface-hub[cli]>=0.35.3",
9
+ "torch>=2.8.0",
10
+ "torchvision>=0.23.0",
11
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff