|
|
|
|
|
import io |
|
|
import os |
|
|
import uvicorn |
|
|
import traceback |
|
|
from typing import Optional |
|
|
from fastapi import FastAPI, File, UploadFile, Form |
|
|
from fastapi.responses import JSONResponse, StreamingResponse, Response |
|
|
from pydantic import BaseModel |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import cv2 |
|
|
import base64 |
|
|
|
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
from deepforest import main |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path to the YOLO classification weights; override with the YOLO_MODEL_PATH
# environment variable (defaults to the bundled two-class olive model file).
YOLO_MODEL_PATH = os.environ.get("YOLO_MODEL_PATH", "olive_cls_2c.pt")

# Device string passed to the YOLO classifier at predict time; override with
# the DEVICE environment variable.
DEVICE = os.environ.get("DEVICE", "cpu")
|
|
|
|
|
def read_imagefile(file_bytes) -> Image.Image:
    """Decode raw uploaded image bytes into a PIL image converted to RGB mode."""
    stream = io.BytesIO(file_bytes)
    decoded = Image.open(stream)
    return decoded.convert("RGB")
|
|
|
|
|
|
|
|
def get_text_size(draw, text: str, font):
    """Measure possibly multi-line ``text`` and return ``(width, height)``.

    Strategies are tried in order, and any exception from one falls through
    to the next:

    1. ``draw.multiline_textbbox``
    2. ``draw.textbbox``
    3. ``font.getsize`` (only present on older Pillow releases)
    4. a crude estimate of 7 px per character of the longest line and
       12 px per line.
    """

    def bbox_extent(box):
        # Bounding boxes are (left, top, right, bottom).
        return (box[2] - box[0], box[3] - box[1])

    strategies = (
        lambda: bbox_extent(draw.multiline_textbbox((0, 0), text, font=font)),
        lambda: bbox_extent(draw.textbbox((0, 0), text, font=font)),
        lambda: font.getsize(text),
    )
    for measure in strategies:
        try:
            return measure()
        except Exception:
            continue

    # Last resort: no measuring API worked, so approximate from character counts.
    rows = text.splitlines() or [text]
    return (max(7 * len(row) for row in rows), 12 * len(rows))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="Olive Tree Analyzer")


def _load_yolo():
    """Load the YOLO health classifier.

    Returns ``(model, class_names)``. ``class_names`` is None when the model
    exposes no non-empty ``names``. Returns ``(None, None)`` if loading fails.
    """
    # NOTE(review): DEVICE is only applied later, at predict time; the model is
    # not moved to DEVICE here despite what this message says.
    print(f"Loading YOLO model from {YOLO_MODEL_PATH} on device {DEVICE} ...")
    try:
        model = YOLO(YOLO_MODEL_PATH)
    except Exception as e:
        print("Failed to load YOLO model:", str(e))
        traceback.print_exc()
        return None, None

    class_names = None
    try:
        raw_names = getattr(model.model, "names", None)
        if raw_names is not None:
            names = list(raw_names.values()) if isinstance(raw_names, dict) else list(raw_names)
            if names:
                class_names = names
    except Exception as e:
        # Class names are informational only; the model itself is still usable.
        print(f"Could not read class names from YOLO model: {e}")
        traceback.print_exc()
    return model, class_names


def _load_deepforest():
    """Load the DeepForest tree detector, or return None if any step fails.

    The model is published only after ``load_model`` succeeds. This keeps a
    constructed-but-unloaded detector from being reported as ready.
    """
    print("Loading DeepForest model (Weecology/deepforest-tree) ...")
    try:
        model = main.deepforest()
        model.load_model("Weecology/deepforest-tree")
        # Minimum detection score kept by DeepForest for predicted boxes.
        model.config["score_thresh"] = 0.15
    except Exception as e:
        print("Failed to load DeepForest model:", str(e))
        traceback.print_exc()
        return None
    print("DeepForest model ready")
    return model


# Each model loads independently at import time, so a failure in one does not
# prevent the other from loading. Each global stays None unless its model
# loaded completely.
YOLO_MODEL, CLASS_NAMES = _load_yolo()
DEEPFOREST_MODEL = _load_deepforest()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report a constant "ok" status plus which models finished loading at startup."""
    report = {"status": "ok"}
    report["deepforest_model_loaded"] = DEEPFOREST_MODEL is not None
    report["yolo_model_loaded"] = YOLO_MODEL is not None
    report["classes_known"] = CLASS_NAMES is not None
    return report
|
|
|
|
|
|
|
|
|
|
|
# Outline / label-background colour per health class (RGB). Any class name not
# listed here is drawn in white.
_HEALTH_COLORS = {
    "healthy": (4, 189, 44),
    "dry": (192, 217, 4),
    "unknown": (128, 128, 128),
}

# (label, confidence) recorded for a tree that could not be classified.
_UNCLASSIFIED = ("unknown", 0.0)


def _box_coords(row):
    """Return a prediction row's (xmin, ymin, xmax, ymax) as ints; missing fields become 0."""
    return tuple(int(row.get(key, 0)) for key in ("xmin", "ymin", "xmax", "ymax"))


