imeesam commited on
Commit
03e6b2e
·
verified ·
1 Parent(s): b956dea

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -4
inference.py CHANGED
@@ -1,24 +1,32 @@
1
  from PIL import Image
2
  import tensorflow as tf
3
  import numpy as np
 
 
4
 
5
- # Load the TFLite model
6
  interpreter = tf.lite.Interpreter(model_path="leaf_model_85_percent.tflite")
7
  interpreter.allocate_tensors()
8
  input_details = interpreter.get_input_details()
9
  output_details = interpreter.get_output_details()
10
 
11
- # Define preprocess and inference
12
  def preprocess(image: Image.Image):
13
  img = image.resize((224, 224)).convert("RGB")
14
  img_array = np.array(img) / 255.0
15
  img_array = np.expand_dims(img_array, axis=0).astype(np.float32)
16
  return img_array
17
 
18
- def predict(image: Image.Image):
 
19
  input_tensor = preprocess(image)
20
  interpreter.set_tensor(input_details[0]['index'], input_tensor)
21
  interpreter.invoke()
22
  output = interpreter.get_tensor(output_details[0]['index'])[0][0]
23
  label = "Unhealthy" if output > 0.5 else "Healthy"
24
- return [{"label": label, "score": float(output)}]
 
 
 
 
 
 
 
1
  from PIL import Image
2
  import tensorflow as tf
3
  import numpy as np
4
+ import io
5
+ import base64
6
 
7
+ # Load the TFLite model once
8
  interpreter = tf.lite.Interpreter(model_path="leaf_model_85_percent.tflite")
9
  interpreter.allocate_tensors()
10
  input_details = interpreter.get_input_details()
11
  output_details = interpreter.get_output_details()
12
 
 
13
  def preprocess(image: Image.Image):
14
  img = image.resize((224, 224)).convert("RGB")
15
  img_array = np.array(img) / 255.0
16
  img_array = np.expand_dims(img_array, axis=0).astype(np.float32)
17
  return img_array
18
 
19
+ def classify(image_bytes: bytes):
20
+ image = Image.open(io.BytesIO(image_bytes))
21
  input_tensor = preprocess(image)
22
  interpreter.set_tensor(input_details[0]['index'], input_tensor)
23
  interpreter.invoke()
24
  output = interpreter.get_tensor(output_details[0]['index'])[0][0]
25
  label = "Unhealthy" if output > 0.5 else "Healthy"
26
+ return {"label": label, "score": float(output)}
27
+
28
+ # This is the function HF Spaces calls to do inference
29
+ def inference(payload):
30
+ # payload will have image data in base64 or bytes format, depends on your API input
31
+ image_bytes = base64.b64decode(payload["data"][0].split(",")[1]) # Assuming base64 input
32
+ return classify(image_bytes)