# crowdpollen / pollen_prediction.py
# HF Space Deployer
# updates
# fc91668
import torch
import json
import cv2
import numpy as np
def load_pollen_model(weights_path="pollen_weights_only.pt", config_path="pollen_model_info.json"):
    """Load the pollen model configuration from a JSON file.

    Only *config_path* is read. *weights_path* is accepted for API
    compatibility but is not opened or used yet.

    Args:
        weights_path: Path to the exported weights file (currently unused).
        config_path: Path to the JSON model-info file.

    Returns:
        ``{"config": <parsed JSON>, "weights_loaded": True}`` on success, or
        ``None`` if the configuration could not be opened or parsed. In the
        failure case the error is printed.
    """
    try:
        # JSON is UTF-8 by specification; pass the encoding explicitly so the
        # result does not depend on the platform's locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"Error loading model: {e}")
        return None

    # TODO: load weights from weights_path once a real model replaces the
    # placeholder detector.
    # NOTE: "weights_loaded" is reported True even though weights_path is not
    # read; kept unchanged because callers may check this flag.
    return {
        "config": config,
        "weights_loaded": True,
    }
# Density bands applied to the detected grain count, checked in ascending
# order: (inclusive upper bound on count, density label, user-facing advice).
# A count above every bound maps to _VERY_HIGH_DENSITY.
_DENSITY_BANDS = (
    (10, "low", "Low pollen levels - good day for outdoor activities"),
    (30, "medium", "Moderate pollen levels - consider precautions if sensitive"),
    (60, "high", "High pollen levels - limit outdoor exposure"),
)
_VERY_HIGH_DENSITY = ("very_high", "Very high pollen levels - stay indoors if allergic")


def _classify_density(count):
    """Return ``(density_label, advice)`` for a grain count using _DENSITY_BANDS."""
    for upper, level, advice in _DENSITY_BANDS:
        if count <= upper:
            return level, advice
    return _VERY_HIGH_DENSITY


def predict_pollen_density(image_path, model_info, confidence=0.25):
    """Estimate pollen density in an image by counting circular blobs.

    This is a placeholder detector. The image is converted to grayscale and
    ``cv2.HoughCircles`` is run. The number of circles found is treated as
    the grain count and mapped to a density band:
    <=10 "low", <=30 "medium", <=60 "high", otherwise "very_high".

    Args:
        image_path: Path of the image file to analyse.
        model_info: Value returned by ``load_pollen_model``. It is currently
            not used by this function.
        confidence: Echoed back in the result; it does not influence detection.

    Returns:
        On success, a dict with keys ``success`` (True), ``total_grains``,
        ``density_level``, ``confidence``, ``advice`` and ``message``.
        On any failure, a dict with ``success`` (False), ``error`` (the
        exception text), ``total_grains`` (0) and ``density_level``
        ("unknown"). This function does not raise.
    """
    try:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; check explicitly so callers get a clear error
        # rather than an opaque assertion from the next OpenCV call.
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not read image: {image_path}")

        # imread yields BGR; convert directly to grayscale. This matches the
        # BGR->RGB->GRAY round trip without the extra full-image copy.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Detect circular objects (potential pollen grains).
        circles = cv2.HoughCircles(
            gray, cv2.HOUGH_GRADIENT, 1, 20,
            param1=50, param2=30, minRadius=1, maxRadius=50,
        )
        # HoughCircles returns None when nothing is found; otherwise the
        # circles are listed under the first index.
        count = 0 if circles is None else len(circles[0])

        density, advice = _classify_density(count)
        return {
            "success": True,
            "total_grains": count,
            "density_level": density,
            "confidence": confidence,
            "advice": advice,
            "message": f"Detected {count} potential pollen grains with {density} density level",
        }
    except Exception as e:
        # API boundary: report failures in the result dict instead of raising.
        return {
            "success": False,
            "error": str(e),
            "total_grains": 0,
            "density_level": "unknown",
        }
# Example usage:
# model_info = load_pollen_model()
# result = predict_pollen_density("test_image.jpg", model_info)
# print(result)