omniverse1 commited on
Commit
88fcfd5
·
verified ·
1 Parent(s): 9cf1928

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +8 -4
model_handler.py CHANGED
@@ -1,21 +1,23 @@
1
  import numpy as np
2
  import torch
3
- from chronos import BaseChronosPipeline
 
4
 
5
  class ModelHandler:
6
  def __init__(self):
7
  self.model_name = "amazon/chronos-2"
8
  self.pipeline = None
 
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
  self.load_model()
11
 
12
  def load_model(self):
13
- """Load Chronos-2 model using the official ChronosPipeline"""
14
  try:
15
  print(f"Loading {self.model_name} on {self.device}...")
16
 
17
- # FIX UTAMA: Pemuatan otomatis oleh pipeline
18
- self.pipeline = ChronosPipeline.from_pretrained(
19
  self.model_name,
20
  device_map=self.device,
21
  )
@@ -49,12 +51,14 @@ class ModelHandler:
49
  return np.array(predictions)
50
 
51
  # --- Chronos-2 Inference ---
 
52
  predictions_samples = self.pipeline.predict(
53
  data['original'],
54
  prediction_length=horizon,
55
  num_samples=20
56
  )
57
 
 
58
  mean_predictions = np.mean(predictions_samples, axis=0)
59
 
60
  return mean_predictions
 
1
  import numpy as np
2
  import torch
3
+ # PENTING: Mengganti ChronosPipeline dengan BaseChronosPipeline sesuai referensi terbaru
4
+ from chronos import BaseChronosPipeline
5
 
6
  class ModelHandler:
7
  def __init__(self):
8
  self.model_name = "amazon/chronos-2"
9
  self.pipeline = None
10
+ # Penentuan device
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
  self.load_model()
13
 
14
  def load_model(self):
15
+ """Load Chronos-2 model using the BaseChronosPipeline"""
16
  try:
17
  print(f"Loading {self.model_name} on {self.device}...")
18
 
19
+ # Perhatikan: Menggunakan BaseChronosPipeline.from_pretrained
20
+ self.pipeline = BaseChronosPipeline.from_pretrained(
21
  self.model_name,
22
  device_map=self.device,
23
  )
 
51
  return np.array(predictions)
52
 
53
  # --- Chronos-2 Inference ---
54
+ # NOTE: BaseChronosPipeline.predict mengembalikan array of arrays (sampel)
55
  predictions_samples = self.pipeline.predict(
56
  data['original'],
57
  prediction_length=horizon,
58
  num_samples=20
59
  )
60
 
61
+ # Mengambil nilai rata-rata (mean) dari semua sampel
62
  mean_predictions = np.mean(predictions_samples, axis=0)
63
 
64
  return mean_predictions