File size: 3,956 Bytes
5966cb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import tensorflow as tf
import numpy as np
import gradio as gr
import json
from PIL import Image, ImageDraw, ImageFont
import os

def load_model(model_path):
    """Load and return the Keras model saved at ``model_path``.

    Kept as a single wrapper so the loading mechanism can be swapped in one
    place.
    """
    keras_models = tf.keras.models
    return keras_models.load_model(model_path)

def preprocess_image(image, target_size, *, mode="RGB"):
    """Turn a PIL image into a single-image float batch for the model.

    Args:
        image: A PIL ``Image`` (as delivered by ``gr.Image(type="pil")``).
        target_size: ``(width, height)`` passed to ``Image.resize``.
        mode: PIL mode to convert to before resizing. Defaults to ``"RGB"`` so
            RGBA (e.g. PNG with transparency), palette, or grayscale uploads
            still yield three channels; pass ``None`` to skip conversion.

    Returns:
        A ``float32`` array with a leading batch axis of size 1 (shape
        ``(1, height, width, channels)`` for multi-band modes such as RGB),
        with pixel values scaled to ``[0, 1]``.
    """
    if mode is not None and image.mode != mode:
        image = image.convert(mode)
    image = image.resize(target_size)
    # float32 matches Keras' default input dtype and halves memory vs float64.
    array = np.asarray(image, dtype=np.float32) / 255.0  # normalize to [0, 1]
    return np.expand_dims(array, axis=0)  # add batch dimension

def predict_top_10(image, model, class_indices, k=10, target_size=(224, 224)):
    """Return the ``k`` highest-scoring classes for ``image``.

    Args:
        image: PIL image to classify.
        model: Keras model whose ``predict`` returns one score vector per
            input image.
        class_indices: Mapping from output index (int) to class name.
        k: Number of predictions to return. Slicing caps this at the number
            of model outputs, so models with fewer than ``k`` classes no
            longer raise ``IndexError``.
        target_size: ``(width, height)`` the image is resized to; must match
            the model's expected input size.

    Returns:
        List of ``(class_name, score)`` tuples, highest score first. Scores
        are plain Python floats (probabilities if the model ends in softmax).
    """
    input_image = preprocess_image(image, target_size=target_size)

    # verbose=0 suppresses Keras' per-call progress bar in the server log.
    predictions = model.predict(input_image, verbose=0)[0]  # single-image batch

    # argsort is ascending: reverse for descending order, then take the top k.
    top_indices = np.argsort(predictions)[::-1][:k]
    # int()/float() unwrap numpy scalars so callers get plain Python values.
    output = [(class_indices[int(i)], float(predictions[i])) for i in top_indices]
    print(f"Top {len(output)} predictions:", output)  # Debugging line

    return output

def _label_image(image, text, font):
    """Draw ``text`` at the top-left of ``image`` on a white box so it is legible."""
    draw = ImageDraw.Draw(image)
    left, top, right, bottom = draw.textbbox((10, 10), text, font=font)
    # Backing box keeps black text readable on dark photos and placeholders.
    draw.rectangle((left - 2, top - 2, right + 2, bottom + 2), fill="white")
    draw.text((10, 10), text, fill="black", font=font)
    return image


def visualize_predictions(predictions, class_image_dir, size=(300, 300)):
    """Render one labelled reference image per prediction.

    For each ``(class_name, prob)`` pair, loads ``<class_image_dir>/<class_name>.jpg``,
    resizes it to ``size`` and stamps ``"<class_name>: <prob as %>"`` on it.
    If the file is missing or cannot be processed, a plain placeholder of the
    same size is labelled instead, with the reason appended.

    Args:
        predictions: Iterable of ``(class_name, prob)`` pairs, ``prob`` being a
            number formattable with ``:.2%``.
        class_image_dir: Directory holding one ``.jpg`` per class name.
        size: ``(width, height)`` of every output image.

    Returns:
        List of RGB PIL images, in the same order as ``predictions``.
    """
    font = ImageFont.load_default()  # loaded once, not per prediction
    output_images = []

    for class_name, prob in predictions:
        class_image_path = os.path.join(class_image_dir, f"{class_name}.jpg")
        label = f"{class_name}: {prob:.2%}"
        try:
            # Context manager guarantees the file handle is released.
            with Image.open(class_image_path) as src:
                class_image = src.convert("RGB").resize(size)
        except FileNotFoundError:
            class_image = Image.new("RGB", size, "lightgray")
            label = f"{label} (Image not found)"
        except Exception as e:
            # Best-effort: one bad reference image must not break the gallery.
            print(f"Error processing {class_name}: {e}")
            class_image = Image.new("RGB", size, "lightgray")
            label = f"{label} (Error: {e})"
        output_images.append(_label_image(class_image, label, font))

    return output_images

# Load model and class indices
model = load_model("owned_coins.keras")  # Update with your model's file path

# class_indices.json is presumably the "class name -> output index" mapping
# (the shape Keras' flow_from_directory(...).class_indices produces) — confirm.
# Explicit UTF-8 keeps non-ASCII class names independent of the OS locale.
with open("class_indices.json", "r", encoding="utf-8") as f:
    # Invert to "output index -> class name" for lookup by prediction index;
    # int() also accepts indices stored as JSON strings.
    class_indices = {int(index): name for name, index in json.load(f).items()}

# Directory containing one "<class name>.jpg" reference image per class
class_image_dir = "downloaded_images"  # Update with the actual path to your image directory

def gradio_predict(image):
    """Gradio callback: classify ``image`` and return labelled class images.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        List of PIL images for the ``gr.Gallery`` output, one per top-10 class.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI
            instead of a raw traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    predictions = predict_top_10(image, model, class_indices)
    return visualize_predictions(predictions, class_image_dir)



# Create Gradio interface: one uploaded PIL image in, a gallery of labelled
# reference images out (produced by gradio_predict).
interface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Gallery(label="Top 10 Predictions"),
    title="Multiclass Image Classifier",
    description="Upload an image to see the top 10 predictions with class images and probabilities.",
)

if __name__ == "__main__":
    # Serve locally; share=True additionally opens a temporary public share
    # link. This does NOT deploy to Hugging Face — for a permanent Space,
    # push this app to a Hugging Face Space repository instead.
    interface.launch(share=True)