yichuan-huang committed on
Commit
d958a06
·
1 Parent(s): 5e7cf9c
Files changed (8) hide show
  1. .gitignore +3 -0
  2. app.py +97 -0
  3. categories.json +0 -0
  4. classifier.py +271 -0
  5. config.py +31 -0
  6. knowledge_base.py +55 -0
  7. requirement.txt +6 -2
  8. test_images/cardboard1.jpg +0 -0
.gitignore CHANGED
@@ -190,3 +190,6 @@ $RECYCLE.BIN/
190
 
191
  # Log files created by default by the nohup command
192
  nohup.out
 
 
 
 
190
 
191
  # Log files created by default by the nohup command
192
  nohup.out
193
+
194
+ # Python cache
195
+ __pycache__
app.py CHANGED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ from classifier import GarbageClassifier
5
+ from config import Config
6
+
7
# Initialize classifier
# Both objects are module-level globals: the Gradio callbacks and the
# launch call further down in this file read them directly.
config = Config()
classifier = GarbageClassifier(config)

# Load model at startup
# This runs at import time, before the UI is built. load_model() re-raises
# on failure, so a model that cannot be loaded stops the app here.
print("Loading model...")
classifier.load_model()
print("Model loaded successfully!")
15
+
16
+
17
def classify_garbage(image):
    """Gradio callback: classify the garbage shown in an uploaded image.

    Returns a ``(classification, details)`` pair of strings. A missing image
    and any exception raised by the classifier are reported through the
    returned strings instead of being propagated to the UI.
    """
    if image is None:
        return "Please upload an image", "No image provided"

    try:
        label, details = classifier.classify_image(image)
    except Exception as exc:
        return "Error", f"Classification failed: {exc}"
    return label, details
29
+
30
+
31
def get_example_images():
    """Return up to three example image paths from the ``test_images`` folder.

    Only regular files ending in ``.png``, ``.jpg`` or ``.jpeg`` (any case)
    are considered. An empty list is returned when the folder is missing or
    is not a directory.
    """
    example_dir = "test_images"
    if not os.path.isdir(example_dir):
        return []

    image_suffixes = (".png", ".jpg", ".jpeg")
    examples = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # chosen examples identical across runs and platforms.
    for name in sorted(os.listdir(example_dir)):
        path = os.path.join(example_dir, name)
        if name.lower().endswith(image_suffixes) and os.path.isfile(path):
            examples.append(path)
    return examples[:3]  # Limit to 3 examples
