Spaces:
Sleeping
Sleeping
File size: 2,372 Bytes
8b61522 f28de38 8b61522 f8aac29 8b61522 db3e092 8b61522 7b0eb5f 8b61522 998c2b4 8b61522 60a0c88 8b61522 b0fe863 8b61522 812cd20 8b61522 812cd20 8b61522 812cd20 8b61522 812cd20 8b61522 812cd20 8b61522 812cd20 8b61522 812cd20 8b61522 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import io
from fastapi import FastAPI, UploadFile, File, HTTPException
from PIL import Image
from ultralytics import YOLO
import torch
app = FastAPI(title="Plant Disease Detection API")


def _load_detector():
    """Construct the YOLO detector used by every request.

    The custom weights are expected in ``model.pt`` in the working directory
    (the trained ``best.pt`` renamed). If building that model raises for any
    reason, the error is printed and the stock ``yolov8n.pt`` weights are
    loaded instead, so the app still starts when the file is absent at build
    time.
    """
    try:
        return YOLO("model.pt")
    except Exception as e:
        # Best-effort startup: report the problem, then fall back to a
        # small placeholder model rather than crashing the process.
        print(f"Error loading model: {e}")
        print("Warning: 'model.pt' not found, loading 'yolov8n.pt' as placeholder.")
        return YOLO("yolov8n.pt")


# Loaded exactly once, at import time, and shared by all requests.
model = _load_detector()
@app.get("/")
def home():
    """Return a short status payload that points callers at ``/predict``."""
    status = {
        "message": "Plant Disease Detection API is running. POST to /predict with an image."
    }
    return status
def _decode_image(image_data):
    """Decode raw upload bytes into an RGB PIL image.

    Args:
        image_data: The complete body of the uploaded file.

    Returns:
        A fully loaded ``PIL.Image.Image`` in RGB mode.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image
            (empty, truncated, unknown format) or trip Pillow's
            decompression-bomb guard.
    """
    try:
        # convert() loads the pixel data and returns an independent copy,
        # so the source image can be closed as soon as the block exits.
        with Image.open(io.BytesIO(image_data)) as raw_image:
            return raw_image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError is a subclass of OSError, so this also
        # covers "cannot identify image file". A bad upload is the client's
        # error, not a server fault.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc


def _collect_detections(result):
    """Flatten one YOLO result into a list of JSON-serialisable dicts.

    Args:
        result: A single element of the list returned by ``model.predict``.

    Returns:
        One dict per detected box with keys ``bbox`` (``[x_min, y_min,
        x_max, y_max]``), ``confidence``, ``class_id`` and ``class_name``.
        Empty when the result carries no boxes.
    """
    boxes = result.boxes
    if boxes is None:
        # Ultralytics documents ``boxes`` as optional on a result; treat
        # its absence as "nothing detected" instead of failing.
        return []

    # One bulk conversion per attribute instead of three per box.
    corners = boxes.xyxy.tolist()
    scores = boxes.conf.tolist()
    class_ids = [int(value) for value in boxes.cls.tolist()]

    return [
        {
            "bbox": coords,  # [x_min, y_min, x_max, y_max]
            "confidence": score,
            "class_id": cls_id,
            "class_name": result.names[cls_id],
        }
        for coords, score, cls_id in zip(corners, scores, class_ids)
    ]


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return the boxes found.

    Args:
        file: Multipart upload; its declared content type must start with
            ``image/``.

    Returns:
        A dict with ``filename`` (as sent by the client), ``count`` (number
        of detections) and ``predictions`` (list of detection dicts, each
        with ``bbox``, ``confidence``, ``class_id`` and ``class_name``).

    Raises:
        HTTPException: 400 if the upload is not declared as an image or
            cannot be decoded as one; 500 if reading the upload or running
            inference fails.
    """
    # 1. Validate the declared content type before reading the body.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")

    try:
        # 2. Read and decode the upload.
        image_data = await file.read()
        image = _decode_image(image_data)

        # 3. Run inference. conf=0.25 is the confidence threshold; lower it
        #    for more sensitivity.
        # NOTE(review): model.predict is a synchronous call inside an async
        # handler, so it runs on the event loop. Left unchanged; confirm
        # that a shared model instance is safe to use from worker threads
        # before offloading it.
        results = model.predict(image, conf=0.25)

        # 4. Only one image was submitted, so only the first result matters.
        detections = _collect_detections(results[0])
    except HTTPException:
        # Deliberate client errors (e.g. the 400 from decoding) must not be
        # rewritten into a 500 by the catch-all below.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {
        "filename": file.filename,
        "count": len(detections),
        "predictions": detections,
    }