# tools/models.py import torch import logging import onnxruntime as ort from time import time from typing import Union from configs import ModelConfig, InferenceConfig from visualization import draw_text_on_image from pipelines import VideoClassificationPipeline # Nếu cần thiết import numpy as np class Predictions: def __init__( self, predictions: list[dict] = None, inference_time: float = 0, start_time: float = 0, end_time: float = 0, ) -> None: self.predictions = predictions self.inference_time = inference_time self.start_time = start_time self.end_time = end_time def visualize( self, frame: np.ndarray, position: tuple = (20, 100), prefix: str = "Predictions", color: tuple = (0, 0, 255), ) -> np.ndarray: text = prefix + ": " + self.get_pred_message() return draw_text_on_image( image=frame, text=text, position=position, color=color, font_size=20, ) def get_pred_message(self) -> str: if not any(( self.start_time, self.end_time, self.inference_time, self.predictions )): return "" return ', '.join( [ f"{pred['gloss']} ({pred['score']*100:.2f}%)" for pred in self.predictions ] ) def __str__(self) -> str: if not any(( self.start_time, self.end_time, self.inference_time, self.predictions )): return "" predictions = self.get_pred_message() message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}" return message.format(self.start_time, self.end_time, self.inference_time, predictions) def merge_results(self, results: dict = None) -> dict: if results is None: results = { "start_time": [], "end_time": [], "inference_time": [], "prediction": [], } results["start_time"].append(self.start_time) results["end_time"].append(self.end_time) results["inference_time"].append(self.inference_time) results["prediction"].append(self.predictions) return results def load_model( model_config: ModelConfig, inference_config: InferenceConfig, label2id: dict = None, id2label: dict = None, ) -> ort.InferenceSession: ''' Tải mô hình ONNX sử dụng onnxruntime. ''' try: session = ort.InferenceSession(model_config.pretrained) logging.info(f"ONNX model loaded from {model_config.pretrained}") except Exception as e: logging.error(f"Failed to load ONNX model: {e}") raise e return session def load_pipeline( model_config: ModelConfig, inference_config: InferenceConfig, ) -> ort.InferenceSession: ''' Tải onnxruntime session dựa trên cấu hình mô hình. ''' session = load_model(model_config, inference_config) return session def preprocess_inputs_onnx(inputs: np.ndarray, processor=None) -> dict: ''' Chuyển đổi đầu vào cho mô hình ONNX nếu cần. Bạn có thể thêm các bước tiền xử lý cụ thể ở đây nếu cần. ''' # Ví dụ: Đảm bảo rằng đầu vào có định dạng phù hợp # inputs = processor(inputs) # Nếu cần thiết return {"pixel_values": inputs.astype(np.float32)} # Điều chỉnh tùy thuộc vào yêu cầu của mô hình def get_predictions( inputs: np.ndarray, model: ort.InferenceSession, id2gloss: dict, k: int = 3, ) -> Predictions: ''' Lấy top-k dự đoán từ mô hình ONNX. Parameters ---------- inputs : np.ndarray Dữ liệu đầu vào đã được tiền xử lý. model : ort.InferenceSession Mô hình ONNX đã được tải. id2gloss : dict Bản đồ từ ID lớp sang gloss. k : int, optional Số lượng dự đoán cần trả về, mặc định là 3. Returns ------- Predictions Đối tượng chứa các dự đoán và thời gian suy luận. ''' if inputs is None: return Predictions() # Tiền xử lý đầu vào cho ONNX preprocessed_inputs = preprocess_inputs_onnx(inputs) # Lấy logits start_time = time() try: logits = model.run(None, preprocessed_inputs)[0] except Exception as e: logging.error(f"Error during ONNX inference: {e}") raise e inference_time = time() - start_time logits = torch.from_numpy(logits) # Lấy top-k dự đoán topk_scores, topk_indices = torch.topk(logits, k, dim=1) topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy() topk_indices = topk_indices.squeeze().detach().numpy() predictions = [] for i in range(k): class_idx = str(topk_indices[i]) gloss = id2gloss.get(class_idx, "Unknown") score = topk_scores[i] predictions.append({ 'gloss': gloss, 'score': score, }) return Predictions(predictions=predictions, inference_time=inference_time)