crowdpollen / predict.py
alie354's picture
Upload 7 files
12610d0 verified
raw
history blame contribute delete
931 Bytes
import cog
import torch
import tempfile
import os
class Predictor(cog.Predictor):
    """Cog predictor that counts model detections in an uploaded image.

    The detection count is reported as ``total_grains`` and bucketed into a
    coarse ``density_level`` of ``"low"``, ``"medium"`` or ``"high"``.
    """

    # Inclusive upper bounds on the detection count for each density bucket;
    # any count above MEDIUM_DENSITY_MAX is reported as "high".
    LOW_DENSITY_MAX = 10
    MEDIUM_DENSITY_MAX = 30

    def setup(self):
        """Load the serialized model onto the CPU and put it in eval mode."""
        # SECURITY NOTE(review): torch.load unpickles the checkpoint, which can
        # execute arbitrary code. "complete_model.pt" must come from a trusted
        # source.
        self.model = torch.load("complete_model.pt", map_location="cpu")
        self.model.eval()

    @cog.input("image", type=cog.File)
    @cog.input("confidence", type=float, default=0.25)
    def predict(self, image, confidence):
        """Run the model on ``image`` and summarize the detections.

        Args:
            image: Readable file-like object holding the image bytes.
            confidence: Value forwarded to the model as ``conf``.

        Returns:
            A dict with ``"total_grains"`` (number of detections) and
            ``"density_level"`` (``"low"``, ``"medium"`` or ``"high"``).
        """
        # The model is invoked with a file path, so the upload is spooled to a
        # named temporary file. delete=False lets the file be closed before it
        # is reopened by the model; it is removed explicitly below.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
        image_path = tmp.name
        try:
            with tmp:
                tmp.write(image.read())
            # NOTE(review): assumes the loaded model accepts a path plus a
            # `conf` keyword and returns an object exposing
            # `.pandas().xyxy` (YOLOv5-hub style) -- confirm against the
            # checkpoint actually shipped.
            results = self.model(image_path, conf=confidence)
            detections = results.pandas().xyxy[0]
            count = len(detections)
        finally:
            # Always remove the temp file, even when reading the upload or
            # running inference raises; otherwise failed requests leak files.
            os.unlink(image_path)

        return {
            "total_grains": count,
            "density_level": self._density_level(count),
        }

    @classmethod
    def _density_level(cls, count):
        """Map a detection count to its density bucket name."""
        if count <= cls.LOW_DENSITY_MAX:
            return "low"
        if count <= cls.MEDIUM_DENSITY_MAX:
            return "medium"
        return "high"