import os
import io
import base64
import gc

import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from PIL import Image
from ultralytics import YOLO
import supervision as sv

# Environment setup for CPU efficiency: hide GPUs and cap thread counts so the
# process does not oversubscribe a small CPU quota.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["OMP_NUM_THREADS"] = "4"
torch.set_num_threads(4)

# Import local RF-DETR wrapper (assuming the library is installed)
from rfdetr import RFDETRSegPreview

app = FastAPI()

# Model paths and lazily-populated global model cache.
SEG_MODEL_PATH = "/tmp/checkpoint_best_total.pth"
CLS_MODEL_PATH = "weights/yolo12_cls.pt"
models = {"seg": None, "cls": None}


def load_models():
    """Lazily load both models into the module-level cache (idempotent).

    NOTE(review): not thread-safe — two concurrent first requests could each
    trigger a load. Acceptable for a single-worker demo; add a lock or a
    startup hook if multiple workers/threads serve traffic.
    """
    if models["seg"] is None:
        # RF-DETR segmentation model
        models["seg"] = RFDETRSegPreview(pretrain_weights=SEG_MODEL_PATH)
        models["seg"].optimize_for_inference()
    if models["cls"] is None:
        # YOLO12-cls classification model
        models["cls"] = YOLO(CLS_MODEL_PATH)


class PredictionConfig(BaseModel):
    """Request payload: a base64 image (optionally a data URL) plus UI toggles."""

    image: str
    seg_enabled: bool
    seg_conf: float
    seg_show_conf: bool
    cls_enabled: bool
    cls_show_conf: bool
    cls_show_label: bool


@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the static single-page UI."""
    with open("index.html", "r") as f:
        return f.read()


def _decode_image(data: str) -> Image.Image:
    """Decode a base64 string (with or without a data-URL header) to RGB."""
    encoded = data.split(",", 1)[1] if "," in data else data
    img_bytes = base64.b64decode(encoded)
    return Image.open(io.BytesIO(img_bytes)).convert("RGB")


def _classification_labels(detections, original_img: Image.Image, config: PredictionConfig) -> list:
    """Classify each detection crop; return one (possibly empty) label per box."""
    labels = []
    width, height = original_img.size
    for box in detections.xyxy:
        x1, y1, x2, y2 = box.astype(int)
        # Clamp to image bounds: detector boxes can spill past the edges, and
        # PIL silently pads out-of-range crops with black, which would skew
        # the classifier's input.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)
        crop = original_img.crop((x1, y1, x2, y2))
        cls_res = models["cls"](crop)[0]
        top1_idx = cls_res.probs.top1
        name = cls_res.names[top1_idx]
        conf = float(cls_res.probs.top1conf)
        parts = []
        if config.cls_show_label:
            parts.append(name)
        if config.cls_show_conf:
            parts.append(f"{conf:.2f}")
        labels.append(" ".join(parts))
    return labels


@app.post("/predict")
async def predict(config: PredictionConfig):
    """Segment (and optionally classify) objects in the posted image.

    Returns ``{"annotated": <PNG data URL>, "count": <detections>}``.
    Raises 400 for an undecodable image payload, 500 for inference errors.
    """
    load_models()
    try:
        try:
            original_img = _decode_image(config.image)
        except Exception as e:
            # A malformed payload is the client's fault: 400, not 500.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}")

        # 1. Segmentation phase
        detections = models["seg"].predict(original_img, threshold=config.seg_conf)
        if len(detections) == 0:
            # Nothing to draw — echo the input back unchanged.
            return {"annotated": config.image, "count": 0}

        # 2. Classification phase (if enabled)
        if config.cls_enabled:
            labels = _classification_labels(detections, original_img, config)
        elif detections.confidence is None or not config.seg_show_conf:
            # supervision may report no per-detection confidence; fall back to
            # a bare generic label rather than crashing on None.
            labels = ["Leaf"] * len(detections)
        else:
            labels = [f"Leaf {c:.2f}" for c in detections.confidence]

        # 3. Annotation phase
        palette = sv.ColorPalette.from_hex(["#EA782D", "#FF7A5A", "#FFA382"])
        mask_annotator = sv.MaskAnnotator(color=palette)
        label_annotator = sv.LabelAnnotator(
            color=palette,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=0.5,
        )
        annotated_img = original_img.copy()
        if config.seg_enabled:
            annotated_img = mask_annotator.annotate(scene=annotated_img, detections=detections)
        # Labels are drawn regardless of the mask toggle so classification
        # results stay visible when masks are hidden.
        annotated_img = label_annotator.annotate(
            scene=annotated_img, detections=detections, labels=labels
        )

        # Encode result as a PNG data URL.
        buffered = io.BytesIO()
        annotated_img.save(buffered, format="PNG")
        encoded_res = base64.b64encode(buffered.getvalue()).decode("ascii")
        return {
            "annotated": f"data:image/png;base64,{encoded_res}",
            "count": len(detections),
        }
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors (e.g. the 400 above) as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Reclaim inference temporaries between requests on this small box.
        gc.collect()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)