baonathor committed on
Commit
e540b53
·
verified ·
1 Parent(s): f910de1

Set weights_only=False in the torch.load call

Browse files
Files changed (1) hide show
  1. public_inference_extreme.py +333 -333
public_inference_extreme.py CHANGED
@@ -1,333 +1,333 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import pandas as pd
5
- import math
6
- import os
7
- import joblib
8
- import time
9
- from typing import List, Dict, Optional
10
- from transformers import AutoModel, AutoTokenizer, AutoConfig
11
-
12
- # ==========================================
13
- # 1. CONFIGURATION
14
- # ==========================================
15
class PublicConfig:
    """Static settings shared by the model definition and the predictor.

    All values are class attributes; the module-level ``cfg`` instance below
    is what the rest of the file reads.
    """

    # Model architecture
    max_length = 256  # tokenizer pads/truncates every input to this many tokens
    num_labels_3m = 3
    num_labels_30m = 3

    # Feature settings.  The predictor builds the scaler input in exactly this
    # order, so the list must not be reordered independently of the scaler.
    feature_cols = [
        "feat_ret_1m", "feat_ret_5m", "feat_ret_15m",
        "feat_volatility_60m", "feat_num_trades_60m", "feat_volume_60m",
        "feat_tweet_freq_24h", "feat_time_since_prev_tweet",
        "feat_btc_ret_60m", "feat_btc_ret_24h",
        "feat_fear_greed_index", "feat_btc_dominance", "feat_altseason_index",
    ]

    # Inference settings: evaluated once, when the class body runs at import.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Paths (relative to this script or defined by the user)
    checkpoint_dir = "checkpoints_market_event_multitask"
    model_filename = "best_model_extreme.pt"
    scaler_filename = "scaler.pkl"


