omniverse1 commited on
Commit
ec4d569
·
verified ·
1 Parent(s): 8ce0053

update app

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -3,7 +3,7 @@ import yfinance as yf
3
  import pandas as pd
4
  import numpy as np
5
  import torch
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from datetime import datetime, timedelta
8
  import plotly.graph_objects as go
9
  import plotly.express as px
@@ -30,10 +30,11 @@ from config import IDX_STOCKS, TECHNICAL_INDICATORS, PREDICTION_CONFIG
30
  @spaces.GPU(duration=120)
31
  def load_model():
32
  """Load the Amazon Chronos-Bolt model for time series forecasting"""
33
- model = AutoModelForCausalLM.from_pretrained(
34
  "amazon/chronos-bolt-base",
35
  torch_dtype=torch.bfloat16,
36
- device_map="auto"
 
37
  )
38
  tokenizer = AutoTokenizer.from_pretrained("amazon/chronos-bolt-base")
39
  return model, tokenizer
 
3
  import pandas as pd
4
  import numpy as np
5
  import torch
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
  from datetime import datetime, timedelta
8
  import plotly.graph_objects as go
9
  import plotly.express as px
 
30
  @spaces.GPU(duration=120)
31
  def load_model():
32
  """Load the Amazon Chronos-Bolt model for time series forecasting"""
33
+ model = AutoModelForSeq2SeqLM.from_pretrained(
34
  "amazon/chronos-bolt-base",
35
  torch_dtype=torch.bfloat16,
36
+ device_map="auto",
37
+ trust_remote_code=True
38
  )
39
  tokenizer = AutoTokenizer.from_pretrained("amazon/chronos-bolt-base")
40
  return model, tokenizer