| """ |
| Pollen Grain Counter - Hugging Face Spaces Version |
| Enhanced drag-and-drop pollen-grain counter (multi-image, CSV download) |
| """ |
|
|
| import os |
| import cv2 |
| import csv |
| import tempfile |
| import numpy as np |
| from PIL import Image |
| from ultralytics import YOLO |
| import gradio as gr |
| import logging |
| from pathlib import Path |
| import torch |
| import torch.serialization |
|
|
| |
# Root logging at INFO so model loading and per-image errors appear in the app log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------- configuration
MODEL_NAME = "best.pt"  # YOLO weights file, resolved relative to the working directory
CONF_THRES = 0.37  # default detection confidence; rebound at runtime by update_confidence()
DEVICE = "cpu"  # passed as `device=` on every inference call
MAX_IMAGE_SIZE = 50 << 20  # 50 MiB upload limit, in bytes
SUPPORTED_FORMATS = set(".jpg .jpeg .png .bmp .tiff .tif".split())  # lower-case, dot-prefixed
| |
|
|
def load_model(model_path=None):
    """Load the YOLO pollen detector from a local weights file.

    Args:
        model_path: Path to the ``.pt`` weights. Defaults to ``MODEL_NAME``.

    Returns:
        The loaded ``YOLO`` model. No device placement happens here; ``DEVICE``
        is passed on each inference call instead.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing regular file.
        Exception: Anything raised while constructing ``YOLO`` is logged and
            re-raised unchanged.
    """
    if model_path is None:
        model_path = MODEL_NAME
    try:
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # NOTE(review): ultralytics reads YOLO_VERBOSE once, when it is first
        # imported, so setting it here (after the module-level import) does not
        # change its verbosity. It is kept for anything that reads it later.
        os.environ['YOLO_VERBOSE'] = 'False'
        # Lowering the level of the library's logger (named "ultralytics") is
        # what can still quiet its INFO output at this point; if the name ever
        # differs this is a harmless no-op.
        logging.getLogger("ultralytics").setLevel(logging.WARNING)

        model = YOLO(model_path)
        logger.info(f"Model loaded successfully on {DEVICE}")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


# Load once at import time; every request reuses this single model instance.
model = load_model()
|
|
def validate_image_file(file_path, max_size=None, supported_formats=None):
    """Check that an uploaded image is an existing regular file of acceptable size and type.

    Args:
        file_path: Path to the uploaded image.
        max_size: Maximum allowed size in bytes. Defaults to ``MAX_IMAGE_SIZE``.
        supported_formats: Allowed lower-case, dot-prefixed extensions.
            Defaults to ``SUPPORTED_FORMATS``.

    Returns:
        ``(True, "Valid")`` when the file passes every check, otherwise
        ``(False, reason)`` with a human-readable reason.
    """
    # Resolve defaults at call time so the module constants are read when used.
    if max_size is None:
        max_size = MAX_IMAGE_SIZE
    if supported_formats is None:
        supported_formats = SUPPORTED_FORMATS

    if not os.path.exists(file_path):
        return False, "File does not exist"
    # exists() is also true for directories, which must not pass as images.
    if not os.path.isfile(file_path):
        return False, "Not a regular file"

    try:
        file_size = os.path.getsize(file_path)
    except OSError as e:  # removed or made unreadable after the checks above
        return False, f"Cannot read file: {e}"
    if file_size > max_size:
        return False, f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024)}MB)"

    ext = Path(file_path).suffix.lower()
    if ext not in supported_formats:
        return False, f"Unsupported format: {ext}"

    return True, "Valid"
|
|
def _draw_boxes(canvas, results):
    """Draw every predicted box in ``results`` onto ``canvas`` in place and return how many were drawn."""
    count = 0
    for res in results:
        boxes = getattr(res, "boxes", None)
        if boxes is None:
            continue
        for x1, y1, x2, y2 in boxes.xyxy.cpu().numpy().astype(int):
            # Plain Python ints keep OpenCV's point parsing unambiguous.
            cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
            count += 1
    return count


def process_single_image(file_path, progress_callback=None, conf=None):
    """Run pollen detection on one image file and draw the detected boxes.

    Args:
        file_path: Path to the image on disk.
        progress_callback: Optional callable taking one status string; it is
            called once, just before inference.
        conf: Confidence threshold for this call. Defaults to the module-level
            ``CONF_THRES`` read at call time, so slider updates still apply.

    Returns:
        ``(annotated_rgb, count, "Success")`` where ``annotated_rgb`` is a
        contiguous RGB ``numpy`` array with a green box per detection.
        On failure returns ``(None, 0, message)``. Exceptions are logged and
        converted to that tuple, never propagated.
    """
    if conf is None:
        conf = CONF_THRES
    try:
        is_valid, msg = validate_image_file(file_path)
        if not is_valid:
            return None, 0, f"Validation failed: {msg}"

        # PIL opens files lazily; the context manager releases the handle
        # as soon as the pixels have been decoded.
        with Image.open(file_path) as img:
            pil_img = img.convert("RGB")
        base_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

        if progress_callback:
            progress_callback("Running YOLO detection...")

        results = model(base_bgr, conf=conf, verbose=False, device=DEVICE)

        # Draw on a copy so the array handed to the model is never modified.
        overlay = base_bgr.copy()
        total_detections = _draw_boxes(overlay, results)

        # cvtColor returns a contiguous array, unlike the negative-stride
        # view produced by overlay[:, :, ::-1].
        annotated_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        return annotated_rgb, total_detections, "Success"

    except Exception as e:
        error_msg = f"Error processing {os.path.basename(file_path)}: {str(e)}"
        logger.error(error_msg)
        return None, 0, error_msg