40
+
41
+
42
# Build the Gradio UI: image upload on the left, results on the right.
with gr.Blocks(title="Garbage Classification System") as demo:
    gr.Markdown("# 🗂️ Garbage Classification System")
    gr.Markdown(
        "Upload an image to classify garbage into: Recyclable Waste, Food/Kitchen Waste, Hazardous Waste, or Other Waste"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Garbage Image")
            classify_btn = gr.Button("Classify Garbage", variant="primary", size="lg")

        with gr.Column():
            classification_output = gr.Textbox(
                label="Classification Result",
                placeholder="Upload an image and click classify",
            )
            full_response_output = gr.Textbox(
                label="Detailed Analysis",
                placeholder="Detailed reasoning will appear here",
                lines=10,
            )

    # Collapsible reference listing every category with its description.
    with gr.Accordion("📋 Garbage Categories Information", open=False):
        category_info = classifier.get_categories_info()
        for category, description in category_info.items():
            gr.Markdown(f"**{category}**: {description}")

    # Example thumbnails are shown only when example files were found.
    examples = get_example_images()
    if examples:
        gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    # The button click and a change of the image both run the same callback
    # with the same inputs and outputs.
    for register in (classify_btn.click, image_input.change):
        register(
            fn=classify_garbage,
            inputs=image_input,
            outputs=[classification_output, full_response_output],
        )

if __name__ == "__main__":
    launch_options = {
        "share": config.GRADIO_SHARE,
        "server_name": config.GRADIO_SERVER_NAME,
        "server_port": config.GRADIO_PORT,
    }
    demo.launch(**launch_options)
categories.json DELETED
File without changes
classifier.py CHANGED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, Gemma3nForConditionalGeneration
2
+ from PIL import Image
3
+ import torch
4
+ import logging
5
+ from typing import Union, Tuple
6
+ from config import Config
7
+ from knowledge_base import GarbageClassificationKnowledge
8
+
9
+
10
class GarbageClassifier:
    """Classify garbage photos with a Gemma 3n vision-language model.

    The model is prompted with the system prompt from
    ``GarbageClassificationKnowledge`` and its free-text answer is mapped
    back onto one of the known category names.
    """

    # Fallback vocabulary used when the answer names no category verbatim.
    # Categories are tried in this order and the first one with a hit wins.
    # A bare "can" is deliberately absent: it is almost always the modal verb.
    _CATEGORY_KEYWORDS = {
        "Recyclable Waste": [
            "recyclable",
            "recycle",
            "plastic",
            "paper",
            "metal",
            "glass",
            "bottle",
            "cans",
            "tin can",
            "soda can",
            "aluminum",
            "cardboard",
        ],
        "Food/Kitchen Waste": [
            "food",
            "kitchen",
            "organic",
            "fruit",
            "vegetable",
            "leftovers",
            "scraps",
            "peel",
            "core",
            "bone",
        ],
        "Hazardous Waste": [
            "hazardous",
            "dangerous",
            "toxic",
            "battery",
            "batteries",
            "chemical",
            "medicine",
            "paint",
            "pharmaceutical",
        ],
        "Other Waste": [
            "other",
            "general",
            "trash",
            "garbage",
            "waste",
            "cigarette",
            "ceramic",
            "dust",
        ],
    }

    def __init__(self, config: Config = None):
        """Store the configuration; the model itself is loaded by load_model().

        Args:
            config: Settings object; a default ``Config()`` is created when
                omitted.
        """
        self.config = config or Config()
        self.knowledge = GarbageClassificationKnowledge()
        self.processor = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the processor and the model named by ``config.MODEL_NAME``.

        Raises:
            Exception: whatever the underlying loaders raise; the error is
                logged and re-raised unchanged.
        """
        try:
            self.logger.info(f"Loading model: {self.config.MODEL_NAME}")

            # The same credentials are needed for the processor files and for
            # the model weights, so the token is passed to both loaders.
            auth_kwargs = {}
            if self.config.HF_TOKEN:
                auth_kwargs["token"] = self.config.HF_TOKEN

            self.processor = AutoProcessor.from_pretrained(
                self.config.MODEL_NAME, **auth_kwargs
            )

            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.config.MODEL_NAME,
                torch_dtype=self.config.TORCH_DTYPE,
                device_map=self.config.DEVICE_MAP,
                **auth_kwargs,
            ).eval()

            self.logger.info("Model loaded successfully")

        except Exception as e:
            self.logger.error(f"Error loading model: {str(e)}")
            raise

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Letterbox an image onto a white square canvas.

        The image is converted to RGB, scaled so its longer side equals
        ``config.IMAGE_SIZE`` (512 when the config has no such field) while
        keeping the aspect ratio, and centred on the canvas.

        Args:
            image: Source PIL image.

        Returns:
            A new square RGB image.

        Raises:
            ValueError: if the image has a zero width or height.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        side = int(getattr(self.config, "IMAGE_SIZE", 512))
        target_size = (side, side)

        original_width, original_height = image.size
        if original_width <= 0 or original_height <= 0:
            raise ValueError("Image must have a non-zero width and height")

        aspect_ratio = original_width / original_height
        if aspect_ratio > 1:
            # Width is larger
            new_width = side
            # Clamp: very wide images would otherwise truncate to 0 pixels.
            new_height = max(1, int(side / aspect_ratio))
        else:
            # Height is larger or equal
            new_height = side
            new_width = max(1, int(side * aspect_ratio))

        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # White background; the resized image is pasted in the centre.
        processed_image = Image.new("RGB", target_size, (255, 255, 255))
        x_offset = (side - new_width) // 2
        y_offset = (side - new_height) // 2
        processed_image.paste(resized, (x_offset, y_offset))

        return processed_image

    def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str]:
        """Classify the garbage shown in an image.

        Args:
            image: PIL Image or path to an image file.

        Returns:
            Tuple of (classification_result, full_response). Failures during
            classification are logged and returned as
            ``("Error", "Classification failed: ...")`` instead of raised.

        Raises:
            RuntimeError: if load_model() has not been called yet.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            if isinstance(image, str):
                # convert() reads the pixel data, so the file handle can be
                # released as soon as the block exits.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be a PIL Image or file path")

            processed_image = self.preprocess_image(image)

            # System prompt carries the classification rules; the user turn
            # carries the image and the question.
            messages = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": self.knowledge.get_system_prompt(),
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {
                            "type": "text",
                            "text": "Please classify the garbage in this image and explain your reasoning.",
                        },
                    ],
                },
            ]

            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )

            # Move inputs to model device and set dtype
            inputs = inputs.to(self.model.device, dtype=self.model.dtype)
            # Prompt length, used below to strip the prompt from the output.
            input_len = inputs["input_ids"].shape[-1]

            with torch.no_grad():
                generation_kwargs = {
                    "max_new_tokens": self.config.MAX_NEW_TOKENS,
                    "pad_token_id": self.processor.tokenizer.eos_token_id,
                    "disable_compile": True,  # Important for stability
                }

                if self.config.DO_SAMPLE:
                    generation_kwargs.update(
                        {
                            "do_sample": True,
                            "temperature": self.config.TEMPERATURE,
                            "top_p": self.config.TOP_P,
                            "top_k": self.config.TOP_K,
                        }
                    )
                else:
                    generation_kwargs["do_sample"] = False

                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the newly generated tokens.
            response = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )[0]

            classification = self._extract_classification(response)
            formatted_response = self._format_response(classification, response)

            return classification, formatted_response

        except Exception as e:
            # logger.exception records the traceback together with the message.
            self.logger.exception(f"Error during classification: {e}")
            return "Error", f"Classification failed: {str(e)}"

    @staticmethod
    def _first_mentioned(text_lower, categories):
        """Return the category whose name occurs earliest in the text, or None."""
        best_category = None
        best_position = -1
        for category in categories:
            position = text_lower.find(category.lower())
            if position != -1 and (best_category is None or position < best_position):
                best_category, best_position = category, position
        return best_category

    def _extract_classification(self, response: str) -> str:
        """Map the model's free-text answer onto one known category.

        The text after an explicit "Classification:" label is examined first,
        then the whole response. For each, a verbatim category name is
        preferred (earliest mention wins); otherwise whole-word keyword
        matching is used.

        Args:
            response: Decoded model output.

        Returns:
            A category name, or "Unable to classify" when nothing matches.
        """
        import re

        if not response:
            return "Unable to classify"

        categories = self.knowledge.get_categories()
        response_lower = response.lower()

        # The label line is the model's actual verdict; the reasoning text
        # often names other categories only to rule them out.
        candidates = []
        label_match = re.search(r"classification\W*?:\W*([^\n]+)", response_lower)
        if label_match:
            candidates.append(label_match.group(1))
        candidates.append(response_lower)

        # Whole-word patterns (with simple plural / past-tense endings) so
        # that e.g. "other" does not fire on "another".
        keyword_patterns = {
            category: re.compile(
                r"\b(?:" + "|".join(re.escape(word) for word in keywords) + r")(?:e?s|d)?\b"
            )
            for category, keywords in self._CATEGORY_KEYWORDS.items()
        }

        for text in candidates:
            found = self._first_mentioned(text, categories)
            if found is not None:
                return found
            for category, pattern in keyword_patterns.items():
                if pattern.search(text):
                    return category

        return "Unable to classify"

    def _format_response(self, classification: str, full_response: str) -> str:
        """Return the response in the "Classification / Reasoning" layout.

        A response that already contains both bold headings is returned
        unchanged; an empty response gets a placeholder reasoning line.
        """
        if not full_response.strip():
            return f"**Classification**: {classification}\n**Reasoning**: No detailed analysis available."

        # If response already contains structured format, return as is
        if "**Classification**" in full_response and "**Reasoning**" in full_response:
            return full_response

        # Otherwise, format it
        return f"**Classification**: {classification}\n\n**Reasoning**: {full_response}"

    def get_categories_info(self):
        """Return the category descriptions provided by the knowledge base."""
        return self.knowledge.get_category_descriptions()
