OlejnikM commited on
Commit
5966cb5
·
verified ·
1 Parent(s): 74afaa6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import gradio as gr
4
+ import json
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import os
7
+
8
+ def load_model(model_path):
9
+ # Load your TensorFlow model
10
+ return tf.keras.models.load_model(model_path)
11
+
12
+ def preprocess_image(image, target_size):
13
+ # Resize and normalize the image
14
+ image = image.resize(target_size)
15
+ image = np.array(image) / 255.0 # Normalize to [0, 1]
16
+ image = np.expand_dims(image, axis=0) # Add batch dimension
17
+ return image
18
+
19
+ def predict_top_10(image, model, class_indices):
20
+ # Preprocess the image
21
+ input_image = preprocess_image(image, target_size=(224, 224))
22
+
23
+ # Get predictions
24
+ predictions = model.predict(input_image)[0] # Assuming single image batch
25
+ # print("Model predictions:", predictions) # Debugging line
26
+
27
+ # Get top 10 predictions
28
+ top_indices = np.argsort(predictions)[::-1][:10] # Sort descending, take top 10
29
+ top_probs = predictions[top_indices]
30
+ top_classes = [class_indices[i] for i in top_indices]
31
+
32
+ output = [(top_classes[i], top_probs[i]) for i in range(10)]
33
+ print("Top 10 predictions:", output) # Debugging line
34
+
35
+ return output
36
+
37
+ def visualize_predictions(predictions, class_image_dir):
38
+ print("Received predictions:", predictions) # Debugging to check what's coming in
39
+ output_images = []
40
+ # Load the font once outside the loop to avoid redundant loading
41
+ font = ImageFont.load_default() # or ImageFont.truetype("./arial.ttf", 20)
42
+
43
+ for class_name, prob in predictions:
44
+ class_image_path = os.path.join(class_image_dir, f"{class_name}.jpg")
45
+ try:
46
+ class_image = Image.open(class_image_path).convert("RGB")
47
+ # Resize image if necessary
48
+ class_image = class_image.resize((300, 300))
49
+ draw = ImageDraw.Draw(class_image)
50
+ text = f"{class_name}: {prob:.2%}"
51
+ draw.text((10, 10), text, fill="black", font=font)
52
+ output_images.append(class_image)
53
+ except FileNotFoundError:
54
+ blank_image = Image.new("RGB", (300, 300), "black")
55
+ draw = ImageDraw.Draw(blank_image)
56
+ draw.text((10, 10), f"{class_name}: {prob:.2%} (Image not found)", fill="black", font=font)
57
+ output_images.append(blank_image)
58
+ except Exception as e:
59
+ print(f"Error processing {class_name}: {e}")
60
+ blank_image = Image.new("RGB", (300, 300), "black")
61
+ draw = ImageDraw.Draw(blank_image)
62
+ draw.text((10, 10), f"{class_name}: {prob:.2%} (Error: {str(e)})", fill="black", font=font)
63
+ output_images.append(blank_image)
64
+ return output_images
65
+
66
+ # Load model and class indices
67
+ model = load_model("owned_coins.keras") # Update with your model's file path
68
+ with open("class_indices.json", "r") as f:
69
+ class_indices = json.load(f)
70
+
71
+
72
+ # Reverse the dictionary
73
+ class_indices = {v: k for k, v in class_indices.items()}
74
+
75
+ # Specify the directory containing class images
76
+ class_image_dir = "downloaded_images" # Update with the actual path to your image directory
77
+
78
+ def gradio_predict(image):
79
+ # Get top 10 predictions
80
+ predictions = predict_top_10(image, model, class_indices)
81
+ print("Predictions:", predictions) # Debugging line
82
+
83
+ # Visualize predictions with class images
84
+ result_images = visualize_predictions(predictions, class_image_dir)
85
+ return result_images
86
+
87
+
88
+
89
+ # Create Gradio interface
90
+ interface = gr.Interface(
91
+ fn=gradio_predict,
92
+ inputs=gr.Image(type="pil", label="Upload an Image"),
93
+ outputs=gr.Gallery(label="Top 10 Predictions"),
94
+ title="Multiclass Image Classifier",
95
+ description="Upload an image to see the top 10 predictions with class images and probabilities."
96
+ )
97
+
98
+ # Launch locally and push to Hugging Face
99
+ if __name__ == "__main__":
100
+ interface.launch(share=True)
101
+ # Uncomment the following line to push to Hugging Face (requires Hugging Face credentials)
102
+ # interface.launch(share=True)