|
|
def _write_counts_csv(counts):
    """Write ``(filename, count)`` pairs to a new temporary CSV file and return its path.

    The file is created with ``delete=False`` and is not removed here, because
    it must stay on disk for Gradio to serve it as a download.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as tmp_csv:
        writer = csv.writer(tmp_csv)
        writer.writerow(["filename", "count"])
        writer.writerows(counts)
    return tmp_csv.name


def predict(files, progress=gr.Progress()):
    """Gradio callback: detect and count pollen grains in every uploaded image.

    ``progress=gr.Progress()`` is Gradio's documented way to request a progress
    tracker; Gradio supplies the tracker when it invokes the callback.

    Args:
        files: Image file paths from the ``gr.File`` upload component, or a
            falsy value when nothing was uploaded.
        progress: Gradio progress tracker.

    Returns:
        ``(annotated_images, csv_path, status_msg)`` for the gallery, the CSV
        download and the status textbox. When no image could be processed, the
        first two are ``[]`` and ``None``.
    """
    if not files:
        return [], None, "No files uploaded"

    n_files = len(files)
    annotated_images = []
    counts = []  # (basename, detections) for each successfully processed image
    errors = []

    progress(0, desc="Starting analysis...")

    for i, file in enumerate(files):
        # Report the fraction already completed (i of n). The mid-image
        # update below then moves the bar forward, never backwards.
        progress(i / n_files, desc=f"Processing image {i+1}/{n_files}")

        # Bind i now so the callback cannot observe a later loop value.
        def progress_callback(msg, _i=i):
            progress((_i + 0.5) / n_files, desc=msg)

        annotated_img, count, status = process_single_image(file, progress_callback)

        if annotated_img is None:
            errors.append(status)
            continue
        annotated_images.append(annotated_img)
        counts.append((os.path.basename(file), count))

    if not counts:
        error_summary = "No images could be processed:\n" + "\n".join(errors)
        return [], None, error_summary

    tmp_csv_path = _write_counts_csv(counts)

    total_count = sum(count for _, count in counts)
    progress(1.0, desc=f"Complete! Processed {len(counts)} images, found {total_count} pollen grains")

    # Show at most three individual errors to keep the status box readable.
    status_msg = f"Successfully processed {len(counts)} images"
    if errors:
        status_msg += f"\n{len(errors)} errors occurred:\n" + "\n".join(errors[:3])
        if len(errors) > 3:
            status_msg += f"\n... and {len(errors) - 3} more errors"

    return annotated_images, tmp_csv_path, status_msg
|
|
| |
# Styles for the custom element classes used in the layout below.
APP_CSS = """
.main-title {
    font-size: 2.5rem;
    font-weight: bold;
    text-align: center;
    margin-bottom: 1rem;
    color: #374151;
}
.subtitle {
    font-size: 1.1rem;
    text-align: center;
    margin-bottom: 2rem;
    color: #6b7280;
}
.control-panel {
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 1.5rem;
}
.results-panel {
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 1.5rem;
}
"""


def update_confidence(new_conf):
    """Rebind the module-level CONF_THRES to the slider value and report the change.

    NOTE(review): CONF_THRES is a module global, so a change made from one
    browser session also affects every other session served by this process.
    """
    global CONF_THRES
    CONF_THRES = new_conf
    return f"Confidence threshold updated to {new_conf}"


def clear_all():
    """Return empty values for the upload, gallery and CSV download, plus a ready message."""
    return (None, [], None, "Ready for new images")


with gr.Blocks(css=APP_CSS) as demo:
    gr.Markdown("<div class='main-title'>Pollen Grain Counter</div>")
    gr.Markdown("<div class='subtitle'>Upload Images for automated pollen detection and counting</div>")

    with gr.Row():
        # Left column: upload, actions, settings and non-image outputs.
        with gr.Column(scale=1, elem_classes="control-panel"):
            file_input = gr.File(label="Upload Images", file_count="multiple", type="filepath")

            with gr.Row():
                run_button = gr.Button("Start Processing", variant="primary", size="lg")
                clear_button = gr.Button("Clear", variant="secondary")

            with gr.Accordion("Settings", open=False):
                conf_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=CONF_THRES,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Lower = more detections, higher = more precise",
                )

            download_csv = gr.File(label="Download Results (CSV)", visible=True)
            status_output = gr.Textbox(label="Status", interactive=False, lines=4)

        # Right column: annotated images.
        with gr.Column(scale=2, elem_classes="results-panel"):
            gallery = gr.Gallery(
                label="Detected Pollen Grains",
                show_label=True,
                columns=3,
                height="auto",
            )

    # Event wiring, positional form: (fn, inputs, outputs).
    run_button.click(predict, file_input, [gallery, download_csv, status_output])
    conf_slider.change(update_confidence, conf_slider, status_output)
    clear_button.click(clear_all, outputs=[file_input, gallery, download_csv, status_output])


if __name__ == "__main__":
    print("Starting Pollen Counter on Hugging Face Spaces")
    demo.launch()