cryptogold-backend / model_handler.py
omniverse1's picture
Update model_handler.py
a90bc0e verified
raw
history blame
3.01 kB
import numpy as np
import torch
# Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
from chronos import ChronosPipeline
class ModelHandler:
def __init__(self):
# Mengganti model lama dengan Chronos-2 yang lebih canggih
self.model_name = "amazon/chronos-2"
self.pipeline = None
# Penentuan device: "cuda" jika ada GPU, jika tidak "cpu"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.load_model()
def load_model(self):
"""Load Chronos-2 model optimized for CPU/GPU"""
try:
print(f"Loading {self.model_name} on {self.device}...")
# ChronosPipeline menangani semua proses tokenisasi dan pemuatan arsitektur
self.pipeline = ChronosPipeline.from_pretrained(
self.model_name,
device_map=self.device,
)
print("Chronos-2 pipeline loaded successfully.")
except Exception as e:
print(f"Error loading Chronos-2 model: {e}")
print("Using fallback prediction method")
self.pipeline = None
def predict(self, data, horizon=10):
"""Generate predictions using Chronos-2 or fallback"""
try:
# Menggunakan data['original'] yang merupakan harga aktual riil
if data is None or len(data['original']) < 20:
return np.array([0] * horizon)
if self.pipeline is None:
# --- Fallback Logic ---
# Logic ekstrapolasi tren lama tetap dipertahankan jika model Deep Learning gagal dimuat
values = data['original']
recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
predictions = []
last_value = values[-1]
for i in range(horizon):
next_value = last_value + recent_trend * (i + 1)
noise = np.random.normal(0, data['std'] * 0.1)
predictions.append(next_value + noise)
return np.array(predictions)
# --- Chronos-2 Inference ---
# Input: numpy array dari harga Close historis yang riil
predictions_samples = self.pipeline.predict(
data['original'],
prediction_length=horizon,
# Mengambil 20 sampel prediksi untuk mendapatkan prediksi probablistik
num_samples=20
)
# Untuk chart (garis tunggal), ambil nilai rata-rata (mean) dari semua sampel.
mean_predictions = np.mean(predictions_samples, axis=0)
return mean_predictions
except Exception as e:
print(f"Prediction error: {e}")
# Mengembalikan array nol jika ada error saat inferensi Chronos
return np.array([0] * horizon)