# Ashish Reddy — "Update app.py" (commit 25b517f, verified)
"""
Pollen Grain Counter - Hugging Face Spaces Version
Enhanced drag-and-drop pollen-grain counter (multi-image, CSV download)
"""
import os
import cv2
import csv
import tempfile
import numpy as np
from PIL import Image
from ultralytics import YOLO
import gradio as gr
import logging
from pathlib import Path
import torch
import torch.serialization
# Set up logging: INFO-level root configuration plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ─────────── configuration ───────────
MODEL_NAME = "best.pt"            # weights file, looked up relative to the working directory
CONF_THRES = 0.37                 # default YOLO confidence threshold, read at inference time
DEVICE = "cpu"                    # device string handed to each model call
MAX_IMAGE_SIZE = 50 * 1024 ** 2   # per-image upload cap in bytes (50 MiB)
SUPPORTED_FORMATS = {             # accepted file extensions (compared lower-cased)
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
}
# ──────────────────────────────────────
def load_model(path=None):
    """Load the YOLO detection model from a local weights file.

    Args:
        path: Weights file to load. When omitted (the default) ``MODEL_NAME``
            is used, so existing ``load_model()`` callers are unaffected.

    Returns:
        The ``YOLO`` model instance built from the weights file.

    Raises:
        FileNotFoundError: If ``path`` does not name an existing regular file.
        Exception: Anything raised while constructing ``YOLO(...)`` is logged
            and re-raised unchanged.
    """
    weights = MODEL_NAME if path is None else path

    # isfile (not exists) so that a directory with the weights' name is
    # rejected here with a clear message instead of failing inside YOLO().
    if not os.path.isfile(weights):
        error = FileNotFoundError(f"Model file not found: {weights}")
        logger.error(f"Failed to load model: {error}")
        raise error

    # NOTE(review): ultralytics appears to read YOLO_VERBOSE when it is first
    # imported, and the import happens at the top of this file. Setting it here
    # may therefore not affect this process — confirm. Kept for compatibility.
    os.environ['YOLO_VERBOSE'] = 'False'
    try:
        model = YOLO(weights)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    # The model is not moved to DEVICE here; this message only reports the
    # configured device.
    logger.info(f"Model loaded successfully on {DEVICE}")
    return model


# Load the model once at start-up. load_model re-raises on failure, so a
# missing or unreadable weights file stops the app at import time.
model = load_model()
def validate_image_file(file_path):
    """Check that an image path exists, is within the size cap and has a known extension.

    Checks run in order: existence, then size against ``MAX_IMAGE_SIZE``,
    then extension against ``SUPPORTED_FORMATS``.

    Returns:
        ``(True, "Valid")`` when every check passes; otherwise ``(False, reason)``
        describing the first check that failed.
    """
    if os.path.exists(file_path):
        size_bytes = os.path.getsize(file_path)
    else:
        return False, "File does not exist"

    mib = 1024 * 1024
    if size_bytes > MAX_IMAGE_SIZE:
        return False, f"File too large: {size_bytes / mib:.1f}MB (max: {MAX_IMAGE_SIZE / mib}MB)"

    suffix = Path(file_path).suffix.lower()
    if suffix in SUPPORTED_FORMATS:
        return True, "Valid"
    return False, f"Unsupported format: {suffix}"
def process_single_image(file_path, progress_callback=None, conf=None):
    """Run pollen detection on one image file and draw the detected boxes.

    Args:
        file_path: Path to the image on disk.
        progress_callback: Optional callable taking a status string. It is
            invoked once, just before inference.
        conf: Optional confidence threshold for this call only. ``None`` (the
            default) uses the module-level ``CONF_THRES`` at call time.

    Returns:
        ``(annotated_rgb, count, "Success")``: the annotated image as a
        contiguous RGB ``numpy`` array and the number of detected boxes.
        On a validation failure or any exception it returns
        ``(None, 0, message)`` instead; this function does not raise.
    """
    try:
        threshold = CONF_THRES if conf is None else conf

        is_valid, msg = validate_image_file(file_path)
        if not is_valid:
            return None, 0, f"Validation failed: {msg}"

        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that stays usable after the file is closed.
        with Image.open(file_path) as src:
            pil_img = src.convert("RGB")
        # BGR array: used both as model input and as the OpenCV drawing canvas.
        base_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        overlay = base_bgr.copy()

        if progress_callback:
            progress_callback("Running YOLO detection...")

        # Direct YOLO inference on the full image.
        results = model(base_bgr, conf=threshold, verbose=False, device=DEVICE)

        total_detections = 0
        for res in results:
            boxes = getattr(res, "boxes", None)
            if boxes is None:
                continue
            # tolist() yields plain Python ints for OpenCV's point arguments.
            for x1, y1, x2, y2 in boxes.xyxy.cpu().numpy().astype(int).tolist():
                total_detections += 1
                # Green 1-px outline; a thin line suits small objects.
                cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 1)

        # cvtColor returns a contiguous RGB array. overlay[:, :, ::-1] would
        # instead be a negative-stride view of the BGR buffer.
        annotated_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        return annotated_rgb, total_detections, "Success"
    except Exception as e:
        error_msg = f"Error processing {os.path.basename(file_path)}: {str(e)}"
        # logger.exception also records the traceback for debugging.
        logger.exception(error_msg)
        return None, 0, error_msg