config.py CHANGED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class Config:
    """Settings for the garbage classification app.

    Groups the model id, generation parameters, device placement, image
    preprocessing size, Hugging Face credentials and Gradio server options.
    Every field has a default, so ``Config()`` works without arguments.
    """

    # Gemma3n model configuration
    MODEL_NAME: str = "google/gemma-3n-E2B-it"

    # Generation parameters
    MAX_NEW_TOKENS: int = 256
    TEMPERATURE: float = 0.3
    DO_SAMPLE: bool = True
    TOP_P: float = 0.8
    TOP_K: int = 40

    # Device configuration
    # The value is a torch dtype object, not a string.
    TORCH_DTYPE: torch.dtype = torch.bfloat16
    DEVICE_MAP: str = "auto"

    # Image preprocessing
    IMAGE_SIZE: int = 512

    # Hugging Face token
    # Read from the environment once, when this class body is executed;
    # an empty string means "no token".
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")

    # Gradio configuration
    GRADIO_SHARE: bool = False
    GRADIO_PORT: int = 7860
    GRADIO_SERVER_NAME: str = "0.0.0.0"
knowledge_base.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class GarbageClassificationKnowledge:
    """Static reference data for the classifier: prompt, categories, descriptions."""

    # Instructions sent to the model as the system message.
    _SYSTEM_PROMPT = """You are a professional garbage classification expert. You need to carefully observe the items in the picture, analyze their materials, properties and uses, and then make accurate judgments according to garbage classification standards.

Garbage classification standards:

**Recyclable Waste**:
- Paper: newspapers, magazines, books, various packaging papers, office paper, advertising flyers, cardboard boxes, copy paper, etc.
- Plastics: various plastic bags, plastic packaging, disposable plastic food containers and utensils, toothbrushes, cups, water bottles, plastic toys, etc.
- Metals: aluminum cans, tin cans, toothpaste tubes, metal toys, metal stationery, nails, metal sheets, aluminum foil, etc.
- Glass: glass bottles, broken glass pieces, mirrors, light bulbs, vacuum flasks, etc.
- Textiles: old clothing, textile products, shoes, curtains, towels, bags, etc.

**Food/Kitchen Waste**:
- Food scraps: rice, noodles, bread, meat, fish, shrimp shells, crab shells, bones, etc.
- Fruit peels and cores: watermelon rinds, apple cores, orange peels, banana peels, nut shells, etc.
- Plants: withered branches and leaves, flowers, traditional Chinese medicine residue, etc.
- Expired food: expired canned food, cookies, candy, etc.

**Hazardous Waste**:
- Batteries: dry batteries, rechargeable batteries, button batteries, and all types of batteries
- Light tubes: energy-saving lamps, fluorescent tubes, incandescent bulbs, LED lights, etc.
- Pharmaceuticals: expired medicines, medicine packaging, thermometers, blood pressure monitors, etc.
- Paints: paint, coatings, glue, nail polish, cosmetics, etc.
- Others: pesticides, cleaning agents, agricultural chemicals, X-ray films, etc.

**Other Waste**:
- Contaminated non-recyclable paper: toilet paper, diapers, wet wipes, napkins, etc.
- Cigarette butts, ceramics, dust, disposable tableware (non-plastic)
- Large bones, hard shells, hard fruit pits (coconut shells, durian shells, walnut shells, corn cobs, etc.)
- Hair, pet waste, cat litter, etc.

Please observe the items in the image carefully according to the above classification standards, provide accurate garbage classification results, and briefly explain the classification reasoning. Format your response as:

**Classification**: [Category Name]
**Reasoning**: [Brief explanation of why this item belongs to this category]"""

    # Category name -> one-line description. The insertion order is also the
    # order in which get_categories() lists the names.
    _CATEGORY_DESCRIPTIONS = {
        "Recyclable Waste": "Items that can be processed and reused, including paper, plastic, metal, glass, and textiles",
        "Food/Kitchen Waste": "Organic waste from food preparation and consumption",
        "Hazardous Waste": "Items containing harmful substances that require special disposal",
        "Other Waste": "Items that don't fit into other categories and go to general waste",
    }

    @staticmethod
    def get_system_prompt():
        """Return the system prompt holding the classification rules and answer format."""
        return GarbageClassificationKnowledge._SYSTEM_PROMPT

    @staticmethod
    def get_categories():
        """Return a new list with the four category names."""
        return list(GarbageClassificationKnowledge._CATEGORY_DESCRIPTIONS)

    @staticmethod
    def get_category_descriptions():
        """Return a new dict mapping each category name to its description."""
        return dict(GarbageClassificationKnowledge._CATEGORY_DESCRIPTIONS)
requirement.txt CHANGED
@@ -1,4 +1,8 @@
1
  numpy
2
- torch>=2.1.0
3
- transformers
 
 
 
 
4
  gradio
 
1
  numpy
2
+ pillow
3
+ torch
4
+ torchvision
5
+ transformers >= 4.53
6
+ accelerate
7
+ timm
8
  gradio
test_images/cardboard1.jpg ADDED