cv / app.py
SoraRyuu's picture
Update app.py
fff128f verified
import base64
import io
import json
import os
import tempfile
import traceback

import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
# gradio_client handles for the two remote Hugging Face Spaces queried on every
# prediction. Both connections are created once, at module import time.
resnet_client, yolo_client = (
    Client("raqiat123/crop_disease_detection"),
    Client("SoraRyuu/cv_first"),
)
def safe_load_json(maybe_json_str):
    """Decode *maybe_json_str* as JSON when possible.

    Non-string inputs are handed back untouched, and so are strings that
    cannot be parsed, so the caller always gets *something* to inspect.
    """
    if isinstance(maybe_json_str, str):
        try:
            decoded = json.loads(maybe_json_str)
        except Exception:
            decoded = maybe_json_str
        return decoded
    return maybe_json_str
def _decode_image(value):
    """Best-effort conversion of *value* into a PIL image; None on failure.

    Accepted forms: a PIL image (returned unchanged), raw image bytes, a
    ``data:`` URL, a bare base64 string, or a path to an existing image file
    (gradio_client typically hands image outputs back as a local file path).
    """
    if isinstance(value, Image.Image):
        return value
    try:
        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value)).convert("RGB")
        if isinstance(value, str):
            if os.path.isfile(value):
                with Image.open(value) as img:
                    return img.convert("RGB")
            # Data URLs look like "data:image/png;base64,....".
            payload = value.split(",", 1)[1] if value.startswith("data:") else value
            return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
    except Exception:
        # PIL and base64 raise many different errors on garbage input. This is
        # a deliberate best-effort decode, so any failure means "no image".
        return None
    return None


def _find_embedded_dict(text):
    """Return the JSON object that starts at the first "{" in *text*, or None.

    ``raw_decode`` stops at the end of the object, so trailing text after it
    is tolerated (``json.loads`` would fail on it with "Extra data").
    """
    start = text.find("{")
    if start == -1:
        return None
    try:
        candidate, _end = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        return None
    return candidate if isinstance(candidate, dict) else None


def parse_model_response(resp):
    """Normalize a ``gradio_client`` ``predict`` result.

    Returns a ``(primary, optional_image, raw)`` triple:
      - primary: the first dict found in the response (expected to map
        label -> confidence), or None
      - optional_image: a PIL image decoded from the response, or None
      - raw: *resp* unchanged, kept for debugging

    Handles:
      - a dict, or a JSON string encoding one
      - ``[dict, image]`` / ``(dict, image)`` and ``[image, dict]`` sequences
      - strings with a JSON object embedded in surrounding text
      - images given as PIL objects, bytes, data URLs, base64 or file paths
    """
    primary = None
    optional_image = None

    if isinstance(resp, (list, tuple)) and resp:
        first = safe_load_json(resp[0])
        if isinstance(first, dict):
            primary = first
        if len(resp) > 1:
            # Usual layout is (dict, image): try the second element as an image.
            optional_image = _decode_image(resp[1])
            if primary is None:
                candidate = safe_load_json(resp[1])
                if isinstance(candidate, dict):
                    primary = candidate
                    # Layout is (image, dict), so the image is the first element.
                    if optional_image is None:
                        optional_image = _decode_image(resp[0])

    # The response itself may be a dict (or a JSON string encoding one).
    if primary is None:
        parsed = safe_load_json(resp)
        if isinstance(parsed, dict):
            primary = parsed

    # Last resort: a JSON object embedded somewhere inside a string.
    if primary is None and isinstance(resp, str):
        primary = _find_embedded_dict(resp)

    return primary, optional_image, resp
def _label_payload_to_scores(result_dict):
    """Flatten a Gradio ``Label`` payload into ``{label: confidence}``.

    The payload shape is ``{"label": ..., "confidences": [{"label": ...,
    "confidence": ...}, ...]}``. Returns None when *result_dict* has no such
    ``confidences`` list, or when the list holds no usable entries.
    """
    confidences = result_dict.get("confidences")
    if not isinstance(confidences, list):
        return None
    scores = {}
    for entry in confidences:
        if isinstance(entry, dict) and "label" in entry:
            scores[entry["label"]] = entry.get("confidence")
    return scores or None


def extract_best_prediction(result_dict):
    """Return ``(label, confidence)`` for the highest-scoring entry.

    *result_dict* is either a ``{label: confidence}`` mapping or a Gradio
    ``Label`` payload. Confidences may be numbers or numeric strings. They are
    converted to float *before* comparing, so ``"1e-1"`` does not beat
    ``"0.5"`` through string ordering. Returns ``(None, 0.0)`` for empty or
    non-dict input, or when any confidence is not numeric (e.g. an
    ``{"error": "..."}`` dict).
    """
    if not result_dict or not isinstance(result_dict, dict):
        return None, 0.0
    scores = _label_payload_to_scores(result_dict) or result_dict
    try:
        numeric = {label: float(score) for label, score in scores.items()}
    except (TypeError, ValueError, OverflowError):
        return None, 0.0
    # max() keeps the first of equal scores, i.e. dict insertion order.
    best_label = max(numeric, key=numeric.get)
    return best_label, numeric[best_label]
