# FastAPI inference service: RF-DETR segmentation + YOLO12-cls classification.
import os
import io
import base64
import gc
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from PIL import Image
from ultralytics import YOLO
import supervision as sv
# Environment setup for CPU efficiency.
# Hide all CUDA devices so torch runs CPU-only; must be set before any CUDA
# context is created (presumably a CPU-only host, e.g. HF Spaces — confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Cap OpenMP threads used by native BLAS/OpenMP kernels.
os.environ["OMP_NUM_THREADS"] = "4"
# Keep torch's intra-op thread pool on the same 4-thread budget.
torch.set_num_threads(4)
# Import local RF-DETR wrapper (assuming the library is installed)
from rfdetr import RFDETRSegPreview
app = FastAPI()
# Model paths and Globals.
# Segmentation checkpoint lives in /tmp; classifier weights are relative to
# the working directory — NOTE(review): mixed locations, confirm both exist
# at deploy time.
SEG_MODEL_PATH = "/tmp/checkpoint_best_total.pth"
CLS_MODEL_PATH = "weights/yolo12_cls.pt"
# Lazily-populated model registry; filled by load_models() on first request.
models = {"seg": None, "cls": None}
def load_models():
    """Lazily instantiate both inference models into the module-level registry.

    Each slot in ``models`` is filled exactly once; later calls are cheap
    no-ops, so this is safe to invoke on every request.
    """
    # YOLO12-cls classifier.
    if models["cls"] is None:
        models["cls"] = YOLO(CLS_MODEL_PATH)
    # RF-DETR segmentation model, prepared for CPU inference.
    if models["seg"] is None:
        models["seg"] = RFDETRSegPreview(pretrain_weights=SEG_MODEL_PATH)
        models["seg"].optimize_for_inference()
class PredictionConfig(BaseModel):
    """Request payload for the /predict endpoint."""
    image: str  # base64-encoded image, with or without a data-URL header
    seg_enabled: bool  # draw segmentation masks/labels on the output image
    seg_conf: float  # RF-DETR detection threshold (presumably 0..1 — confirm)
    seg_show_conf: bool  # include segmentation confidence in fallback labels
    cls_enabled: bool  # run the per-crop YOLO classifier
    cls_show_conf: bool  # include classifier confidence in labels
    cls_show_label: bool  # include the predicted class name in labels
@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the static single-page UI.

    Reads index.html from the working directory on every request, so edits
    to the page are picked up without restarting the server.
    """
    # Explicit encoding: the platform default may not be UTF-8 (PEP 597).
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
@app.post("/predict")
async def predict(config: PredictionConfig):
    """Segment an image with RF-DETR and optionally classify each crop.

    Decodes the base64 image from ``config.image``, detects instances above
    ``config.seg_conf``, builds a label per detection (classifier output or a
    generic "Leaf" fallback), overlays masks/labels when ``seg_enabled``, and
    returns ``{"annotated": <data URL>, "count": <detections>}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    load_models()
    try:
        # Accept both bare base64 and full "data:image/...;base64,..." URLs.
        header, encoded = config.image.split(",", 1) if "," in config.image else (None, config.image)
        img_bytes = base64.b64decode(encoded)
        original_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        # 1. Segmentation phase.
        detections = models["seg"].predict(original_img, threshold=config.seg_conf)
        if len(detections) == 0:
            # Nothing detected: echo the input back unchanged.
            return {"annotated": config.image, "count": 0}

        # 2. Classification phase (if enabled): one label per detection.
        labels = []
        if config.cls_enabled:
            for x1, y1, x2, y2 in detections.xyxy.astype(int):
                crop = original_img.crop((x1, y1, x2, y2))
                cls_res = models["cls"](crop)[0]
                top1_idx = cls_res.probs.top1
                name = cls_res.names[top1_idx]
                conf = float(cls_res.probs.top1conf)
                parts = []
                if config.cls_show_label:
                    parts.append(name)
                if config.cls_show_conf:
                    parts.append(f"{conf:.2f}")
                labels.append(" ".join(parts))
        else:
            # Fallback to generic labels, optionally with segmentation conf.
            labels = [
                f"Leaf {conf:.2f}" if config.seg_show_conf else "Leaf"
                for conf in detections.confidence
            ]

        # 3. Annotation phase: overlay masks and labels when enabled.
        palette = sv.ColorPalette.from_hex(["#EA782D", "#FF7A5A", "#FFA382"])
        mask_annotator = sv.MaskAnnotator(color=palette)
        label_annotator = sv.LabelAnnotator(
            color=palette,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=0.5,
        )
        annotated_img = original_img.copy()
        if config.seg_enabled:
            annotated_img = mask_annotator.annotate(scene=annotated_img, detections=detections)
            annotated_img = label_annotator.annotate(scene=annotated_img, detections=detections, labels=labels)

        # Encode the annotated result back into a data URL.
        buffered = io.BytesIO()
        annotated_img.save(buffered, format="PNG")
        encoded_res = base64.b64encode(buffered.getvalue()).decode("ascii")
        return {
            "annotated": f"data:image/png;base64,{encoded_res}",
            "count": len(detections),
        }
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Proactively release per-request buffers; this runs on a small
        # CPU-only instance where memory headroom is tight.
        gc.collect()
if __name__ == "__main__":
    import uvicorn
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    # (Removed a stray trailing "|" artifact that would break parsing.)
    uvicorn.run(app, host="0.0.0.0", port=7860)