# YOLO-Detect / handler.py
# Uploaded by VNraizo using huggingface_hub (commit f344f14, verified).
# handler.py — imports (including the newly added ones)
import base64
import io
import logging
import os
from typing import Any, Dict, Optional

import requests
from PIL import Image, ImageDraw
from ultralytics import YOLO

# ==== Gemini ====
from google import genai
from google.genai import types
def _to_ints(xyxy):
    """Convert a 4-element ``(x1, y1, x2, y2)`` box into a tuple of ints.

    The sequence is unpacked first, so anything other than exactly four
    elements raises ``ValueError``. Each coordinate then goes through
    ``int()``, which truncates toward zero rather than rounding.
    """
    left, top, right, bottom = xyxy
    return (int(left), int(top), int(right), int(bottom))
def _annotate_to_b64(img: Image.Image, boxes_out):
    """Draw detection rectangles onto ``img`` and return it as a base64 PNG.

    Args:
        img: Image to draw on. It is modified in place, so callers that need
            the clean image should pass a copy.
        boxes_out: Sequence of dicts with ``x1``, ``y1``, ``x2``, ``y2`` keys.

    Returns:
        The annotated image as a base64-encoded PNG string, or ``None`` when
        ``boxes_out`` is empty or otherwise falsy.
    """
    if boxes_out:
        painter = ImageDraw.Draw(img)
        for box in boxes_out:
            corners = [(box["x1"], box["y1"]), (box["x2"], box["y2"])]
            painter.rectangle(corners, outline=(255, 0, 0), width=3)
        with io.BytesIO() as png_buffer:
            img.save(png_buffer, format="PNG")
            encoded = base64.b64encode(png_buffer.getvalue())
        return encoded.decode("utf-8")
    return None
def _ask_gemini_with_image(image_bytes: bytes, meta: dict) -> str:
    """Ask Gemini for a short Vietnamese assessment of an annotated leaf image.

    This call is best-effort. Any failure while creating the client or calling
    the API is logged and turned into ``""``, so a Gemini problem never breaks
    the detection response.

    Environment variables:
        API_KEY: Gemini API key. If unset or empty, Gemini is skipped and
            ``""`` is returned immediately.
        MODEL_NAME: Gemini model id (default ``"gemini-2.0-pro"``).

    Args:
        image_bytes: PNG-encoded image with the suspicious regions drawn on it.
        meta: Detector summary. Reads ``is_diseased``, ``max_confidence``,
            ``num_detections``, ``image_width``, ``image_height``,
            ``threshold_conf`` (missing keys render as ``None``) and ``boxes``
            (dicts with ``x1``, ``y1``, ``x2``, ``y2``, ``conf``; missing or
            ``None`` means no boxes).

    Returns:
        The model's stripped text reply, or ``""`` when skipped or on failure.
    """
    api_key = os.getenv("API_KEY", "")
    if not api_key:
        return ""
    # Model is selectable via the environment and defaults to the "pro" model.
    # NOTE(review): confirm this default id is served for this key; an unknown
    # model makes every request fail (and this function return "").
    model_name = os.getenv("MODEL_NAME", "gemini-2.0-pro")

    context_lines = [
        f"YOLO is_diseased: {meta.get('is_diseased')}",
        f"YOLO max_confidence: {meta.get('max_confidence')}",
        f"YOLO num_detections: {meta.get('num_detections')}",
        f"Image size: {meta.get('image_width')}x{meta.get('image_height')}",
        f"threshold_conf: {meta.get('threshold_conf')}",
    ]
    # `or []` also covers an explicit boxes=None, which enumerate() rejects.
    for i, b in enumerate(meta.get("boxes") or [], 1):
        context_lines.append(
            f"Box#{i}: ({b['x1']},{b['y1']},{b['x2']},{b['y2']}), conf={b['conf']}"
        )

    system_instruction = (
        "You are a plant pathology assistant for Brassica (bok choy). "
        "Analyze the annotated image (rectangles show suspicious regions). "
        "Be concise and avoid over-diagnosis beyond visible evidence. "
        "Also, re-evaluate based on your own understanding and not 100% on the data provided."
    )
    # Adjacent literals are joined implicitly; every line must end with its
    # own separator, or the sentences run together.
    user_prompt = (
        "Given the annotated image of bok choy leaves, determine whether disease signs are present.\n"
        "If sick, describe the visible symptoms and suggest possible illnesses with a brief explanation. Also answer in one sentence\n"
        "Context from detector:\n" + "\n".join(context_lines) + "\n"
        "Finally transform result to vietnamese. Only reply result vietnamese (not include english)\n"
        "Please answer clearly with just one sentence."
    )
    try:
        # Client construction is inside the guard so that it also degrades
        # to "" instead of escaping the best-effort contract.
        client = genai.Client(api_key=api_key)
        resp = client.models.generate_content(
            model=model_name,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
                user_prompt,
            ],
            # Pass the system prompt through the dedicated config field rather
            # than as an ordinary user content part.
            config=types.GenerateContentConfig(system_instruction=system_instruction),
        )
        return (resp.text or "").strip()
    except Exception:
        logging.getLogger(__name__).exception(
            "Gemini request failed (model=%s)", model_name
        )
        return ""
