File size: 3,046 Bytes
b8086d5
a90bc0e
 
9927daa
b8086d5
 
 
a90bc0e
 
 
 
 
b8086d5
 
 
9927daa
b8086d5
a90bc0e
b8086d5
9927daa
a90bc0e
 
 
b8086d5
a90bc0e
b8086d5
 
9927daa
a90bc0e
b8086d5
a90bc0e
b8086d5
 
9927daa
b8086d5
9927daa
 
b8086d5
 
a90bc0e
9927daa
b8086d5
 
 
 
 
 
 
9927daa
b8086d5
9927daa
 
b8086d5
 
 
 
a90bc0e
 
 
 
 
 
 
9927daa
a90bc0e
 
 
b8086d5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import torch
# Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
from chronos import ChronosPipeline 

class ModelHandler:
    """Forecast a series with a Chronos pipeline, with a linear-trend fallback.

    If loading the pipeline fails, ``self.pipeline`` stays ``None`` and
    :meth:`predict` extrapolates a linear trend (plus small noise) instead.
    """

    # Minimum history length required before any forecast is attempted; also
    # the size of the trailing window used to estimate the fallback trend.
    MIN_HISTORY = 20
    # Number of sample paths requested from the pipeline and then averaged.
    NUM_SAMPLES = 20

    def __init__(self):
        # Chronos-2 replaces the older model used previously.
        self.model_name = "amazon/chronos-2"
        self.pipeline = None
        # Run on the GPU when one is available, otherwise on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Load the model via ``ChronosPipeline.from_pretrained``.

        On any failure the error is printed and ``self.pipeline`` is left as
        ``None``, so :meth:`predict` uses the fallback method instead.
        """
        try:
            print(f"Loading {self.model_name} on {self.device}...")

            # NOTE(review): recent chronos-forecasting releases expose a
            # dedicated pipeline class for Chronos-2 (and BaseChronosPipeline,
            # which dispatches on the checkpoint); confirm ChronosPipeline can
            # load "amazon/chronos-2" with the installed version.
            self.pipeline = ChronosPipeline.from_pretrained(
                self.model_name,
                device_map=self.device,
            )
            print("Chronos-2 pipeline loaded successfully.")

        except Exception as e:
            # Best effort: keep running with the fallback predictor.
            print(f"Error loading Chronos-2 model: {e}")
            print("Using fallback prediction method")
            self.pipeline = None

    def predict(self, data, horizon=10):
        """Return a 1-D array of ``horizon`` forecast values.

        Args:
            data: dict produced by ``data_processor.prepare_for_chronos``; must
                contain the history under ``'original'`` and may contain
                ``'std'`` (used to scale the fallback noise).
            horizon: number of future steps to forecast.

        Returns:
            A float ``np.ndarray`` of length ``horizon``. All zeros when
            ``data`` is invalid, has fewer than ``MIN_HISTORY`` points, or
            when forecasting raises (the error is printed).
        """
        try:
            if not self._has_enough_history(data):
                return self._no_forecast(horizon)

            if self.pipeline is None:
                return self._fallback_predict(data, horizon)

            return self._chronos_predict(data['original'], horizon)

        except Exception as e:
            print(f"Prediction error: {e}")
            return self._no_forecast(horizon)

    @classmethod
    def _has_enough_history(cls, data):
        """Check that ``data`` is a dict with at least MIN_HISTORY 'original' points."""
        return (
            isinstance(data, dict)
            and 'original' in data
            and len(data['original']) >= cls.MIN_HISTORY
        )

    @staticmethod
    def _no_forecast(horizon):
        """Build the all-zero result (float, like every other return path)."""
        return np.array([0.0] * horizon)

    @classmethod
    def _fallback_predict(cls, data, horizon):
        """Extrapolate the trailing linear trend, adding N(0, 0.1*std) noise."""
        values = np.asarray(data['original'], dtype=float)
        window = values[-cls.MIN_HISTORY:]
        # Slope of a least-squares line through the trailing window.
        slope = np.polyfit(np.arange(len(window)), window, 1)[0]
        steps = np.arange(1, horizon + 1)
        noise = np.random.normal(0, data.get('std', 1.0) * 0.1, size=horizon)
        return values[-1] + slope * steps + noise

    @staticmethod
    def _to_context(values):
        """Convert a 1-D history (list / ndarray / Series) to a float32 tensor.

        ChronosPipeline.predict accepts a torch.Tensor or a list of tensors,
        not a numpy array or a plain list of floats.
        """
        return torch.as_tensor(np.asarray(values, dtype=np.float32))

    @staticmethod
    def _mean_over_samples(samples):
        """Average sampled paths into one forecast of shape ``(horizon,)``.

        ChronosPipeline.predict returns ``[num_series, num_samples, horizon]``;
        only one series is passed in, so the leading axes are flattened and
        the mean is taken over samples (not over the series axis).
        """
        if isinstance(samples, torch.Tensor):
            # float() first: bfloat16 tensors cannot be converted to numpy.
            samples = samples.detach().float().cpu().numpy()
        arr = np.asarray(samples, dtype=float)
        arr = arr.reshape(-1, arr.shape[-1])
        return arr.mean(axis=0)

    def _chronos_predict(self, values, horizon):
        """Run the loaded pipeline on ``values`` and return the mean forecast."""
        # NOTE(review): `num_samples` applies to sample-based Chronos models;
        # confirm the loaded pipeline's predict() accepts it.
        samples = self.pipeline.predict(
            self._to_context(values),
            prediction_length=horizon,
            num_samples=self.NUM_SAMPLES,
        )
        return self._mean_over_samples(samples)