courte commited on
Commit
f7da8d7
·
verified ·
1 Parent(s): d40745f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -2,43 +2,37 @@ import tensorflow as tf
2
  import numpy as np
3
  from PIL import Image
4
  import gradio as gr
5
- import os
6
 
7
- # Load the model
8
- model = tf.keras.models.load_model("car_brand_classifier_final.h5", compile=False)
9
-
10
- # Define image directory where car brand images are stored
11
- IMAGE_DIR = "car_brands" # Make sure this folder exists with images named as class labels
12
 
13
  # Define the preprocessing function
14
- def preprocess_image(image_path):
15
- image = Image.open(image_path) # Open image from path
16
- image = image.resize((299, 299)) # Resize to model input size
17
  image = np.array(image) / 255.0 # Normalize pixel values
18
  image = np.expand_dims(image, axis=0) # Add batch dimension
19
  return image
20
 
21
- # Prediction function
22
  def predict(image):
 
23
  processed_image = preprocess_image(image)
24
- predictions = model.predict(processed_image)
25
- predicted_class = np.argmax(predictions, axis=1)[0] # Get predicted class index
26
 
27
- # Find corresponding image for predicted class
28
- matching_image_path = os.path.join(IMAGE_DIR, f"{predicted_class}.jpg")
 
29
 
30
- if os.path.exists(matching_image_path):
31
- return matching_image_path # Return image file path
32
- else:
33
- return "No matching image found"
34
 
35
- # Gradio Interface
36
  iface = gr.Interface(
37
  fn=predict,
38
  inputs="image",
39
- outputs="image", # Change output to "image"
40
  title="Car Vision",
41
- description="Upload an image of a car, and get a matching brand image.",
42
  )
43
 
 
44
  iface.launch()
 
2
  import numpy as np
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
+ # Load the Keras model
7
+ model = tf.keras.models.load_model("car_brand_classifier_final.h5")
 
 
 
8
 
9
  # Define the preprocessing function
10
+ def preprocess_image(image):
11
+ image = image.resize((299, 299)) # Resize to match model input size
 
12
  image = np.array(image) / 255.0 # Normalize pixel values
13
  image = np.expand_dims(image, axis=0) # Add batch dimension
14
  return image
15
 
16
+ # Define the prediction function
17
  def predict(image):
18
+ # Preprocess the image
19
  processed_image = preprocess_image(image)
 
 
20
 
21
+ # Make a prediction
22
+ predictions = model.predict(processed_image)
23
+ predicted_class = np.argmax(predictions, axis=1)[0]
24
 
25
+ # Return the result
26
+ return f"Predicted class: {predicted_class}"
 
 
27
 
28
+ # Create the Gradio interface
29
  iface = gr.Interface(
30
  fn=predict,
31
  inputs="image",
32
+ outputs="text",
33
  title="Car Vision",
34
+ description="Upload an image of a car to classify its brand."
35
  )
36
 
37
+ # Launch the app
38
  iface.launch()