class EndpointHandler:
    """Custom inference-endpoint handler: YOLO disease detection on a leaf
    image, plus an optional one-sentence Gemini explanation of the detections.
    """

    # Confidence threshold used when the request does not supply "conf".
    DEFAULT_CONF = 0.5

    def __init__(self, path: str = ""):
        """Load the YOLO weights file ``best.pt`` from directory ``path``.

        ``os.path.join`` keeps the default ``path=""`` relative (``best.pt``)
        instead of turning it into the filesystem root (``/best.pt``).
        """
        self.model = YOLO(os.path.join(path, "best.pt"))

    def _load_image(self, data: Dict) -> Image.Image:
        """Decode the request payload into an RGB PIL image.

        Accepted payloads (the dict forms may be wrapped in ``{"inputs": {...}}``):
          * ``{"image_url": str}``: downloaded with a 15 s timeout;
          * ``{"image_base64": str}``: a base64-encoded image file;
          * raw ``bytes``/``bytearray`` holding an image file, either directly
            or under ``"inputs"``;
          * an already-decoded ``PIL.Image.Image``, directly or under
            ``"inputs"``.

        Raises:
            requests.HTTPError: the image URL answered with an error status.
            ValueError: no supported image source was found.
        """
        # Type-check before using `in`: `"inputs" in b"..."` raises TypeError,
        # which previously broke the raw-bytes path entirely.
        if isinstance(data, dict) and isinstance(data.get("inputs"), dict):
            data = data["inputs"]

        if isinstance(data, dict):
            if data.get("image_url"):
                # NOTE(review): the URL is caller-controlled, so the endpoint
                # will fetch arbitrary hosts (SSRF risk); restrict if exposed.
                response = requests.get(data["image_url"], timeout=15)
                # Fail on HTTP errors instead of feeding an error page to PIL.
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content)).convert("RGB")
            if data.get("image_base64"):
                img_bytes = base64.b64decode(data["image_base64"])
                return Image.open(io.BytesIO(img_bytes)).convert("RGB")
            # The serving layer may place a raw body under "inputs".
            raw = data.get("inputs")
        else:
            raw = data

        if isinstance(raw, Image.Image):
            return raw.convert("RGB")
        if isinstance(raw, (bytes, bytearray)):
            return Image.open(io.BytesIO(raw)).convert("RGB")
        raise ValueError("No image provided. Use 'image_url' or 'image_base64' or raw bytes.")

    @classmethod
    def _read_conf(cls, data: Any) -> float:
        """Return the ``conf`` threshold from the request root or ``inputs``.

        Returns ``DEFAULT_CONF`` when neither location provides it, or when
        ``data`` is not a dict (e.g. raw bytes).
        """
        if isinstance(data, dict):
            if "conf" in data:
                return float(data["conf"])
            inner = data.get("inputs")
            if isinstance(inner, dict) and "conf" in inner:
                return float(inner["conf"])
        return cls.DEFAULT_CONF

    @staticmethod
    def _collect_boxes(result):
        """Convert one ultralytics result's boxes to the response box schema.

        Every box is labelled ``"diseased"``. The label is hard-coded, not read
        from the model's class names. Corners are truncated to integers.
        """
        if result.boxes is None:
            return []
        boxes_out = []
        for box in result.boxes:
            x1, y1, x2, y2 = _to_ints(box.xyxy[0].tolist())
            boxes_out.append({
                "cls": "diseased",
                "conf": float(box.conf[0]),
                "x1": x1, "y1": y1, "x2": x2, "y2": y2,
            })
        return boxes_out

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection on one request and build the response payload.

        Image fields are described in ``_load_image``. An optional ``conf``
        (at the root or inside ``inputs``) sets the YOLO confidence threshold
        and defaults to ``DEFAULT_CONF`` (0.5).

        Returns:
            A dict with ``is_diseased``, ``max_confidence`` (0.0 when nothing
            was detected), ``num_detections``, ``image_width``,
            ``image_height``, ``threshold_conf``, ``boxes`` and
            ``prediction_text``. The last is Gemini's reply, or ``""`` when
            nothing was detected or Gemini is unavailable.
        """
        conf = self._read_conf(data)

        # 1) Load the image.
        img = self._load_image(data)
        width, height = img.size

        # 2) Predict.
        result = self.model.predict(img, conf=conf, verbose=False)[0]

        # 3) Build boxes (schema as in schemas.py) and the shared summary,
        #    which is both the Gemini context and the response body.
        boxes_out = self._collect_boxes(result)
        summary = {
            "is_diseased": bool(boxes_out),
            "max_confidence": max((b["conf"] for b in boxes_out), default=0.0),
            "num_detections": len(boxes_out),
            "image_width": width,
            "image_height": height,
            "threshold_conf": conf,
            "boxes": boxes_out,
        }

        # 4) + 5) Only when something was detected: draw the boxes on a copy
        #    and ask Gemini about the annotated image.
        gemini_text = ""
        if boxes_out:
            annotated_b64: Optional[str] = _annotate_to_b64(img.copy(), boxes_out)
            if annotated_b64:
                try:
                    gemini_text = _ask_gemini_with_image(
                        base64.b64decode(annotated_b64), summary
                    )
                except Exception:
                    # Best-effort: the explanation must never break detection.
                    logging.getLogger(__name__).exception("Gemini explanation failed")
                    gemini_text = ""

        return {**summary, "prediction_text": gemini_text}