courte commited on
Commit
416498a
·
1 Parent(s): 263deba
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -2,12 +2,25 @@ import tensorflow as tf
2
  import numpy as np
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
  # Load Keras model
7
  model = tf.keras.models.load_model("car_brand_classifier_final.h5")
8
 
 
 
 
 
 
 
 
 
 
9
  # Define the preprocessing function
10
  def preprocess_image(image):
 
 
 
11
  if isinstance(image, np.ndarray): # If Gradio gives a NumPy array, convert it to PIL Image
12
  image = Image.fromarray(image)
13
 
@@ -19,24 +32,43 @@ def preprocess_image(image):
19
 
20
  # Define the prediction function
21
  def predict(image):
 
 
 
22
  # Preprocess the image
23
  processed_image = preprocess_image(image)
24
 
25
  # Make a prediction
26
  predictions = model.predict(processed_image)
27
- predicted_class = np.argmax(predictions, axis=1)[0]
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Return the result
30
- return f"Predicted class: {predicted_class}"
 
 
 
31
 
32
  # Create the Gradio interface
33
  iface = gr.Interface(
34
  fn=predict,
35
  inputs=gr.Image(type="numpy"), # Ensure Gradio passes a NumPy array
36
- outputs="text",
 
 
 
37
  title="Car Vision",
38
- description="Upload an image of a car to classify its brand."
39
  )
40
 
41
  # Launch the app
42
- iface.launch()
 
2
  import numpy as np
3
  from PIL import Image
4
  import gradio as gr
5
+ import os
6
 
7
  # Load Keras model
8
  model = tf.keras.models.load_model("car_brand_classifier_final.h5")
9
 
10
+ # Define the path to the dataset containing car brand folders
11
+ DATASET_PATH = "./Car_Sales_vision_ai_project"
12
+
13
+ # Get the list of car brands (folder names) in the dataset
14
+ CAR_BRANDS = sorted(os.listdir(DATASET_PATH))
15
+
16
+ # Ensure the CAR_BRANDS list contains only valid directories
17
+ CAR_BRANDS = [brand for brand in CAR_BRANDS if os.path.isdir(os.path.join(DATASET_PATH, brand))]
18
+
19
  # Define the preprocessing function
20
  def preprocess_image(image):
21
+ """
22
+ Preprocess the input image for the model.
23
+ """
24
  if isinstance(image, np.ndarray): # If Gradio gives a NumPy array, convert it to PIL Image
25
  image = Image.fromarray(image)
26
 
 
32
 
33
  # Define the prediction function
34
  def predict(image):
35
+ """
36
+ Predict the car brand and return sample images from the corresponding folder.
37
+ """
38
  # Preprocess the image
39
  processed_image = preprocess_image(image)
40
 
41
  # Make a prediction
42
  predictions = model.predict(processed_image)
43
+ predicted_class_index = np.argmax(predictions, axis=1)[0]
44
+ predicted_brand = CAR_BRANDS[predicted_class_index]
45
+
46
+ # Get sample images from the predicted brand folder
47
+ brand_folder = os.path.join(DATASET_PATH, predicted_brand)
48
+ sample_images = []
49
+ if os.path.exists(brand_folder):
50
+ for filename in os.listdir(brand_folder)[:5]: # Limit to 5 sample images
51
+ img_path = os.path.join(brand_folder, filename)
52
+ if os.path.isfile(img_path):
53
+ sample_images.append(Image.open(img_path).resize((200, 200))) # Resize for consistency
54
 
55
+ # Return the predicted brand and sample images
56
+ return {
57
+ "Predicted Brand": predicted_brand,
58
+ "Sample Images": sample_images or ["No images found for this brand."]
59
+ }
60
 
61
  # Create the Gradio interface
62
  iface = gr.Interface(
63
  fn=predict,
64
  inputs=gr.Image(type="numpy"), # Ensure Gradio passes a NumPy array
65
+ outputs=[
66
+ gr.Textbox(label="Predicted Brand"),
67
+ gr.Gallery(label="Sample Images").style(grid=[5], height="auto") # Display images in a gallery
68
+ ],
69
  title="Car Vision",
70
+ description="Upload an image of a car to classify its brand and view sample images."
71
  )
72
 
73
  # Launch the app
74
+ iface.launch()