File size: 2,111 Bytes
e916c8e
 
 
 
 
 
 
 
 
 
 
ec56169
e916c8e
 
 
 
 
 
 
 
 
 
 
ec56169
e916c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from nets import get_model_from_name
from utils.utils import cvtColor, get_classes, letterbox_image, preprocess_input
import tempfile

class Classification:
    """Image classifier whose weights are fetched from the Hugging Face Hub.

    On construction the instance downloads the weight file matching
    ``model_choice``, reads the class names from ``classes_path``, builds the
    network through ``get_model_from_name`` and loads the weights into it.
    """

    # Hub repository that hosts the weight files.
    _HF_REPO_ID = "sudo-paras-shah/micro-expression-casme2"
    # Weight file used when model_choice == "mobilenet".
    _MOBILENET_WEIGHTS = "ep097.weights.h5"
    # Weight file used for every other model choice.
    _DEFAULT_WEIGHTS = "ep089.weights.h5"

    def __init__(self, model_choice):
        """Download the weights and build the model for ``model_choice``.

        Args:
            model_choice: Key into ``get_model_from_name``. The value
                ``"mobilenet"`` selects the mobilenet weight file and passes
                ``alpha`` to the model constructor; any other value selects
                the default weight file.
        """
        self.model_choice = model_choice
        self.classes_path = "src/cls_classes.txt"
        self.input_shape = (224, 224)
        self.alpha = 0.25

        # Keep the Hub cache under the system temp directory.
        cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        self.model_path = hf_hub_download(
            repo_id=self._HF_REPO_ID,
            filename=self._weights_filename(self.model_choice),
            cache_dir=cache_dir,
        )

        # Load class names and model
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.load_model()

    @staticmethod
    def _weights_filename(model_choice):
        """Return the weight file name that belongs to ``model_choice``.

        The comparison uses ``==`` rather than ``is``: identity only holds
        for strings that happen to be the same object, so a value built at
        runtime (for example from user input) could equal ``"mobilenet"``
        and still select the wrong file, leaving the downloaded weights out
        of sync with the architecture built in ``load_model``.
        """
        if model_choice == "mobilenet":
            return Classification._MOBILENET_WEIGHTS
        return Classification._DEFAULT_WEIGHTS

    def load_model(self):
        """Build the network for ``self.model_choice`` and load its weights.

        Sets ``self.model`` and prints the weight path and class names.
        """
        build_kwargs = {
            "input_shape": [self.input_shape[0], self.input_shape[1], 3],
            "classes": self.num_classes,
        }
        # Only the mobilenet constructor is given the alpha argument.
        if self.model_choice == "mobilenet":
            build_kwargs["alpha"] = self.alpha

        self.model = get_model_from_name[self.model_choice](**build_kwargs)

        self.model.load_weights(self.model_path)
        print("Model loaded from", self.model_path)
        print("Classes:", self.class_names)

    def detect_image(self, image):
        """Classify a single image.

        Args:
            image: Input accepted by ``cvtColor`` / ``letterbox_image``
                (presumably a PIL image; confirm against ``utils.utils``).

        Returns:
            A ``(class_name, probability)`` tuple: the entry of
            ``self.class_names`` at the arg-max of the model output, and the
            model output value at that index.
        """
        image = cvtColor(image)
        # Size is passed as [input_shape[1], input_shape[0]]; presumably
        # (width, height) -- confirm against letterbox_image.
        image = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        image = np.array(image, dtype=np.float32)
        image = preprocess_input(image)
        # Add a leading batch axis so the model receives a batch of one.
        image = np.expand_dims(image, axis=0)

        preds = self.model.predict(image)[0]
        class_index = np.argmax(preds)
        class_name = self.class_names[class_index]
        probability = preds[class_index]

        return class_name, probability