File size: 8,805 Bytes
ee39f59
 
 
c46eb4d
ee39f59
 
 
c46eb4d
 
 
ee39f59
 
 
c46eb4d
ee39f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46eb4d
ee39f59
 
 
 
 
 
 
8ecbf61
ee39f59
 
 
 
 
8ecbf61
 
 
ee39f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ecbf61
 
 
ee39f59
8ecbf61
 
 
4901a7e
ee39f59
 
 
 
 
 
 
 
 
 
 
 
 
8ecbf61
ee39f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ecbf61
 
ee39f59
a1c0906
 
 
 
 
 
 
 
 
ee39f59
 
 
 
 
 
 
 
 
 
 
c46eb4d
ee39f59
 
 
 
 
8ecbf61
 
ee39f59
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import time
import torch

from contextlib import nullcontext
from typing import Any, Dict, Optional, Union
from pathlib import Path

from apogee.tokenizer import Tokenizer
from apogee.model import GPT, ModelConfig

# Enable TensorFloat-32 kernels on Ampere+ GPUs: substantially faster matmul /
# cuDNN convolutions at slightly reduced mantissa precision. No effect on CPU
# or pre-Ampere hardware, and acceptable here since the handler is inference-only.
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

class EndpointHandler:
    """
    Inference endpoint that samples future OHLCV candle scenarios from a GPT
    checkpoint.

    On construction the handler loads ``ckpt.pt`` from ``base_path``, restores
    the weights, compiles the model, and runs one full-context warm-up forward
    pass. Calling the instance with a request dict generates ``n_scenarios``
    alternative continuations of ``steps`` candles each.
    """

    def __init__(self, base_path: Optional[Union[str, Path]] = None, device: Optional[str] = None):
        """
        Args:
            base_path: Directory containing ``ckpt.pt``. Defaults to the
                directory of this file.
            device: Torch device string (e.g. ``"cuda"``, ``"cpu"``).
                Auto-detected when omitted.
        """
        if base_path is None:
            base_path = Path(__file__).parent
        self.base_path = Path(base_path)
        # Pick the compute device, preferring CUDA when available.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Handler spawned on device {self.device} 🚀")  # fix: was "spwaned"
        ckpt_path = self.base_path / "ckpt.pt"
        print(f"Loading model from {ckpt_path} 🤖")
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.config = ModelConfig(**checkpoint["model_config"])
        self.tokenizer = Tokenizer()
        self.model = GPT(self.config)
        state_dict = checkpoint['model']
        # Checkpoints saved from a torch.compile'd model prefix every key with
        # '_orig_mod.'; strip it so the keys match the bare module.
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)
        self.model = torch.compile(self.model)
        # bfloat16 requires compute capability >= 8 (Ampere); otherwise fall
        # back to float16 for autocast inference.
        dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and torch.cuda.get_device_capability()[0] >= 8 else 'float16' # 'float32' or 'bfloat16' or 'float16'
        ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
        # Mixed-precision context for generation; plain no-op context on CPU.
        self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
        print("Warming up hardware 🔥")
        # One full-block forward pass to trigger compilation and kernel selection
        # before the first real request.
        with torch.no_grad(), self.ctx:
            self.model(torch.randint(0, self.tokenizer.vocabulary_size, (1, self.config.block_size), device=self.device))
        print("Model ready ! ✅")
        # Precompute useful values: how many candles fit in the model context
        # alongside the metadata tokens.
        self.max_candles = (self.config.block_size - self.config.meta_size) // self.tokenizer.tokens_per_candle

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict[str, Any]):
                inputs: Dict[str, Union[str, List[float]]] with keys:
                    pair: Pair symbol
                    frequency: Frequency of the time serie (1m, 5m, 30m, 2h, 8h, 1d)
                    timestamps: Timestamps of the time serie
                    open: Open prices
                    high: High prices
                    low: Low prices
                    close: Close prices
                    volume: Volumes
                steps: int = 4 | Number of sampling steps
                n_scenarios: int = 32 | Number of scenarios to generate
                seed: Optional[int] = None | Seed for the random number generator
        Return:
            Dict[str, Any] Generated scenarios with keys:
                timestamps: Timestamps of the time serie
                open: Open prices
                high: High prices
                low: Low prices
                close: Close prices
                volume: Volumes
        """
        t_start = time.time() # Start the timer
        # Unpack input data; a flat payload (no "inputs" wrapper) is also accepted.
        inputs = data.pop("inputs", data)
        # Validate the inputs.
        # NOTE(review): asserts are stripped under `python -O`; kept here so the
        # exception type (AssertionError) seen by callers is unchanged.
        required = ("pair", "frequency", "timestamps", "open", "high", "low", "close", "volume")
        series_keys = required[2:]
        assert all(k in inputs for k in required), "Required keys: pair, frequency, timestamps, open, high, low, close, volume"
        assert isinstance(inputs["pair"], str) and isinstance(inputs["frequency"], str) and all(isinstance(inputs[k], list) for k in series_keys), "Inputs must be lists"
        assert inputs["frequency"] in ["1m", "5m", "30m", "2h", "8h", "1d"], "Invalid frequency"
        assert len({len(inputs[k]) for k in series_keys}) == 1, "Inputs must have the same length"
        pair, freq = inputs["pair"], inputs["frequency"]
        # Normalize "BTC/USDT" -> "BTCUSDT" and default to the binance exchange prefix.
        pair = "".join(pair.split("/"))
        pair = f"binance.{pair.upper()}" if "." not in pair else pair
        timestamps = torch.tensor(list(map(int, inputs["timestamps"])))
        # Stack the OHLCV columns into a (n_candles, 5) float32 tensor.
        samples = torch.tensor([inputs["open"], inputs["high"], inputs["low"], inputs["close"], inputs["volume"]], dtype=torch.float32).T.contiguous()
        steps = data.pop("steps", 4)
        n_scenarios = data.pop("n_scenarios", 32)
        seed = data.pop("seed", None)
        # Validate the params
        assert isinstance(steps, int) and steps > 0, "steps must be a positive integer"
        # Fix: previously steps >= max_candles made the context slice below
        # silently wrong (positive index) before failing mid-generation.
        assert steps < self.max_candles, "steps must be smaller than the model's candle capacity"
        assert isinstance(n_scenarios, int) and n_scenarios > 0, "n_scenarios must be a positive integer"
        if seed is not None:
            assert isinstance(seed, int), "seed must be an integer"
            torch.manual_seed(seed)
            if torch.cuda.is_available():  # fix: only seed CUDA when it exists
                torch.cuda.manual_seed(seed)
        # Generate scenarios
        # Keep only the most recent candles, leaving room in the context for
        # the `steps` candles we are about to generate.
        samples = samples[-(self.max_candles - steps):]
        tokens = self.tokenizer.encode(pair, freq, samples) # Encode the samples into tokens
        tokens = tokens.to(self.device).unsqueeze(0).long() # Add a batch dimension
        with torch.no_grad(), self.ctx:
            for _ in range(steps * self.tokenizer.tokens_per_candle):
                assert tokens.shape[1] <= self.config.block_size, "Too many tokens in the sequence"
                logits = self.model(tokens) # forward the model to get the logits for the index in the sequence
                logits = logits[:, -1, :] # pluck the logits at the final step
                # apply softmax to convert logits to (normalized) probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution
                if probs.shape[0] != n_scenarios:
                    # First step: fan the single prompt out into n_scenarios
                    # sequences by drawing n_scenarios samples at once.
                    next_tokens = torch.multinomial(probs, num_samples=n_scenarios, replacement=True).T
                    tokens = tokens.expand(n_scenarios, -1)
                else:
                    next_tokens = torch.multinomial(probs, num_samples=1)
                # append sampled index to the running sequence and continue
                tokens = torch.cat((tokens, next_tokens), dim=1)
        # Decode the tokens back into samples; keep only the generated candles.
        _, _, scenarios = self.tokenizer.decode(tokens)
        scenarios = scenarios[:, -steps:]
        print(f"Generated {n_scenarios} scenarios in {time.time() - t_start:.2f} seconds ⏱")
        # Diagnostics: NaN/Inf counts and OHLC consistency rates (high should be
        # the max of OHLC, low the min).
        print("Nans:", torch.isnan(scenarios).sum().item())
        print("Infs:", torch.isinf(scenarios).sum().item())
        high_not_highest = (scenarios[..., :4].max(-1).values > scenarios[..., 1])
        low_not_lowest = (scenarios[..., :4].min(-1).values < scenarios[..., 2])
        invalid_candle = high_not_highest | low_not_lowest
        print("Highest not high rate:", high_not_highest.float().mean().item())
        print("Lowest not low rate:", low_not_lowest.float().mean().item())
        print("Invalid candles rate:", invalid_candle.float().mean().item())
        print("Invalid scenario rate:", invalid_candle.any(dim=-1).float().mean().item())
        # Future timestamps are extrapolated with the median historical interval
        # (median is robust to gaps in the input series).
        return {
            "timestamps": (timestamps[-1] + torch.arange(1, steps+1) * torch.median(torch.diff(timestamps)).item()).tolist(),
            "open": scenarios[:, :, 0].tolist(),
            "high": scenarios[:, :, 1].tolist(),
            "low": scenarios[:, :, 2].tolist(),
            "close": scenarios[:, :, 3].tolist(),
            "volume": scenarios[:, :, 4].tolist()
        }

if __name__ == "__main__":
    # Smoke test: run the handler on a month of 1-minute BTCUSDT candles.
    import pandas as pd
    handler = EndpointHandler()
    csv_path = Path(__file__).parents[2] / "tests" / "assets" / "BTCUSDT-1m-2019-03.csv"
    frame = pd.read_csv(csv_path)
    # Columns are positional: timestamp, open, high, low, close, volume.
    ts_col, o_col, h_col, l_col, c_col, v_col = frame.columns[:6]
    y = handler({
        "inputs": {
            "pair": "binance.BTCUSDT",
            "frequency": "1m",
            "timestamps": frame[ts_col].tolist(),
            "open": frame[o_col].tolist(),
            "high": frame[h_col].tolist(),
            "low": frame[l_col].tolist(),
            "close": frame[c_col].tolist(),
            "volume": frame[v_col].tolist()
        },
        "steps": 4,
        "n_scenarios": 64,
        "seed": 42
    })