omniverse1 committed on
Commit
fca6922
·
verified ·
1 Parent(s): e045d56

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +9 -5
utils.py CHANGED
@@ -271,14 +271,19 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
271
  input_sequence = prices[-context_length:]
272
 
273
  # Create prediction input
274
- # CRITICAL FIX: Cast input to LongTensor to satisfy model embedding layer expectation.
275
- prediction_input = torch.tensor(input_sequence).unsqueeze(0).long().to(model.device)
 
 
276
 
277
  # Generate predictions
278
  with torch.no_grad():
279
- # NOTE: Removed prediction_length from generate call as it was causing failures.
 
280
  forecast = model.generate(
281
  prediction_input,
 
 
282
  )
283
 
284
  # Handle complex Chronos output: [batch_size, num_samples, prediction_length]
@@ -287,8 +292,7 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
287
  output_tensor = forecast[0] if isinstance(forecast, (tuple, list)) else forecast
288
 
289
  # 2. Average across the samples (dim=1) and convert to a simple 1D numpy array
290
- # predictions = output_tensor.mean(dim=1).squeeze().cpu().numpy()
291
- # To avoid unexpected dimensions, let's simplify averaging:
292
  predictions = output_tensor.float().mean(dim=1).squeeze().cpu().numpy()
293
 
294
  # Handle case where predictions is a single scalar (convert to array for safety)
 
271
  input_sequence = prices[-context_length:]
272
 
273
  # Create prediction input
274
+ # CRITICAL FIX: Revert to float, but use max_new_tokens for generation.
275
+ # This forces the model's custom generation logic to handle the raw floats,
276
+ # as manual quantization/token mapping is impossible without the Chronos tokenizer.
277
+ prediction_input = torch.tensor(input_sequence).unsqueeze(0).float().to(model.device)
278
 
279
  # Generate predictions
280
  with torch.no_grad():
281
+ # CRITICAL FIX: Use max_new_tokens, which is standard for Seq2Seq generation length.
282
+ # Removed prediction_length keyword and added do_sample to align with typical Chronos usage.
283
  forecast = model.generate(
284
  prediction_input,
285
+ max_new_tokens=prediction_days,
286
+ do_sample=True
287
  )
288
 
289
  # Handle complex Chronos output: [batch_size, num_samples, prediction_length]
 
292
  output_tensor = forecast[0] if isinstance(forecast, (tuple, list)) else forecast
293
 
294
  # 2. Average across the samples (dim=1) and convert to a simple 1D numpy array
295
+ # output_tensor shape: [1, num_samples, prediction_length] (If sampling, the float values are the predictions)
 
296
  predictions = output_tensor.float().mean(dim=1).squeeze().cpu().numpy()
297
 
298
  # Handle case where predictions is a single scalar (convert to array for safety)