Update README.md
Browse files
README.md
CHANGED
|
@@ -26,6 +26,11 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 26 |
batch_size, lookback_length, channels = 1, 2880, 7
|
| 27 |
time_series = torch.randn(batch_size, lookback_length, channels)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# Generate forecasts
|
| 30 |
forecast_length = 96
|
| 31 |
predictions = model.generate(time_series, max_new_tokens=forecast_length)
|
|
|
|
| 26 |
batch_size, lookback_length, channels = 1, 2880, 7
|
| 27 |
time_series = torch.randn(batch_size, lookback_length, channels)
|
| 28 |
|
| 29 |
+
# Load the model and data to the same device
|
| 30 |
+
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
|
| 31 |
+
model = model.to(device)
|
| 32 |
+
time_series = time_series.to(device)
|
| 33 |
+
|
| 34 |
# Generate forecasts
|
| 35 |
forecast_length = 96
|
| 36 |
predictions = model.generate(time_series, max_new_tokens=forecast_length)
|