Spaces:
Running
Running
File size: 8,305 Bytes
1802f47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 | """
BIST Predictor — TimesFM 2.5 Tahmin Motoru
Google TimesFM foundation model ile hisse fiyat tahmini.
NVIDIA 4060Ti (CUDA) desteği ile GPU üzerinde çalışır.
"""
import logging
from datetime import datetime, timedelta
from typing import Optional
import numpy as np
from config import (
MODEL_NAME, MAX_CONTEXT, NORMALIZE_INPUTS,
USE_QUANTILE_HEAD, HORIZONS, QUANTILE_LEVELS, DEFAULT_HORIZON
)
logger = logging.getLogger(__name__)

# Lazily-populated, process-wide model cache. The TimesFM model is large, so
# _load_model() loads it once and every later call reuses the same instance.
# _model holds the compiled model (None until loaded); _model_loaded flips to
# True once loading and compiling have finished.
_model, _model_loaded = None, False
def _load_model():
    """Load and compile the TimesFM 2.5 model on the first call, then reuse it.

    The compiled instance is cached in the module globals ``_model`` /
    ``_model_loaded``; later calls return it without reloading the checkpoint.

    Returns:
        The compiled ``timesfm.TimesFM_2p5_200M_torch`` instance.

    Raises:
        Exception: any error from importing torch/timesfm, loading the
            checkpoint or compiling the model is logged (with traceback) and
            re-raised unchanged.
    """
    global _model, _model_loaded
    if _model_loaded:
        return _model
    try:
        import torch

        # Trade a little float32 matmul precision for speed (e.g. TF32 on
        # supported NVIDIA GPUs).
        torch.set_float32_matmul_precision("high")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"PyTorch cihaz: {device}")
        if device == "cuda":
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        # NOTE(review): `device` is only logged; the model is never moved to it
        # here, so placement is left to timesfm -- confirm it selects CUDA itself.

        import timesfm

        logger.info(f"TimesFM modeli yükleniyor: {MODEL_NAME}")
        model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(MODEL_NAME)
        # Compile for the largest configured horizon: every entry of HORIZONS
        # plus DEFAULT_HORIZON (the default horizon of predict_stock and
        # predict_batch). This also avoids max() failing on an empty HORIZONS.
        max_horizon = max([*HORIZONS, DEFAULT_HORIZON])
        model.compile(
            timesfm.ForecastConfig(
                max_context=MAX_CONTEXT,
                max_horizon=max_horizon,
                normalize_inputs=NORMALIZE_INPUTS,
                use_continuous_quantile_head=USE_QUANTILE_HEAD,
            )
        )
    except Exception as e:
        # logger.exception records the traceback; the error still propagates.
        logger.exception(f"Model yükleme hatası: {e}")
        raise
    # Publish to the globals only after compile() succeeded, so a failed
    # compile never leaves an uncompiled model cached in `_model`.
    _model = model
    _model_loaded = True
    logger.info(f"TimesFM modeli başarıyla yüklendi. Max horizon: {max_horizon}")
    return _model
