# app.py
"""Olive Tree Analyzer.

FastAPI service that detects tree crowns with DeepForest, classifies the
health of each crown with a YOLO classification model, and returns an
annotated image together with per-state counts.
"""
import io
import os
import uvicorn
import traceback
from typing import Optional
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse, StreamingResponse, Response
from pydantic import BaseModel
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pandas as pd
import cv2
import base64

# Ultralytics YOLO
from ultralytics import YOLO

# DeepForest
from deepforest import main

# ------------------------
# Configuration
# ------------------------
# Path to the YOLO classification weights (overridable via env var).
YOLO_MODEL_PATH = os.environ.get("YOLO_MODEL_PATH", "olive_cls_2c.pt")
DEVICE = os.environ.get("DEVICE", "cpu")  # 'cpu' or 'cuda'
# Detection score threshold written into the DeepForest model config.
DEEPFOREST_SCORE_THRESH = 0.15

# Outline / label-background colour (RGB) per health state; any other
# state is drawn in white.
COLOR_MAP = {
    "healthy": (4, 189, 44),
    "dry": (192, 217, 4),
    # "dead": (255, 54, 54),
    "unknown": (128, 128, 128),
}
LABEL_TEXT_COLOR = (255, 255, 255)
# Placeholder state/confidence used whenever a crown cannot be classified.
UNKNOWN_RESULT = ("unknown", 0.0)


def read_imagefile(file_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    ``convert("RGB")`` drops any alpha channel and expands greyscale or
    palette images, so callers always get exactly three channels.
    """
    return Image.open(io.BytesIO(file_bytes)).convert("RGB")


def get_text_size(draw, text: str, font):
    """Return ``(width, height)`` in pixels for possibly multi-line *text*.

    Tries ``draw.multiline_textbbox`` and ``draw.textbbox`` first, then the
    legacy ``font.getsize``, and finally a crude estimate of 7 px per
    character and 12 px per line if none of those are available.
    """
    try:
        bbox = draw.multiline_textbbox((0, 0), text, font=font)
        return (bbox[2] - bbox[0], bbox[3] - bbox[1])
    except Exception:
        pass
    try:
        bbox = draw.textbbox((0, 0), text, font=font)
        return (bbox[2] - bbox[0], bbox[3] - bbox[1])
    except Exception:
        pass
    try:
        return font.getsize(text)
    except Exception:
        pass
    lines = text.splitlines() or [text]
    return (max(len(line) * 7 for line in lines), 12 * len(lines))


def _load_font(size: int = 14):
    """Load Arial at *size* if available, otherwise Pillow's default font."""
    try:
        return ImageFont.truetype("arial.ttf", size=size)
    except Exception:
        return ImageFont.load_default()


LABEL_FONT = _load_font()


# ------------------------
# Initialize App & Model
# ------------------------
app = FastAPI(title="Olive Tree Analyzer")

YOLO_MODEL = None
CLASS_NAMES = None
DEEPFOREST_MODEL = None


def _extract_class_names(model) -> Optional[list]:
    """Return the class names stored on an ultralytics model, or None.

    Ultralytics models typically expose ``model.model.names`` as either a
    ``{index: name}`` dict or a sequence; an empty/missing value yields None.
    """
    names = getattr(getattr(model, "model", None), "names", None)
    if not names:
        return None
    model_names = list(names.values()) if isinstance(names, dict) else list(names)
    return model_names or None


def _load_deepforest():
    """Build the pretrained DeepForest tree detector and set its threshold."""
    model = main.deepforest()
    model.load_model("Weecology/deepforest-tree")
    model.config["score_thresh"] = DEEPFOREST_SCORE_THRESH
    return model


# The two models load independently so a failure in one does not prevent
# the other from loading (and /health reports each accurately).
try:
    print(f"Loading YOLO model from {YOLO_MODEL_PATH} on device {DEVICE} ...")
    YOLO_MODEL = YOLO(YOLO_MODEL_PATH)
    CLASS_NAMES = _extract_class_names(YOLO_MODEL)
except Exception as e:
    print("Failed to load YOLO model:", str(e))
    traceback.print_exc()

try:
    print("Loading DeepForest model (Weecology/deepforest-tree) ...")
    # Assign only after weights load, so a half-initialised model is never
    # reported as loaded.
    DEEPFOREST_MODEL = _load_deepforest()
    print("DeepForest model ready")
except Exception as e:
    print("Failed to load DeepForest model:", str(e))
    traceback.print_exc()


# ------------------------
# Pipeline helpers
# ------------------------
def _clip_box(row, width: int, height: int):
    """Return the row's box as ints clipped to the image, or None if empty.

    *row* is one DeepForest prediction row with ``xmin``/``ymin``/``xmax``/
    ``ymax`` entries (missing entries default to 0).
    """
    xmin = max(0, int(row.get("xmin", 0)))
    ymin = max(0, int(row.get("ymin", 0)))
    xmax = min(width, int(row.get("xmax", 0)))
    ymax = min(height, int(row.get("ymax", 0)))
    if xmax <= xmin or ymax <= ymin:
        print(f"Invalid bounding box: ({xmin}, {ymin}, {xmax}, {ymax})")
        return None
    return xmin, ymin, xmax, ymax


def _classify_crop(crop_rgb: np.ndarray):
    """Classify one RGB crop with the YOLO model.

    Returns ``(predicted_class, confidence)``; any failure is logged and
    reported as ``("unknown", 0.0)`` so one bad crop never aborts a request.
    """
    if crop_rgb.size == 0:
        print("Empty crop detected")
        return UNKNOWN_RESULT
    try:
        # Ultralytics interprets numpy sources as OpenCV-style BGR, while this
        # crop comes from a PIL image in RGB; flip channels so the classifier
        # sees the true colours (contiguous copy for the backend).
        crop_bgr = np.ascontiguousarray(crop_rgb[..., ::-1])
        results = YOLO_MODEL.predict(
            source=crop_bgr, device=DEVICE, imgsz=224, batch=1, verbose=False
        )
    except Exception as e:
        print(f"YOLO prediction error: {e}")
        traceback.print_exc()
        return UNKNOWN_RESULT

    if not results:
        print("No results from YOLO")
        return UNKNOWN_RESULT

    result = results[0]
    try:
        probs = getattr(result, "probs", None)
        if probs is None:
            # Not a classification result (e.g. a detection model was loaded).
            print("No probs attribute found in results")
            return UNKNOWN_RESULT
        predicted_class = result.names[probs.top1]
        confidence = float(probs.top1conf)
    except Exception as e:
        print(f"Error extracting classification results: {e}")
        traceback.print_exc()
        return UNKNOWN_RESULT

    print(f"Classified as {predicted_class} with confidence {confidence:.3f}")
    return predicted_class, confidence


def _draw_annotations(pil_img: Image.Image, detections) -> None:
    """Draw each valid box and its label onto *pil_img* in place.

    *detections* is a list of ``(box, state, confidence, det_score)`` where
    ``box`` is a clipped ``(xmin, ymin, xmax, ymax)`` tuple or None (skipped).
    """
    draw = ImageDraw.Draw(pil_img)
    for box, state, confidence, det_score in detections:
        if box is None:
            continue
        xmin, ymin, xmax, ymax = box
        health = str(state)
        color = COLOR_MAP.get(health.lower(), (255, 255, 255))
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)

        label = f"{health}\nYOLO: {confidence:.2f} DET: {det_score:.2f}"
        text_w, text_h = get_text_size(draw, label, LABEL_FONT)
        # Put the label just above the box; if it would run off the top of
        # the image, tuck it inside the top of the box instead.
        label_top = ymin - text_h - 4
        if label_top < 0:
            label_top = ymin
        draw.rectangle(
            (xmin, label_top, xmin + text_w + 4, label_top + text_h + 4), fill=color
        )
        draw.multiline_text(
            (xmin + 2, label_top + 2), label, fill=LABEL_TEXT_COLOR, font=LABEL_FONT
        )


