|
|
|
|
|
import os
|
|
|
import io, base64, requests
|
|
|
from typing import Any, Dict, Optional
|
|
|
from PIL import Image, ImageDraw
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
|
|
|
from google import genai
|
|
|
from google.genai import types
|
|
|
|
|
|
def _to_ints(xyxy):
|
|
|
x1, y1, x2, y2 = xyxy
|
|
|
return int(x1), int(y1), int(x2), int(y2)
|
|
|
|
|
|
def _annotate_to_b64(img: Image.Image, boxes_out):
|
|
|
if not boxes_out:
|
|
|
return None
|
|
|
draw = ImageDraw.Draw(img)
|
|
|
for b in boxes_out:
|
|
|
draw.rectangle([(b["x1"], b["y1"]), (b["x2"], b["y2"])], outline=(255, 0, 0), width=3)
|
|
|
buf = io.BytesIO()
|
|
|
img.save(buf, format="PNG")
|
|
|
return base64.b64encode(buf.getvalue()).decode("utf-8")
|
|
|
|
|
|
def _ask_gemini_with_image(image_bytes: bytes, meta: dict) -> str:
|
|
|
api_key = os.getenv("API_KEY", "")
|
|
|
if not api_key:
|
|
|
return ""
|
|
|
|
|
|
model_name = os.getenv("MODEL_NAME", "gemini-2.0-pro")
|
|
|
|
|
|
client = genai.Client(api_key=api_key)
|
|
|
|
|
|
context_lines = [
|
|
|
f"YOLO is_diseased: {meta.get('is_diseased')}",
|
|
|
f"YOLO max_confidence: {meta.get('max_confidence')}",
|
|
|
f"YOLO num_detections: {meta.get('num_detections')}",
|
|
|
f"Image size: {meta.get('image_width')}x{meta.get('image_height')}",
|
|
|
f"threshold_conf: {meta.get('threshold_conf')}",
|
|
|
]
|
|
|
for i, b in enumerate(meta.get("boxes", []), 1):
|
|
|
context_lines.append(
|
|
|
f"Box#{i}: ({b['x1']},{b['y1']},{b['x2']},{b['y2']}), conf={b['conf']}"
|
|
|
)
|
|
|
|
|
|
system_instruction = (
|
|
|
"You are a plant pathology assistant for Brassica (bok choy). "
|
|
|
"Analyze the annotated image (rectangles show suspicious regions). "
|
|
|
"Be concise and avoid over-diagnosis beyond visible evidence. "
|
|
|
"Also, re-evaluate based on your own understanding and not 100% on the data provided."
|
|
|
)
|
|
|
user_prompt = (
|
|
|
"Given the annotated image of bok choy leaves, determine whether disease signs are present.\n"
|
|
|
"If sick, describe the visible symptoms and suggest possible illnesses with a brief explanation. Also answer in one sentence\n"
|
|
|
"Context from detector:\n" + "\n".join(context_lines) + "\n"
|
|
|
"Finally transform result to vietnamese. Only reply result vietnamese (not include english)"
|
|
|
"Please answer clearly with just one sentence."
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
resp = client.models.generate_content(
|
|
|
model=model_name,
|
|
|
contents=[
|
|
|
system_instruction,
|
|
|
types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
|
|
|
user_prompt,
|
|
|
],
|
|
|
)
|
|
|
return (resp.text or "").strip()
|
|
|
except Exception:
|
|
|
return ""
|
|
|
|
|
|
class EndpointHandler:
    """Detect suspicious regions with a YOLO model and describe them via Gemini.

    Instances are callable: pass a request payload and receive a dict holding
    the detections plus an optional text description.

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape looks like a
    Hugging Face Inference Endpoints custom handler -- confirm against the
    deployment.
    """

    def __init__(self, path: str = ""):
        """Load the YOLO weights stored as ``best.pt`` inside ``path``.

        Args:
            path: Directory containing ``best.pt``. An empty string resolves to
                ``best.pt`` relative to the current working directory.
        """
        # os.path.join keeps the default path="" relative; plain string
        # formatting turned it into the absolute path "/best.pt".
        self.model = YOLO(os.path.join(path or "", "best.pt"))

    @staticmethod
    def _resolve_conf(data: Any, default: float = 0.5) -> float:
        """Return the confidence threshold requested in ``data``.

        A top-level ``"conf"`` wins over ``data["inputs"]["conf"]``; payloads
        that are not dicts, or carry neither key, yield ``default``.

        Raises:
            ValueError, TypeError: If the supplied value cannot be converted
                with ``float()``.
        """
        if not isinstance(data, dict):
            return default
        if "conf" in data:
            return float(data["conf"])
        nested = data.get("inputs")
        if isinstance(nested, dict) and "conf" in nested:
            return float(nested["conf"])
        return default

    def _load_image(self, data: Dict) -> Image.Image:
        """Decode the request payload into an RGB image.

        Accepted payloads, checked in this order (a dict under ``"inputs"`` is
        unwrapped first):

        * ``{"image_url": ...}`` -- downloaded with a 15 s timeout;
        * ``{"image_base64": ...}`` -- base64 text, optionally with a
          ``data:...;base64,`` prefix;
        * raw ``bytes``/``bytearray``, directly or under ``"inputs"``;
        * a PIL image, directly or under ``"inputs"``.

        Raises:
            ValueError: If no image can be found in the payload.
            requests.HTTPError: If the ``image_url`` download returns an error
                status.
        """
        payload = data
        # Type-check before any key lookup: `"inputs" in <bytes>` raises
        # TypeError, which used to make the raw-bytes path unreachable.
        if isinstance(payload, dict) and isinstance(payload.get("inputs"), dict):
            payload = payload["inputs"]

        if isinstance(payload, dict):
            if payload.get("image_url"):
                # NOTE(security): the URL is fetched exactly as supplied. If
                # payloads can come from untrusted callers this permits
                # server-side request forgery; restrict reachable hosts.
                response = requests.get(payload["image_url"], timeout=15)
                # Fail with the HTTP status instead of handing an error page
                # to the image decoder.
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content)).convert("RGB")

            if payload.get("image_base64"):
                encoded = payload["image_base64"]
                if isinstance(encoded, str) and encoded.startswith("data:"):
                    # Drop a "data:image/...;base64," prefix; a comma never
                    # occurs inside base64 data itself.
                    encoded = encoded.partition(",")[2]
                img_bytes = base64.b64decode(encoded)
                return Image.open(io.BytesIO(img_bytes)).convert("RGB")

            # No URL/base64 key: the image itself may sit under "inputs".
            payload = payload.get("inputs")

        if isinstance(payload, (bytes, bytearray)):
            return Image.open(io.BytesIO(payload)).convert("RGB")

        # NOTE(review): assumes some serving layers pass an already-decoded PIL
        # image under "inputs" -- confirm for the target deployment.
        if isinstance(payload, Image.Image):
            return payload.convert("RGB")

        raise ValueError("No image provided. Use 'image_url' or 'image_base64' or raw bytes.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection on the payload's image and build the response.

        Args:
            data: Request payload; see ``_load_image`` for the image formats
                and ``_resolve_conf`` for the optional ``conf`` threshold
                (default 0.5).

        Returns:
            Dict with ``is_diseased``, ``max_confidence``, ``num_detections``,
            ``image_width``, ``image_height``, ``threshold_conf``, ``boxes``
            (one dict per detection: ``cls``, ``conf``, ``x1``, ``y1``, ``x2``,
            ``y2``) and ``prediction_text`` (Gemini's answer, or ``""`` when
            nothing was detected or the description is unavailable).
        """
        # Function-scope import keeps this block self-contained.
        import logging

        conf = self._resolve_conf(data)
        img = self._load_image(data)
        width, height = img.width, img.height

        # A single image goes in, so only the first result entry is used.
        result = self.model.predict(img, conf=conf, verbose=False)[0]

        boxes_out = []
        if result.boxes is not None:
            for box in result.boxes:
                x1, y1, x2, y2 = _to_ints(box.xyxy[0].tolist())
                boxes_out.append({
                    # Every detection is labelled "diseased"; the model's own
                    # class id is not consulted.
                    # NOTE(review): confirm the weights are single-class.
                    "cls": "diseased",
                    "conf": float(box.conf[0]),
                    "x1": x1, "y1": y1, "x2": x2, "y2": y2,
                })

        is_diseased = bool(boxes_out)
        summary = {
            "is_diseased": is_diseased,
            "max_confidence": max((b["conf"] for b in boxes_out), default=0.0),
            "num_detections": len(boxes_out),
            "image_width": width,
            "image_height": height,
            "threshold_conf": conf,
            "boxes": boxes_out,
        }

        gemini_text = ""
        if is_diseased:
            # Draw on a copy so the decoded input image stays unmodified.
            annotated_b64: Optional[str] = _annotate_to_b64(img.copy(), boxes_out)
            if annotated_b64:
                try:
                    gemini_text = _ask_gemini_with_image(
                        base64.b64decode(annotated_b64), summary
                    )
                except Exception:
                    # The description is optional; detections are still
                    # returned, and the failure is logged rather than hidden.
                    logging.getLogger(__name__).exception(
                        "Gemini description step failed"
                    )
                    gemini_text = ""

        return {**summary, "prediction_text": gemini_text}
|
|
|
|