def predict_stock(closing_prices: list[float], horizon: int = DEFAULT_HORIZON) -> Optional[dict]:
    """Produce a TimesFM forecast for a single stock.

    Args:
        closing_prices: Historical closing prices in chronological order
            (oldest first). Only the last ``MAX_CONTEXT`` values are used.
        horizon: Number of future steps to forecast.

    Returns:
        ``None`` when the price series is empty or forecasting fails,
        otherwise a dict::

            {
                "point_forecast": [float],   # at most `horizon` values
                "quantiles": {"p10": [float], ..., "p90": [float]},
                "horizon": int,
                "context_length": int,       # number of prices fed to the model
                "last_known_price": float,   # closing_prices[-1]
            }

        ``quantiles`` stays empty when the model returns no quantile output.

    Raises:
        Whatever ``_load_model`` raises when the model cannot be loaded (the
        load happens outside the error-swallowing ``try``).
    """
    quantile_keys = ("p10", "p20", "p30", "p40", "p50", "p60", "p70", "p80", "p90")

    model = _load_model()
    if model is None:
        logger.error("Model yüklenemedi, tahmin yapılamıyor.")
        return None
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(closing_prices) == 0:
        logger.error("Tahmin hatası: fiyat serisi boş")
        return None
    try:
        # Keep only the most recent MAX_CONTEXT prices (slicing is a no-op
        # for shorter series).
        context = closing_prices[-MAX_CONTEXT:]
        input_array = np.array(context, dtype=np.float32)

        point_forecast, quantile_forecast = model.forecast(
            horizon=horizon,
            inputs=[input_array],
        )

        result = {
            "point_forecast": point_forecast[0].tolist()[:horizon],
            "quantiles": {},
            "horizon": horizon,
            "context_length": len(context),
            "last_known_price": float(closing_prices[-1]),
        }

        if quantile_forecast is not None and len(quantile_forecast.shape) >= 3:
            q_data = quantile_forecast[0]  # the single input series
            num_channels = q_data.shape[-1]
            # TimesFM's quantile output puts the mean in channel 0, followed by
            # the 0.1 ... 0.9 quantiles. Skip that leading mean channel when it
            # is present so that "p10" really is the 10th percentile.
            offset = 1 if num_channels > len(quantile_keys) else 0
            for i, key in enumerate(quantile_keys):
                channel = i + offset
                if channel < num_channels:
                    result["quantiles"][key] = q_data[:horizon, channel].tolist()

        logger.info(f"Tahmin üretildi: horizon={horizon}, context={len(context)}")
        return result
    except Exception as e:
        # The error is swallowed (None is returned), so keep the traceback.
        logger.exception(f"Tahmin hatası: {e}")
        return None
def predict_stock_multi_horizon(closing_prices: list[float],
                                horizons: Optional[list[int]] = None) -> dict:
    """Forecast the same price series for several horizons.

    Each horizon is forecast by a separate ``predict_stock`` call.

    Args:
        closing_prices: Historical closing prices in chronological order.
        horizons: Horizons to forecast; ``None`` means the configured
            ``HORIZONS``.

    Returns:
        dict mapping each horizon to its ``predict_stock`` result. Horizons
        whose forecast failed (``predict_stock`` returned ``None``) are left
        out.
    """
    if horizons is None:
        horizons = HORIZONS
    return {
        h: result
        for h in horizons
        if (result := predict_stock(closing_prices, horizon=h))
    }