def _safe_predict(client, img_path, tag):
    """Call the ``/predict`` endpoint of *client* with the image at *img_path*.

    Any exception from the remote call becomes
    ``{"error": "<tag> predict call failed: ..."}``, so one broken Space does
    not prevent the other model's result from being shown.
    """
    try:
        return client.predict(image=handle_file(img_path), api_name="/predict")
    except Exception as e:
        return {"error": f"{tag} predict call failed: {repr(e)}"}


def combined_predict(image_pil):
    """Run both remote models on *image_pil* and report the more confident one.

    Args:
        image_pil: PIL image from the Gradio ``Image`` input, or None when the
            button is clicked before an image is uploaded.

    Returns:
        A ``(text, json)`` pair. ``text`` is a short summary for the textbox.
        ``json`` holds the chosen prediction plus per-model debug info. On an
        unexpected failure, the error message and traceback are returned
        instead.
    """
    if image_pil is None:
        return ("❌ Please upload an image first.", {"error": "no image provided"})

    img_path = None
    try:
        # gradio_client uploads files by path, so persist the image to a temp
        # PNG. Write through the open handle: per the tempfile docs, reopening
        # a still-open NamedTemporaryFile by name is not portable.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img_path = tmp.name
            image_pil.save(tmp, format="PNG")

        resnet_raw = _safe_predict(resnet_client, img_path, "resnet")
        yolo_raw = _safe_predict(yolo_client, img_path, "yolo")

        resnet_dict, _resnet_img, resnet_rawstore = parse_model_response(resnet_raw)
        yolo_dict, _yolo_img, yolo_rawstore = parse_model_response(yolo_raw)

        r_label, r_conf = extract_best_prediction(resnet_dict)
        y_label, y_conf = extract_best_prediction(yolo_dict)

        debug = {
            "resnet_raw": resnet_rawstore,
            "resnet_parsed_dict": resnet_dict,
            "resnet_best": {"label": r_label, "confidence": r_conf},
            "yolo_raw": yolo_rawstore,
            "yolo_parsed_dict": yolo_dict,
            "yolo_best": {"label": y_label, "confidence": y_conf},
        }

        candidates = [
            ("ResNet", "ResNet (crop_disease_detection)", r_label, r_conf, resnet_dict),
            ("YOLO", "YOLO (cv_first)", y_label, y_conf, yolo_dict),
        ]
        # A model whose response could not be parsed has label None. Never
        # prefer it over a model that produced an actual prediction.
        usable = [c for c in candidates if c[2] is not None]
        if not usable:
            return (
                "❌ Neither model returned a usable prediction (see debug output).",
                {"chosen": None, "debug": debug},
            )

        # max() returns the first of equal maxima, so ResNet still wins ties.
        short_name, model_name, label, conf, full_output = max(usable, key=lambda c: c[3])
        chosen = {
            "chosen_model": model_name,
            "label": label,
            "confidence": conf,
            "full_output": full_output,
        }
        text = f"Model Selected: {short_name}\nPrediction: {label}\nConfidence: {conf:.4f}"
        return text, {"chosen": chosen, "debug": debug}
    except Exception as e:
        # Top-level UI boundary: show the error and stack trace for debugging.
        tb = traceback.format_exc()
        return ("❌ Internal error: " + str(e),
                {"error": str(e), "traceback": tb})
    finally:
        # delete=False keeps the file after the `with` block; remove it here.
        if img_path is not None:
            try:
                os.remove(img_path)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless
# Gradio UI: one image input, a button, and two outputs (summary text and a
# debug JSON view) wired to combined_predict.
with gr.Blocks() as demo:
    gr.Markdown("# 🌿 Crop Disease Classifier (PIL)")
    gr.Markdown("Uploads an image (PIL). Robust parsing & debug info included.")
    # type="pil" asks Gradio to pass the upload to the handler as a PIL image.
    img = gr.Image(type="pil")
    text_out = gr.Textbox(label="Final Prediction", lines=2)
    json_out = gr.JSON(label="Raw Output (debug)")
    btn = gr.Button("Run Prediction")
    # combined_predict returns (text, json), matching the order of `outputs`.
    btn.click(fn=combined_predict, inputs=img, outputs=[text_out, json_out])
# Launch runs at module level (no __main__ guard), so it runs whenever this
# file is executed or imported.
demo.launch()