| | from __future__ import annotations |
| |
|
| | import logging |
| | import os |
| | from typing import Dict, List, Optional |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | import tensorflow as tf |
| | from PIL import Image |
| |
|
| | from config import settings |
| |
|
# Default GRADIO_SERVER_QUEUE_ENABLED to "0" unless the environment already
# provides a value.
# NOTE(review): presumably intended to turn off Gradio's request queue; this
# runs after `import gradio`, so confirm Gradio reads the variable late enough.
os.environ.setdefault("GRADIO_SERVER_QUEUE_ENABLED", "0")

# Root logging configuration: INFO level, timestamped single-line records.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# Edge length, in pixels, of the square image handed to the model.
# NOTE(review): 244 is unusual (224 is far more common) -- confirm against the
# input size the model was trained with.
IMG_SIZE = 244
| |
|
| |
|
| | def _ensure_three_channels(array: np.ndarray) -> np.ndarray: |
| | if array.ndim == 2: |
| | array = np.stack([array] * 3, axis=-1) |
| | elif array.ndim == 3: |
| | if array.shape[-1] == 1: |
| | array = np.repeat(array, 3, axis=-1) |
| | elif array.shape[-1] > 3: |
| | array = array[..., :3] |
| | return array |
| |
|
| |
|
class FaceShapeModel:
    """Wrap a Keras model that scores an image against a fixed label list."""

    def __init__(self, model_path: str, labels: List[str]):
        """Load the model stored at ``model_path``.

        Args:
            model_path: Filesystem path passed to
                ``tf.keras.models.load_model``.
            labels: Class names, paired positionally with the model's
                output values.

        Raises:
            FileNotFoundError: If nothing exists at ``model_path``.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at: {model_path}")

        self.labels = labels
        logger.info("Loading TensorFlow model from %s", model_path)
        self.model = tf.keras.models.load_model(model_path)
        logger.info("Model loaded successfully with %d labels", len(labels))

    @staticmethod
    def _preprocess(image: Image.Image) -> np.ndarray:
        """Turn a PIL image into a single-item float32 batch.

        The image is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``
        with bilinear resampling, scaled from 0-255 to 0-1, and given a
        leading batch axis.

        Args:
            image: The PIL image to prepare.

        Returns:
            A float32 array shaped ``(1, IMG_SIZE, IMG_SIZE, 3)``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        resized = image.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
        array = np.asarray(resized, dtype="float32")
        array = _ensure_three_channels(array)
        array /= 255.0
        array = np.expand_dims(array, axis=0)
        return array

    def predict_image(self, image: Image.Image) -> Dict[str, float]:
        """Score ``image`` and return one value per label.

        Args:
            image: The PIL image to classify.

        Returns:
            A mapping from each entry of ``self.labels`` to the model's
            score for it, as a plain ``float``.

        Raises:
            ValueError: If ``image`` is ``None``, or if the number of values
                the model produced differs from the number of labels.
        """
        if image is None:
            # Fail with a clear message instead of an AttributeError raised
            # from inside preprocessing.
            raise ValueError("No image provided for prediction.")

        batch = self._preprocess(image)
        preds = self.model.predict(batch, verbose=0)

        # Multi-output models return a sequence; only the first output is used.
        if isinstance(preds, (list, tuple)):
            preds = preds[0]

        # Flatten rather than squeeze: this also covers a single scalar
        # output and makes the size check below count every value produced.
        scores = np.asarray(preds).reshape(-1)

        if scores.size != len(self.labels):
            raise ValueError(
                "Model output length does not match labels. "
                f"Expected {len(self.labels)} values, got {scores.size}."
            )

        return {label: float(score) for label, score in zip(self.labels, scores.tolist())}
| |
|
| |
|
# Module-level cache for the model instance; stays None until first requested.
_model: Optional[FaceShapeModel] = None


def get_model() -> FaceShapeModel:
    """Return the cached ``FaceShapeModel``, building it on the first call.

    The instance is constructed from ``settings.model_path`` and
    ``settings.labels`` and stored in the module-level ``_model`` so later
    calls reuse it. Any exception raised by the constructor propagates and
    leaves the cache empty.
    """
    global _model
    if _model is not None:
        return _model

    _model = FaceShapeModel(settings.model_path, settings.labels)
    return _model
| |
|
| |
|
def predict(image: Image.Image) -> Dict[str, float]:
    """Gradio handler: classify ``image`` and return per-label scores.

    Args:
        image: The uploaded PIL image, or ``None`` when the form was
            submitted without one.

    Returns:
        The label-to-score mapping produced by
        ``FaceShapeModel.predict_image``.

    Raises:
        gr.Error: If no image was supplied, if the model cannot be loaded,
            or if prediction fails. The original exception is chained as
            the cause for the latter two.
    """
    if image is None:
        # An empty submission would otherwise surface as an opaque
        # "'NoneType' object has no attribute 'mode'" failure and a logged
        # traceback; tell the user what to do instead.
        raise gr.Error("Silakan unggah gambar terlebih dahulu.")

    try:
        model = get_model()
    except Exception as exc:
        logger.exception("Failed to load model")
        raise gr.Error(f"Model gagal dimuat: {exc}") from exc

    try:
        return model.predict_image(image)
    except Exception as exc:
        logger.exception("Prediction failed")
        raise gr.Error(f"Prediksi gagal: {exc}") from exc
| |
|
| |
|
def build_interface() -> gr.Interface:
    """Assemble the Gradio interface that wires ``predict`` to the UI.

    Returns:
        A ``gr.Interface`` with a PIL/RGB image input, a label output
        limited to the top three classes, and flagging turned off.
    """
    image_input = gr.Image(type="pil", image_mode="RGB")
    label_output = gr.Label(num_top_classes=3)

    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=label_output,
        title="Face Shape Detection",
        description="Unggah foto wajah untuk mendeteksi bentuk wajah Anda menggunakan model TensorFlow.",
        allow_flagging="never",
    )
| |
|
| |
|
def launch_app():
    """Build the interface and start the Gradio server.

    The server listens on all interfaces (``0.0.0.0``) at ``settings.port``
    and passes ``settings.share`` through to Gradio. Basic auth is enabled
    only when both ``settings.gradio_username`` and
    ``settings.gradio_password`` are truthy.
    """
    iface = build_interface()

    options: Dict[str, object] = dict(
        server_name="0.0.0.0",
        server_port=settings.port,
        share=settings.share,
        show_api=True,
    )

    # Credentials are optional; without both values the app runs unauthenticated.
    if settings.gradio_username and settings.gradio_password:
        options.update(
            auth=(settings.gradio_username, settings.gradio_password),
            auth_message="Masukkan kredensial untuk mengakses demo",
        )

    iface.launch(**options)
| |
|
| |
|
# Script entry point: run the app, and make sure a fatal error is logged with
# its traceback before it propagates and terminates the process.
if __name__ == "__main__":
    try:
        launch_app()
    except Exception:
        logger.exception("Gradio application terminated due to an error")
        raise
| |
|