Spaces:
Runtime error
Runtime error
add app
Browse files
app.py
CHANGED
|
@@ -2,12 +2,25 @@ import tensorflow as tf
|
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
import gradio as gr
|
|
|
|
| 5 |
|
| 6 |
# Load Keras model
|
| 7 |
model = tf.keras.models.load_model("car_brand_classifier_final.h5")
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# Define the preprocessing function
|
| 10 |
def preprocess_image(image):
|
|
|
|
|
|
|
|
|
|
| 11 |
if isinstance(image, np.ndarray): # If Gradio gives a NumPy array, convert it to PIL Image
|
| 12 |
image = Image.fromarray(image)
|
| 13 |
|
|
@@ -19,24 +32,43 @@ def preprocess_image(image):
|
|
| 19 |
|
| 20 |
# Define the prediction function
|
| 21 |
def predict(image):
|
|
|
|
|
|
|
|
|
|
| 22 |
# Preprocess the image
|
| 23 |
processed_image = preprocess_image(image)
|
| 24 |
|
| 25 |
# Make a prediction
|
| 26 |
predictions = model.predict(processed_image)
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
# Return the
|
| 30 |
-
return
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# Create the Gradio interface
|
| 33 |
iface = gr.Interface(
|
| 34 |
fn=predict,
|
| 35 |
inputs=gr.Image(type="numpy"), # Ensure Gradio passes a NumPy array
|
| 36 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
| 37 |
title="Car Vision",
|
| 38 |
-
description="Upload an image of a car to classify its brand."
|
| 39 |
)
|
| 40 |
|
| 41 |
# Launch the app
|
| 42 |
-
iface.launch()
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
import gradio as gr
|
| 5 |
+
import os
|
| 6 |
|
| 7 |
# Load Keras model
|
| 8 |
model = tf.keras.models.load_model("car_brand_classifier_final.h5")
|
| 9 |
|
| 10 |
+
# Define the path to the dataset containing car brand folders
|
| 11 |
+
DATASET_PATH = "./Car_Sales_vision_ai_project"
|
| 12 |
+
|
| 13 |
+
# Get the list of car brands (folder names) in the dataset
|
| 14 |
+
CAR_BRANDS = sorted(os.listdir(DATASET_PATH))
|
| 15 |
+
|
| 16 |
+
# Ensure the CAR_BRANDS list contains only valid directories
|
| 17 |
+
CAR_BRANDS = [brand for brand in CAR_BRANDS if os.path.isdir(os.path.join(DATASET_PATH, brand))]
|
| 18 |
+
|
| 19 |
# Define the preprocessing function
|
| 20 |
def preprocess_image(image):
|
| 21 |
+
"""
|
| 22 |
+
Preprocess the input image for the model.
|
| 23 |
+
"""
|
| 24 |
if isinstance(image, np.ndarray): # If Gradio gives a NumPy array, convert it to PIL Image
|
| 25 |
image = Image.fromarray(image)
|
| 26 |
|
|
|
|
| 32 |
|
| 33 |
# Define the prediction function
|
| 34 |
def predict(image):
|
| 35 |
+
"""
|
| 36 |
+
Predict the car brand and return sample images from the corresponding folder.
|
| 37 |
+
"""
|
| 38 |
# Preprocess the image
|
| 39 |
processed_image = preprocess_image(image)
|
| 40 |
|
| 41 |
# Make a prediction
|
| 42 |
predictions = model.predict(processed_image)
|
| 43 |
+
predicted_class_index = np.argmax(predictions, axis=1)[0]
|
| 44 |
+
predicted_brand = CAR_BRANDS[predicted_class_index]
|
| 45 |
+
|
| 46 |
+
# Get sample images from the predicted brand folder
|
| 47 |
+
brand_folder = os.path.join(DATASET_PATH, predicted_brand)
|
| 48 |
+
sample_images = []
|
| 49 |
+
if os.path.exists(brand_folder):
|
| 50 |
+
for filename in os.listdir(brand_folder)[:5]: # Limit to 5 sample images
|
| 51 |
+
img_path = os.path.join(brand_folder, filename)
|
| 52 |
+
if os.path.isfile(img_path):
|
| 53 |
+
sample_images.append(Image.open(img_path).resize((200, 200))) # Resize for consistency
|
| 54 |
|
| 55 |
+
# Return the predicted brand and sample images
|
| 56 |
+
return {
|
| 57 |
+
"Predicted Brand": predicted_brand,
|
| 58 |
+
"Sample Images": sample_images or ["No images found for this brand."]
|
| 59 |
+
}
|
| 60 |
|
| 61 |
# Create the Gradio interface
|
| 62 |
iface = gr.Interface(
|
| 63 |
fn=predict,
|
| 64 |
inputs=gr.Image(type="numpy"), # Ensure Gradio passes a NumPy array
|
| 65 |
+
outputs=[
|
| 66 |
+
gr.Textbox(label="Predicted Brand"),
|
| 67 |
+
gr.Gallery(label="Sample Images").style(grid=[5], height="auto") # Display images in a gallery
|
| 68 |
+
],
|
| 69 |
title="Car Vision",
|
| 70 |
+
description="Upload an image of a car to classify its brand and view sample images."
|
| 71 |
)
|
| 72 |
|
| 73 |
# Launch the app
|
| 74 |
+
iface.launch()
|