Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import tensorflow as tf | |
| from typing import List, Dict | |
| import json | |
| import os | |
# Directory containing this script. Bundled resources (the label file and the
# saved model) are resolved relative to it, so the app works regardless of the
# current working directory it is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Absolute path to the class_names.json file that sits next to this script.
# (Previously this joined only the file name, which resolved against the CWD.)
labels_path = os.path.join(script_dir, "class_names.json")

# Load the class labels. Presumably a JSON array of label strings in the same
# order as the model's output units — TODO confirm against the training code.
with open(labels_path, "r", encoding="utf-8") as f:
    LABELS: List[str] = json.load(f)
def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
    """Return the pixels of ``image`` as a NumPy array in RGB mode.

    An image whose mode is anything other than "RGB" is converted to "RGB"
    first; an image already in "RGB" mode is used unchanged.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    return np.asarray(rgb_image)
def _resize_image(img_rgb: np.ndarray) -> np.ndarray:
    """Resize an RGB pixel array to 256x256 with nearest-neighbour sampling.

    The array round-trips through PIL: it is wrapped with ``Image.fromarray``,
    resized with ``Image.NEAREST`` and converted back to a NumPy array. The
    aspect ratio is not preserved.
    """
    target_size = (256, 256)  # PIL's resize takes (width, height)
    resized = Image.fromarray(img_rgb).resize(target_size, Image.NEAREST)
    return np.asarray(resized)
def _preprocess(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-image model input batch.

    Pipeline:
    1. Convert to RGB.
    2. Resize to 256x256.
    3. Cast to float32. Values stay in the 0..255 range; no normalisation is
       applied here.
    4. Prepend a batch axis, giving shape [1, 256, 256, 3].
    """
    pixels = _resize_image(_load_image_to_rgb(image)).astype("float32")
    return pixels[np.newaxis, ...]
class PreTrainedModel:
    """Wrapper around a saved Keras classifier that scores PIL images.

    Model outputs are paired with the module-level ``LABELS`` list by position.
    """

    def __init__(self, model_path: str = "model/model_final_saved.keras") -> None:
        """Load the Keras model from disk.

        Args:
            model_path: Path to the saved model, relative to ``script_dir``
                (the directory containing this script). An absolute path also
                works, because ``os.path.join`` then discards ``script_dir``.
        """
        abs_model_path = os.path.join(script_dir, model_path)
        self.model = tf.keras.models.load_model(abs_model_path)

    def predict_image(self, image: Image.Image) -> Dict[str, float]:
        """Run the model on one image and return a ``{label: score}`` mapping.

        Scores are the model's raw outputs for the single batch element,
        converted to Python floats. Whether they are probabilities depends on
        the model's final layer (presumably softmax — TODO confirm).

        Raises:
            ValueError: if the number of model outputs differs from
                ``len(LABELS)``. Without this check, ``zip`` would silently
                drop or misalign labels.
        """
        x = _preprocess(image)
        # verbose=0 stops Keras from printing a progress bar on every request.
        preds = self.model.predict(x, verbose=0)
        # If predict returns a list/tuple (e.g. a multi-output model), only the
        # first output is used.
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        # Flatten the batch-of-one output to a 1-D list. Unlike
        # squeeze().tolist(), this stays a list even when there is one class.
        scores = np.asarray(preds).reshape(-1).tolist()
        if len(scores) != len(LABELS):
            raise ValueError(
                f"Model produced {len(scores)} scores but {len(LABELS)} labels "
                f"were loaded from class_names.json"
            )
        return dict(zip(LABELS, scores))
# One shared model instance, created at import time. The Gradio handler below
# reuses it on every call instead of loading the weights per request.
model: PreTrainedModel = PreTrainedModel()
def predict(image):
    """Gradio handler: classify ``image`` and summarise the result as JSON.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict with three keys:
        - ``"label"``: the top-scoring class.
        - ``"percentage"``: the top class's score x100, rounded to 2 decimals.
        - ``"probabilities"``: every label's score x100, rounded to 2 decimals.

    Raises:
        gr.Error: if no image was provided or the model returned no scores.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError
        # raised deep inside preprocessing.
        raise gr.Error("Please upload an image of a flower.")
    predictions = model.predict_image(image)
    if not predictions:
        raise gr.Error("The model returned no predictions.")
    # Choose the winner from the raw scores. Rounding first can turn
    # near-ties into exact ties, which max() would break by dict order
    # rather than by true score.
    max_label = max(predictions, key=predictions.get)
    probs_percent = {label: round(score * 100, 2)
                     for label, score in predictions.items()}
    return {
        "label": max_label,
        "percentage": probs_percent[max_label],
        "probabilities": probs_percent,
    }
# Web UI: one uploaded image goes in, and the JSON summary from predict()
# comes out.
iface = gr.Interface(
    title="Flower Classification",
    description="Upload an image of a flower to classify it.",
    fn=predict,
    inputs=gr.Image(type="pil"),  # predict() receives a PIL image
    outputs=gr.JSON(),
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()