"""Gradio front-end that queries two remote crop-disease Spaces and reports
whichever prediction is more confident.

The uploaded image is sent to both Spaces through ``gradio_client``. Each
response is normalised into a dict, the highest-confidence label is taken
from each, and the model with the larger confidence is reported. Raw and
parsed responses are returned alongside for debugging.
"""

import base64
import io
import json
import os
import tempfile
import traceback

import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image

# Clients for the two external Spaces.
# NOTE(review): these are created at import time, so an unreachable Space
# presumably prevents the app from starting at all — confirm whether lazy
# construction would be preferable.
resnet_client = Client("raqiat123/crop_disease_detection")
yolo_client = Client("SoraRyuu/cv_first")


def safe_load_json(maybe_json_str):
    """Parse *maybe_json_str* as JSON when it is a string.

    Non-string inputs, and strings that are not valid JSON, are returned
    unchanged.
    """
    if not isinstance(maybe_json_str, str):
        return maybe_json_str
    try:
        return json.loads(maybe_json_str)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        return maybe_json_str


def _decode_image(value):
    """Best-effort conversion of an endpoint's image-like output to a PIL image.

    Accepts a PIL image (returned as-is), raw bytes, a path to a local file
    (gradio_client typically downloads file outputs and returns their path),
    a ``data:`` URL, or a bare base64 string. Returns None when *value*
    cannot be decoded.
    """
    if isinstance(value, Image.Image):
        return value
    try:
        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value)).convert("RGB")
        if isinstance(value, str):
            if os.path.isfile(value):
                with Image.open(value) as img:
                    return img.convert("RGB")
            # Some Gradio endpoints return data URLs, e.g. "data:image/png;base64,...."
            payload = value.split(",", 1)[1] if value.startswith("data:") else value
            return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
    except Exception:
        # Deliberately best-effort: an undecodable image is simply omitted.
        return None
    return None


def parse_model_response(resp):
    """Normalise a response from ``gradio_client`` ``predict``.

    Returns a 3-tuple ``(primary, optional_image, raw)``:

    - ``primary``: the dict found in the response (typically
      ``{label: confidence}``), or None.
    - ``optional_image``: a PIL image decoded from the response, or None.
    - ``raw``: *resp* itself, kept for debugging.

    Handles a bare dict, a JSON string, a JSON object embedded in a longer
    string, and ``[dict, image]`` / ``[image, dict]`` lists or tuples. The
    image may be a PIL image, bytes, a file path, a data URL or base64.
    """
    primary = None
    optional_image = None

    if isinstance(resp, (list, tuple)) and resp:
        first = safe_load_json(resp[0])
        if isinstance(first, dict):
            primary = first
        if len(resp) > 1:
            optional_image = _decode_image(resp[1])
            if primary is None:
                # The layout may be reversed: (image, dict).
                second = safe_load_json(resp[1])
                if isinstance(second, dict):
                    primary = second
                    optional_image = _decode_image(resp[0])

    # The response itself may be a dict or a JSON-encoded dict.
    if primary is None:
        candidate = safe_load_json(resp)
        if isinstance(candidate, dict):
            primary = candidate

    # A string may wrap a JSON object in other text. Decode the first object
    # and ignore anything after it (json.loads would reject trailing text).
    if primary is None and isinstance(resp, str):
        start = resp.find("{")
        if start != -1:
            try:
                candidate, _ = json.JSONDecoder().raw_decode(resp[start:])
            except (ValueError, RecursionError):
                candidate = None
            if isinstance(candidate, dict):
                primary = candidate

    return primary, optional_image, resp


def _label_confidences(result_dict):
    """Flatten Gradio ``Label`` output into a ``{label: confidence}`` dict.

    ``gr.Label`` output has the form
    ``{"label": top, "confidences": [{"label": ..., "confidence": ...}, ...]}``.
    Dicts without a usable ``"confidences"`` list are returned unchanged.
    """
    confidences = result_dict.get("confidences")
    if not isinstance(confidences, list):
        return result_dict
    flat = {}
    for entry in confidences:
        if isinstance(entry, dict) and "label" in entry and "confidence" in entry:
            flat[entry["label"]] = entry["confidence"]
    return flat or result_dict


def extract_best_prediction(result_dict):
    """Return ``(label, confidence)`` for the highest-scoring entry, or ``(None, 0.0)``.

    Values may be numbers or numeric strings; they are always compared as
    floats, so ``"10"`` beats ``"9"``. Gradio ``Label`` output is flattened
    first. ``(None, 0.0)`` is returned when *result_dict* is empty, is not a
    dict, or has any value that cannot be converted to float (for example an
    ``{"error": ...}`` dict).
    """
    if not result_dict or not isinstance(result_dict, dict):
        return None, 0.0
    try:
        scores = {
            label: float(conf)
            for label, conf in _label_confidences(result_dict).items()
        }
    except (TypeError, ValueError, OverflowError):
        return None, 0.0
    if not scores:
        return None, 0.0
    best_label = max(scores, key=scores.get)
    return best_label, scores[best_label]


