File size: 4,035 Bytes
b176a89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3

import os
from pathlib import Path

import numpy as np
import onnxruntime as ort
import torch


class EmbeddingOnnx:
    """ONNX Runtime wrapper around an embedding-layer model.

    Loads the model from ``filename``, reads vocab/hidden-size metadata from
    the graph, and maps token-id arrays to embedding arrays via ``__call__``.
    """

    def __init__(self, filename: str, device: str = "cpu"):
        """Create an inference session for the embedding model.

        Args:
            filename: Path to the ``.onnx`` model file.
            device: ``"cuda"`` to prefer the CUDA execution provider
                (silently falls back to CPU when unavailable); any other
                value runs on CPU.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
            RuntimeError: If the loaded graph has no inputs or no outputs.
        """
        if not os.path.exists(filename):
            # BUG FIX: the original message contained a literal "(unknown)"
            # placeholder instead of interpolating the actual path.
            raise FileNotFoundError(f"ONNX embedding model file not found: {filename}")

        # Large models may keep weights in a sibling external-data file;
        # record it if present (informational — ORT locates it by itself).
        filename_path = Path(filename)
        external_candidates = [
            filename_path.parent / f"{filename_path.name}.data",
            filename_path.parent / f"{filename_path.stem}.onnx_data",
        ]
        self.external_data_file = next(
            (p for p in external_candidates if p.exists()), None
        )

        # Single-threaded session: this wrapper is expected to be driven by
        # an outer pipeline that manages its own parallelism.
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.session_opts = session_opts

        if device == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
            use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            # "cpu" and any unrecognized device string fall back to CPU.
            use_providers = ["CPUExecutionProvider"]

        # NOTE: the original wrapped this call in a try/except that only
        # re-raised the exception unchanged; that no-op has been removed.
        self.model = ort.InferenceSession(
            filename,
            sess_options=session_opts,
            providers=use_providers,
        )

        # Model metadata, with fall-back defaults matching the expected model.
        meta = self.model.get_modelmeta().custom_metadata_map
        self.vocab_size = int(meta.get("vocab_size", 151936))
        self.hidden_size = int(meta.get("hidden_size", 1024))
        self.model_type = meta.get("model_type", "embedding_layer")

        model_inputs = self.model.get_inputs()
        model_outputs = self.model.get_outputs()
        if not model_inputs:
            raise RuntimeError("ONNX embedding model has no inputs")
        if not model_outputs:
            raise RuntimeError("ONNX embedding model has no outputs")
        self.input_name = model_inputs[0].name
        self.output_name = model_outputs[0].name

        # Token ids are always fed as int64; the output dtype follows the
        # graph's declared type, except that int8-quantized models (detected
        # by filename convention) dequantize to float32.
        self.input_dtype = np.int64
        first_output_type = str(model_outputs[0].type).lower()
        if "int8" in filename.lower():
            self.output_dtype = np.float32
        elif "float16" in first_output_type or "fp16" in first_output_type:
            self.output_dtype = np.float16
        else:
            # Covers "float32", plain "float", and any unknown type string.
            self.output_dtype = np.float32

    def __call__(self, input_ids):
        """Look up embeddings for ``input_ids``.

        Args:
            input_ids: Token ids as a ``torch.Tensor``, numpy array, or
                (nested) list. 1-D input is promoted to a single-row batch.

        Returns:
            numpy array of embeddings cast to ``self.output_dtype``; any
            NaN/Inf values produced by the model are replaced with 0.0.

        Raises:
            ValueError: If the normalized input is not 2-D.
        """
        # Normalize the input to a 2-D int64 numpy array.
        # (The redundant function-local `import torch` was removed; torch is
        # already imported at module level.)
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.detach().cpu().numpy()
        elif not isinstance(input_ids, np.ndarray):
            input_ids = np.asarray(input_ids)

        input_ids = input_ids.astype(np.int64, copy=False)
        if input_ids.ndim == 1:
            input_ids = input_ids[None, :]
        if input_ids.ndim != 2:
            raise ValueError(
                f"input_ids must be 2-D (batch_size, seq_length), got shape {input_ids.shape}"
            )

        # Clamp out-of-range ids so the gather inside the model never reads
        # past the embedding table.
        input_ids = np.clip(input_ids, 0, self.vocab_size - 1)

        embeddings = self.model.run([self.output_name], {self.input_name: input_ids})[0]

        if embeddings.dtype != self.output_dtype:
            embeddings = embeddings.astype(self.output_dtype, copy=False)

        # Guard against NaN/Inf leaking out of a corrupted or degenerate model.
        if not np.all(np.isfinite(embeddings)):
            embeddings = np.where(np.isfinite(embeddings), embeddings, 0.0)

        return embeddings