"""Gradio demo that classifies an uploaded flower image with a saved Keras model.

At import time the module loads the class labels from ``class_names.json`` and
the Keras model from ``model/model_final_saved.keras``, then builds a
``gr.Interface`` named ``iface``. Running the file as a script launches it.
"""

import json
import os
from typing import List, Dict

import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf

# Absolute path of the directory containing this script, so that bundled
# resources are found regardless of the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Prefer the labels file that sits next to the script; fall back to the
# original working-directory-relative lookup so existing setups keep working.
labels_path = os.path.join(script_dir, "class_names.json")
if not os.path.exists(labels_path):
    labels_path = os.path.join("class_names.json")

# Load labels from the JSON file. The encoding is explicit so that non-ASCII
# class names do not depend on the platform's default encoding.
with open(labels_path, "r", encoding="utf-8") as f:
    LABELS: List[str] = json.load(f)

# Spatial size (width, height) every image is resized to before inference.
_TARGET_SIZE = (256, 256)


def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
    """Return *image* as an array, converting to RGB mode first if needed."""
    if image.mode != "RGB":
        image = image.convert("RGB")
    return np.asarray(image)


def _resize_image(img_rgb: np.ndarray) -> np.ndarray:
    """Resize an RGB array to 256x256 using nearest-neighbour resampling."""
    im = Image.fromarray(img_rgb)
    im = im.resize(_TARGET_SIZE, Image.NEAREST)
    return np.asarray(im)


def _preprocess(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-item float32 batch for the model.

    The result has shape ``[1, 256, 256, 3]`` with values left in the
    0..255 range (no normalisation is applied here).
    """
    rgb = _load_image_to_rgb(image)
    rgb_resized = _resize_image(rgb)
    arr = rgb_resized.astype("float32")
    return np.expand_dims(arr, axis=0)


class PreTrainedModel:
    """Thin wrapper around a saved Keras model that returns labelled scores."""

    def __init__(self, model_path: str = "model/model_final_saved.keras") -> None:
        """Load the Keras model at *model_path*.

        A relative *model_path* is resolved against the script directory;
        an absolute one is used as given (``os.path.join`` semantics).
        """
        abs_model_path = os.path.join(script_dir, model_path)
        self.model = tf.keras.models.load_model(abs_model_path)

    def predict_image(self, image: Image.Image) -> Dict[str, float]:
        """Run the model on *image* and map each label to its score.

        Labels and scores are paired positionally with ``zip``, so pairing
        stops at the shorter of ``LABELS`` and the model output.
        """
        x = _preprocess(image)
        preds = self.model.predict(x)
        # Multi-output models return a list/tuple; use the first output.
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        # atleast_1d keeps a single-class output iterable: squeeze() alone
        # would collapse it to a scalar and break the zip below.
        probs = np.atleast_1d(np.asarray(preds).squeeze()).tolist()
        return dict(zip(LABELS, probs))


model = PreTrainedModel()


def predict(image):
    """Classify *image* and return the top label plus all percentages.

    Args:
        image: The PIL image supplied by the Gradio input component, or
            ``None`` when nothing was uploaded.

    Returns:
        A dict with ``"label"`` (highest-scoring class), ``"percentage"``
        (its score times 100, rounded to two decimals) and
        ``"probabilities"`` (that percentage for every class).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message instead of an AttributeError on None.
        raise gr.Error("Please upload an image before submitting.")

    predictions = model.predict_image(image)
    probs_percent = {label: round(p * 100, 2) for label, p in predictions.items()}
    max_label = max(probs_percent, key=probs_percent.get)
    return {
        "label": max_label,
        "percentage": probs_percent[max_label],
        "probabilities": probs_percent,
    }


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
    title="Flower Classification",
    description="Upload an image of a flower to classify it.",
)

if __name__ == "__main__":
    iface.launch()