File size: 5,395 Bytes
74050d9
 
5e98889
 
a7eca0b
5e98889
a7eca0b
 
5e98889
74050d9
 
5e98889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74050d9
5e98889
 
 
74050d9
5e98889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74050d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e98889
74050d9
 
5e98889
 
 
 
74050d9
 
5e98889
 
74050d9
 
 
 
5e98889
74050d9
5e98889
74050d9
 
5e98889
 
74050d9
 
5e98889
 
 
 
74050d9
 
 
 
5e98889
74050d9
 
 
 
 
5e98889
 
74050d9
 
5e98889
 
 
 
74050d9
 
 
 
 
 
 
 
 
5e98889
74050d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# tools/models.py

import torch
import logging
import onnxruntime as ort
from time import time
from typing import Union
from configs import ModelConfig, InferenceConfig
from visualization import draw_text_on_image
from pipelines import VideoClassificationPipeline  # Nếu cần thiết
import numpy as np


class Predictions:
    def __init__(
        self,
        predictions: list[dict] = None,
        inference_time: float = 0,
        start_time: float = 0,
        end_time: float = 0,
    ) -> None:
        self.predictions = predictions
        self.inference_time = inference_time
        self.start_time = start_time
        self.end_time = end_time

    def visualize(
        self,
        frame: np.ndarray,
        position: tuple = (20, 100),
        prefix: str = "Predictions",
        color: tuple = (0, 0, 255),
    ) -> np.ndarray:
        text = prefix + ": " + self.get_pred_message()
        return draw_text_on_image(
            image=frame,
            text=text,
            position=position,
            color=color,
            font_size=20,
        )

    def get_pred_message(self) -> str:
        if not any((
            self.start_time,
            self.end_time,
            self.inference_time,
            self.predictions
        )):
            return ""

        return ', '.join(
            [
                f"{pred['gloss']} ({pred['score']*100:.2f}%)"
                for pred in self.predictions
            ]
        )

    def __str__(self) -> str:
        if not any((
            self.start_time,
            self.end_time,
            self.inference_time,
            self.predictions
        )):
            return ""

        predictions = self.get_pred_message()
        message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}"
        return message.format(self.start_time, self.end_time, self.inference_time, predictions)

    def merge_results(self, results: dict = None) -> dict:
        if results is None:
            results = {
                "start_time": [],
                "end_time": [],
                "inference_time": [],
                "prediction": [],
            }
        results["start_time"].append(self.start_time)
        results["end_time"].append(self.end_time)
        results["inference_time"].append(self.inference_time)
        results["prediction"].append(self.predictions)
        return results


def load_model(
    model_config: ModelConfig,
    inference_config: InferenceConfig,
    label2id: dict = None,
    id2label: dict = None,
) -> ort.InferenceSession:
    '''
    Tải mô hình ONNX sử dụng onnxruntime.
    '''
    try:
        session = ort.InferenceSession(model_config.pretrained)
        logging.info(f"ONNX model loaded from {model_config.pretrained}")
    except Exception as e:
        logging.error(f"Failed to load ONNX model: {e}")
        raise e
    return session


def load_pipeline(
    model_config: ModelConfig,
    inference_config: InferenceConfig,
) -> ort.InferenceSession:
    '''
    Tải onnxruntime session dựa trên cấu hình mô hình.
    '''
    session = load_model(model_config, inference_config)
    return session


def preprocess_inputs_onnx(inputs: np.ndarray, processor=None) -> dict:
    '''
    Chuyển đổi đầu vào cho mô hình ONNX nếu cần.
    Bạn có thể thêm các bước tiền xử lý cụ thể ở đây nếu cần.
    '''
    # Ví dụ: Đảm bảo rằng đầu vào có định dạng phù hợp
    # inputs = processor(inputs)  # Nếu cần thiết
    return {"pixel_values": inputs.astype(np.float32)}  # Điều chỉnh tùy thuộc vào yêu cầu của mô hình


def get_predictions(
    inputs: np.ndarray,
    model: ort.InferenceSession,
    id2gloss: dict,
    k: int = 3,
) -> Predictions:
    '''
    Lấy top-k dự đoán từ mô hình ONNX.

    Parameters
    ----------
    inputs : np.ndarray
        Dữ liệu đầu vào đã được tiền xử lý.
    model : ort.InferenceSession
        Mô hình ONNX đã được tải.
    id2gloss : dict
        Bản đồ từ ID lớp sang gloss.
    k : int, optional
        Số lượng dự đoán cần trả về, mặc định là 3.

    Returns
    -------
    Predictions
        Đối tượng chứa các dự đoán và thời gian suy luận.
    '''
    if inputs is None:
        return Predictions()

    # Tiền xử lý đầu vào cho ONNX
    preprocessed_inputs = preprocess_inputs_onnx(inputs)

    # Lấy logits
    start_time = time()
    try:
        logits = model.run(None, preprocessed_inputs)[0]
    except Exception as e:
        logging.error(f"Error during ONNX inference: {e}")
        raise e
    inference_time = time() - start_time

    logits = torch.from_numpy(logits)
    # Lấy top-k dự đoán
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
    topk_indices = topk_indices.squeeze().detach().numpy()

    predictions = []
    for i in range(k):
        class_idx = str(topk_indices[i])
        gloss = id2gloss.get(class_idx, "Unknown")
        score = topk_scores[i]
        predictions.append({
            'gloss': gloss,
            'score': score,
        })

    return Predictions(predictions=predictions, inference_time=inference_time)