File size: 8,120 Bytes
89b4b0c
 
 
 
 
 
 
 
 
 
 
 
4702f0f
89b4b0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac47b5
89b4b0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c342354
89b4b0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81df173
 
 
 
 
 
 
 
 
89b4b0c
 
 
 
 
 
 
 
 
 
 
4702f0f
89b4b0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from contextlib import nullcontext
import time
import torch
from apogee.tokenizer import Tokenizer
from apogee.model import GPT, ModelConfig

from typing import Any, Dict, Optional, Union
from pathlib import Path

# Enable TF32 tensor cores for fp32 matmuls/convolutions on Ampere+ GPUs:
# large speedup at a small, usually acceptable precision cost for inference.
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

class EndpointHandler:
    """
    Inference handler serving OHLCV scenario generation with a GPT candle model.

    On construction it loads the ``ckpt.pt`` checkpoint found under *base_path*,
    restores the weights into a fresh ``GPT``, compiles the model, selects a
    mixed-precision autocast context and runs one warm-up forward pass.
    ``__call__`` then turns a candle history into ``n_scenarios`` sampled
    future candle sequences.
    """

    def __init__(self, base_path: Optional[Union[str, Path]] = None, device: Optional[str] = None):
        """
        Load the checkpoint and prepare the model for inference.

        Args:
            base_path: Directory containing ``ckpt.pt``. Defaults to the
                directory this file lives in.
            device: Torch device string ("cuda" or "cpu"). Auto-detected when
                None. NOTE(review): ``torch.amp.autocast`` below receives this
                string as ``device_type``, which expects a bare device type —
                passing e.g. "cuda:1" may fail; confirm intended usage.
        """
        if base_path is None:
            base_path = Path(__file__).parent
        self.base_path = Path(base_path)
        # Auto-select the device unless the caller pinned one.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Handler spawned on device {self.device} 🚀")  # fixed typo: was "spwaned"
        ckpt_path = self.base_path / "ckpt.pt"
        print(f"Loading model from {ckpt_path} 🤖")
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.config = ModelConfig(**checkpoint["model_config"])
        self.tokenizer = Tokenizer()
        self.model = GPT(self.config)
        state_dict = checkpoint['model']
        # Checkpoints saved from a torch.compile()'d model prefix every key
        # with '_orig_mod.'; strip it so keys match a fresh (uncompiled) GPT.
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)
        self.model = torch.compile(self.model)
        # bfloat16 requires an Ampere-or-newer GPU (compute capability >= 8);
        # otherwise fall back to float16. On CPU no autocast is used at all.
        dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and torch.cuda.get_device_capability()[0] >= 8 else 'float16' # 'float32' or 'bfloat16' or 'float16'
        ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
        self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
        # A dummy full-context forward pass triggers torch.compile compilation
        # and CUDA kernel warm-up so the first real request isn't slowed by it.
        print("Warming up hardware 🔥")
        with torch.no_grad(), self.ctx:
            self.model(torch.randint(0, self.tokenizer.vocabulary_size, (1, self.config.block_size), device=self.device))
        print("Model ready ! ✅")
        # Precompute useful values: how many whole candles fit in the context window.
        self.max_candles = self.config.block_size // self.tokenizer.tokens_per_candle

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict[str, Any]):
                inputs: Dict[str, List[float]] with keys:
                    timestamps: Timestamps of the time serie
                    open: Open prices
                    high: High prices
                    low: Low prices
                    close: Close prices
                    volume: Volumes
                steps: int = 4 | Number of sampling steps
                n_scenarios: int = 32 | Number of scenarios to generate
                seed: Optional[int] = None | Seed for the random number generator
        Return:
            Dict[str, Any] Generated scenarios with keys:
                timestamps: Timestamps of the time serie
                open: Open prices
                high: High prices
                low: Low prices
                close: Close prices
                volume: Volumes
        """
        t_start = time.time() # Start the timer
        # Unpack input data (fall back to treating `data` itself as the inputs).
        inputs = data.pop("inputs", data)
        # Validate the inputs.
        # NOTE(review): assert-based validation is stripped under `python -O`;
        # kept as-is since callers may rely on AssertionError — confirm.
        assert "timestamps" in inputs and "open" in inputs and "high" in inputs and "low" in inputs and "close" in inputs and "volume" in inputs, "Required keys: timestamps, open, high, low, close, volume"
        assert isinstance(inputs["timestamps"], list) and isinstance(inputs["open"], list) and isinstance(inputs["high"], list) and isinstance(inputs["low"], list) and isinstance(inputs["close"], list) and isinstance(inputs["volume"], list), "Inputs must be lists"
        assert len(inputs["timestamps"]) == len(inputs["open"]) == len(inputs["high"]) == len(inputs["low"]) == len(inputs["close"]) == len(inputs["volume"]), "Inputs must have the same length"
        timestamps = torch.tensor(list(map(int, inputs["timestamps"])))
        # Shape (n_candles, 5): columns are open, high, low, close, volume.
        samples = torch.tensor([inputs["open"], inputs["high"], inputs["low"], inputs["close"], inputs["volume"]], dtype=torch.float32).T.contiguous()
        steps = data.pop("steps", 4)
        n_scenarios = data.pop("n_scenarios", 32)
        seed = data.pop("seed", None)
        # Validate the params
        assert isinstance(steps, int) and steps > 0, "steps must be a positive integer"
        assert isinstance(n_scenarios, int) and n_scenarios > 0, "n_scenarios must be a positive integer"
        if seed is not None:
            assert isinstance(seed, int), "seed must be an integer"
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        # Generate scenarios.
        # Keep only the trailing history so that history + `steps` generated
        # candles fit in the context window. (Assumes steps < max_candles —
        # TODO confirm; otherwise the slice start becomes non-negative.)
        samples = samples[-self.max_candles + steps:] # Keep only the last candles that fit in the model's context
        tokens = self.tokenizer.encode(samples) # Encode the samples into tokens
        tokens = tokens.to(self.device).unsqueeze(0).long() # Add a batch dimension
        with torch.no_grad(), self.ctx:
            for _ in range(steps * self.tokenizer.tokens_per_candle):
                assert tokens.shape[1] <= self.config.block_size, "Too many tokens in the sequence"
                logits = self.model(tokens) # forward the model to get the logits for the index in the sequence
                logits = logits[:, -1, :] # pluck the logits at the final step
                # apply softmax to convert logits to (normalized) probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution
                if probs.shape[0] != n_scenarios:
                    # First step: fan the single prompt out into n_scenarios
                    # branches by drawing n_scenarios tokens from one distribution.
                    next_tokens = torch.multinomial(probs, num_samples=n_scenarios, replacement=True).T
                    tokens = tokens.expand(n_scenarios, -1)
                else:
                    # Subsequent steps: one token per scenario.
                    next_tokens = torch.multinomial(probs, num_samples=1)
                # append sampled index to the running sequence and continue
                tokens = torch.cat((tokens, next_tokens), dim=1)
        # Decode the tokens back into samples, keeping only the generated candles.
        scenarios = self.tokenizer.decode(tokens)[:, -steps:]
        print(f"Generated {n_scenarios} scenarios in {time.time() - t_start:.2f} seconds ⏱")
        # Sanity diagnostics: NaN/Inf counts and OHLC consistency rates
        # (high must be the max and low the min of the four price columns).
        print("Nans:", torch.isnan(scenarios).sum().item())
        print("Infs:", torch.isinf(scenarios).sum().item())
        high_not_highest = (scenarios[..., :4].max(-1).values > scenarios[..., 1])
        low_not_lowest = (scenarios[..., :4].min(-1).values < scenarios[..., 2])
        invalid_candle = high_not_highest | low_not_lowest
        print("Highest not high rate:", high_not_highest.float().mean().item())
        print("Lowest not low rate:", low_not_lowest.float().mean().item())
        print("Invalid candles rate:", invalid_candle.float().mean().item())
        print("Invalid scenario rate:", invalid_candle.any(dim=-1).float().mean().item())
        return {
            # Future timestamps extrapolated with the median historical interval
            # (median is robust to occasional gaps in the input series).
            "timestamps": (timestamps[-1] + torch.arange(1, steps+1) * torch.median(torch.diff(timestamps)).item()).tolist(),
            "open": scenarios[:, :, 0].tolist(),
            "high": scenarios[:, :, 1].tolist(),
            "low": scenarios[:, :, 2].tolist(),
            "close": scenarios[:, :, 3].tolist(),
            "volume": scenarios[:, :, 4].tolist()
        }

if __name__ == "__main__":
    import pandas as pd

    # Smoke-test the handler against a bundled historical CSV fixture.
    handler = EndpointHandler()
    test_path = Path(__file__).parents[2] / "tests" / "assets" / "BTCUSDT-1m-2019-03.csv"
    with open(test_path, "r") as csv_file:
        frame = pd.read_csv(csv_file)
    # The CSV columns map positionally onto the handler's expected input keys.
    field_names = ("timestamps", "open", "high", "low", "close", "volume")
    y = handler({
        "inputs": {
            name: frame[frame.columns[position]].tolist()
            for position, name in enumerate(field_names)
        },
        "steps": 4,
        "n_scenarios": 64,
        "seed": 42
    })