File size: 4,089 Bytes
41c9fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import io
import base64
import gc
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from PIL import Image
from ultralytics import YOLO
import supervision as sv

# Environment setup for CPU efficiency.
# CUDA is hidden so torch never initializes a GPU context; thread counts are
# capped to keep the container responsive under concurrent requests.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["OMP_NUM_THREADS"] = "4"
torch.set_num_threads(4)

# Import local RF-DETR wrapper (assuming the library is installed)
from rfdetr import RFDETRSegPreview

app = FastAPI()

# Model checkpoint paths.
SEG_MODEL_PATH = "/tmp/checkpoint_best_total.pth"
CLS_MODEL_PATH = "weights/yolo12_cls.pt"

# Lazy model registry: entries stay None until load_models() populates them
# on the first request, so startup is cheap.
models = {"seg": None, "cls": None}

def load_models():
    """Populate the global ``models`` registry on first use.

    Idempotent: each model is constructed only if its slot is still None,
    so repeated calls (one per request) are cheap after the first.
    """
    if models.get("seg") is None:
        # RF-DETR segmentation model, prepared for inference-only use.
        seg_model = RFDETRSegPreview(pretrain_weights=SEG_MODEL_PATH)
        seg_model.optimize_for_inference()
        models["seg"] = seg_model
    if models.get("cls") is None:
        # YOLO12 classification head for per-crop labeling.
        models["cls"] = YOLO(CLS_MODEL_PATH)

class PredictionConfig(BaseModel):
    """Request body for /predict."""

    # Base64-encoded input image; may carry a "data:image/...;base64," prefix.
    image: str
    # Draw segmentation masks on the output image.
    seg_enabled: bool
    # Confidence threshold passed to the segmentation model.
    seg_conf: float
    # Include segmentation confidence in fallback labels (used when
    # classification is disabled).
    seg_show_conf: bool
    # Run the per-crop classifier on each detection.
    cls_enabled: bool
    # Include classifier confidence in each label.
    cls_show_conf: bool
    # Include classifier class name in each label.
    cls_show_label: bool

@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the static single-page UI from index.html in the working dir."""
    # Explicit encoding: the default is platform-dependent (PEP 597), which
    # can corrupt non-ASCII characters in the page on some hosts.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()

@app.post("/predict")
async def predict(config: PredictionConfig):
    """Segment a base64-encoded image and optionally classify each detection.

    Pipeline: decode -> RF-DETR segmentation -> (optional) YOLO
    classification per crop -> supervision annotation -> PNG data URL.

    Returns:
        dict with "annotated" (PNG data URL) and "count" (detections).

    Raises:
        HTTPException(500) on any processing failure.
    """
    load_models()

    try:
        # Decode image; tolerate either a bare base64 string or a full
        # "data:image/...;base64,<payload>" URL (header part is unused).
        _, encoded = config.image.split(",", 1) if "," in config.image else (None, config.image)
        img_bytes = base64.b64decode(encoded)
        original_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        width, height = original_img.size

        # 1. Segmentation Phase
        detections = models["seg"].predict(original_img, threshold=config.seg_conf)

        if len(detections) == 0:
            # Nothing detected: echo the input back unchanged.
            return {"annotated": config.image, "count": 0}

        # 2. Classification Phase (if enabled)
        labels = []
        if config.cls_enabled:
            for box in detections.xyxy:
                x1, y1, x2, y2 = box.astype(int)
                # Clamp to image bounds: predicted boxes can spill slightly
                # outside the frame, and PIL.crop would otherwise pad/garble.
                x1, y1 = max(x1, 0), max(y1, 0)
                x2, y2 = min(x2, width), min(y2, height)
                if x2 <= x1 or y2 <= y1:
                    # Degenerate box: append an empty label so the labels
                    # list stays aligned 1:1 with the detections.
                    labels.append("")
                    continue
                crop = original_img.crop((x1, y1, x2, y2))
                cls_res = models["cls"](crop)[0]

                top1_idx = cls_res.probs.top1
                name = cls_res.names[top1_idx]
                conf = float(cls_res.probs.top1conf)

                parts = []
                if config.cls_show_label:
                    parts.append(name)
                if config.cls_show_conf:
                    parts.append(f"{conf:.2f}")
                labels.append(" ".join(parts))
        else:
            # Fallback to generic labels, optionally with segmentation conf.
            labels = [
                f"Leaf {conf:.2f}" if config.seg_show_conf else "Leaf"
                for conf in detections.confidence
            ]

        # 3. Annotation Phase
        palette = sv.ColorPalette.from_hex(["#EA782D", "#FF7A5A", "#FFA382"])
        mask_annotator = sv.MaskAnnotator(color=palette)
        label_annotator = sv.LabelAnnotator(
            color=palette,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=0.5
        )

        annotated_img = original_img.copy()
        if config.seg_enabled:
            annotated_img = mask_annotator.annotate(scene=annotated_img, detections=detections)

        # Labels are always drawn; masks are gated by seg_enabled above.
        annotated_img = label_annotator.annotate(scene=annotated_img, detections=detections, labels=labels)

        # Encode result as a PNG data URL for direct use in an <img> tag.
        buffered = io.BytesIO()
        annotated_img.save(buffered, format="PNG")
        encoded_res = base64.b64encode(buffered.getvalue()).decode("ascii")

        return {
            "annotated": f"data:image/png;base64,{encoded_res}",
            "count": len(detections)
        }

    except Exception as e:
        # Top-level boundary: surface any failure as a 500 with the message.
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Free per-request tensors/images promptly on a memory-tight host.
        gc.collect()

if __name__ == "__main__":
    # Local/dev entry point: run the ASGI app directly with uvicorn,
    # listening on all interfaces on port 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)