def predict_batch(stocks_data: dict, horizon: int = DEFAULT_HORIZON) -> dict:
    """Forecast many stocks with a single batched TimesFM call.

    If the batched call fails for any reason, the function falls back to
    forecasting each symbol individually with ``predict_stock``.

    Args:
        stocks_data: ``{symbol: closing_prices}``, each price list in
            chronological order. Only the last ``MAX_CONTEXT`` prices are used.
        horizon: Number of future steps to forecast.

    Returns:
        ``{symbol: result}`` where ``result`` has the same structure as
        ``predict_stock`` returns. Symbols whose forecast failed are omitted.
        An empty input yields an empty dict.

    Raises:
        Whatever ``_load_model`` raises when the model cannot be loaded.
    """
    quantile_keys = ("p10", "p20", "p30", "p40", "p50", "p60", "p70", "p80", "p90")

    model = _load_model()
    if model is None:
        return {}
    if not stocks_data:
        return {}

    try:
        symbols = list(stocks_data)
        # Keep only the most recent MAX_CONTEXT prices of every series.
        inputs = [
            np.array(stocks_data[symbol][-MAX_CONTEXT:], dtype=np.float32)
            for symbol in symbols
        ]

        point_forecasts, quantile_forecasts = model.forecast(
            horizon=horizon,
            inputs=inputs,
        )
        has_quantiles = (
            quantile_forecasts is not None and len(quantile_forecasts.shape) >= 3
        )

        results = {}
        for i, symbol in enumerate(symbols):
            result = {
                "point_forecast": point_forecasts[i].tolist()[:horizon],
                "quantiles": {},
                "horizon": horizon,
                "context_length": len(inputs[i]),
                "last_known_price": float(stocks_data[symbol][-1]),
            }
            if has_quantiles:
                q_data = quantile_forecasts[i]
                num_channels = q_data.shape[-1]
                # Channel 0 of TimesFM's quantile output is the mean, followed
                # by the 0.1 ... 0.9 quantiles; skip it when present so the
                # "pXX" keys map to the matching percentile.
                offset = 1 if num_channels > len(quantile_keys) else 0
                for qi, key in enumerate(quantile_keys):
                    channel = qi + offset
                    if channel < num_channels:
                        result["quantiles"][key] = q_data[:horizon, channel].tolist()
            results[symbol] = result

        logger.info(f"Toplu tahmin tamamlandı: {len(results)} hisse, horizon={horizon}")
        return results
    except Exception as e:
        logger.exception(f"Toplu tahmin hatası: {e}")

    # Fallback: forecast each symbol on its own so one bad series does not
    # cost the whole batch.
    results = {}
    for symbol, prices in stocks_data.items():
        try:
            r = predict_stock(prices, horizon)
        except Exception as e:
            logger.warning(f"{symbol} için tahmin atlandı: {e}")
            continue
        if r:
            results[symbol] = r
    return results
def generate_target_dates(from_date: str, horizon: int,
                          holidays: Optional[list[str]] = None) -> list[str]:
    """Return the ``horizon`` business days that follow ``from_date``.

    Saturdays and Sundays are always skipped. Exchange holidays are not known
    to this function; pass them via ``holidays`` so the generated dates line
    up with real trading sessions.

    Args:
        from_date: Start date as ``YYYY-MM-DD``. It is never part of the result.
        horizon: Number of business days to generate. Zero or a negative value
            yields an empty list.
        holidays: Optional collection of ``YYYY-MM-DD`` dates to skip in
            addition to weekends. Any iterable of strings is accepted.

    Returns:
        list[str]: ``horizon`` dates in ascending order, formatted ``YYYY-MM-DD``.

    Raises:
        ValueError: if ``from_date`` is not a valid ``YYYY-MM-DD`` date.
    """
    skipped = frozenset(holidays) if holidays is not None else frozenset()
    current = datetime.strptime(from_date, "%Y-%m-%d")
    target_dates: list[str] = []
    while len(target_dates) < horizon:
        current += timedelta(days=1)
        day = current.strftime("%Y-%m-%d")
        # weekday(): Monday == 0 ... Friday == 4
        if current.weekday() < 5 and day not in skipped:
            target_dates.append(day)
    return target_dates
def is_model_loaded() -> bool:
    """Tell whether the TimesFM model has already been loaded and compiled.

    Reads the module-level ``_model_loaded`` flag. It never triggers a load.
    """
    return _model_loaded
def get_model_info() -> dict:
    """Summarise the model configuration and the available compute device.

    The returned dict contains the configured model name, whether the model is
    loaded, CUDA availability, the device string ("cuda" or "cpu"), the name
    of GPU 0 (``None`` without CUDA) and the context/horizon/quantile settings
    imported from config.
    """
    import torch

    cuda_ok = torch.cuda.is_available()
    info = {
        "model_name": MODEL_NAME,
        "loaded": _model_loaded,
        "cuda_available": cuda_ok,
    }
    info["device"] = "cuda" if cuda_ok else "cpu"
    info["gpu_name"] = torch.cuda.get_device_name(0) if cuda_ok else None
    info.update(
        max_context=MAX_CONTEXT,
        horizons=HORIZONS,
        quantile_levels=QUANTILE_LEVELS,
    )
    return info
|