Ashish Reddy commited on
Commit
2d61b5a
Β·
verified Β·
1 Parent(s): 4de78a5

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +266 -0
  2. best.pt +3 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pollen Grain Counter - Hugging Face Spaces Version
3
+ Enhanced drag-and-drop pollen-grain counter (multi-image, CSV download)
4
+ """
5
+
6
+ import os
7
+ import cv2
8
+ import csv
9
+ import tempfile
10
+ import numpy as np
11
+ from PIL import Image
12
+ from ultralytics import YOLO
13
+ import gradio as gr
14
+ import logging
15
+ from pathlib import Path
16
+ import requests
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ # Set up logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ─────────── configuration ───────────
24
+ MODEL_NAME = "best.pt" # Your model file name
25
+ CONF_THRES = 0.37 # YOLO confidence threshold
26
+ DEVICE = "cpu" # HF Spaces typically use CPU
27
+ MAX_IMAGE_SIZE = 50 * 1024 * 1024 # 50MB max per image
28
+ SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
29
+ # ──────────────────────────────────────
30
+
31
+ def load_model():
32
+ """Load YOLO model from local file."""
33
+ try:
34
+ # Check if model exists locally
35
+ if os.path.exists(MODEL_NAME):
36
+ model = YOLO(MODEL_NAME)
37
+ logger.info(f"Model loaded successfully on {DEVICE}")
38
+ return model
39
+ else:
40
+ raise FileNotFoundError(f"Model file not found: {MODEL_NAME}")
41
+ except Exception as e:
42
+ logger.error(f"Failed to load model: {e}")
43
+ raise
44
+
45
+ # Load model once at start-up
46
+ model = load_model()
47
+
48
+ def validate_image_file(file_path):
49
+ """Validate image file size and format."""
50
+ if not os.path.exists(file_path):
51
+ return False, "File does not exist"
52
+
53
+ # Check file size
54
+ file_size = os.path.getsize(file_path)
55
+ if file_size > MAX_IMAGE_SIZE:
56
+ return False, f"File too large: {file_size / (1024*1024):.1f}MB (max: {MAX_IMAGE_SIZE / (1024*1024)}MB)"
57
+
58
+ # Check file extension
59
+ ext = Path(file_path).suffix.lower()
60
+ if ext not in SUPPORTED_FORMATS:
61
+ return False, f"Unsupported format: {ext}"
62
+
63
+ return True, "Valid"
64
+
65
+ def process_single_image(file_path, progress_callback=None):
66
+ """Process a single image and return annotated result + count."""
67
+ try:
68
+ # Validate file
69
+ is_valid, msg = validate_image_file(file_path)
70
+ if not is_valid:
71
+ return None, 0, f"Validation failed: {msg}"
72
+
73
+ # Load and convert image
74
+ pil_img = Image.open(file_path).convert("RGB")
75
+ base_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
76
+ overlay = base_bgr.copy()
77
+
78
+ if progress_callback:
79
+ progress_callback("Running YOLO detection...")
80
+
81
+ # Direct YOLO inference on full image
82
+ results = model(base_bgr, conf=CONF_THRES, verbose=False, device=DEVICE)
83
+
84
+ total_detections = 0
85
+
86
+ # Draw boxes on overlay
87
+ for res in results:
88
+ if hasattr(res, 'boxes') and res.boxes is not None:
89
+ for box in res.boxes.xyxy.cpu().numpy().astype(int):
90
+ x1, y1, x2, y2 = box
91
+ total_detections += 1
92
+ cv2.rectangle(
93
+ overlay,
94
+ (x1, y1),
95
+ (x2, y2),
96
+ (0, 255, 0),
97
+ 1 # Line width = 1 for small objects
98
+ )
99
+
100
+ # Convert BGR overlay back to RGB for Gradio
101
+ annotated_rgb = overlay[:, :, ::-1]
102
+ return annotated_rgb, total_detections, "Success"
103
+
104
+ except Exception as e:
105
+ error_msg = f"Error processing {os.path.basename(file_path)}: {str(e)}"
106
+ logger.error(error_msg)
107
+ return None, 0, error_msg
108
+
109
+ def predict(files, progress=gr.Progress()):
110
+ """Enhanced Gradio callback with progress tracking."""
111
+ if not files:
112
+ return [], None, "No files uploaded"
113
+
114
+ annotated_images = []
115
+ counts = []
116
+ errors = []
117
+
118
+ progress(0, desc="Starting analysis...")
119
+
120
+ # Process each uploaded file
121
+ for i, file in enumerate(files):
122
+ progress((i + 1) / len(files), desc=f"Processing image {i+1}/{len(files)}")
123
+
124
+ def progress_callback(msg):
125
+ progress((i + 0.5) / len(files), desc=msg)
126
+
127
+ annotated_img, count, status = process_single_image(file, progress_callback)
128
+
129
+ if annotated_img is not None:
130
+ annotated_images.append(annotated_img)
131
+ fname = os.path.basename(file)
132
+ counts.append((fname, count))
133
+ else:
134
+ errors.append(status)
135
+
136
+ # Create CSV with results
137
+ if counts:
138
+ tmp_csv = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
139
+ tmp_csv_path = tmp_csv.name
140
+ tmp_csv.close()
141
+
142
+ with open(tmp_csv_path, mode="w", newline="", encoding='utf-8') as f:
143
+ writer = csv.writer(f)
144
+ writer.writerow(["filename", "count"])
145
+
146
+ for fname, count in counts:
147
+ writer.writerow([fname, count])
148
+
149
+ total_count = sum(count for _, count in counts)
150
+ progress(1.0, desc=f"Complete! Processed {len(counts)} images, found {total_count} pollen grains")
151
+
152
+ # Prepare status message
153
+ status_msg = f"Successfully processed {len(counts)} images"
154
+ if errors:
155
+ status_msg += f"\n{len(errors)} errors occurred:\n" + "\n".join(errors[:3])
156
+ if len(errors) > 3:
157
+ status_msg += f"\n... and {len(errors) - 3} more errors"
158
+
159
+ return annotated_images, tmp_csv_path, status_msg
160
+ else:
161
+ error_summary = "No images could be processed:\n" + "\n".join(errors)
162
+ return [], None, error_summary
163
+
164
+ # ─────────── Gradio UI ───────────
165
+ with gr.Blocks(css="""
166
+ .main-title {
167
+ font-size: 2.5rem;
168
+ font-weight: bold;
169
+ text-align: center;
170
+ margin-bottom: 1rem;
171
+ color: #374151;
172
+ }
173
+ .subtitle {
174
+ font-size: 1.1rem;
175
+ text-align: center;
176
+ margin-bottom: 2rem;
177
+ color: #6b7280;
178
+ }
179
+ .control-panel {
180
+ border: 1px solid #e5e7eb;
181
+ border-radius: 8px;
182
+ padding: 1.5rem;
183
+ }
184
+ .results-panel {
185
+ border: 1px solid #e5e7eb;
186
+ border-radius: 8px;
187
+ padding: 1.5rem;
188
+ }
189
+ """) as demo:
190
+
191
+ gr.Markdown("<div class='main-title'>Pollen Grain Counter</div>")
192
+ gr.Markdown("<div class='subtitle'>Upload images for automated pollen detection and counting</div>")
193
+
194
+ with gr.Row():
195
+ # Left column - Controls and Downloads
196
+ with gr.Column(scale=1, elem_classes="control-panel"):
197
+ file_input = gr.File(
198
+ label="Upload Images",
199
+ file_count="multiple",
200
+ type="filepath"
201
+ )
202
+
203
+ with gr.Row():
204
+ run_button = gr.Button("Analyze Images", variant="primary", size="lg")
205
+ clear_button = gr.Button("Clear", variant="secondary")
206
+
207
+ # Configuration section
208
+ with gr.Accordion("Settings", open=False):
209
+ conf_slider = gr.Slider(
210
+ minimum=0.1, maximum=0.9, value=CONF_THRES, step=0.05,
211
+ label="Confidence Threshold",
212
+ info="Lower = more detections, higher = more precise"
213
+ )
214
+
215
+ # Download section
216
+ download_csv = gr.File(
217
+ label="Download Results (CSV)",
218
+ visible=True
219
+ )
220
+
221
+ status_output = gr.Textbox(
222
+ label="Status",
223
+ interactive=False,
224
+ lines=4
225
+ )
226
+
227
+ # Right column - Results Gallery
228
+ with gr.Column(scale=2, elem_classes="results-panel"):
229
+ gallery = gr.Gallery(
230
+ label="Detected Pollen Grains",
231
+ show_label=True,
232
+ columns=3,
233
+ height="auto"
234
+ )
235
+
236
+ # Event handlers
237
+ def update_confidence(new_conf):
238
+ global CONF_THRES
239
+ CONF_THRES = new_conf
240
+ return f"Confidence threshold updated to {new_conf}"
241
+
242
+ def clear_all():
243
+ return None, [], None, "Ready for new images"
244
+
245
+ # Link interactions
246
+ run_button.click(
247
+ fn=predict,
248
+ inputs=file_input,
249
+ outputs=[gallery, download_csv, status_output]
250
+ )
251
+
252
+ conf_slider.change(
253
+ fn=update_confidence,
254
+ inputs=conf_slider,
255
+ outputs=status_output
256
+ )
257
+
258
+ clear_button.click(
259
+ fn=clear_all,
260
+ outputs=[file_input, gallery, download_csv, status_output]
261
+ )
262
+
263
+ # ─────────── Main ───────────
264
+ if __name__ == "__main__":
265
+ print("Starting Pollen Counter on Hugging Face Spaces")
266
+ demo.launch()
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:109be7882d87d4d2bd1e7f85ae40c2fdeef57c55796718072aee1a1127537f22
3
+ size 40541285
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ultralytics==8.0.196
2
+ gradio==4.44.0
3
+ pillow==10.0.1
4
+ opencv-python-headless==4.8.1.78
5
+ numpy==1.24.3
6
+ huggingface_hub==0.17.3