import time
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch

from apogee.model import GPT, ModelConfig
from apogee.tokenizer import Tokenizer

torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 on matmul
torch.backends.cudnn.allow_tf32 = True  # Allow TF32 on cuDNN
class EndpointHandler:
    """
    Inference handler that loads an Apogee GPT checkpoint and generates
    OHLCV candle scenarios from an input time series.
    """
    def __init__(self, base_path: Optional[Union[str, Path]] = None, device: Optional[str] = None):
        if base_path is None:
            base_path = Path(__file__).parent
        self.base_path = Path(base_path)
        # Pick the device: prefer CUDA when available
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Handler spawned on device {self.device} 🚀")
        ckpt_path = self.base_path / "ckpt.pt"
        print(f"Loading model from {ckpt_path} 🤖")
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.config = ModelConfig(**checkpoint["model_config"])
        self.tokenizer = Tokenizer()
        self.model = GPT(self.config)
        state_dict = checkpoint['model']
        # Strip the prefix that torch.compile added when the checkpoint was saved
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)
        self.model = torch.compile(self.model)
        # Use bfloat16 on GPUs with compute capability >= 8 (Ampere or newer), float16 otherwise
        dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and torch.cuda.get_device_capability()[0] >= 8 else 'float16'
        ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
        self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
        print("Warming up hardware 🔥")
        with torch.no_grad(), self.ctx:
            self.model(torch.randint(0, self.tokenizer.vocabulary_size, (1, self.config.block_size), device=self.device))
        print("Model ready! ✅")
        # Precompute the maximum number of candles that fit in the model's context window
        self.max_candles = self.config.block_size // self.tokenizer.tokens_per_candle
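        # For example, with a hypothetical block_size of 1024 and 4 tokens per
        # candle (the real values come from the checkpoint's ModelConfig and
        # Tokenizer), the model could attend to at most 1024 // 4 = 256 candles.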
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict[str, Any]):
                inputs: Dict[str, List[float]] with keys:
                    timestamps: Timestamps of the time series
                    open: Open prices
                    high: High prices
                    low: Low prices
                    close: Close prices
                    volume: Volumes
                steps: int = 4 | Number of sampling steps (candles to generate)
                n_scenarios: int = 32 | Number of scenarios to generate
                seed: Optional[int] = None | Seed for the random number generator
        Returns:
            Dict[str, Any]: Generated scenarios with keys:
                timestamps: Timestamps of the time series
                open: Open prices
                high: High prices
                low: Low prices
                close: Close prices
                volume: Volumes
        """
        t_start = time.time()  # Start the timer
        # Unpack the input data
        inputs = data.pop("inputs", data)
        # Validate the inputs
        required_keys = ("timestamps", "open", "high", "low", "close", "volume")
        assert all(k in inputs for k in required_keys), "Required keys: timestamps, open, high, low, close, volume"
        assert all(isinstance(inputs[k], list) for k in required_keys), "Inputs must be lists"
        assert len(set(len(inputs[k]) for k in required_keys)) == 1, "Inputs must have the same length"
        timestamps = torch.tensor(list(map(int, inputs["timestamps"])))
        # Stack the OHLCV columns into a (n_candles, 5) tensor
        samples = torch.tensor([inputs["open"], inputs["high"], inputs["low"], inputs["close"], inputs["volume"]], dtype=torch.float32).T.contiguous()
        steps = data.pop("steps", 4)
        n_scenarios = data.pop("n_scenarios", 32)
        seed = data.pop("seed", None)
        # Validate the params
        assert isinstance(steps, int) and steps > 0, "steps must be a positive integer"
        assert isinstance(n_scenarios, int) and n_scenarios > 0, "n_scenarios must be a positive integer"
        if seed is not None:
            assert isinstance(seed, int), "seed must be an integer"
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        # Generate scenarios
        samples = samples[-self.max_candles + steps:]  # Keep only the most recent candles, leaving room for `steps` new ones
        tokens = self.tokenizer.encode(samples)  # Encode the samples into tokens
        tokens = tokens.to(self.device).unsqueeze(0).long()  # Add a batch dimension
        with torch.no_grad(), self.ctx:
            for _ in range(steps * self.tokenizer.tokens_per_candle):
                assert tokens.shape[1] <= self.config.block_size, "Too many tokens in the sequence"
                logits = self.model(tokens)  # Forward the model to get the logits for each position in the sequence
                logits = logits[:, -1, :]  # Pluck the logits at the final step
                # Apply softmax to convert the logits to (normalized) probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # Sample from the distribution
                if probs.shape[0] != n_scenarios:
                    # First iteration: the batch still holds the single prompt, so draw
                    # n_scenarios samples from it and fan the sequence out into n_scenarios copies
                    next_tokens = torch.multinomial(probs, num_samples=n_scenarios, replacement=True).T
                    tokens = tokens.expand(n_scenarios, -1)
                else:
                    next_tokens = torch.multinomial(probs, num_samples=1)
                # Append the sampled tokens to the running sequences and continue
                tokens = torch.cat((tokens, next_tokens), dim=1)
        # Decode the tokens back into candles and keep only the newly generated steps
        scenarios = self.tokenizer.decode(tokens)[:, -steps:]
        print(f"Generated {n_scenarios} scenarios in {time.time() - t_start:.2f} seconds ⏱")
        # Sanity checks: count degenerate values and structurally invalid candles
        print("NaNs:", torch.isnan(scenarios).sum().item())
        print("Infs:", torch.isinf(scenarios).sum().item())
        high_not_highest = (scenarios[..., :4].max(-1).values > scenarios[..., 1])  # High is not the highest of OHLC
        low_not_lowest = (scenarios[..., :4].min(-1).values < scenarios[..., 2])  # Low is not the lowest of OHLC
        invalid_candle = high_not_highest | low_not_lowest
        print("High-not-highest rate:", high_not_highest.float().mean().item())
        print("Low-not-lowest rate:", low_not_lowest.float().mean().item())
        print("Invalid candle rate:", invalid_candle.float().mean().item())
        print("Invalid scenario rate:", invalid_candle.any(dim=-1).float().mean().item())
        # Extrapolate future timestamps from the median interval of the input series
        return {
            "timestamps": (timestamps[-1] + torch.arange(1, steps + 1) * torch.median(torch.diff(timestamps)).item()).tolist(),
            "open": scenarios[:, :, 0].tolist(),
            "high": scenarios[:, :, 1].tolist(),
            "low": scenarios[:, :, 2].tolist(),
            "close": scenarios[:, :, 3].tolist(),
            "volume": scenarios[:, :, 4].tolist()
        }
if __name__ == "__main__":
    import pandas as pd

    handler = EndpointHandler()
    test_path = Path(__file__).parents[2] / "tests" / "assets" / "BTCUSDT-1m-2019-03.csv"
    data = pd.read_csv(test_path)
    y = handler({
        "inputs": {
            "timestamps": data[data.columns[0]].tolist(),
            "open": data[data.columns[1]].tolist(),
            "high": data[data.columns[2]].tolist(),
            "low": data[data.columns[3]].tolist(),
            "close": data[data.columns[4]].tolist(),
            "volume": data[data.columns[5]].tolist()
        },
        "steps": 4,
        "n_scenarios": 64,
        "seed": 42
    })
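    # Quick sanity check on the response shape: each price key maps to
    # n_scenarios lists of `steps` values (here 64 scenarios x 4 candles),
    # and "timestamps" is a single list of `steps` future timestamps.
    assert len(y["timestamps"]) == 4
    assert len(y["close"]) == 64 and all(len(s) == 4 for s in y["close"])
    print("First scenario (close prices):", y["close"][0])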