def _count_states(states) -> dict:
    """Return total / healthy / dry counts (case-insensitive) for *states*."""
    normalized = [str(state).lower() for state in states]
    return {
        "total_trees_count": len(normalized),
        "healthy_trees_count": normalized.count("healthy"),
        "dry_trees_count": normalized.count("dry"),
    }


# ------------------------
# Routes
# ------------------------
@app.get("/health")
def health():
    """Report server status and which models were loaded at startup."""
    return {
        "status": "ok",
        "deepforest_model_loaded": DEEPFOREST_MODEL is not None,
        "yolo_model_loaded": YOLO_MODEL is not None,
        "classes_known": CLASS_NAMES is not None,
    }


@app.post("/analyze")
async def analyze(image: UploadFile = File(...), conf: float = Form(0.25), iou: float = Form(0.45)):
    """Detect olive trees in an uploaded image and classify their health.

    Accepts a multipart/form-data file upload (key: ``image``). DeepForest
    detects tree crowns; each crown is cropped and classified by the YOLO
    model. Every valid box is drawn with its health label, the YOLO
    classification confidence and the DeepForest detection score.

    The ``conf`` and ``iou`` form fields are accepted for API compatibility
    but are not currently used by the pipeline.

    Returns JSON with ``image`` (base64-encoded annotated JPEG),
    ``total_trees_count``, ``healthy_trees_count`` and ``dry_trees_count``.
    On failure returns HTTP 500 with ``{"error": ...}``.
    """
    if YOLO_MODEL is None:
        return JSONResponse(status_code=500, content={"error": "Model not loaded on server."})
    if DEEPFOREST_MODEL is None:
        return JSONResponse(status_code=500, content={"error": "DeepForest model not loaded on server."})

    try:
        contents = await image.read()
        pil_img = read_imagefile(contents)
        # uint8 RGB array; read_imagefile already guarantees three channels.
        img_np = np.array(pil_img)
        height, width = img_np.shape[:2]

        df_pred = DEEPFOREST_MODEL.predict_image(img_np)
        if df_pred is None:
            # DeepForest returns None when it finds no trees; treat as empty.
            df_pred = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "score"])

        # One (box, state, confidence, det_score) entry per DeepForest row.
        detections = []
        for _, row in df_pred.iterrows():
            box = _clip_box(row, width, height)
            if box is None:
                state, confidence = UNKNOWN_RESULT
            else:
                xmin, ymin, xmax, ymax = box
                state, confidence = _classify_crop(img_np[ymin:ymax, xmin:xmax])
            score = row.get("score")
            det_score = float(score) if score is not None else 0.0
            detections.append((box, state, confidence, det_score))

        _draw_annotations(pil_img, detections)

        states = [state for _, state, _, _ in detections]
        print("\n=== Health State Summary ===")
        print(pd.Series(states, name="health_state", dtype=object).value_counts())
        print(f"\nProcessed {len(detections)} trees")

        buf = io.BytesIO()
        pil_img.save(buf, format="JPEG", quality=90)
        img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")

        return JSONResponse(content={"image": img_b64, **_count_states(states)})
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})


# ------------------------
# Local debug run (when not in HF Space this can be used)
# ------------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))