Kresna commited on
Commit
a0e0898
·
1 Parent(s): f3b95cd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +48 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import requests
6
+ from io import BytesIO
7
+
8
+ # Load the model
9
+ model = tf.keras.models.load_model('Nutrient-Model.h5')
10
+
11
+ # Define the class names
12
+ class_names = ['Iron', 'Magnesium', 'Nitrogen', 'Potassium', 'Zinc']
13
+
14
+ def classify_image(image):
15
+ # Convert the numpy array to a PIL Image object
16
+ pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
17
+
18
+ # Resize the image
19
+ pil_image = pil_image.resize((224, 224))
20
+
21
+ # Convert the PIL Image object to a numpy array
22
+ image_array = np.array(pil_image)
23
+
24
+ # Normalize the image
25
+ normalized_image_array = (image_array.astype(np.float32) / 255.0)
26
+
27
+ # Reshape the image
28
+ data = normalized_image_array.reshape((1, 224, 224, 3))
29
+
30
+ # Make the prediction
31
+ prediction = model.predict(data)[0]
32
+
33
+ # Get the predicted class name
34
+ predicted_class = class_names[np.argmax(prediction)]
35
+
36
+ # Get the confidence score for the predicted class
37
+ confidence_score = np.max(prediction)
38
+
39
+ # Return the predicted class and confidence score
40
+ return f"{predicted_class} ({confidence_score*100:.2f}%)"
41
+
42
+ # Define the Gradio interface
43
+ image_input = gr.inputs.Image()
44
+ output_text = gr.outputs.Textbox()
45
+ gr.Interface(fn=classify_image, inputs=image_input, outputs=output_text, title="Image Classification", description="Classify an image into one of five classes: Iron, Magnesium, Nitrogen, Potassium, or Zinc."
46
+ ).launch(share=True)
47
+
48
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==3.28.1
2
+ numpy==1.23.4
3
+ Pillow==9.1.1
4
+ Pillow==9.5.0
5
+ Requests==2.29.0
6
+ tensorflow==2.10.0
7
+ tensorflow_gpu==2.10.1