File size: 8,305 Bytes
1802f47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""
BIST Predictor — TimesFM 2.5 Tahmin Motoru
Google TimesFM foundation model ile hisse fiyat tahmini.
NVIDIA 4060Ti (CUDA) desteği ile GPU üzerinde çalışır.
"""

import logging
from datetime import datetime, timedelta
from typing import Optional
import numpy as np

from config import (
    MODEL_NAME, MAX_CONTEXT, NORMALIZE_INPUTS,
    USE_QUANTILE_HEAD, HORIZONS, QUANTILE_LEVELS, DEFAULT_HORIZON
)

logger = logging.getLogger(__name__)

# Module-level cache for the TimesFM model. The model is large, so
# _load_model() fills these in on first use and later calls reuse it.
_model_loaded: bool = False
_model = None


def _load_model():
    """
    Load and compile the TimesFM 2.5 model on first use; return the cached one afterwards.

    The model is created with ``from_pretrained(MODEL_NAME)`` and compiled once
    with a ``ForecastConfig`` whose ``max_horizon`` is ``max(HORIZONS)``.
    Once loading succeeds, every later call returns the same object.

    Returns:
        The compiled TimesFM model.

    Raises:
        Exception: anything raised while importing torch/timesfm, loading the
            checkpoint or compiling the model. The error is logged with its
            traceback and re-raised. The module cache is left untouched, so a
            later call retries from a clean state.
    """
    global _model, _model_loaded

    if _model_loaded:
        return _model

    try:
        import torch
        # Allow PyTorch to use faster, lower-precision float32 matmul kernels.
        torch.set_float32_matmul_precision("high")

        # The device is only logged here; it is not passed to TimesFM below.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"PyTorch cihaz: {device}")
        if device == "cuda":
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

        import timesfm

        logger.info(f"TimesFM modeli yükleniyor: {MODEL_NAME}")
        model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(MODEL_NAME)

        # Compile once for the largest configured horizon; presumably so that
        # every horizon in HORIZONS can reuse the same compiled model.
        max_horizon = max(HORIZONS)
        model.compile(
            timesfm.ForecastConfig(
                max_context=MAX_CONTEXT,
                max_horizon=max_horizon,
                normalize_inputs=NORMALIZE_INPUTS,
                use_continuous_quantile_head=USE_QUANTILE_HEAD,
            )
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Model yükleme hatası: {e}")
        raise

    # Publish to the module cache only after compile() succeeded, so a failed
    # attempt never leaves a half-initialised model behind in _model.
    _model = model
    _model_loaded = True
    logger.info(f"TimesFM modeli başarıyla yüklendi. Max horizon: {max_horizon}")
    return _model


def predict_stock(closing_prices: list[float], horizon: int = DEFAULT_HORIZON) -> Optional[dict]:
    """
    Forecast the closing prices of a single stock.

    Only the most recent ``MAX_CONTEXT`` prices are passed to the model.
    Failures do not propagate. An empty price series, a model that could not
    be loaded, or an error during forecasting each return ``None``, and the
    failure is logged.

    Args:
        closing_prices: Historical closing prices in chronological order.
        horizon: Number of future steps (days) to forecast.

    Returns:
        Optional[dict]: ``None`` on failure, otherwise::

            {
                "point_forecast": [float],     # up to `horizon` values
                "quantiles": {                 # empty if no quantiles returned
                    "p10": [float], "p20": [float], ... "p90": [float]
                },
                "horizon": int,
                "context_length": int,         # number of prices actually used
                "last_known_price": float,     # closing_prices[-1]
            }
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(closing_prices) == 0:
        logger.error("Boş fiyat serisi, tahmin yapılamıyor.")
        return None

    try:
        model = _load_model()
    except Exception:
        # _load_model() already logged the traceback; honour this function's
        # "None on failure" contract instead of letting the error escape.
        model = None
    if model is None:
        logger.error("Model yüklenemedi, tahmin yapılamıyor.")
        return None

    try:
        # Slicing past the start of a sequence yields the whole sequence, so
        # this keeps at most the last MAX_CONTEXT prices without a length check.
        context = closing_prices[-MAX_CONTEXT:]
        input_array = np.array(context, dtype=np.float32)

        point_forecast, quantile_forecast = model.forecast(
            horizon=horizon,
            inputs=[input_array],
        )

        result = {
            "point_forecast": point_forecast[0].tolist()[:horizon],
            "quantiles": {},
            "horizon": horizon,
            "context_length": len(context),
            "last_known_price": float(closing_prices[-1]),
        }

        if quantile_forecast is not None and len(quantile_forecast.shape) >= 3:
            q_data = quantile_forecast[0]  # first (only) series: (steps, channels)
            num_channels = q_data.shape[-1]
            quantile_keys = ("p10", "p20", "p30", "p40", "p50", "p60", "p70", "p80", "p90")
            # Per the TimesFM docs the last axis is [mean, q10, q20, ..., q90].
            # Skip the leading mean channel so each pNN key gets its own decile.
            offset = 1 if num_channels > len(quantile_keys) else 0
            for i, key in enumerate(quantile_keys):
                channel = i + offset
                if channel < num_channels:
                    result["quantiles"][key] = q_data[:horizon, channel].tolist()

        logger.info(f"Tahmin üretildi: horizon={horizon}, context={len(context)}")
        return result

    except Exception as e:
        logger.exception(f"Tahmin hatası: {e}")
        return None


