|
|
import base64
import io
import json
import math
import os
import tempfile
import traceback

import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
|
|
|
|
|
|
|
|
# Hugging Face Spaces queried for every prediction.
# Both clients are constructed once, at import time, and shared by all requests.
# NOTE(review): gradio_client's Client presumably contacts the Space on construction,
# so importing this module needs network access.
RESNET_SPACE = "raqiat123/crop_disease_detection"
YOLO_SPACE = "SoraRyuu/cv_first"

resnet_client = Client(RESNET_SPACE)
yolo_client = Client(YOLO_SPACE)
|
|
|
|
|
def safe_load_json(maybe_json_str):
    """Decode *maybe_json_str* as JSON when possible.

    Non-string inputs are handed back untouched, and so are strings that fail
    to decode. Only a string that decodes successfully is replaced by the
    resulting Python object.
    """
    if not isinstance(maybe_json_str, str):
        return maybe_json_str
    try:
        decoded = json.loads(maybe_json_str)
    except Exception:
        # Best-effort: anything that is not valid JSON stays a plain string.
        return maybe_json_str
    else:
        return decoded
|
|
|
|
|
def _image_from_value(value):
    """Best-effort conversion of one model-output value into a PIL image.

    Accepts:
    - a PIL image, returned as-is;
    - raw ``bytes``/``bytearray``;
    - a local file path, which is how gradio_client hands back image outputs;
    - a ``{"path": ...}`` file dict;
    - a ``data:`` URI;
    - a bare base64 string.

    Decoded images are converted to RGB. Returns ``None`` when nothing decodes.
    """
    if isinstance(value, Image.Image):
        return value
    if isinstance(value, dict) and isinstance(value.get("path"), str):
        value = value["path"]
    try:
        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value)).convert("RGB")
        if isinstance(value, str):
            if os.path.isfile(value):
                # convert() copies the pixels; the with-block closes the file handle.
                with Image.open(value) as img:
                    return img.convert("RGB")
            # "data:<mime>;base64,<payload>" -> keep only the payload; a data URI
            # without a comma raises IndexError and is treated as "no image".
            payload = value.split(",", 1)[1] if value.startswith("data:") else value
            return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
    except Exception:
        # Deliberate best-effort: an undecodable payload just means "no image".
        return None
    return None


def _normalize_label_dict(result):
    """Flatten a Gradio ``Label`` payload into ``{label: confidence}``.

    ``{"label": "x", "confidences": [{"label": "x", "confidence": 0.9}, ...]}``
    becomes ``{"x": 0.9, ...}``. Any other dict is returned unchanged.
    """
    confidences = result.get("confidences")
    if isinstance(confidences, list) and confidences and all(
        isinstance(item, dict) and "label" in item and "confidence" in item
        for item in confidences
    ):
        try:
            return {item["label"]: item["confidence"] for item in confidences}
        except TypeError:  # unhashable label; keep the original payload
            return result
    return result


def _first_json_object(text):
    """Return the JSON object that starts at the first ``{`` in *text*, or None.

    ``raw_decode`` stops at the end of the object, so trailing non-JSON text
    (which made a plain ``json.loads(text[idx:])`` fail) is tolerated.
    """
    idx = text.find("{")
    if idx == -1:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(text[idx:])
    except ValueError:
        return None
    return obj if isinstance(obj, dict) else None


def parse_model_response(resp):
    """Normalize a ``gradio_client`` ``predict`` result.

    Returns a tuple ``(primary, optional_image, raw)``:
    - ``primary``: dict of ``{label: confidence}`` or ``None``. Gradio ``Label``
      payloads are flattened into this shape; other dicts are returned as found.
    - ``optional_image``: a PIL image taken from the second element of a
      list/tuple response, or ``None``.
    - ``raw``: the original response, kept for debugging.

    Handles:
    - a plain dict, or a JSON string encoding one;
    - ``[dict, image]`` / ``(dict, image)``, where the image may be a PIL
      image, bytes, a file path, a ``{"path": ...}`` dict, a data URI or
      base64;
    - a list/tuple whose second element is the dict when the first is not;
    - a string with a JSON object embedded after some leading text.
    """
    primary = None
    optional_image = None

    if isinstance(resp, (list, tuple)) and resp:
        first = safe_load_json(resp[0])
        if isinstance(first, dict):
            primary = first
        if len(resp) > 1:
            optional_image = _image_from_value(resp[1])
            if primary is None:
                # Some endpoints put the prediction dict second.
                candidate = safe_load_json(resp[1])
                if isinstance(candidate, dict):
                    primary = candidate

    if primary is None:
        whole = safe_load_json(resp)
        if isinstance(whole, dict):
            primary = whole

    if primary is None and isinstance(resp, str):
        # Last resort: a JSON object embedded in surrounding text.
        primary = _first_json_object(resp)

    if primary is not None:
        primary = _normalize_label_dict(primary)

    return primary, optional_image, resp
|
|
|
|
|
def _label_confidence_pairs(result_dict):
    """Return ``(label, raw_confidence)`` pairs from a flat or Gradio-Label dict."""
    confidences = result_dict.get("confidences")
    if isinstance(confidences, list) and confidences and all(
        isinstance(item, dict) and "label" in item and "confidence" in item
        for item in confidences
    ):
        # Gradio Label payload: {"label": ..., "confidences": [{"label", "confidence"}, ...]}
        return [(item["label"], item["confidence"]) for item in confidences]
    return list(result_dict.items())


