|
|
import os
|
|
|
import io
|
|
|
import base64
|
|
|
import gc
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from fastapi.responses import HTMLResponse
|
|
|
from pydantic import BaseModel
|
|
|
from PIL import Image
|
|
|
from ultralytics import YOLO
|
|
|
import supervision as sv
|
|
|
|
|
|
|
|
|
# Force CPU-only, bounded-thread inference so the service behaves predictably
# on shared CPU hosts.
# NOTE(review): torch is imported above this point; CUDA_VISIBLE_DEVICES still
# takes effect here because torch initializes CUDA lazily — confirm no CUDA
# call happens before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Cap OpenMP threads (used by numpy/torch C backends) to match the torch cap below.
os.environ["OMP_NUM_THREADS"] = "4"

torch.set_num_threads(4)
|
|
|
|
|
|
|
|
|
from rfdetr import RFDETRSegPreview
|
|
|
|
|
|
app = FastAPI()


# Model weight locations. The classifier ships with the app; the segmentation
# checkpoint lives in /tmp — presumably staged there by a startup/download
# step outside this file. TODO confirm what writes it.
SEG_MODEL_PATH = "/tmp/checkpoint_best_total.pth"

CLS_MODEL_PATH = "weights/yolo12_cls.pt"


# Lazy model cache: both entries start empty and are populated by
# load_models() on the first request, keeping process startup fast.
models = {"seg": None, "cls": None}
|
|
|
|
|
|
def load_models():
    """Populate the module-level ``models`` cache on first use.

    Idempotent: entries that are already loaded are left untouched, so the
    per-request call is a cheap no-op after the first request. The RF-DETR
    segmentation model is additionally optimized for inference right after
    loading.
    """
    if models.get("seg") is None:
        models["seg"] = RFDETRSegPreview(pretrain_weights=SEG_MODEL_PATH)
        models["seg"].optimize_for_inference()

    if models.get("cls") is None:
        models["cls"] = YOLO(CLS_MODEL_PATH)
|
|
|
|
|
|
class PredictionConfig(BaseModel):
    """Request body for ``POST /predict``.

    ``image`` carries the input picture as base64, optionally prefixed with a
    data-URL header (``data:image/...;base64,``); the remaining flags control
    which models run and how the annotations are rendered.
    """

    # Base64-encoded input image (data-URL prefix allowed).
    image: str
    # Draw segmentation masks on the output image.
    seg_enabled: bool
    # Confidence threshold passed to the segmentation model.
    seg_conf: float
    # Include segmentation confidence in labels (used when classification is off).
    seg_show_conf: bool
    # Classify each detected region with the YOLO classifier.
    cls_enabled: bool
    # Include classifier confidence in labels.
    cls_show_conf: bool
    # Include classifier class name in labels.
    cls_show_label: bool
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the static single-page UI.

    Reads ``index.html`` from the working directory on every request.
    Explicit UTF-8 decoding avoids locale-dependent behavior on the host
    (the implicit default encoding varies by platform).
    """
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
|
|
|
|
|
|
@app.post("/predict")
async def predict(config: PredictionConfig):
    """Run segmentation (and optional per-detection classification) on an image.

    Args:
        config: base64 image plus rendering flags (see ``PredictionConfig``).

    Returns:
        ``{"annotated": <PNG data URL>, "count": <int>}``; when nothing is
        detected the original image string is echoed back with count 0.

    Raises:
        HTTPException: 400 for an undecodable image payload, 500 for any
            unexpected inference/annotation failure.
    """
    load_models()

    try:
        original_img = _decode_image(config.image)

        detections = models["seg"].predict(original_img, threshold=config.seg_conf)
        if len(detections) == 0:
            return {"annotated": config.image, "count": 0}

        labels = _build_labels(detections, original_img, config)
        annotated_img = _annotate(original_img, detections, labels, config.seg_enabled)

        return {
            "annotated": _to_png_data_url(annotated_img),
            "count": len(detections),
        }
    except HTTPException:
        # Re-raise deliberate errors (e.g. the 400 from _decode_image)
        # unchanged instead of wrapping them in a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Inference allocates large intermediate arrays; collect eagerly to
        # keep the long-running process's footprint down.
        gc.collect()


def _decode_image(data: str) -> Image.Image:
    """Decode a base64 string (optionally a full data URL) into an RGB image."""
    encoded = data.split(",", 1)[1] if "," in data else data
    try:
        return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")
    except Exception as e:
        # Bad client input is a 400, not a server error.
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {e}")


def _build_labels(detections, img: Image.Image, config: PredictionConfig) -> list:
    """Build exactly one label string per detection.

    With classification enabled, each detection's crop is run through the
    YOLO classifier; otherwise a generic "Leaf" label (optionally carrying
    the segmentation confidence) is used. Always returns ``len(detections)``
    labels so the annotator's labels/detections lengths match.
    """
    if not config.cls_enabled:
        return [
            f"Leaf {conf:.2f}" if config.seg_show_conf else "Leaf"
            for conf in detections.confidence
        ]

    labels = []
    width, height = img.size
    for box in detections.xyxy:
        x1, y1, x2, y2 = box.astype(int)
        # Clamp the box to the image bounds; a degenerate (empty) crop would
        # make the classifier fail, so skip it with an empty label instead.
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(width, int(x2)), min(height, int(y2))
        if x2 <= x1 or y2 <= y1:
            labels.append("")
            continue

        cls_res = models["cls"](img.crop((x1, y1, x2, y2)))[0]
        name = cls_res.names[cls_res.probs.top1]
        conf = float(cls_res.probs.top1conf)

        parts = []
        if config.cls_show_label:
            parts.append(name)
        if config.cls_show_conf:
            parts.append(f"{conf:.2f}")
        labels.append(" ".join(parts))
    return labels


def _annotate(img: Image.Image, detections, labels: list, seg_enabled: bool) -> Image.Image:
    """Draw masks (when enabled) and labels onto a copy of ``img``."""
    palette = sv.ColorPalette.from_hex(["#EA782D", "#FF7A5A", "#FFA382"])

    annotated = img.copy()
    if seg_enabled:
        annotated = sv.MaskAnnotator(color=palette).annotate(
            scene=annotated, detections=detections
        )

    label_annotator = sv.LabelAnnotator(
        color=palette,
        text_position=sv.Position.CENTER_OF_MASS,
        text_scale=0.5,
    )
    return label_annotator.annotate(scene=annotated, detections=detections, labels=labels)


def _to_png_data_url(img: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG data URL."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"
|
|
|
|
|
|
# Script entry point: serve the API with uvicorn when executed directly
# (e.g. `python app.py`). Port 7860 suggests a Hugging Face Spaces
# deployment — TODO confirm. When imported by an external ASGI server,
# this block is skipped.
if __name__ == "__main__":

    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)