Arunisto commited on
Commit
4952a5c
·
verified ·
1 Parent(s): 4e9ff7d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +69 -0
  2. requirements.txt +160 -0
  3. veggie_net.pth +3 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import os
7
+
8
+ # 3. Define the model used for training
9
+ class VeggieNet(nn.Module):
10
+ def __init__(self, num_classes):
11
+ super().__init__()
12
+ self.net = nn.Sequential(
13
+ nn.Flatten(),
14
+ nn.Linear(3 * 128 * 128, 512),
15
+ nn.BatchNorm1d(512),
16
+ nn.ReLU(),
17
+ nn.Dropout(0.3),
18
+ nn.Linear(512, 256),
19
+ nn.BatchNorm1d(256),
20
+ nn.ReLU(),
21
+ nn.Dropout(0.3),
22
+ nn.Linear(256, 128),
23
+ nn.BatchNorm1d(128),
24
+ nn.ReLU(),
25
+ nn.Dropout(0.3),
26
+ nn.Linear(128, num_classes)
27
+ )
28
+
29
+ def forward(self, x):
30
+ return self.net(x)
31
+
32
+ # Manually loading the class names to match the dataset
33
+ class_names = ['Bean', 'Bitter_Gourd', 'Bottle_Gourd', 'Brinjal', 'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cauliflower', 'Cucumber', 'Papaya', 'Potato', 'Pumpkin', 'Radish', 'Tomato']
34
+
35
+ #loading the model
36
+ device = "gpu" if torch.cuda.is_available() else "cpu"
37
+ model = VeggieNet(num_classes=len(class_names))
38
+ model.load_state_dict(torch.load("veggie_net.pth", map_location=device))
39
+ model.eval()
40
+
41
+ #image preprocessing
42
+ transform = transforms.Compose([
43
+ transforms.Resize((128, 128)),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize((0.5,), (0.5,))
46
+ ])
47
+
48
+ # prediction function
49
+ def predict(image):
50
+ img = image.convert("RGB")
51
+ img = transform(img)
52
+ img = img.unsqueeze(0)
53
+ with torch.no_grad():
54
+ outputs = model(img)
55
+ _, predicted = torch.max(outputs, 1)
56
+ return class_names[predicted.item()]
57
+
58
+ # gradio ui
59
+ interface = gr.Interface(
60
+ fn=predict,
61
+ inputs=gr.Image(type="pil"),
62
+ outputs="label",
63
+ title="🥕 Vegetable Image Classifier",
64
+ description="Upload a vegetable image and the model will try to guess what it is!"
65
+ )
66
+
67
+ #launching the app
68
+ if __name__ == "__main__":
69
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.8.0
4
+ argon2-cffi==23.1.0
5
+ argon2-cffi-bindings==21.2.0
6
+ arrow==1.3.0
7
+ asttokens==3.0.0
8
+ async-lru==2.0.4
9
+ attrs==25.1.0
10
+ babel==2.17.0
11
+ beautifulsoup4==4.13.3
12
+ bleach==6.2.0
13
+ certifi==2025.1.31
14
+ cffi==1.17.1
15
+ charset-normalizer==3.4.1
16
+ click==8.1.8
17
+ comm==0.2.2
18
+ contourpy==1.3.1
19
+ cycler==0.12.1
20
+ debugpy==1.8.12
21
+ decorator==5.2.1
22
+ defusedxml==0.7.1
23
+ executing==2.2.0
24
+ fastapi==0.115.11
25
+ fastjsonschema==2.21.1
26
+ ffmpy==0.5.0
27
+ filelock==3.17.0
28
+ fonttools==4.56.0
29
+ fqdn==1.5.1
30
+ fsspec==2025.2.0
31
+ gradio==5.20.1
32
+ gradio_client==1.7.2
33
+ groovy==0.1.2
34
+ h11==0.14.0
35
+ httpcore==1.0.7
36
+ httpx==0.28.1
37
+ huggingface-hub==0.29.3
38
+ idna==3.10
39
+ ipykernel==6.29.5
40
+ ipython==9.0.1
41
+ ipython_pygments_lexers==1.1.1
42
+ ipywidgets==8.1.5
43
+ isoduration==20.11.0
44
+ jedi==0.19.2
45
+ Jinja2==3.1.5
46
+ joblib==1.4.2
47
+ json5==0.10.0
48
+ jsonpointer==3.0.0
49
+ jsonschema==4.23.0
50
+ jsonschema-specifications==2024.10.1
51
+ jupyter==1.1.1
52
+ jupyter-console==6.6.3
53
+ jupyter-events==0.12.0
54
+ jupyter-lsp==2.2.5
55
+ jupyter_client==8.6.3
56
+ jupyter_core==5.7.2
57
+ jupyter_server==2.15.0
58
+ jupyter_server_terminals==0.5.3
59
+ jupyterlab==4.3.5
60
+ jupyterlab_pygments==0.3.0
61
+ jupyterlab_server==2.27.3
62
+ jupyterlab_widgets==3.0.13
63
+ kiwisolver==1.4.8
64
+ markdown-it-py==3.0.0
65
+ MarkupSafe==2.1.5
66
+ matplotlib==3.10.1
67
+ matplotlib-inline==0.1.7
68
+ mdurl==0.1.2
69
+ mistune==3.1.2
70
+ mpmath==1.3.0
71
+ nbclient==0.10.2
72
+ nbconvert==7.16.6
73
+ nbformat==5.10.4
74
+ nest-asyncio==1.6.0
75
+ networkx==3.4.2
76
+ notebook==7.3.2
77
+ notebook_shim==0.2.4
78
+ numpy==2.2.3
79
+ nvidia-cublas-cu12==12.4.5.8
80
+ nvidia-cuda-cupti-cu12==12.4.127
81
+ nvidia-cuda-nvrtc-cu12==12.4.127
82
+ nvidia-cuda-runtime-cu12==12.4.127
83
+ nvidia-cudnn-cu12==9.1.0.70
84
+ nvidia-cufft-cu12==11.2.1.3
85
+ nvidia-curand-cu12==10.3.5.147
86
+ nvidia-cusolver-cu12==11.6.1.9
87
+ nvidia-cusparse-cu12==12.3.1.170
88
+ nvidia-cusparselt-cu12==0.6.2
89
+ nvidia-nccl-cu12==2.21.5
90
+ nvidia-nvjitlink-cu12==12.4.127
91
+ nvidia-nvtx-cu12==12.4.127
92
+ orjson==3.10.15
93
+ overrides==7.7.0
94
+ packaging==24.2
95
+ pandas==2.2.3
96
+ pandocfilters==1.5.1
97
+ parso==0.8.4
98
+ pexpect==4.9.0
99
+ pillow==11.1.0
100
+ platformdirs==4.3.6
101
+ prometheus_client==0.21.1
102
+ prompt_toolkit==3.0.50
103
+ psutil==7.0.0
104
+ ptyprocess==0.7.0
105
+ pure_eval==0.2.3
106
+ pycparser==2.22
107
+ pydantic==2.10.6
108
+ pydantic_core==2.27.2
109
+ pydub==0.25.1
110
+ Pygments==2.19.1
111
+ pyparsing==3.2.1
112
+ python-dateutil==2.9.0.post0
113
+ python-json-logger==3.2.1
114
+ python-multipart==0.0.20
115
+ pytz==2025.1
116
+ PyYAML==6.0.2
117
+ pyzmq==26.2.1
118
+ referencing==0.36.2
119
+ requests==2.32.3
120
+ rfc3339-validator==0.1.4
121
+ rfc3986-validator==0.1.1
122
+ rich==13.9.4
123
+ rpds-py==0.23.1
124
+ ruff==0.9.10
125
+ safehttpx==0.1.6
126
+ scikit-learn==1.6.1
127
+ scipy==1.15.2
128
+ semantic-version==2.10.0
129
+ Send2Trash==1.8.3
130
+ setuptools==75.8.2
131
+ shellingham==1.5.4
132
+ six==1.17.0
133
+ sniffio==1.3.1
134
+ soupsieve==2.6
135
+ stack-data==0.6.3
136
+ starlette==0.46.1
137
+ sympy==1.13.1
138
+ terminado==0.18.1
139
+ threadpoolctl==3.5.0
140
+ tinycss2==1.4.0
141
+ tomlkit==0.13.2
142
+ torch==2.6.0
143
+ torchvision==0.21.0
144
+ tornado==6.4.2
145
+ tqdm==4.67.1
146
+ traitlets==5.14.3
147
+ triton==3.2.0
148
+ typer==0.15.2
149
+ types-python-dateutil==2.9.0.20241206
150
+ typing_extensions==4.12.2
151
+ tzdata==2025.1
152
+ uri-template==1.3.0
153
+ urllib3==2.3.0
154
+ uvicorn==0.34.0
155
+ wcwidth==0.2.13
156
+ webcolors==24.11.1
157
+ webencodings==0.5.1
158
+ websocket-client==1.8.0
159
+ websockets==15.0.1
160
+ widgetsnbextension==4.0.13
veggie_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84d06bfeaefe102198eb22344af28cc01fbdce2453657ba2b08fdc80c0758d62
3
+ size 101351650