def _classify_crop(crop_rgb):
    """Classify one RGB tree crop with the YOLO classifier.

    Returns ``(label, confidence)``. Returns ``_UNCLASSIFIED`` when prediction
    raises, returns no results, or yields no class probabilities.
    """
    # Ultralytics interprets numpy sources as BGR (OpenCV channel order), but
    # this crop comes from a PIL RGB image. Reverse the channel axis so the
    # classifier sees the colours it was trained on.
    crop_bgr = np.ascontiguousarray(crop_rgb[..., ::-1])
    try:
        results = YOLO_MODEL.predict(source=crop_bgr, device=DEVICE, imgsz=224, batch=1, verbose=False)
    except Exception as e:
        print(f"YOLO prediction error: {e}")
        traceback.print_exc()
        return _UNCLASSIFIED
    if not results:
        print("No results from YOLO")
        return _UNCLASSIFIED

    first = results[0]
    probs = getattr(first, "probs", None)
    if probs is None:
        print("No probs attribute found in results")
        return _UNCLASSIFIED
    try:
        predicted_class = first.names[probs.top1]
        confidence = float(probs.top1conf)
    except Exception as e:
        print(f"Error extracting classification results: {e}")
        traceback.print_exc()
        return _UNCLASSIFIED
    print(f"Classified as {predicted_class} with confidence {confidence:.3f}")
    return predicted_class, confidence


def _classify_tree(row, img_np):
    """Crop one detected tree out of ``img_np`` (an (H, W, 3) RGB array) and classify it."""
    height, width = img_np.shape[:2]
    xmin, ymin, xmax, ymax = _box_coords(row)
    # Clip the box to the image bounds before cropping.
    xmin, ymin = max(0, xmin), max(0, ymin)
    xmax, ymax = min(width, xmax), min(height, ymax)
    if xmax <= xmin or ymax <= ymin:
        print(f"Invalid bounding box: ({xmin}, {ymin}, {xmax}, {ymax})")
        return _UNCLASSIFIED
    # The box is non-empty after clipping, so the crop always contains pixels.
    return _classify_crop(img_np[ymin:ymax, xmin:xmax])


def _annotate(pil_img, df_pred):
    """Draw, in place on ``pil_img``, each tree's box plus a two-line label.

    The label shows the health class, then the YOLO and detection scores.
    """
    from PIL import ImageDraw, ImageFont

    draw = ImageDraw.Draw(pil_img)
    try:
        font = ImageFont.truetype("arial.ttf", size=14)
    except Exception:
        font = ImageFont.load_default()

    for _, row in df_pred.iterrows():
        xmin, ymin, xmax, ymax = _box_coords(row)
        health = str(row.get("health_state", "unknown"))
        cls_conf = float(row.get("health_confidence", 0.0))
        raw_score = row.get("score")
        det_score = float(raw_score) if raw_score is not None else 0.0
        color = _HEALTH_COLORS.get(health.lower(), (255, 255, 255))

        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)

        label = f"{health}\nYOLO: {cls_conf:.2f} DET: {det_score:.2f}"
        text_w, text_h = get_text_size(draw, label, font)
        # Label sits just above the box, clamped to the top edge of the image.
        draw.rectangle((xmin, max(0, ymin - text_h - 4), xmin + text_w + 4, ymin), fill=color)
        draw.multiline_text((xmin + 2, max(0, ymin - text_h - 2)), label, fill=(255, 255, 255), font=font)


@app.post("/analyze")
async def analyze(image: UploadFile = File(...), conf: float = Form(0.25), iou: float = Form(0.45)):
    """
    Accepts a multipart/form-data file upload (key: image).

    Processing steps:

    1. Detect trees with DeepForest.
    2. Classify each detected tree's crop with the YOLO classifier.
    3. Draw a box and label per tree.

    Returns JSON containing the annotated JPEG (base64, key ``image``) and
    the total, healthy and dry tree counts. An image with no detections
    yields zero counts and the unannotated image.

    ``conf`` and ``iou`` are accepted for API compatibility but are currently
    unused; detection uses the DeepForest ``score_thresh`` set at load time.

    On any failure, returns HTTP 500 with ``{"error": ...}``.
    """
    if YOLO_MODEL is None:
        return JSONResponse(status_code=500, content={"error": "Model not loaded on server."})
    if DEEPFOREST_MODEL is None:
        return JSONResponse(status_code=500, content={"error": "DeepForest model not loaded on server."})

    # NOTE(review): inference below runs synchronously inside this async handler
    # and blocks the event loop for its duration. Moving it to a thread pool
    # would also need the shared models guarded: Ultralytics documents that
    # sharing one model instance across threads is not thread-safe.
    try:
        contents = await image.read()
        pil_img = read_imagefile(contents)
        # read_imagefile converts to RGB, so this is always an (H, W, 3) uint8 array.
        img_np = np.array(pil_img)

        df_pred = DEEPFOREST_MODEL.predict_image(img_np)
        if df_pred is None:
            # DeepForest can return None instead of an empty DataFrame when it
            # detects nothing. Treat that as zero trees rather than failing.
            df_pred = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "score"])

        # Exactly one (label, confidence) pair per prediction row.
        classifications = [_classify_tree(row, img_np) for _, row in df_pred.iterrows()]
        health_states = [state for state, _ in classifications]
        df_pred["health_state"] = health_states
        df_pred["health_confidence"] = [score for _, score in classifications]

        _annotate(pil_img, df_pred)

        print("\n=== Health State Summary ===")
        print(df_pred["health_state"].value_counts())
        print(f"\nProcessed {len(df_pred)} trees")

        buf = io.BytesIO()
        pil_img.save(buf, format="JPEG", quality=90)
        img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")

        normalized = [str(state).lower() for state in health_states]
        return JSONResponse(content={
            "image": img_b64,
            "total_trees_count": len(normalized),
            "healthy_trees_count": normalized.count("healthy"),
            "dry_trees_count": normalized.count("dry"),
        })

    except Exception as e:
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces; the PORT environment variable overrides 8000.
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)