File size: 1,781 Bytes
217c3af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10d4547
217c3af
 
 
 
 
 
 
10d4547
 
 
 
 
 
 
 
 
 
 
217c3af
 
 
 
10d4547
217c3af
 
 
 
 
 
 
 
 
10d4547
 
217c3af
 
 
 
10d4547
 
217c3af
10d4547
217c3af
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import annotations

import io
from typing import Any, Dict, List

import numpy as np
from PIL import Image
import tensorflow as tf

# Class names indexed by model output position: `predict` maps the argmax of
# the model's output vector straight into this list. Presumably this order
# matches the label encoding used at training time (it is alphabetical) —
# confirm against the training pipeline.
LABELS: List[str] = ["Heart", "Oblong", "Oval", "Round", "Square"]

# Side length, in pixels, of the square image fed to the model.
# NOTE(review): 244 is unusual (224 is the common ImageNet size); confirm it
# matches the input size the model was trained with.
TARGET_SIZE = 244


def _load_image(image_bytes: bytes) -> Image.Image:
    """Decode encoded image bytes into a PIL image in ``"RGB"`` mode.

    Images already in RGB are returned as opened; any other mode (greyscale,
    palette, RGBA, ...) is converted with ``Image.convert("RGB")``, which
    discards an alpha channel rather than compositing it.

    NOTE(review): EXIF orientation is not applied, so phone photos may reach
    the model rotated — consider ``ImageOps.exif_transpose`` if the model was
    trained on upright images (confirm before changing).
    """
    img = Image.open(io.BytesIO(image_bytes))
    return img if img.mode == "RGB" else img.convert("RGB")


def _ensure_three_channels(array: np.ndarray) -> np.ndarray:
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3:
        if array.shape[-1] == 1:
            array = np.repeat(array, 3, axis=-1)
        elif array.shape[-1] > 3:
            array = array[..., :3]
    return array


def _preprocess(image_bytes: bytes) -> np.ndarray:
    """Turn encoded image bytes into a model-ready batch containing one image.

    Steps: decode to RGB, bilinearly resize to ``TARGET_SIZE`` x
    ``TARGET_SIZE``, convert to float32, force three channels, scale pixel
    values by 1/255, and prepend a batch axis of length 1.

    Returns:
        A float32 array of shape ``(1, TARGET_SIZE, TARGET_SIZE, 3)``.
    """
    side = TARGET_SIZE
    resized = _load_image(image_bytes).resize((side, side), Image.BILINEAR)
    pixels = _ensure_three_channels(np.asarray(resized, dtype="float32"))
    # Dividing a float32 array by a Python float keeps float32.
    normalized = pixels / 255.0
    return normalized[np.newaxis, ...]


class PreTrainedModel:
    def __init__(self, model_path: str = "model/best_model.keras") -> None:
        self.model = tf.keras.models.load_model(model_path)

    def predict(self, inputs: bytes) -> List[Dict[str, Any]]:
        x = _preprocess(inputs)
        preds = self.model.predict(x, verbose=0)

        if isinstance(preds, (list, tuple)):
            preds = preds[0]

        probs = np.asarray(preds).squeeze().tolist()
        idx = int(np.argmax(probs))
        return [
            {"label": LABELS[idx], "score": float(probs[idx])},
        ]


def load_model(model_dir: str = ".") -> PreTrainedModel:
    """Build a ``PreTrainedModel`` from the model file under *model_dir*.

    Args:
        model_dir: Directory that contains ``model/best_model.keras``. It is
            joined to that relative path with a literal ``/`` separator.

    Returns:
        The loaded ``PreTrainedModel`` wrapper.
    """
    relative_file = "model/best_model.keras"
    full_path = f"{model_dir}/{relative_file}"
    return PreTrainedModel(model_path=full_path)