cryptogold-backend / model_handler.py
omniverse1's picture
Update model_handler.py
9927daa verified
raw
history blame
3.05 kB
import numpy as np
import torch
# Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
from chronos import ChronosPipeline
class ModelHandler:
    """Wrap a Chronos forecasting pipeline with a trend-based fallback.

    If the pipeline cannot be loaded, ``predict`` falls back to a linear
    extrapolation of the most recent values plus small Gaussian noise.
    """

    # Minimum number of historical points required by ``predict``; it is also
    # the window used to fit the fallback trend line.
    MIN_HISTORY = 20

    def __init__(self, model_name="amazon/chronos-2"):
        """Pick a device and try to load the forecasting pipeline.

        Args:
            model_name: Model id passed to ``ChronosPipeline.from_pretrained``.
        """
        # NOTE(review): "amazon/chronos-2" is a Chronos-2 checkpoint; newer
        # chronos-forecasting releases load it via Chronos2Pipeline /
        # BaseChronosPipeline rather than ChronosPipeline -- confirm it loads.
        self.model_name = model_name
        self.pipeline = None
        # Use the GPU when one is available, otherwise run on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Load the pipeline; on any failure leave ``self.pipeline`` as None."""
        try:
            print(f"Loading {self.model_name} on {self.device}...")
            self.pipeline = ChronosPipeline.from_pretrained(
                self.model_name,
                device_map=self.device,
            )
            print("Chronos-2 pipeline loaded successfully.")
        except Exception as e:
            # Best-effort load: predict() switches to the fallback when the
            # pipeline is None, so the service stays usable.
            print(f"Error loading Chronos-2 model: {e}")
            print("Using fallback prediction method")
            self.pipeline = None

    def predict(self, data, horizon=10, num_samples=20):
        """Forecast ``horizon`` future values from ``data['original']``.

        Args:
            data: Dict expected to come from
                ``data_processor.prepare_for_chronos``. Must contain
                ``'original'`` (the history) and may contain ``'std'``
                (used to scale the fallback noise, default 1.0).
            horizon: Number of future steps to forecast.
            num_samples: Number of sample paths requested from the pipeline.

        Returns:
            A 1-D float numpy array of length ``horizon``. All zeros when
            the input is invalid or shorter than ``MIN_HISTORY``, or when
            prediction raises.
        """
        try:
            if not self._has_enough_history(data):
                return self._empty_forecast(horizon)
            values = np.asarray(data['original'], dtype=float)
            if self.pipeline is None:
                return self._fallback_predict(
                    values, horizon, data.get('std', 1.0)
                )
            return self._chronos_predict(values, horizon, num_samples)
        except Exception as e:
            print(f"Prediction error: {e}")
            return self._empty_forecast(horizon)

    def _has_enough_history(self, data):
        """Return True when ``data`` is a dict with enough 'original' points."""
        return (
            isinstance(data, dict)
            and 'original' in data
            and len(data['original']) >= self.MIN_HISTORY
        )

    @staticmethod
    def _empty_forecast(horizon):
        """Return a float zero forecast (empty for non-positive horizons)."""
        return np.zeros(max(horizon, 0))

    def _fallback_predict(self, values, horizon, std):
        """Extrapolate the recent linear trend and add small Gaussian noise."""
        window = values[-self.MIN_HISTORY:]
        slope = np.polyfit(np.arange(len(window)), window, 1)[0]
        steps = np.arange(1, horizon + 1)
        # Noise is scaled to 10% of the series' reported std.
        noise = np.random.normal(0, std * 0.1, size=horizon)
        return values[-1] + slope * steps + noise

    def _chronos_predict(self, values, horizon, num_samples):
        """Run the pipeline and average its sample paths into one forecast."""
        # ChronosPipeline.predict expects a torch tensor (or a list of
        # tensors), not a Python list / numpy array.
        context = torch.tensor(values, dtype=torch.float32)
        samples = self.pipeline.predict(
            context,
            prediction_length=horizon,
            num_samples=num_samples,
        )
        if isinstance(samples, torch.Tensor):
            samples = samples.detach().cpu().numpy()
        samples = np.asarray(samples, dtype=float)
        # ChronosPipeline documents its output as
        # (batch, num_samples, prediction_length); collapse the leading axes
        # so the mean runs over sample paths rather than the batch axis.
        samples = samples.reshape(-1, samples.shape[-1])
        return samples.mean(axis=0)