def _call_space(client, name, img_path):
    """Call *client*'s ``/predict`` endpoint with the image at *img_path*.

    Any exception is turned into ``{"error": ...}``, so one failing Space
    does not prevent the other model's result from being used.
    """
    try:
        return client.predict(image=handle_file(img_path), api_name="/predict")
    except Exception as e:
        return {"error": f"{name} predict call failed: {repr(e)}"}


def _to_jsonable(obj):
    """Return a JSON-safe copy of *obj* for the debug panel.

    Objects that json cannot encode (e.g. PIL images or bytes in a raw
    response) are replaced by their repr().
    """
    return json.loads(json.dumps(obj, default=repr))


def combined_predict(image_pil):
    """Run both Spaces on *image_pil* and report the more confident prediction.

    Args:
        image_pil: PIL image from the ``gr.Image(type="pil")`` input, or None
            when nothing was uploaded.

    Returns:
        ``(text, json)``: a human-readable summary and a JSON-safe dict with
        ``"chosen"`` and ``"debug"`` keys. On an unexpected error, the text
        describes the error and the dict holds ``"error"`` and ``"traceback"``.
    """
    if image_pil is None:
        return "❌ No image provided. Please upload an image first.", {"error": "no image provided"}

    img_path = None
    try:
        # gradio_client uploads files by path, so write the image to disk.
        # Write through the open handle rather than reopening by name, which
        # fails on Windows while the handle is still open.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            img_path = tmp.name
            image_pil.save(tmp, format="PNG")

        resnet_raw = _call_space(resnet_client, "resnet", img_path)
        yolo_raw = _call_space(yolo_client, "yolo", img_path)

        resnet_dict, _, resnet_rawstore = parse_model_response(resnet_raw)
        yolo_dict, _, yolo_rawstore = parse_model_response(yolo_raw)

        r_label, r_conf = extract_best_prediction(resnet_dict)
        y_label, y_conf = extract_best_prediction(yolo_dict)

        debug = {
            "resnet_raw": resnet_rawstore,
            "resnet_parsed_dict": resnet_dict,
            "resnet_best": {"label": r_label, "confidence": r_conf},
            "yolo_raw": yolo_rawstore,
            "yolo_parsed_dict": yolo_dict,
            "yolo_best": {"label": y_label, "confidence": y_conf},
        }

        if r_label is None and y_label is None:
            text = "❌ Neither model returned a usable prediction."
            return text, _to_jsonable({"chosen": None, "debug": debug})

        # Only a model that actually produced a label can win. When both did,
        # the higher confidence wins and ties go to ResNet.
        if r_label is not None and (y_label is None or r_conf >= y_conf):
            chosen = {
                "chosen_model": "ResNet (crop_disease_detection)",
                "label": r_label,
                "confidence": r_conf,
                "full_output": resnet_dict,
            }
            text = f"Model Selected: ResNet\nPrediction: {r_label}\nConfidence: {r_conf:.4f}"
        else:
            chosen = {
                "chosen_model": "YOLO (cv_first)",
                "label": y_label,
                "confidence": y_conf,
                "full_output": yolo_dict,
            }
            text = f"Model Selected: YOLO\nPrediction: {y_label}\nConfidence: {y_conf:.4f}"

        return text, _to_jsonable({"chosen": chosen, "debug": debug})
    except Exception as e:
        # Top-level UI boundary: show the exception and stack trace for debugging.
        tb = traceback.format_exc()
        return ("❌ Internal error: " + str(e), {"error": str(e), "traceback": tb})
    finally:
        # The temp file was created with delete=False; remove it once both
        # (synchronous) predict calls have finished with it.
        if img_path is not None:
            try:
                os.remove(img_path)
            except OSError:
                pass


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🌿 Crop Disease Classifier (PIL)")
    gr.Markdown("Uploads an image (PIL). Robust parsing & debug info included.")
    img = gr.Image(type="pil")
    text_out = gr.Textbox(label="Final Prediction", lines=2)
    json_out = gr.JSON(label="Raw Output (debug)")
    btn = gr.Button("Run Prediction")
    btn.click(fn=combined_predict, inputs=img, outputs=[text_out, json_out])

if __name__ == "__main__":
    # Guarded so that importing this module does not start a server.
    demo.launch()