def extract_best_prediction(result_dict):
    """Return the highest-confidence ``(label, confidence)`` pair.

    Accepts either a flat ``{label: confidence}`` mapping or a Gradio ``Label``
    payload (``{"label": ..., "confidences": [{"label": ..., "confidence": ...}]}``).

    Every confidence is coerced with ``float()`` before comparing, so numeric
    strings compare numerically (``"10"`` beats ``"9"``). Entries that cannot be
    converted, such as an ``{"error": "..."}`` payload, are skipped, and so are
    NaN values. Ties keep the first entry in iteration order.

    Returns:
        ``(label, confidence)`` for the best entry, or ``(None, 0.0)`` when the
        input is empty, is not a dict, or holds no numeric confidence.
    """
    if not result_dict or not isinstance(result_dict, dict):
        return None, 0.0

    scores = {}
    for label, value in _label_confidence_pairs(result_dict):
        try:
            conf = float(value)
        except (TypeError, ValueError):
            continue
        if math.isnan(conf):
            continue  # NaN would make max() depend on ordering
        scores[label] = conf

    if not scores:
        return None, 0.0
    best_label = max(scores, key=scores.get)
    return best_label, scores[best_label]
|
|
|
|
|
def _remove_quietly(path):
    """Delete *path*, ignoring OS errors (best-effort temp-file cleanup)."""
    try:
        os.remove(path)
    except OSError:
        pass


def _save_temp_png(image_pil):
    """Write *image_pil* to a new temporary ``.png`` file and return its path.

    The image is written through the already-open file object, so the file is
    never reopened by name while still held open (that fails on Windows). The
    caller owns the returned file and must delete it.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        with tmp:
            image_pil.save(tmp, format="PNG")
    except BaseException:
        _remove_quietly(tmp.name)
        raise
    return tmp.name


def _call_model(client, img_path, name):
    """Run ``client.predict`` on the image file.

    On failure, return ``{"error": "<name> predict call failed: ..."}`` instead
    of raising, so one broken Space does not prevent the other from answering.
    """
    try:
        return client.predict(image=handle_file(img_path), api_name="/predict")
    except Exception as e:
        return {"error": f"{name} predict call failed: {repr(e)}"}


def combined_predict(image_pil):
    """Run both remote classifiers on one image and keep the more confident answer.

    Args:
        image_pil: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the button is clicked before an image is uploaded.

    Returns:
        ``(text, json)``:
        - ``text`` is a human-readable summary for the textbox.
        - ``json`` holds ``chosen`` (``None`` when neither model produced a
          label) and per-model ``debug`` info. On an unexpected error it holds
          ``error`` and ``traceback`` instead.

    Ties in confidence go to the ResNet model.
    """
    if image_pil is None:
        return ("⚠️ Please upload an image first.", {"error": "no image provided"})

    img_path = None
    try:
        img_path = _save_temp_png(image_pil)

        resnet_raw = _call_model(resnet_client, img_path, "resnet")
        yolo_raw = _call_model(yolo_client, img_path, "yolo")

        resnet_dict, _resnet_img, resnet_rawstore = parse_model_response(resnet_raw)
        yolo_dict, _yolo_img, yolo_rawstore = parse_model_response(yolo_raw)

        r_label, r_conf = extract_best_prediction(resnet_dict)
        y_label, y_conf = extract_best_prediction(yolo_dict)

        debug = {
            "resnet_raw": resnet_rawstore,
            "resnet_parsed_dict": resnet_dict,
            "resnet_best": {"label": r_label, "confidence": r_conf},
            "yolo_raw": yolo_rawstore,
            "yolo_parsed_dict": yolo_dict,
            "yolo_best": {"label": y_label, "confidence": y_conf},
        }

        resnet_ok = r_label is not None
        yolo_ok = y_label is not None
        if not (resnet_ok or yolo_ok):
            # Don't present a "None" prediction as if a model had chosen it.
            text = "❌ Neither model returned a usable prediction (see debug output)."
            return text, {"chosen": None, "debug": debug}

        # Only models that produced a label compete; ties go to ResNet (>=).
        if resnet_ok and (not yolo_ok or r_conf >= y_conf):
            chosen = {
                "chosen_model": "ResNet (crop_disease_detection)",
                "label": r_label,
                "confidence": r_conf,
                "full_output": resnet_dict,
            }
            text = f"Model Selected: ResNet\nPrediction: {r_label}\nConfidence: {r_conf:.4f}"
        else:
            chosen = {
                "chosen_model": "YOLO (cv_first)",
                "label": y_label,
                "confidence": y_conf,
                "full_output": yolo_dict,
            }
            text = f"Model Selected: YOLO\nPrediction: {y_label}\nConfidence: {y_conf:.4f}"

        return text, {"chosen": chosen, "debug": debug}

    except Exception as e:
        tb = traceback.format_exc()
        return ("❌ Internal error: " + str(e), {"error": str(e), "traceback": tb})

    finally:
        # The original never deleted the temp file, leaking one per request.
        if img_path is not None:
            _remove_quietly(img_path)
|
|
|
|
|
|
|
|
# Gradio UI: one image input and one button. The button runs combined_predict,
# which queries both remote models, and shows a text summary plus the full
# debug JSON.
with gr.Blocks() as demo:
    gr.Markdown("# 🌿 Crop Disease Classifier (PIL)")
    gr.Markdown("Uploads an image (PIL). Robust parsing & debug info included.")

    img = gr.Image(type="pil")
    text_out = gr.Textbox(label="Final Prediction", lines=2)
    json_out = gr.JSON(label="Raw Output (debug)")

    btn = gr.Button("Run Prediction")
    btn.click(fn=combined_predict, inputs=img, outputs=[text_out, json_out])


# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module from tests or tooling does not start a web server.
if __name__ == "__main__":
    demo.launch()