def predict(files, progress=gr.Progress()):
    """Gradio callback: count pollen grains in every uploaded image.

    Args:
        files: Uploaded image paths (the ``gr.File`` input uses
            ``type="filepath"``), or an empty value when nothing was uploaded.
        progress: Gradio progress tracker, injected by Gradio at call time.

    Returns:
        ``(gallery_images, csv_path, status_message)``. ``csv_path`` is a
        temporary CSV holding one ``filename,count`` row per successfully
        processed image. It is ``None`` when no image could be processed.
    """
    if not files:
        return [], None, "No files uploaded"

    annotated_images = []
    counts = []
    errors = []
    total_files = len(files)

    progress(0, desc="Starting analysis...")

    for i, file in enumerate(files):
        # Report the fraction finished *before* this image (i / n) and the
        # mid-image point ((i + 0.5) / n) during inference. The bar then only
        # moves forward; the old (i + 1) / n start made it jump back.
        progress(i / total_files, desc=f"Processing image {i+1}/{total_files}")

        # The fraction is bound as a default so the callback does not depend
        # on the loop variable's late binding.
        def progress_callback(msg, _fraction=(i + 0.5) / total_files):
            progress(_fraction, desc=msg)

        annotated_img, count, status = process_single_image(file, progress_callback)
        if annotated_img is not None:
            annotated_images.append(annotated_img)
            counts.append((os.path.basename(file), count))
        else:
            errors.append(status)

    if not counts:
        error_summary = "No images could be processed:\n" + "\n".join(errors)
        return [], None, error_summary

    # delete=False: the path is returned to the download component, so the
    # file must outlive this call. Written in one pass without reopening.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "count"])
        writer.writerows(counts)
        tmp_csv_path = f.name

    total_count = sum(count for _, count in counts)
    progress(1.0, desc=f"Complete! Processed {len(counts)} images, found {total_count} pollen grains")

    status_msg = f"Successfully processed {len(counts)} images"
    if errors:
        # Show at most three error messages, then summarize the remainder.
        status_msg += f"\n{len(errors)} errors occurred:\n" + "\n".join(errors[:3])
        if len(errors) > 3:
            status_msg += f"\n... and {len(errors) - 3} more errors"
    return annotated_images, tmp_csv_path, status_msg
# ─────────── Gradio UI ───────────
_APP_CSS = """
.main-title {
font-size: 2.5rem;
font-weight: bold;
text-align: center;
margin-bottom: 1rem;
color: #374151;
}
.subtitle {
font-size: 1.1rem;
text-align: center;
margin-bottom: 2rem;
color: #6b7280;
}
.control-panel {
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 1.5rem;
}
.results-panel {
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 1.5rem;
}
"""

with gr.Blocks(css=_APP_CSS) as demo:
    gr.Markdown("<div class='main-title'>Pollen Grain Counter</div>")
    gr.Markdown("<div class='subtitle'>Upload Images for automated pollen detection and counting</div>")

    with gr.Row():
        # Left column - controls and downloads
        with gr.Column(scale=1, elem_classes="control-panel"):
            file_input = gr.File(
                label="Upload Images",
                file_count="multiple",
                type="filepath",
            )
            with gr.Row():
                run_button = gr.Button("Start Processing", variant="primary", size="lg")
                clear_button = gr.Button("Clear", variant="secondary")

            # Configuration section
            with gr.Accordion("Settings", open=False):
                conf_slider = gr.Slider(
                    minimum=0.1, maximum=0.9, value=CONF_THRES, step=0.05,
                    label="Confidence Threshold",
                    info="Lower = more detections, higher = more precise",
                )

            # Download section
            download_csv = gr.File(
                label="Download Results (CSV)",
                visible=True,
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                lines=4,
            )

        # Right column - results gallery
        with gr.Column(scale=2, elem_classes="results-panel"):
            gallery = gr.Gallery(
                label="Detected Pollen Grains",
                show_label=True,
                columns=3,
                height="auto",
            )

    # Event handlers
    def update_confidence(new_conf):
        """Acknowledge a slider change in the status box.

        The value is applied when processing starts (run_analysis). This
        handler no longer writes the process-wide CONF_THRES, so moving the
        slider in one browser session does not change other sessions.
        """
        return f"Confidence threshold updated to {new_conf}"

    def run_analysis(files, conf, progress=gr.Progress()):
        """Apply this session's slider value, then run predict on the uploads."""
        # NOTE(review): CONF_THRES is still a process-wide global. Setting it
        # immediately before the run is only race-free while runs are
        # serialized (e.g. Gradio's per-event concurrency limit of 1) —
        # confirm for the deployed Gradio version.
        global CONF_THRES
        CONF_THRES = conf
        return predict(files, progress)

    def clear_all():
        """Reset the upload box, gallery, CSV download and status text."""
        return None, [], None, "Ready for new images"

    # Link interactions
    run_button.click(
        fn=run_analysis,
        inputs=[file_input, conf_slider],
        outputs=[gallery, download_csv, status_output],
    )
    conf_slider.change(
        fn=update_confidence,
        inputs=conf_slider,
        outputs=status_output,
    )
    clear_button.click(
        fn=clear_all,
        outputs=[file_input, gallery, download_csv, status_output],
    )
# ─────────── Main ───────────
if __name__ == "__main__":
    # Send the start-up banner through the logger configured at the top of
    # the file, consistent with the rest of the module's diagnostics.
    logger.info("Starting Pollen Counter on Hugging Face Spaces")
    demo.launch()