File size: 3,658 Bytes
120f728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from pycoral.utils import edgetpu
import time
from abc import ABC, abstractmethod
from pathlib import Path
import numpy as np

# Location of the quantized model inside the ``models`` directory that sits
# beside this module.
DEFAULT_MODEL_PATH = str(Path(__file__).parent.joinpath("models/portiloop_model_quant.tflite"))
# Echo the resolved model path when the module is imported.
print(DEFAULT_MODEL_PATH)

class AbstractQuantizedModelForInference(ABC):
    """Abstract interface for feeding data to a quantized model for inference."""

    @abstractmethod
    def add_datapoints(self, input_float):
        """Feed a batch of datapoints to the model.

        Args:
            input_float: The datapoints to process; the concrete type and the
                return value are defined by subclasses.

        Raises:
            NotImplementedError: Always; concrete subclasses must override this.
        """
        # ``NotImplemented`` is the sentinel for binary operators; abstract
        # methods signal a missing override with ``NotImplementedError``.
        raise NotImplementedError

class QuantizedModelForInference(AbstractQuantizedModelForInference):
    """Streaming inference with a quantized TFLite model on Edge TPU interpreters.

    Samples are pushed one at a time with :meth:`add_datapoint`. The last
    ``window_size`` values of the selected channel are kept in ``self.buffer``;
    once the buffer is full, every new sample slides it by one. The same model
    is loaded into ``num_models_parallel`` interpreters that are used
    round-robin with staggered start times: each interpreter runs once every
    ``seq_stride`` samples, so across all interpreters an inference is
    triggered about every ``seq_stride / num_models_parallel`` samples.
    """

    def __init__(self, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
        """Load the interpreters and initialise the sliding-window state.

        Args:
            num_models_parallel: Number of interpreters used round-robin (>= 1).
            window_size: Number of samples fed to the model per inference.
            seq_stride: Number of samples between two runs of the same
                interpreter; must be >= ``num_models_parallel``.
            model_path: Path of the ``.tflite`` model; ``DEFAULT_MODEL_PATH``
                when ``None``.
            verbose: If true, print every output with its inference time.
            channel: 1-based index of the channel read from each sample.

        Raises:
            ValueError: If ``num_models_parallel < 1``, or if
                ``seq_stride < num_models_parallel`` (some interpreter would
                get a zero-sample gap and the round-robin would stall).
        """
        # Validate before touching the hardware.
        if num_models_parallel < 1:
            raise ValueError(f"num_models_parallel must be >= 1, got {num_models_parallel}")
        if seq_stride < num_models_parallel:
            raise ValueError(
                f"seq_stride ({seq_stride}) must be >= num_models_parallel ({num_models_parallel})"
            )

        model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
        self.verbose = verbose
        self.channel = channel
        self.num_models_parallel = num_models_parallel

        self.interpreters = []
        for _ in range(self.num_models_parallel):
            interpreter = edgetpu.make_interpreter(model_path)
            interpreter.allocate_tensors()
            self.interpreters.append(interpreter)
        self.interpreter_counter = 0

        # Every interpreter loads the same model, so tensor indices, shapes,
        # dtypes and quantization parameters can be read from the first one.
        self.input_details = self.interpreters[0].get_input_details()
        self.output_details = self.interpreters[0].get_output_details()

        self.buffer = []
        self.seq_stride = seq_stride
        self.window_size = window_size

        self.stride_counters = self._compute_stride_counters(seq_stride, num_models_parallel)
        # Start one step before the first gap so interpreter 0 fires on the
        # first sample that slides a full window.
        self.current_stride_counter = self.stride_counters[0] - 1

    @staticmethod
    def _compute_stride_counters(seq_stride, num_models_parallel):
        """Return, per interpreter, how many samples to wait before it runs.

        Entry ``k`` is the gap between the run of interpreter ``k - 1``
        (wrapping around to the last interpreter for ``k == 0``) and the run
        of interpreter ``k``. The gaps are as even as possible and sum to
        ``seq_stride``, so each interpreter runs once every ``seq_stride``
        samples.
        """
        # Cumulative run offsets within one cycle; the last one is seq_stride.
        boundaries = [
            (seq_stride * (i + 1)) // num_models_parallel
            for i in range(num_models_parallel)
        ]
        return [boundaries[0]] + [nxt - prev for prev, nxt in zip(boundaries, boundaries[1:])]

    def add_datapoints(self, inputs_float):
        """Feed several samples in order.

        Args:
            inputs_float: Iterable of samples, each accepted by :meth:`add_datapoint`.

        Returns:
            List of the model outputs produced while consuming the samples, in order.
        """
        results = []
        for sample in inputs_float:
            result = self.add_datapoint(sample)
            if result is not None:
                results.append(result)
        return results

    def add_datapoint(self, input_float):
        """Push one sample into the window and run the scheduled interpreter.

        Args:
            input_float: Indexable sample; only ``input_float[self.channel - 1]``
                is used.

        Returns:
            The dequantized model output if an interpreter was scheduled on
            this sample, otherwise ``None``.
        """
        self.buffer.append(input_float[self.channel - 1])
        if len(self.buffer) <= self.window_size:
            return None  # window not full yet

        # Slide the window by one sample.
        self.buffer = self.buffer[1:]
        self.current_stride_counter += 1
        if self.current_stride_counter != self.stride_counters[self.interpreter_counter]:
            return None

        result = self.call_model(self.interpreter_counter, self.buffer)
        self.interpreter_counter = (self.interpreter_counter + 1) % self.num_models_parallel
        self.current_stride_counter = 0
        return result

    @staticmethod
    def _quantize(values, detail):
        """Convert float values to the dtype and shape of the tensor in ``detail``.

        ``detail`` is one entry of ``get_input_details()``.
        """
        scale, zero_point = detail["quantization"]
        dtype = detail["dtype"]
        shape = tuple(detail["shape"])
        if scale == 0:
            # A zero scale means the tensor is not quantized: pass values through.
            return np.asarray(values, dtype=dtype).reshape(shape)
        quantized = np.round(np.asarray(values, dtype=np.float64) / scale) + zero_point
        if np.issubdtype(dtype, np.integer):
            # Saturate instead of letting out-of-range values wrap around.
            info = np.iinfo(dtype)
            quantized = np.clip(quantized, info.min, info.max)
        return quantized.astype(dtype).reshape(shape)

    @staticmethod
    def _dequantize(raw, detail):
        """Convert a raw output tensor to a Python float using ``detail``'s quantization.

        ``detail`` is one entry of ``get_output_details()``. NOTE(review): this
        assumes the output tensor holds a single value; ``.item()`` raises
        ``ValueError`` otherwise.
        """
        scale, zero_point = detail["quantization"]
        # Work in float64 so subtracting the zero point cannot overflow int8.
        values = np.asarray(raw, dtype=np.float64)
        if scale != 0:
            values = (values - zero_point) * scale
        return float(values.item())

    def call_model(self, idx, input_float=None):
        """Run interpreter ``idx`` on one window and return its dequantized output.

        Args:
            idx: Index into ``self.interpreters``.
            input_float: Float samples whose total size matches the model's
                input tensor. If ``None``, random values spanning the input
                dtype's range are used (debugging aid).

        Returns:
            The model output dequantized to a Python float.
        """
        interpreter = self.interpreters[idx]
        input_detail = self.input_details[0]
        output_detail = self.output_details[0]

        if input_float is None:
            # For debugging purposes: random integers over the input dtype's range.
            info = np.iinfo(input_detail["dtype"])
            model_input = np.random.randint(
                info.min, info.max + 1,
                size=tuple(input_detail["shape"]),
                dtype=input_detail["dtype"],
            )
        else:
            model_input = self._quantize(input_float, input_detail)

        interpreter.set_tensor(input_detail["index"], model_input)
        start_time = time.perf_counter()
        interpreter.invoke()
        end_time = time.perf_counter()

        output = self._dequantize(interpreter.get_tensor(output_detail["index"]), output_detail)

        if self.verbose:
            print(f"Computed output {output} in {end_time - start_time} seconds")

        return output