Spaces:
Runtime error
Runtime error
File size: 5,395 Bytes
74050d9 5e98889 a7eca0b 5e98889 a7eca0b 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 5e98889 74050d9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 | # tools/models.py
import torch
import logging
import onnxruntime as ort
from time import time
from typing import Union
from configs import ModelConfig, InferenceConfig
from visualization import draw_text_on_image
from pipelines import VideoClassificationPipeline # Nếu cần thiết
import numpy as np
class Predictions:
def __init__(
self,
predictions: list[dict] = None,
inference_time: float = 0,
start_time: float = 0,
end_time: float = 0,
) -> None:
self.predictions = predictions
self.inference_time = inference_time
self.start_time = start_time
self.end_time = end_time
def visualize(
self,
frame: np.ndarray,
position: tuple = (20, 100),
prefix: str = "Predictions",
color: tuple = (0, 0, 255),
) -> np.ndarray:
text = prefix + ": " + self.get_pred_message()
return draw_text_on_image(
image=frame,
text=text,
position=position,
color=color,
font_size=20,
)
def get_pred_message(self) -> str:
if not any((
self.start_time,
self.end_time,
self.inference_time,
self.predictions
)):
return ""
return ', '.join(
[
f"{pred['gloss']} ({pred['score']*100:.2f}%)"
for pred in self.predictions
]
)
def __str__(self) -> str:
if not any((
self.start_time,
self.end_time,
self.inference_time,
self.predictions
)):
return ""
predictions = self.get_pred_message()
message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}"
return message.format(self.start_time, self.end_time, self.inference_time, predictions)
def merge_results(self, results: dict = None) -> dict:
if results is None:
results = {
"start_time": [],
"end_time": [],
"inference_time": [],
"prediction": [],
}
results["start_time"].append(self.start_time)
results["end_time"].append(self.end_time)
results["inference_time"].append(self.inference_time)
results["prediction"].append(self.predictions)
return results
def load_model(
model_config: ModelConfig,
inference_config: InferenceConfig,
label2id: dict = None,
id2label: dict = None,
) -> ort.InferenceSession:
'''
Tải mô hình ONNX sử dụng onnxruntime.
'''
try:
session = ort.InferenceSession(model_config.pretrained)
logging.info(f"ONNX model loaded from {model_config.pretrained}")
except Exception as e:
logging.error(f"Failed to load ONNX model: {e}")
raise e
return session
def load_pipeline(
model_config: ModelConfig,
inference_config: InferenceConfig,
) -> ort.InferenceSession:
'''
Tải onnxruntime session dựa trên cấu hình mô hình.
'''
session = load_model(model_config, inference_config)
return session
def preprocess_inputs_onnx(inputs: np.ndarray, processor=None) -> dict:
'''
Chuyển đổi đầu vào cho mô hình ONNX nếu cần.
Bạn có thể thêm các bước tiền xử lý cụ thể ở đây nếu cần.
'''
# Ví dụ: Đảm bảo rằng đầu vào có định dạng phù hợp
# inputs = processor(inputs) # Nếu cần thiết
return {"pixel_values": inputs.astype(np.float32)} # Điều chỉnh tùy thuộc vào yêu cầu của mô hình
def get_predictions(
inputs: np.ndarray,
model: ort.InferenceSession,
id2gloss: dict,
k: int = 3,
) -> Predictions:
'''
Lấy top-k dự đoán từ mô hình ONNX.
Parameters
----------
inputs : np.ndarray
Dữ liệu đầu vào đã được tiền xử lý.
model : ort.InferenceSession
Mô hình ONNX đã được tải.
id2gloss : dict
Bản đồ từ ID lớp sang gloss.
k : int, optional
Số lượng dự đoán cần trả về, mặc định là 3.
Returns
-------
Predictions
Đối tượng chứa các dự đoán và thời gian suy luận.
'''
if inputs is None:
return Predictions()
# Tiền xử lý đầu vào cho ONNX
preprocessed_inputs = preprocess_inputs_onnx(inputs)
# Lấy logits
start_time = time()
try:
logits = model.run(None, preprocessed_inputs)[0]
except Exception as e:
logging.error(f"Error during ONNX inference: {e}")
raise e
inference_time = time() - start_time
logits = torch.from_numpy(logits)
# Lấy top-k dự đoán
topk_scores, topk_indices = torch.topk(logits, k, dim=1)
topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
topk_indices = topk_indices.squeeze().detach().numpy()
predictions = []
for i in range(k):
class_idx = str(topk_indices[i])
gloss = id2gloss.get(class_idx, "Unknown")
score = topk_scores[i]
predictions.append({
'gloss': gloss,
'score': score,
})
return Predictions(predictions=predictions, inference_time=inference_time)
|