sidd-harth011 commited on
Commit
4030d18
·
1 Parent(s): 214cd0e
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -4,10 +4,15 @@ from PIL import Image
4
  import numpy as np
5
  import tensorflow as tf
6
  import gradio as gr
 
7
 
8
- # Paths
9
- working_dir = os.path.dirname(os.path.abspath(__file__))
10
- model_path = f"{working_dir}/trained_model/plant_disease_model.tflite"
 
 
 
 
11
 
12
  # Load TFLite model
13
  interpreter = tf.lite.Interpreter(model_path=model_path)
@@ -17,10 +22,16 @@ interpreter.allocate_tensors()
17
  input_details = interpreter.get_input_details()
18
  output_details = interpreter.get_output_details()
19
 
20
- # Load class indices
21
- class_indices = json.load(open(f"{working_dir}/class_indices.json"))
 
 
 
 
22
 
23
- # Function to preprocess the image
 
 
24
  def load_and_preprocess_image(image, target_size=(224, 224)):
25
  img = image.resize(target_size)
26
  img_array = np.array(img, dtype=np.float32)
@@ -28,29 +39,28 @@ def load_and_preprocess_image(image, target_size=(224, 224)):
28
  img_array = img_array / 255.0
29
  return img_array
30
 
 
31
  # Prediction function
 
32
  def predict_image_class(image):
33
  preprocessed_img = load_and_preprocess_image(image)
34
-
35
- # Set input tensor
36
  interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
37
  interpreter.invoke()
38
-
39
- # Get predictions
40
  predictions = interpreter.get_tensor(output_details[0]['index'])
41
  predicted_class_index = np.argmax(predictions, axis=1)[0]
42
  predicted_class_name = class_indices[str(predicted_class_index)]
43
-
44
  return f"Prediction: {predicted_class_name}"
45
 
 
46
  # Gradio Interface
 
47
  interface = gr.Interface(
48
  fn=predict_image_class,
49
  inputs=gr.Image(type="pil", label="Upload an Image"),
50
  outputs=gr.Textbox(label="Prediction"),
51
  title="🌱 Plant Disease Classifier (TFLite)",
52
- description="Upload a plant leaf image to classify its disease using a compressed TFLite model."
53
  )
54
 
55
  if __name__ == "__main__":
56
- interface.launch()
 
4
  import numpy as np
5
  import tensorflow as tf
6
  import gradio as gr
7
+ from huggingface_hub import hf_hub_download
8
 
9
+ # -----------------------------
10
+ # Download TFLite model from Hugging Face model repo
11
+ # -----------------------------
12
+ model_path = hf_hub_download(
13
+ repo_id="sidd-harth011/checkingPDRMod", # your model repo
14
+ filename="plant_disease_model.tflite"
15
+ )
16
 
17
  # Load TFLite model
18
  interpreter = tf.lite.Interpreter(model_path=model_path)
 
22
  input_details = interpreter.get_input_details()
23
  output_details = interpreter.get_output_details()
24
 
25
+ # Download/load class indices
26
+ class_indices_path = hf_hub_download(
27
+ repo_id="sidd-harth011/checkingPDRMod",
28
+ filename="class_indices.json"
29
+ )
30
+ class_indices = json.load(open(class_indices_path))
31
 
32
+ # -----------------------------
33
+ # Preprocessing function
34
+ # -----------------------------
35
  def load_and_preprocess_image(image, target_size=(224, 224)):
36
  img = image.resize(target_size)
37
  img_array = np.array(img, dtype=np.float32)
 
39
  img_array = img_array / 255.0
40
  return img_array
41
 
42
+ # -----------------------------
43
  # Prediction function
44
+ # -----------------------------
45
  def predict_image_class(image):
46
  preprocessed_img = load_and_preprocess_image(image)
 
 
47
  interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
48
  interpreter.invoke()
 
 
49
  predictions = interpreter.get_tensor(output_details[0]['index'])
50
  predicted_class_index = np.argmax(predictions, axis=1)[0]
51
  predicted_class_name = class_indices[str(predicted_class_index)]
 
52
  return f"Prediction: {predicted_class_name}"
53
 
54
+ # -----------------------------
55
  # Gradio Interface
56
+ # -----------------------------
57
  interface = gr.Interface(
58
  fn=predict_image_class,
59
  inputs=gr.Image(type="pil", label="Upload an Image"),
60
  outputs=gr.Textbox(label="Prediction"),
61
  title="🌱 Plant Disease Classifier (TFLite)",
62
+ description="Upload a plant leaf image to classify its disease using a compressed TFLite model hosted on Hugging Face."
63
  )
64
 
65
  if __name__ == "__main__":
66
+ interface.launch()