SoraRyuu commited on
Commit
fff128f
·
verified ·
1 Parent(s): 2415591

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -62
app.py CHANGED
@@ -1,88 +1,196 @@
1
  import gradio as gr
2
  from gradio_client import Client, handle_file
3
  from PIL import Image
4
- import json
5
  import tempfile
 
 
 
 
6
 
7
- # Load both external Spaces
8
  resnet_client = Client("raqiat123/crop_disease_detection")
9
  yolo_client = Client("SoraRyuu/cv_first")
10
 
 
 
 
 
 
 
 
 
11
 
12
- def extract_best_prediction(result_dict):
13
  """
14
- Extracts the best label + best confidence from:
15
- {
16
- "label1": 0.82,
17
- "label2": 0.13,
18
- ...
19
- }
 
 
 
 
20
  """
21
- if not result_dict:
22
- return None, 0.0
23
-
24
- best_label = max(result_dict, key=result_dict.get)
25
- best_conf = result_dict[best_label]
26
- return best_label, best_conf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def combined_predict(image_pil):
30
  """
31
- Input = PIL image from Gradio
 
32
  """
33
-
34
- # Save the PIL image to a temp file for HF client
35
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
36
- image_pil.save(tmp.name)
37
- img_path = tmp.name
38
-
39
- # Run both external models
40
- resnet_output = resnet_client.predict(
41
- image=handle_file(img_path),
42
- api_name="/predict"
43
- )
44
-
45
- # YOLO space returns (dict, image)
46
- yolo_output, _ = yolo_client.predict(
47
- image=handle_file(img_path),
48
- api_name="/predict"
49
- )
50
-
51
- # Extract best predictions
52
- resnet_label, resnet_conf = extract_best_prediction(resnet_output)
53
- yolo_label, yolo_conf = extract_best_prediction(yolo_output)
54
-
55
- # Choose best model
56
- if resnet_conf >= yolo_conf:
57
- final = {
58
- "chosen_model": "ResNet (crop_disease_detection)",
59
- "label": resnet_label,
60
- "confidence": resnet_conf,
61
- "full_output": resnet_output
 
 
 
62
  }
63
- text = f"Model Selected: ResNet\nPrediction: {resnet_label}\nConfidence: {resnet_conf:.4f}"
64
- else:
65
- final = {
66
- "chosen_model": "YOLO (cv_first)",
67
- "label": yolo_label,
68
- "confidence": yolo_conf,
69
- "full_output": yolo_output
70
- }
71
- text = f"Model Selected: YOLO\nPrediction: {yolo_label}\nConfidence: {yolo_conf:.4f}"
72
-
73
- return text, final
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # Gradio UI
77
  with gr.Blocks() as demo:
78
- gr.Markdown("# 🌿 Crop Disease Classifier")
79
- gr.Markdown("Give an image to detect disease in crop, if any.")
80
 
81
  img = gr.Image(type="pil")
82
- text_output = gr.Textbox(label="Prediction")
83
- json_output = gr.JSON(label="Raw JSON Output")
84
 
85
  btn = gr.Button("Run Prediction")
86
- btn.click(fn=combined_predict, inputs=img, outputs=[text_output, json_output])
87
 
88
- demo.launch()
 
1
  import gradio as gr
2
  from gradio_client import Client, handle_file
3
  from PIL import Image
 
4
  import tempfile
5
+ import json
6
+ import base64
7
+ import io
8
+ import traceback
9
 
10
+ # Clients for the two external Spaces
11
  resnet_client = Client("raqiat123/crop_disease_detection")
12
  yolo_client = Client("SoraRyuu/cv_first")
13
 
14
+ def safe_load_json(maybe_json_str):
15
+ """Try to parse a JSON string, otherwise return original."""
16
+ if not isinstance(maybe_json_str, str):
17
+ return maybe_json_str
18
+ try:
19
+ return json.loads(maybe_json_str)
20
+ except Exception:
21
+ return maybe_json_str
22
 
23
+ def parse_model_response(resp):
24
  """
25
+ Normalize a response from gradio_client.predict into:
26
+ - primary: dict of {label: confidence} (or None)
27
+ - optional_image: a PIL.Image (or None)
28
+ - raw: original response (kept for debug)
29
+ Handles:
30
+ - dict
31
+ - JSON strings
32
+ - [dict, image], (dict, image)
33
+ - list where first element is dict
34
+ - base64 image strings (attempt decode)
35
  """
