# Author: Yasser Babaouamer
# Commit: Added the new 2-class model
# Revision: 5d9efa5
# app.py
import io
import os
import uvicorn
import traceback
from typing import Optional
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse, StreamingResponse, Response
from pydantic import BaseModel
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import base64
# Ultralytics YOLO
from ultralytics import YOLO
# DeepForest
from deepforest import main
# ------------------------
# Configuration
# ------------------------
# Path of the YOLO weights file that is loaded once at startup.
YOLO_MODEL_PATH = "olive_cls_2c.pt"
# Device string handed to YOLO at prediction time; override with the DEVICE
# environment variable.
DEVICE = os.getenv("DEVICE", "cpu")  # 'cpu' or 'cuda'
def read_imagefile(file_bytes) -> Image.Image:
    """Decode raw image bytes and return the picture converted to RGB mode."""
    stream = io.BytesIO(file_bytes)
    decoded = Image.open(stream)
    return decoded.convert("RGB")
def get_text_size(draw, text: str, font):
    """Measure ``text`` (which may span several lines) as ``(width, height)``.

    Strategies are tried in order, moving on whenever one raises:

    1. ``draw.multiline_textbbox`` (Pillow >= 8),
    2. ``draw.textbbox``,
    3. ``font.getsize``,
    4. a rough estimate of 7 px per character and 12 px per line.
    """
    # Bounding-box based measurement: both methods return (x0, y0, x1, y1).
    for measure_name in ("multiline_textbbox", "textbbox"):
        try:
            box = getattr(draw, measure_name)((0, 0), text, font=font)
            return (box[2] - box[0], box[3] - box[1])
        except Exception:
            continue

    # Legacy font API.
    try:
        return font.getsize(text)
    except Exception:
        pass

    # Naive fallback: estimate from character and line counts.
    rows = text.splitlines() or [text]
    widest = max(len(row) * 7 for row in rows)
    return (widest, 12 * len(rows))
# ------------------------
# Initialize App & Model
# ------------------------
app = FastAPI(title="Olive Tree Analyzer")

# Module-level model handles. Each one stays None unless its model finished
# loading; /health reports them and /analyze checks them before running.
YOLO_MODEL = None
CLASS_NAMES = None
DEEPFOREST_MODEL = None

# Load the YOLO model. It gets its own try block so that a failure here does
# not also skip the DeepForest load below.
try:
    print(f"Loading YOLO model from {YOLO_MODEL_PATH} on device {DEVICE} ...")
    YOLO_MODEL = YOLO(YOLO_MODEL_PATH)
    # Prefer the label names shipped with the model when it has any; they may
    # be stored either as a dict or as a plain sequence.
    _embedded_names = getattr(YOLO_MODEL.model, "names", None)
    if _embedded_names is not None:
        if isinstance(_embedded_names, dict):
            _model_names = list(_embedded_names.values())
        else:
            _model_names = list(_embedded_names)
        if _model_names:
            CLASS_NAMES = _model_names
except Exception as e:
    # Best effort: the server still starts and /health exposes the failure.
    print("Failed to load model:", str(e))
    traceback.print_exc()

# Load the DeepForest model.
try:
    print("Loading DeepForest model (Weecology/deepforest-tree) ...")
    # Build and configure into a local first. Publishing the object before
    # load_model() succeeds would leave a non-None but unloaded model behind,
    # which /health would then report as loaded.
    _deepforest_model = main.deepforest()
    _deepforest_model.load_model("Weecology/deepforest-tree")
    _deepforest_model.config["score_thresh"] = 0.15
    DEEPFOREST_MODEL = _deepforest_model
    print("DeepForest model ready")
except Exception as e:
    print("Failed to load model:", str(e))
    traceback.print_exc()
# ------------------------
# Routes
# ------------------------
# Health check endpoint: Returns server status and model loading info.
@app.get("/health")
def health():
    """Return a fixed "ok" status plus flags for which models and class names are set."""
    handles = {
        "deepforest_model_loaded": DEEPFOREST_MODEL,
        "yolo_model_loaded": YOLO_MODEL,
        "classes_known": CLASS_NAMES,
    }
    payload = {"status": "ok"}
    # A handle that is still None means the corresponding load never completed.
    payload.update((key, handle is not None) for key, handle in handles.items())
    return payload
# Analyze endpoint: Accepts an image upload, detects trees with DeepForest, classifies each detected crop with the YOLO model, draws labelled bounding boxes, and returns JSON containing the base64-encoded annotated image plus tree counts.
def _classify_crop(crop):
    """Classify one cropped tree with the YOLO model.

    Returns ``(label, confidence)``. Any failure (prediction error, empty
    result list, missing ``probs``) is logged and reported as
    ``("unknown", 0.0)`` so that one bad crop never aborts the whole request.
    """
    unknown = ("unknown", 0.0)
    try:
        results = YOLO_MODEL.predict(source=crop, device=DEVICE, imgsz=224, batch=1, verbose=False)
        if not results:
            print("No results from YOLO")
            return unknown
        r = results[0]
    except Exception as e:
        print(f"YOLO prediction error: {e}")
        traceback.print_exc()
        return unknown

    try:
        probs = getattr(r, "probs", None)
        if probs is None:
            print("No probs attribute found in results")
            return unknown
        # Top-1 class name and its confidence.
        predicted_class = r.names[probs.top1]
        confidence = float(probs.top1conf)
    except Exception as e:
        print(f"Error extracting classification results: {e}")
        traceback.print_exc()
        return unknown

    print(f"Classified as {predicted_class} with confidence {confidence:.3f}")
    return predicted_class, confidence


