File size: 6,375 Bytes
6d5d850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import logging
from typing import Iterable, Optional

import tensorflow as tf

from . import config
from .preprocessing import VideoPreprocessor

logger = logging.getLogger(__name__)


def _configure_tensorflow() -> None:
    """
    Best-effort TensorFlow runtime setup.

    Raises the TensorFlow logger threshold to ERROR and turns on memory
    growth for every physical GPU that TensorFlow reports. Any exception
    raised along the way is logged at debug level and otherwise ignored.
    """
    try:
        tf_logger = tf.get_logger()
        tf_logger.setLevel(logging.ERROR)
        # Memory growth is requested per device, one call per reported GPU.
        for device in tf.config.list_physical_devices("GPU"):
            tf.config.experimental.set_memory_growth(device, True)
    except Exception as err:
        logger.debug("TensorFlow runtime configuration skipped: %s", err)


class LipReadingModel:
    """
    Wrapper around a Keras lip-reading network with greedy CTC decoding.

    Construction builds the character <-> index lookup layers, applies the
    TensorFlow runtime tweaks, and tries to load a saved model from
    ``model_path``. If loading fails, the architecture from ``build_model``
    is instantiated instead, so ``self.model`` is always set.
    """

    def __init__(self, model_path: str = str(config.MODEL_PATH)):
        """
        Create the lookup layers and load (or build) the underlying model.

        Args:
            model_path: Location passed to ``tf.keras.models.load_model``.
        """
        # Initialize character mappings before loading the model.
        # NOTE(review): this literal looks like the product of a lossy
        # encoding round-trip (long runs of "?" and CP1252-range symbols where
        # accented letters would be expected). It is kept byte-identical here
        # because the index order presumably has to match the vocabulary used
        # at training time -- confirm against the training code.
        vocab_chars = (
            "aa\u0192bcdde\u02c6ghiklmno\u201copqrstuuvxy\u00a0\u2026?"
            "a???????????\u201a\u0160????????\u00a1\u008d?i?\u00a2\u2022?o???????????\u00a3\u2014?u??????y????'?!123456789 "
        )
        # Order-preserving de-duplication: the literal repeats characters and
        # only the first occurrence of each one is kept.
        vocab = list(dict.fromkeys(vocab_chars))
        self.char_to_num = tf.keras.layers.StringLookup(vocabulary=vocab, oov_token="")
        self.num_to_char = tf.keras.layers.StringLookup(
            vocabulary=self.char_to_num.get_vocabulary(), oov_token="", invert=True
        )

        _configure_tensorflow()

        try:
            self.model = tf.keras.models.load_model(
                model_path,
                custom_objects={"CTCLoss": self.CTCLoss},
            )
        except Exception as exc:
            logger.error("Error loading model from %s: %s", model_path, exc)
            # Make the consequence of the fallback explicit: build_model()
            # returns a freshly initialised network with no trained weights.
            logger.warning(
                "Falling back to a freshly built, untrained model; "
                "predictions will not be meaningful."
            )
            self.model = self.build_model()
        else:
            logger.info("Model loaded successfully from %s", model_path)

    @staticmethod
    def CTCLoss(y_true, y_pred):
        """
        CTC batch cost that treats every sample as full length.

        The input length is taken from axis 1 of ``y_pred`` and the label
        length from axis 1 of ``y_true``; both are broadcast to a
        ``(batch, 1)`` int64 tensor before calling ``ctc_batch_cost``.

        Args:
            y_true: Label tensor; axis 0 is the batch, axis 1 the label width.
            y_pred: Prediction tensor; axis 0 is the batch, axis 1 the time steps.

        Returns:
            The tensor returned by ``tf.keras.backend.ctc_batch_cost``.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        return tf.keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)

    def build_model(self):
        """
        Build the fallback network used when no saved model can be loaded.

        The stack is four Conv3D blocks (convolution, normalization, ReLU,
        spatial max-pooling), a per-time-step flatten, two bidirectional
        LSTMs with dropout, and a softmax Dense layer with
        ``vocabulary_size() + 1`` outputs (the extra class is presumably the
        CTC blank -- confirm against the training setup).

        Returns:
            An uncompiled ``tf.keras.Sequential`` model with fresh weights.
        """
        layers = tf.keras.layers
        frame_shape = (None, config.TARGET_SIZE, config.TARGET_SIZE, 1)

        stack = [
            # Block 1
            layers.Conv3D(64, (3, 3, 3), strides=(1, 2, 2), input_shape=frame_shape, padding="same"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.MaxPool3D((1, 2, 2), padding="same"),
            # Block 2
            layers.Conv3D(128, (3, 3, 3), strides=(1, 2, 2), padding="same"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.MaxPool3D((1, 2, 2), padding="same"),
            # Block 3
            layers.Conv3D(256, (3, 3, 3), strides=(1, 2, 2), padding="same"),
            # NOTE(review): this block uses LayerNormalization while the other
            # three use BatchNormalization -- kept as-is; confirm it mirrors
            # the trained architecture.
            layers.LayerNormalization(),
            layers.Activation("relu"),
            layers.MaxPool3D((1, 2, 2), padding="same"),
            # Block 4 (no spatial striding in the convolution)
            layers.Conv3D(256, (3, 3, 3), padding="same"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            layers.MaxPool3D((1, 2, 2), padding="same"),
            # Flatten each time step independently so the time axis survives.
            layers.TimeDistributed(layers.Flatten()),
            layers.Bidirectional(layers.LSTM(512, kernel_initializer="Orthogonal", return_sequences=True)),
            layers.Dropout(0.4),
            layers.Bidirectional(layers.LSTM(256, kernel_initializer="Orthogonal", return_sequences=True)),
            layers.Dropout(0.4),
            layers.Dense(self.char_to_num.vocabulary_size() + 1, kernel_initializer="he_normal", activation="softmax"),
        ]

        model = tf.keras.Sequential()
        for layer in stack:
            model.add(layer)

        logger.info("Built the fallback model architecture.")
        return model

    def predict(self, normalized_frames: Optional[tf.Tensor]) -> str:
        """
        Decode a single preprocessed clip into text.

        Args:
            normalized_frames: Preprocessed frames for one clip, or ``None``.
                A batch axis is added in front before the model is called.

        Returns:
            The decoded, whitespace-stripped text, or a message starting with
            "? " when the model or frames are missing or prediction fails.
        """
        if self.model is None:
            return "? Model not loaded. Please check the model path and ensure the model file is accessible."

        # A missing tensor and an empty tensor are reported identically.
        if normalized_frames is None or int(tf.size(normalized_frames)) == 0:
            return "? No frames extracted from the video. Please ensure the video contains a clear view of the face and lips."

        try:
            frames = tf.expand_dims(normalized_frames, axis=0)
            yhat = self.model.predict(frames, verbose=0)

            # Batch size is 1, so a single input length (all time steps) is enough.
            input_length = [yhat.shape[1]]
            decoded_tf = tf.keras.backend.ctc_decode(yhat, input_length=input_length, greedy=True)[0][0]
            decoded = decoded_tf.numpy().flatten()

            # -1 entries are skipped, exactly as the original per-symbol filter did.
            token_ids = decoded[decoded != -1]
            if token_ids.size == 0:
                return ""

            # One batched lookup instead of one layer invocation per symbol.
            chars = self.num_to_char(tf.convert_to_tensor(token_ids))
            prediction = tf.strings.reduce_join(chars).numpy().decode("utf-8")
            return prediction.strip()
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Error during prediction: %s", exc)
            return f"? An error occurred during prediction: {exc}"


def predict_from_video(
    video_path: Optional[str] = None,
    frames: Optional[Iterable] = None,
    model: Optional[LipReadingModel] = None,
    preprocessor: Optional[VideoPreprocessor] = None,
):
    """
    Predict text from a video file or from already-captured frames.

    ``video_path`` takes precedence when it is truthy; otherwise ``frames``
    is used. Inputs are validated before any default ``VideoPreprocessor``
    or ``LipReadingModel`` is constructed, and the model is only created
    once preprocessing has actually produced frames, so invalid calls do
    not pay for a model load.

    Args:
        video_path: Path handed to ``preprocessor.preprocess_video``.
        frames: Iterable handed to ``preprocessor.preprocess_frames`` when no
            ``video_path`` is given.
        model: Model to use; a default ``LipReadingModel()`` is created when
            omitted and needed.
        preprocessor: Preprocessor to use; a default ``VideoPreprocessor()``
            is created when omitted and needed.

    Returns:
        The result of ``model.predict`` on the preprocessed frames, or a
        message starting with "? " when no input was given or no frames
        could be extracted.
    """
    # Guard clause first: nothing to do without any input.
    if not video_path and frames is None:
        return "? No video or frames provided for prediction."

    if preprocessor is None:
        preprocessor = VideoPreprocessor()

    if video_path:
        normalized_frames = preprocessor.preprocess_video(video_path)
    else:
        normalized_frames = preprocessor.preprocess_frames(frames)

    if normalized_frames is None:
        return "? Unable to extract frames from the provided video."

    # Deferred until here so a failed preprocessing step never loads a model.
    if model is None:
        model = LipReadingModel()

    return model.predict(normalized_frames)