cfg = PublicConfig()
39
-
40
- # ==========================================
41
- # 2. MODEL ARCHITECTURE
42
- # ==========================================
43
class MarketConditionedEventMultiTask(nn.Module):
    """The Extreme Signal model: BERT text encoder + market-feature MLP + attention.

    The market features are embedded by a small MLP; that embedding is used as
    the query of a single-head attention over the BERT token states.  The
    pooled BERT output, the attended context and the market embedding are
    concatenated and fed to two independent classification heads (3m and 30m).

    Sub-module attribute names and the layout of each ``nn.Sequential`` define
    the ``state_dict`` keys, so they must stay as they are for existing
    checkpoints to load.
    """

    def __init__(self, num_features: int,
                 num_labels_3m: int, num_labels_30m: int,
                 bert_config, device: str = "cpu"):
        """Build the network with randomly initialised weights.

        Args:
            num_features: Number of numeric market features per sample.
            num_labels_3m: Output classes of the 3m head.
            num_labels_30m: Output classes of the 30m head.
            bert_config: Transformers config object used to instantiate the
                text encoder structure (no pretrained weights are fetched).
            device: Accepted for backward compatibility; not used here.  The
                caller is expected to move the module with ``.to(...)``.
        """
        super().__init__()

        # Build the BERT structure from config only (offline mode).
        self.bert = AutoModel.from_config(bert_config)

        hidden_size = self.bert.config.hidden_size

        # MLP that encodes the numeric market features.
        self.market_mlp = nn.Sequential(
            nn.Linear(num_features, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
        )

        # Linear projections for the market-conditioned attention.
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.key_proj = nn.Linear(hidden_size, hidden_size)

        combined_size = hidden_size * 3  # pooled output + context + market_emb

        # Classification head for the 3m horizon.
        self.classifier_3m = nn.Sequential(
            nn.Linear(combined_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_labels_3m),
        )

        # Classification head for the 30m horizon.
        self.classifier_30m = nn.Sequential(
            nn.Linear(combined_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_labels_30m),
        )

    def forward(self, input_ids, attention_mask, market_features):
        """Return ``(logits_3m, logits_30m)`` for a batch.

        Args:
            input_ids: Token ids passed straight to the BERT encoder.
            attention_mask: Mask passed to BERT; positions equal to 0 are also
                excluded from the market-conditioned attention.
            market_features: Numeric features fed to ``market_mlp``
                (assumed ``(batch, num_features)`` float tensor - TODO confirm).
        """
        # Encode the tweet with BERT.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        pooled_output = outputs.pooler_output

        # Encode the market state.
        market_emb = self.market_mlp(market_features)

        # Market-conditioned attention: one query per sample over all tokens.
        Q = self.query_proj(market_emb).unsqueeze(1)
        K = self.key_proj(last_hidden_state)
        scores = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(K.size(-1))

        # Use the dtype's most negative finite value rather than -inf: masked
        # tokens still get zero weight, but a row whose mask is all zeros
        # yields a finite (uniform) softmax instead of NaN.
        extended_mask = attention_mask.unsqueeze(1)
        scores = scores.masked_fill(extended_mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)

        context = torch.matmul(attn_weights, last_hidden_state).squeeze(1)

        # Combine and classify.
        combined = torch.cat([pooled_output, context, market_emb], dim=-1)
        logits_3m = self.classifier_3m(combined)
        logits_30m = self.classifier_30m(combined)

        return logits_3m, logits_30m
113
-
114
- # ==========================================
115
- # 3. INFERENCE CLASS
116
- # ==========================================
117
class ExtremeModelPredictor:
    """Load tokenizer, feature scaler and model weights from one folder and
    run single-tweet inference.

    Reads the module-level ``cfg`` for the device, artefact file names,
    feature order and tokenizer length.
    """

    def __init__(self, model_dir: str):
        """Load every artefact needed for inference from ``model_dir``.

        Args:
            model_dir: Folder holding the tokenizer files, ``config.json``,
                the scaler pickle and the model weights.

        Raises:
            FileNotFoundError: If the tokenizer cannot be loaded, or the
                scaler, ``config.json`` or the weight file is missing.
        """
        self.device = cfg.device
        self.model_dir = model_dir

        print(f"Loading Extreme Model from {model_dir}...")
        print(f"Using device: {self.device.upper()}")

        if self.device == 'cpu':
            print("Note: Running on CPU. If you have an NVIDIA GPU, please install PyTorch with CUDA support.")

        # 1. Load tokenizer.  Chain the original error so its traceback is kept.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        except Exception as e:
            raise FileNotFoundError(f"Could not load tokenizer from {model_dir}. Please ensure tokenizer files exist. Error: {e}") from e

        # 2. Load scaler, falling back to the working directory (development setup).
        scaler_path = os.path.join(model_dir, cfg.scaler_filename)
        if not os.path.exists(scaler_path):
            scaler_path = cfg.scaler_filename
        if not os.path.exists(scaler_path):
            raise FileNotFoundError(f"Scaler not found at {scaler_path}. Please include scaler.pkl.")
        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; only point this at scaler files you trust.
        self.scaler = joblib.load(scaler_path)

        # 3. Load model config and weights.
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found at {config_path}. Please ensure config.json exists.")

        bert_config = AutoConfig.from_pretrained(model_dir)

        self.model = MarketConditionedEventMultiTask(
            num_features=len(cfg.feature_cols),
            num_labels_3m=cfg.num_labels_3m,
            num_labels_30m=cfg.num_labels_30m,
            device=self.device,
            bert_config=bert_config
        )

        # Resolve the weight file: preferred name first, then the fallback name.
        weight_path = os.path.join(model_dir, cfg.model_filename)
        if not os.path.exists(weight_path):
            weight_path = os.path.join(model_dir, "best_model.pt")
        if not os.path.exists(weight_path):
            raise FileNotFoundError(
                f"Model weights not found in {model_dir}. "
                f"Expected {cfg.model_filename} or best_model.pt."
            )

        # Load on CPU first, then move the whole module to the target device.
        state_dict = torch.load(weight_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print("Model loaded successfully.")

    def preprocess_features(self, raw_feats: Dict[str, float]) -> np.ndarray:
        """Order, zero-fill and scale one sample of market features.

        Args:
            raw_feats: Mapping from feature name to value.  Names missing from
                the mapping, and NaN/None values, are replaced by 0.

        Returns:
            Whatever ``self.scaler.transform`` returns for the one-row frame
            (presumably a ``(1, n_features)`` array - confirm against the scaler).
        """
        # Fix the column order to cfg.feature_cols and default missing keys to 0.
        vals = [raw_feats.get(col, 0.0) for col in cfg.feature_cols]

        # A named DataFrame avoids sklearn's "feature names" warning.
        df = pd.DataFrame([vals], columns=cfg.feature_cols)
        df = df.fillna(0)

        return self.scaler.transform(df)

    def predict(self, project_name: str, symbol: str, tweet_text: str, market_features: Dict[str, float]):
        """Return class probabilities for the 3m and 30m horizons.

        Classes: 0 (Down), 1 (Neutral), 2 (Up).

        Returns:
            dict with keys ``"3m_probs"`` and ``"30m_probs"`` (each a dict with
            ``Down``/``Neutral``/``Up``), ``"extreme_signal"`` (derived from
            the 3m probabilities only) and ``"inference_time"`` in seconds,
            measured from tokenization through the softmax.
        """
        full_text = f"{project_name} ({symbol}): {tweet_text}"
        start_time = time.perf_counter()

        # Tokenize to a fixed-length batch of one.
        encoded = self.tokenizer(
            full_text,
            padding="max_length",
            truncation=True,
            max_length=cfg.max_length,
            return_tensors="pt"
        )

        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Scale the market features and move them to the model's device.
        feats_scaled = self.preprocess_features(market_features)
        feats_tensor = torch.tensor(feats_scaled, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            logits_3m, logits_30m = self.model(input_ids, attention_mask, feats_tensor)

            probs_3m = torch.softmax(logits_3m, dim=-1).cpu().numpy()[0]
            probs_30m = torch.softmax(logits_30m, dim=-1).cpu().numpy()[0]

        inference_time = time.perf_counter() - start_time

        return {
            "3m_probs": {"Down": probs_3m[0], "Neutral": probs_3m[1], "Up": probs_3m[2]},
            "30m_probs": {"Down": probs_30m[0], "Neutral": probs_30m[1], "Up": probs_30m[2]},
            "extreme_signal": self._check_extreme(probs_3m),
            "inference_time": inference_time
        }

    def _check_extreme(self, probs, threshold: float = 0.7):
        """Label a probability triple as an extreme move or normal.

        Args:
            probs: Indexable as ``[down, neutral, up]``.
            threshold: Probability that Down or Up must strictly exceed.

        Returns:
            ``"EXTREME DOWN"``, ``"EXTREME UP"`` or ``"NORMAL"``.  Down is
            checked first.
        """
        if probs[0] > threshold:
            return "EXTREME DOWN"
        if probs[2] > threshold:
            return "EXTREME UP"
        return "NORMAL"
229
-
230
- # ==========================================
231
- # 4. EXAMPLE USAGE
232
- # ==========================================
233
if __name__ == "__main__":
    # Example of how to use this script.

    # Path to the folder containing: config.json, tokenizer files, scaler.pkl, best_model_extreme.pt
    checkpoint_folder = "checkpoints_market_event_multitask"

    # Only run the demo when the checkpoint folder is present.
    if os.path.exists(checkpoint_folder):
        predictor = ExtremeModelPredictor(checkpoint_folder)

        def run_scenario(name, tweet, features):
            """Run one prediction and print the signal, 3m probabilities and timing."""
            print(f"\n--- Scenario: {name} ---")
            print(f"Tweet: {tweet}")
            # Print key features for context
            print(f"Key Feats: 1m={features.get('feat_ret_1m')}, Vol={features.get('feat_volume_60m')}, F&G={features.get('feat_fear_greed_index')}")

            res = predictor.predict("Test", "TST", tweet, features)
            print(f"Signal: {res['extreme_signal']}")
            print(f"3m Probs: Down={res['3m_probs']['Down']:.3f}, Neutral={res['3m_probs']['Neutral']:.3f}, Up={res['3m_probs']['Up']:.3f}")
            print(f"Time: {res['inference_time']*1000:.2f} ms")
            return res

        # Warm-up: the first call is run once before the scenarios so its
        # timing is not reported.
        print("\n--- Warm-up Run ---")
        predictor.predict("Warmup", "WP", "Warmup", {
            "feat_ret_1m": 0.0, "feat_ret_5m": 0.0, "feat_ret_15m": 0.0,
            "feat_volatility_60m": 0.0, "feat_num_trades_60m": 0, "feat_volume_60m": 0,
            "feat_tweet_freq_24h": 0, "feat_time_since_prev_tweet": 0,
            "feat_btc_ret_60m": 0.0, "feat_btc_ret_24h": 0.0,
            "feat_fear_greed_index": 50, "feat_btc_dominance": 50, "feat_altseason_index": 0
        })

        # (name, tweet, market features) - executed in this order.
        scenarios = [
            # 1. FOMO / Strong Uptrend
            (
                "FOMO / Strong Uptrend",
                "BREAKING: Major exchange listing confirmed! 🚀 #ToTheMoon",
                {
                    "feat_ret_1m": 0.02, "feat_ret_5m": 0.05, "feat_ret_15m": 0.08,
                    "feat_volatility_60m": 0.05, "feat_num_trades_60m": 1000, "feat_volume_60m": 2000000,
                    "feat_tweet_freq_24h": 100, "feat_time_since_prev_tweet": 5,
                    "feat_btc_ret_60m": 0.01, "feat_btc_ret_24h": 0.05,
                    "feat_fear_greed_index": 80, "feat_btc_dominance": 45, "feat_altseason_index": 80
                },
            ),
            # 2. Panic Dump / Crash (The "Mean Reversion" Test)
            (
                "Panic Dump / Crash",
                "URGENT: Security breach detected. Do not interact with contracts.",
                {
                    "feat_ret_1m": -0.05, "feat_ret_5m": -0.08, "feat_ret_15m": -0.10,
                    "feat_volatility_60m": 0.15, "feat_num_trades_60m": 2000, "feat_volume_60m": 5000000,
                    "feat_tweet_freq_24h": 200, "feat_time_since_prev_tweet": 1,
                    "feat_btc_ret_60m": -0.03, "feat_btc_ret_24h": -0.08,
                    "feat_fear_greed_index": 10, "feat_btc_dominance": 60, "feat_altseason_index": 5
                },
            ),
            # 3. Slow Bleed / Bear Market
            (
                "Slow Bleed / Bear Market",
                "Weekly development update. Progress is slow but steady.",
                {
                    "feat_ret_1m": -0.001, "feat_ret_5m": -0.002, "feat_ret_15m": -0.005,
                    "feat_volatility_60m": 0.01, "feat_num_trades_60m": 50, "feat_volume_60m": 100000,
                    "feat_tweet_freq_24h": 10, "feat_time_since_prev_tweet": 60,
                    "feat_btc_ret_60m": -0.001, "feat_btc_ret_24h": -0.01,
                    "feat_fear_greed_index": 30, "feat_btc_dominance": 55, "feat_altseason_index": 10
                },
            ),
            # 4. Sideways / Stable
            (
                "Sideways / Stable",
                "Just a normal day building. #Crypto",
                {
                    "feat_ret_1m": 0.000, "feat_ret_5m": 0.001, "feat_ret_15m": -0.001,
                    "feat_volatility_60m": 0.005, "feat_num_trades_60m": 20, "feat_volume_60m": 50000,
                    "feat_tweet_freq_24h": 5, "feat_time_since_prev_tweet": 120,
                    "feat_btc_ret_60m": 0.000, "feat_btc_ret_24h": 0.002,
                    "feat_fear_greed_index": 50, "feat_btc_dominance": 50, "feat_altseason_index": 20
                },
            ),
            # 5. Divergence: Good News + Bad Price (Opportunity?)
            (
                "Divergence: Good News + Bad Price",
                "Partnership with Google Cloud announced!",
                {
                    "feat_ret_1m": -0.02, "feat_ret_5m": -0.03, "feat_ret_15m": -0.03,
                    "feat_volatility_60m": 0.04, "feat_num_trades_60m": 150, "feat_volume_60m": 300000,
                    "feat_tweet_freq_24h": 20, "feat_time_since_prev_tweet": 10,
                    "feat_btc_ret_60m": -0.01, "feat_btc_ret_24h": -0.02,
                    "feat_fear_greed_index": 40, "feat_btc_dominance": 52, "feat_altseason_index": 15
                },
            ),
        ]

        for scenario_name, scenario_tweet, scenario_feats in scenarios:
            run_scenario(scenario_name, scenario_tweet, scenario_feats)

    else:
        print(f"Error: Checkpoint folder '{checkpoint_folder}' not found.")
        print("Please place this script next to your model folder.")
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pandas as pd
5
+ import math
6
+ import os
7
+ import joblib
8
+ import time
9
+ from typing import List, Dict, Optional
10
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
11
+
12
+ # ==========================================
13
+ # 1. CONFIGURATION
14
+ # ==========================================
15
class PublicConfig:
    """Static settings shared by the model definition and the predictor.

    All values are class attributes; the module-level ``cfg`` instance below
    is what the rest of the file reads.
    """

    # Model architecture
    max_length = 256  # tokenizer pads/truncates every input to this many tokens
    num_labels_3m = 3
    num_labels_30m = 3

    # Feature settings.  The predictor builds the scaler input in exactly this
    # order, so the list must not be reordered independently of the scaler.
    feature_cols = [
        "feat_ret_1m", "feat_ret_5m", "feat_ret_15m",
        "feat_volatility_60m", "feat_num_trades_60m", "feat_volume_60m",
        "feat_tweet_freq_24h", "feat_time_since_prev_tweet",
        "feat_btc_ret_60m", "feat_btc_ret_24h",
        "feat_fear_greed_index", "feat_btc_dominance", "feat_altseason_index",
    ]

    # Inference settings: evaluated once, when the class body runs at import.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Paths (relative to this script or defined by the user)
    checkpoint_dir = "checkpoints_market_event_multitask"
    model_filename = "best_model_extreme.pt"
    scaler_filename = "scaler.pkl"


cfg = PublicConfig()
39
+
40
+ # ==========================================
41
+ # 2. MODEL ARCHITECTURE
42
+ # ==========================================
43
class MarketConditionedEventMultiTask(nn.Module):
    """The Extreme Signal model: BERT text encoder + market-feature MLP + attention.

    The market features are embedded by a small MLP; that embedding is used as
    the query of a single-head attention over the BERT token states.  The
    pooled BERT output, the attended context and the market embedding are
    concatenated and fed to two independent classification heads (3m and 30m).

    Sub-module attribute names and the layout of each ``nn.Sequential`` define
    the ``state_dict`` keys, so they must stay as they are for existing
    checkpoints to load.
    """

    def __init__(self, num_features: int,
                 num_labels_3m: int, num_labels_30m: int,
                 bert_config, device: str = "cpu"):
        """Build the network with randomly initialised weights.

        Args:
            num_features: Number of numeric market features per sample.
            num_labels_3m: Output classes of the 3m head.
            num_labels_30m: Output classes of the 30m head.
            bert_config: Transformers config object used to instantiate the
                text encoder structure (no pretrained weights are fetched).
            device: Accepted for backward compatibility; not used here.  The
                caller is expected to move the module with ``.to(...)``.
        """
        super().__init__()

        # Build the BERT structure from config only (offline mode).
        self.bert = AutoModel.from_config(bert_config)

        hidden_size = self.bert.config.hidden_size

        # MLP that encodes the numeric market features.
        self.market_mlp = nn.Sequential(
            nn.Linear(num_features, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
        )

        # Linear projections for the market-conditioned attention.
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.key_proj = nn.Linear(hidden_size, hidden_size)

        combined_size = hidden_size * 3  # pooled output + context + market_emb

        # Classification head for the 3m horizon.
        self.classifier_3m = nn.Sequential(
            nn.Linear(combined_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_labels_3m),
        )

        # Classification head for the 30m horizon.
        self.classifier_30m = nn.Sequential(
            nn.Linear(combined_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_labels_30m),
        )

    def forward(self, input_ids, attention_mask, market_features):
        """Return ``(logits_3m, logits_30m)`` for a batch.

        Args:
            input_ids: Token ids passed straight to the BERT encoder.
            attention_mask: Mask passed to BERT; positions equal to 0 are also
                excluded from the market-conditioned attention.
            market_features: Numeric features fed to ``market_mlp``
                (assumed ``(batch, num_features)`` float tensor - TODO confirm).
        """
        # Encode the tweet with BERT.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        pooled_output = outputs.pooler_output

        # Encode the market state.
        market_emb = self.market_mlp(market_features)

        # Market-conditioned attention: one query per sample over all tokens.
        Q = self.query_proj(market_emb).unsqueeze(1)
        K = self.key_proj(last_hidden_state)
        scores = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(K.size(-1))

        # Use the dtype's most negative finite value rather than -inf: masked
        # tokens still get zero weight, but a row whose mask is all zeros
        # yields a finite (uniform) softmax instead of NaN.
        extended_mask = attention_mask.unsqueeze(1)
        scores = scores.masked_fill(extended_mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)

        context = torch.matmul(attn_weights, last_hidden_state).squeeze(1)

        # Combine and classify.
        combined = torch.cat([pooled_output, context, market_emb], dim=-1)
        logits_3m = self.classifier_3m(combined)
        logits_30m = self.classifier_30m(combined)

        return logits_3m, logits_30m
113
+
114
+ # ==========================================
115
+ # 3. INFERENCE CLASS
116
+ # ==========================================
117
class ExtremeModelPredictor:
    """Load tokenizer, feature scaler and model weights from one folder and
    run single-tweet inference.

    Reads the module-level ``cfg`` for the device, artefact file names,
    feature order and tokenizer length.
    """

    def __init__(self, model_dir: str):
        """Load every artefact needed for inference from ``model_dir``.

        Args:
            model_dir: Folder holding the tokenizer files, ``config.json``,
                the scaler pickle and the model weights.

        Raises:
            FileNotFoundError: If the tokenizer cannot be loaded, or the
                scaler, ``config.json`` or the weight file is missing.
        """
        self.device = cfg.device
        self.model_dir = model_dir

        print(f"Loading Extreme Model from {model_dir}...")
        print(f"Using device: {self.device.upper()}")

        if self.device == 'cpu':
            print("Note: Running on CPU. If you have an NVIDIA GPU, please install PyTorch with CUDA support.")

        # 1. Load tokenizer.  Chain the original error so its traceback is kept.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        except Exception as e:
            raise FileNotFoundError(f"Could not load tokenizer from {model_dir}. Please ensure tokenizer files exist. Error: {e}") from e

        # 2. Load scaler, falling back to the working directory (development setup).
        scaler_path = os.path.join(model_dir, cfg.scaler_filename)
        if not os.path.exists(scaler_path):
            scaler_path = cfg.scaler_filename
        if not os.path.exists(scaler_path):
            raise FileNotFoundError(f"Scaler not found at {scaler_path}. Please include scaler.pkl.")
        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; only point this at scaler files you trust.
        self.scaler = joblib.load(scaler_path)

        # 3. Load model config and weights.
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found at {config_path}. Please ensure config.json exists.")

        bert_config = AutoConfig.from_pretrained(model_dir)

        self.model = MarketConditionedEventMultiTask(
            num_features=len(cfg.feature_cols),
            num_labels_3m=cfg.num_labels_3m,
            num_labels_30m=cfg.num_labels_30m,
            device=self.device,
            bert_config=bert_config
        )

        # Resolve the weight file: preferred name first, then the fallback name.
        weight_path = os.path.join(model_dir, cfg.model_filename)
        if not os.path.exists(weight_path):
            weight_path = os.path.join(model_dir, "best_model.pt")
        if not os.path.exists(weight_path):
            raise FileNotFoundError(
                f"Model weights not found in {model_dir}. "
                f"Expected {cfg.model_filename} or best_model.pt."
            )

        # NOTE(security): weights_only=False enables full unpickling, which can
        # execute arbitrary code embedded in the checkpoint.  Only load
        # checkpoints from a trusted source.
        # NOTE(review): if the file is a plain tensor state_dict,
        # weights_only=True should be sufficient - confirm and tighten.
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        print("Model loaded successfully.")

    def preprocess_features(self, raw_feats: Dict[str, float]) -> np.ndarray:
        """Order, zero-fill and scale one sample of market features.

        Args:
            raw_feats: Mapping from feature name to value.  Names missing from
                the mapping, and NaN/None values, are replaced by 0.

        Returns:
            Whatever ``self.scaler.transform`` returns for the one-row frame
            (presumably a ``(1, n_features)`` array - confirm against the scaler).
        """
        # Fix the column order to cfg.feature_cols and default missing keys to 0.
        vals = [raw_feats.get(col, 0.0) for col in cfg.feature_cols]

        # A named DataFrame avoids sklearn's "feature names" warning.
        df = pd.DataFrame([vals], columns=cfg.feature_cols)
        df = df.fillna(0)

        return self.scaler.transform(df)

    def predict(self, project_name: str, symbol: str, tweet_text: str, market_features: Dict[str, float]):
        """Return class probabilities for the 3m and 30m horizons.

        Classes: 0 (Down), 1 (Neutral), 2 (Up).

        Returns:
            dict with keys ``"3m_probs"`` and ``"30m_probs"`` (each a dict with
            ``Down``/``Neutral``/``Up``), ``"extreme_signal"`` (derived from
            the 3m probabilities only) and ``"inference_time"`` in seconds,
            measured from tokenization through the softmax.
        """
        full_text = f"{project_name} ({symbol}): {tweet_text}"
        start_time = time.perf_counter()

        # Tokenize to a fixed-length batch of one.
        encoded = self.tokenizer(
            full_text,
            padding="max_length",
            truncation=True,
            max_length=cfg.max_length,
            return_tensors="pt"
        )

        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Scale the market features and move them to the model's device.
        feats_scaled = self.preprocess_features(market_features)
        feats_tensor = torch.tensor(feats_scaled, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            logits_3m, logits_30m = self.model(input_ids, attention_mask, feats_tensor)

            probs_3m = torch.softmax(logits_3m, dim=-1).cpu().numpy()[0]
            probs_30m = torch.softmax(logits_30m, dim=-1).cpu().numpy()[0]

        inference_time = time.perf_counter() - start_time

        return {
            "3m_probs": {"Down": probs_3m[0], "Neutral": probs_3m[1], "Up": probs_3m[2]},
            "30m_probs": {"Down": probs_30m[0], "Neutral": probs_30m[1], "Up": probs_30m[2]},
            "extreme_signal": self._check_extreme(probs_3m),
            "inference_time": inference_time
        }

    def _check_extreme(self, probs, threshold: float = 0.7):
        """Label a probability triple as an extreme move or normal.

        Args:
            probs: Indexable as ``[down, neutral, up]``.
            threshold: Probability that Down or Up must strictly exceed.

        Returns:
            ``"EXTREME DOWN"``, ``"EXTREME UP"`` or ``"NORMAL"``.  Down is
            checked first.
        """
        if probs[0] > threshold:
            return "EXTREME DOWN"
        if probs[2] > threshold:
            return "EXTREME UP"
        return "NORMAL"
229
+
230
+ # ==========================================
231
+ # 4. EXAMPLE USAGE
232
+ # ==========================================
233
if __name__ == "__main__":
    # Example of how to use this script.

    # Path to the folder containing: config.json, tokenizer files, scaler.pkl, best_model_extreme.pt
    checkpoint_folder = "checkpoints_market_event_multitask"

    # Only run the demo when the checkpoint folder is present.
    if os.path.exists(checkpoint_folder):
        predictor = ExtremeModelPredictor(checkpoint_folder)

        def run_scenario(name, tweet, features):
            """Run one prediction and print the signal, 3m probabilities and timing."""
            print(f"\n--- Scenario: {name} ---")
            print(f"Tweet: {tweet}")
            # Print key features for context
            print(f"Key Feats: 1m={features.get('feat_ret_1m')}, Vol={features.get('feat_volume_60m')}, F&G={features.get('feat_fear_greed_index')}")

            res = predictor.predict("Test", "TST", tweet, features)
            print(f"Signal: {res['extreme_signal']}")
            print(f"3m Probs: Down={res['3m_probs']['Down']:.3f}, Neutral={res['3m_probs']['Neutral']:.3f}, Up={res['3m_probs']['Up']:.3f}")
            print(f"Time: {res['inference_time']*1000:.2f} ms")
            return res

        # Warm-up: the first call is run once before the scenarios so its
        # timing is not reported.
        print("\n--- Warm-up Run ---")
        predictor.predict("Warmup", "WP", "Warmup", {
            "feat_ret_1m": 0.0, "feat_ret_5m": 0.0, "feat_ret_15m": 0.0,
            "feat_volatility_60m": 0.0, "feat_num_trades_60m": 0, "feat_volume_60m": 0,
            "feat_tweet_freq_24h": 0, "feat_time_since_prev_tweet": 0,
            "feat_btc_ret_60m": 0.0, "feat_btc_ret_24h": 0.0,
            "feat_fear_greed_index": 50, "feat_btc_dominance": 50, "feat_altseason_index": 0
        })

        # (name, tweet, market features) - executed in this order.
        scenarios = [
            # 1. FOMO / Strong Uptrend
            (
                "FOMO / Strong Uptrend",
                "BREAKING: Major exchange listing confirmed! 🚀 #ToTheMoon",
                {
                    "feat_ret_1m": 0.02, "feat_ret_5m": 0.05, "feat_ret_15m": 0.08,
                    "feat_volatility_60m": 0.05, "feat_num_trades_60m": 1000, "feat_volume_60m": 2000000,
                    "feat_tweet_freq_24h": 100, "feat_time_since_prev_tweet": 5,
                    "feat_btc_ret_60m": 0.01, "feat_btc_ret_24h": 0.05,
                    "feat_fear_greed_index": 80, "feat_btc_dominance": 45, "feat_altseason_index": 80
                },
            ),
            # 2. Panic Dump / Crash (The "Mean Reversion" Test)
            (
                "Panic Dump / Crash",
                "URGENT: Security breach detected. Do not interact with contracts.",
                {
                    "feat_ret_1m": -0.05, "feat_ret_5m": -0.08, "feat_ret_15m": -0.10,
                    "feat_volatility_60m": 0.15, "feat_num_trades_60m": 2000, "feat_volume_60m": 5000000,
                    "feat_tweet_freq_24h": 200, "feat_time_since_prev_tweet": 1,
                    "feat_btc_ret_60m": -0.03, "feat_btc_ret_24h": -0.08,
                    "feat_fear_greed_index": 10, "feat_btc_dominance": 60, "feat_altseason_index": 5
                },
            ),
            # 3. Slow Bleed / Bear Market
            (
                "Slow Bleed / Bear Market",
                "Weekly development update. Progress is slow but steady.",
                {
                    "feat_ret_1m": -0.001, "feat_ret_5m": -0.002, "feat_ret_15m": -0.005,
                    "feat_volatility_60m": 0.01, "feat_num_trades_60m": 50, "feat_volume_60m": 100000,
                    "feat_tweet_freq_24h": 10, "feat_time_since_prev_tweet": 60,
                    "feat_btc_ret_60m": -0.001, "feat_btc_ret_24h": -0.01,
                    "feat_fear_greed_index": 30, "feat_btc_dominance": 55, "feat_altseason_index": 10
                },
            ),
            # 4. Sideways / Stable
            (
                "Sideways / Stable",
                "Just a normal day building. #Crypto",
                {
                    "feat_ret_1m": 0.000, "feat_ret_5m": 0.001, "feat_ret_15m": -0.001,
                    "feat_volatility_60m": 0.005, "feat_num_trades_60m": 20, "feat_volume_60m": 50000,
                    "feat_tweet_freq_24h": 5, "feat_time_since_prev_tweet": 120,
                    "feat_btc_ret_60m": 0.000, "feat_btc_ret_24h": 0.002,
                    "feat_fear_greed_index": 50, "feat_btc_dominance": 50, "feat_altseason_index": 20
                },
            ),
            # 5. Divergence: Good News + Bad Price (Opportunity?)
            (
                "Divergence: Good News + Bad Price",
                "Partnership with Google Cloud announced!",
                {
                    "feat_ret_1m": -0.02, "feat_ret_5m": -0.03, "feat_ret_15m": -0.03,
                    "feat_volatility_60m": 0.04, "feat_num_trades_60m": 150, "feat_volume_60m": 300000,
                    "feat_tweet_freq_24h": 20, "feat_time_since_prev_tweet": 10,
                    "feat_btc_ret_60m": -0.01, "feat_btc_ret_24h": -0.02,
                    "feat_fear_greed_index": 40, "feat_btc_dominance": 52, "feat_altseason_index": 15
                },
            ),
        ]

        for scenario_name, scenario_tweet, scenario_feats in scenarios:
            run_scenario(scenario_name, scenario_tweet, scenario_feats)

    else:
        print(f"Error: Checkpoint folder '{checkpoint_folder}' not found.")
        print("Please place this script next to your model folder.")