def _classify_detections(df_pred, img_np):
    """Classify every detection row; return ``(health_states, health_confidences)``.

    Both lists always contain exactly one entry per row of ``df_pred``, in row
    order, so they can be attached to the frame as columns.
    """
    img_h, img_w = img_np.shape[:2]
    health_states = []
    health_confidences = []
    for _, row in df_pred.iterrows():
        # Clip the box to the image bounds before slicing.
        xmin = max(0, int(row.get("xmin", 0)))
        ymin = max(0, int(row.get("ymin", 0)))
        xmax = min(img_w, int(row.get("xmax", 0)))
        ymax = min(img_h, int(row.get("ymax", 0)))

        state, confidence = "unknown", 0.0
        if xmax <= xmin or ymax <= ymin:
            print(f"Invalid bounding box: ({xmin}, {ymin}, {xmax}, {ymax})")
        else:
            crop = img_np[ymin:ymax, xmin:xmax]
            if crop.size == 0:
                print("Empty crop detected")
            else:
                state, confidence = _classify_crop(crop)

        health_states.append(state)
        health_confidences.append(confidence)
    return health_states, health_confidences


def _draw_annotations(pil_img, df_pred):
    """Draw one coloured box and a two-line label per detection onto ``pil_img`` in place."""
    from PIL import ImageDraw, ImageFont

    draw = ImageDraw.Draw(pil_img)
    try:
        font = ImageFont.truetype("arial.ttf", size=14)
    except Exception:
        # Best effort: fall back to Pillow's built-in font when arial.ttf
        # cannot be loaded.
        font = ImageFont.load_default()

    color_map = {
        "healthy": (4, 189, 44),
        "dry": (192, 217, 4),
        "unknown": (128, 128, 128),
    }
    for _, row in df_pred.iterrows():
        xmin, ymin, xmax, ymax = (int(row.get(key, 0)) for key in ("xmin", "ymin", "xmax", "ymax"))
        health = str(row.get("health_state", "unknown"))
        # Named health_conf so it cannot be confused with the endpoint's
        # ``conf`` form parameter.
        health_conf = float(row.get("health_confidence", 0.0))
        det_score = float(row.get("score", 0.0)) if row.get("score") is not None else 0.0
        color = color_map.get(health.lower(), (255, 255, 255))

        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)
        label = f"{health}\nYOLO: {health_conf:.2f} DET: {det_score:.2f}"
        # Filled strip directly above the box, clamped to the top of the image.
        text_w, text_h = get_text_size(draw, label, font)
        text_bg = (xmin, max(0, ymin - text_h - 4), xmin + text_w + 4, ymin)
        draw.rectangle(text_bg, fill=color)
        draw.multiline_text((xmin + 2, max(0, ymin - text_h - 2)), label, fill=(255, 255, 255), font=font)


@app.post("/analyze")
async def analyze(image: UploadFile = File(...), conf: float = Form(0.25), iou: float = Form(0.45)):
    """
    Accepts a multipart/form-data file upload (key: image).

    Pipeline: DeepForest proposes tree bounding boxes, each box is cropped and
    classified by the YOLO model, and every box is drawn on the image with its
    class, classification confidence and detection score.

    Parameters:
        image: the uploaded image file.
        conf, iou: accepted as form fields but not used by the current
            pipeline; kept so existing clients that send them keep working.

    Returns a JSON response with:
        image: base64-encoded annotated JPEG,
        total_trees_count: number of detections,
        healthy_trees_count / dry_trees_count: detections classified as
            "healthy" / "dry" (case-insensitive).

    A 500 JSON response with an "error" message is returned when either model
    is not loaded or any step raises.
    """
    if YOLO_MODEL is None:
        return JSONResponse(status_code=500, content={"error": "Model not loaded on server."})
    if DEEPFOREST_MODEL is None:
        return JSONResponse(status_code=500, content={"error": "DeepForest model not loaded on server."})
    try:
        contents = await image.read()
        pil_img = read_imagefile(contents)
        # Kept as uint8; the same array feeds DeepForest and the crops for YOLO.
        img_np = np.array(pil_img)
        if img_np.shape[-1] == 4:
            # Defensive: drop an alpha channel if one is present.
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)

        df_pred = DEEPFOREST_MODEL.predict_image(img_np)
        if df_pred is None:
            # No detections: continue with an empty frame so the client gets
            # the unannotated image and zero counts instead of a 500.
            df_pred = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "score"])

        health_states, health_confidences = _classify_detections(df_pred, img_np)
        df_pred["health_state"] = health_states
        df_pred["health_confidence"] = health_confidences

        _draw_annotations(pil_img, df_pred)

        # Print summary
        print("\n=== Health State Summary ===")
        print(df_pred["health_state"].value_counts())
        print(f"\nProcessed {len(df_pred)} trees")

        buf = io.BytesIO()
        pil_img.save(buf, format="JPEG", quality=90)
        img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")

        # Count from the per-row list so every counter is always defined.
        normalized_states = [str(state).lower() for state in health_states]
        return JSONResponse(content={
            "image": img_b64,
            "total_trees_count": len(normalized_states),
            "healthy_trees_count": normalized_states.count("healthy"),
            "dry_trees_count": normalized_states.count("dry"),
        })
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})
# ------------------------
# Local debug run (when not in HF Space this can be used)
# ------------------------
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and defaults to 8000.
    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)