36
+ primary = None
37
+ optional_image = None
38
+ raw = resp
39
+
40
+ # If response is tuple/list, prioritize first element for dict, second for image
41
+ if isinstance(resp, (list, tuple)) and len(resp) > 0:
42
+ # try first element as dict-like
43
+ first = resp[0]
44
+ first = safe_load_json(first)
45
+ if isinstance(first, dict):
46
+ primary = first
47
+ # attempt to parse second element as image (base64 / bytes / PIL)
48
+ if len(resp) > 1:
49
+ second = resp[1]
50
+ # If second is already a PIL Image
51
+ if isinstance(second, Image.Image):
52
+ optional_image = second
53
+ # if second is bytes-like, try to open
54
+ elif isinstance(second, (bytes, bytearray)):
55
+ try:
56
+ optional_image = Image.open(io.BytesIO(second)).convert("RGB")
57
+ except Exception:
58
+ optional_image = None
59
+ # if second is base64 string
60
+ elif isinstance(second, str):
61
+ try:
62
+ # some Gradio endpoints return data URLs e.g. "data:image/png;base64,...."
63
+ if second.startswith("data:"):
64
+ header, b64 = second.split(",", 1)
65
+ decoded = base64.b64decode(b64)
66
+ optional_image = Image.open(io.BytesIO(decoded)).convert("RGB")
67
+ else:
68
+ decoded = base64.b64decode(second)
69
+ optional_image = Image.open(io.BytesIO(decoded)).convert("RGB")
70
+ except Exception:
71
+ optional_image = None
72
+ # If still no primary, maybe the first element was image and second is dict
73
+ if primary is None and len(resp) > 1:
74
+ candidate = safe_load_json(resp[1])
75
+ if isinstance(candidate, dict):
76
+ primary = candidate
77
+
78
+ # If resp itself is a dict
79
+ if primary is None:
80
+ r = safe_load_json(resp)
81
+ if isinstance(r, dict):
82
+ primary = r
83
+
84
+ # If still nothing, attempt to find a dict nested inside resp
85
+ if primary is None:
86
+ try:
87
+ # if it's a string that contains a JSON object somewhere
88
+ if isinstance(resp, str):
89
+ # try to find first "{" and parse
90
+ idx = resp.find("{")
91
+ if idx != -1:
92
+ candidate = safe_load_json(resp[idx:])
93
+ if isinstance(candidate, dict):
94
+ primary = candidate
95
+ except Exception:
96
+ pass
97
+
98
+ return primary, optional_image, raw
99
 
100
+ def extract_best_prediction(result_dict):
101
+ """Return (label, confidence) or (None, 0.0)"""
102
+ if not result_dict or not isinstance(result_dict, dict):
103
+ return None, 0.0
104
+ try:
105
+ best_label = max(result_dict, key=result_dict.get)
106
+ best_conf = float(result_dict[best_label])
107
+ return best_label, best_conf
108
+ except Exception:
109
+ # maybe values are strings that look like floats
110
+ try:
111
+ converted = {k: float(v) for k, v in result_dict.items()}
112
+ best_label = max(converted, key=converted.get)
113
+ return best_label, float(converted[best_label])
114
+ except Exception:
115
+ return None, 0.0
116
 
117
  def combined_predict(image_pil):
118
  """
119
+ image_pil: PIL.Image from Gradio
120
+ Returns: (text, json) where json contains debug info if error happened
121
  """
122
+ try:
123
+ # save to temp file for gradio_client
124
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
125
+ image_pil.save(tmp.name)
126
+ img_path = tmp.name
127
+
128
+ # 1) call resnet space
129
+ try:
130
+ resnet_raw = resnet_client.predict(image=handle_file(img_path), api_name="/predict")
131
+ except Exception as e:
132
+ resnet_raw = {"error": f"resnet predict call failed: {repr(e)}"}
133
+ # 2) call yolo space
134
+ try:
135
+ yolo_raw = yolo_client.predict(image=handle_file(img_path), api_name="/predict")
136
+ except Exception as e:
137
+ yolo_raw = {"error": f"yolo predict call failed: {repr(e)}"}
138
+
139
+ # parse responses
140
+ resnet_dict, resnet_img, resnet_rawstore = parse_model_response(resnet_raw)
141
+ yolo_dict, yolo_img, yolo_rawstore = parse_model_response(yolo_raw)
142
+
143
+ # extract bests
144
+ r_label, r_conf = extract_best_prediction(resnet_dict)
145
+ y_label, y_conf = extract_best_prediction(yolo_dict)
146
+
147
+ debug = {
148
+ "resnet_raw": resnet_rawstore,
149
+ "resnet_parsed_dict": resnet_dict,
150
+ "resnet_best": {"label": r_label, "confidence": r_conf},
151
+ "yolo_raw": yolo_rawstore,
152
+ "yolo_parsed_dict": yolo_dict,
153
+ "yolo_best": {"label": y_label, "confidence": y_conf},
154
  }
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ # Choose winner
157
+ if r_conf >= y_conf:
158
+ chosen = {
159
+ "chosen_model": "ResNet (crop_disease_detection)",
160
+ "label": r_label,
161
+ "confidence": r_conf,
162
+ "full_output": resnet_dict
163
+ }
164
+ text = f"Model Selected: ResNet\nPrediction: {r_label}\nConfidence: {r_conf:.4f}"
165
+ else:
166
+ chosen = {
167
+ "chosen_model": "YOLO (cv_first)",
168
+ "label": y_label,
169
+ "confidence": y_conf,
170
+ "full_output": yolo_dict
171
+ }
172
+ text = f"Model Selected: YOLO\nPrediction: {y_label}\nConfidence: {y_conf:.4f}"
173
+
174
+ # return text and a combined JSON containing debug + chosen
175
+ out_json = {"chosen": chosen, "debug": debug}
176
+ return text, out_json
177
+
178
+ except Exception as e:
179
+ tb = traceback.format_exc()
180
+ # Show the exception and stack trace in the UI for debugging
181
+ return ("❌ Internal error: " + str(e),
182
+ {"error": str(e), "traceback": tb})
183
 
184
  # Gradio UI
185
  with gr.Blocks() as demo:
186
+ gr.Markdown("# 🌿 Crop Disease Classifier (PIL)")
187
+ gr.Markdown("Uploads an image (PIL). Robust parsing & debug info included.")
188
 
189
  img = gr.Image(type="pil")
190
+ text_out = gr.Textbox(label="Final Prediction", lines=2)
191
+ json_out = gr.JSON(label="Raw Output (debug)")
192
 
193
  btn = gr.Button("Run Prediction")
194
+ btn.click(fn=combined_predict, inputs=img, outputs=[text_out, json_out])
195
 
196
+ demo.launch()