def predict_stock_multi_horizon(closing_prices: list[float],
                                 horizons: Optional[list[int]] = None) -> dict:
    """
    Forecast the same stock for several horizons.

    Each horizon gets its own ``predict_stock`` call. Horizons whose forecast
    comes back empty (``None``) are left out of the result.

    Args:
        closing_prices: Historical closing prices in chronological order.
        horizons: Horizons to forecast (e.g. ``[10, 30, 90]``). When ``None``,
            ``HORIZONS`` from config is used.

    Returns:
        dict: ``{horizon: prediction_result}``, where each result has the
        layout returned by ``predict_stock``.
    """
    if horizons is None:
        horizons = HORIZONS

    results = {}
    for h in horizons:
        result = predict_stock(closing_prices, horizon=h)
        if result:
            results[h] = result

    return results


def predict_batch(stocks_data: dict, horizon: int = DEFAULT_HORIZON) -> dict:
    """
    Forecast many stocks with a single batched model call.

    Stocks with an empty price series are skipped. If the batched call fails,
    every stock is retried individually through ``predict_stock``. Stocks that
    still fail are logged and left out of the result.

    Args:
        stocks_data: ``{symbol: closing_prices_list}``, prices in
            chronological order.
        horizon: Number of future steps to forecast.

    Returns:
        dict: ``{symbol: prediction_result}``, with the same per-stock layout
        as ``predict_stock``. Empty if the model could not be loaded.
    """
    try:
        model = _load_model()
    except Exception:
        # _load_model() already logged the traceback; return the documented
        # empty result instead of letting the error escape.
        model = None
    if model is None:
        logger.error("Model yüklenemedi, toplu tahmin yapılamıyor.")
        return {}

    quantile_keys = ("p10", "p20", "p30", "p40", "p50", "p60", "p70", "p80", "p90")

    try:
        # Skip empty series up front. stocks_data[symbol][-1] below would
        # raise on them, pushing the whole batch onto the slow fallback.
        symbols = []
        inputs = []
        for symbol, prices in stocks_data.items():
            if len(prices) == 0:
                logger.warning(f"{symbol}: boş fiyat serisi, atlanıyor.")
                continue
            symbols.append(symbol)
            inputs.append(np.array(prices[-MAX_CONTEXT:], dtype=np.float32))

        if not symbols:
            return {}

        point_forecasts, quantile_forecasts = model.forecast(
            horizon=horizon,
            inputs=inputs,
        )

        has_quantiles = quantile_forecasts is not None and len(quantile_forecasts.shape) >= 3
        num_channels = quantile_forecasts.shape[-1] if has_quantiles else 0
        # Per the TimesFM docs the last axis is [mean, q10, q20, ..., q90].
        # Skip the leading mean channel so each pNN key gets its own decile.
        offset = 1 if num_channels > len(quantile_keys) else 0

        results = {}
        for i, symbol in enumerate(symbols):
            result = {
                "point_forecast": point_forecasts[i].tolist()[:horizon],
                "quantiles": {},
                "horizon": horizon,
                "context_length": len(inputs[i]),
                "last_known_price": float(stocks_data[symbol][-1]),
            }

            if has_quantiles:
                q_data = quantile_forecasts[i]  # (steps, channels) for this stock
                for qi, key in enumerate(quantile_keys):
                    channel = qi + offset
                    if channel < num_channels:
                        result["quantiles"][key] = q_data[:horizon, channel].tolist()

            results[symbol] = result

        logger.info(f"Toplu tahmin tamamlandı: {len(results)} hisse, horizon={horizon}")
        return results

    except Exception as e:
        logger.exception(f"Toplu tahmin hatası: {e}")
        # Fallback: forecast each stock on its own (best effort).
        results = {}
        for symbol, prices in stocks_data.items():
            try:
                r = predict_stock(prices, horizon)
            except Exception as single_err:
                # One bad symbol must not abort the rest, but do not hide it.
                logger.warning(f"{symbol}: tekil tahmin başarısız: {single_err}")
                continue
            if r:
                results[symbol] = r
        return results


def generate_target_dates(from_date: str, horizon: int) -> list[str]:
    """
    List the next ``horizon`` weekdays after ``from_date``.

    Only Saturdays and Sundays are skipped; holidays are not considered.
    ``from_date`` itself is never included.

    Args:
        from_date: Start date as ``YYYY-MM-DD``.
        horizon: Number of weekdays to return (``0`` or less gives ``[]``).

    Returns:
        list[str]: Dates formatted as ``YYYY-MM-DD``, in ascending order.
    """
    day = datetime.strptime(from_date, "%Y-%m-%d")
    one_day = timedelta(days=1)
    dates: list[str] = []

    while len(dates) < horizon:
        day = day + one_day
        if day.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
            continue
        dates.append(day.strftime("%Y-%m-%d"))

    return dates


def is_model_loaded() -> bool:
    """Return True once _load_model() has successfully loaded the model."""
    return bool(_model_loaded)


def get_model_info() -> dict:
    """
    Summarise the model configuration and the compute device PyTorch reports.

    torch is imported lazily here, so calling this requires PyTorch.

    Returns:
        dict with the configured model name, whether the model is loaded,
        CUDA availability, the device name ("cuda" or "cpu"), the first GPU's
        name (or None without CUDA), and the context/horizon/quantile settings
        from config.
    """
    import torch

    has_cuda = torch.cuda.is_available()
    gpu = torch.cuda.get_device_name(0) if has_cuda else None

    # Key order matches the original so serialised output stays identical.
    return dict(
        model_name=MODEL_NAME,
        loaded=_model_loaded,
        cuda_available=has_cuda,
        device="cuda" if has_cuda else "cpu",
        gpu_name=gpu,
        max_context=MAX_CONTEXT,
        horizons=HORIZONS,
        quantile_levels=QUANTILE_LEVELS,
    )