pchen182224 committed on
Commit
2a7112d
·
verified ·
1 Parent(s): 3f3fabd

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -5
README.md CHANGED
@@ -26,21 +26,37 @@ tags:
26
  pip install transformers==4.30.2 # Use this version for stable compatibility
27
  ```
28
 
 
29
  ```
 
 
30
  import torch
31
  from transformers import AutoModelForCausalLM
 
 
32
  # load pretrain model
33
- # supports different lookback/forecast lengths
34
- model = AutoModelForCausalLM.from_pretrained('DecisionIntelligence/LightGTS', trust_remote_code=True)
 
 
 
 
 
 
 
 
35
  # prepare input
36
- batch_size, lookback_length = 1, 528
37
  seqs = torch.randn(batch_size, lookback_length).unsqueeze(-1).float()
38
- # Note that Sundial can generate multiple probable predictions
39
  forecast_length = 192
40
  outputs = model.generate(seqs, patch_len = 48, stride_len=48, max_output_length=forecast_length, inference_patch_len=48)
41
- print(output.shape)
 
42
  ```
43
 
 
 
44
 
45
  ## Citation
46
 
 
26
  pip install transformers==4.30.2 # Use this version for stable compatibility
27
  ```
28
 
29
+ ### Zero-Shot
30
  ```
31
+ from configuration_LightGTS import LightGTSConfig
32
+ from modeling_LightGTS import LightGTSForPrediction
33
  import torch
34
  from transformers import AutoModelForCausalLM
35
+ from transformers import AutoModelForCausalLM, MODEL_MAPPING
36
+ from transformers import AutoConfig
37
  # load pretrain model
38
+ LightGTS_config = LightGTSConfig(context_points=528, c_in=1, target_dim=192, patch_len=48, stride=48)
39
+ LightGTS_config.save_pretrained("LightGTS-huggingface")
40
+
41
+ AutoConfig.register("LightGTS",LightGTSConfig)
42
+ AutoModelForCausalLM.register(LightGTSConfig, LightGTSForPrediction)
43
+
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ "./LightGTS-huggingface",
46
+ trust_remote_code=True
47
+ )
48
  # prepare input
49
+ batch_size, lookback_length = 1, 576
50
  seqs = torch.randn(batch_size, lookback_length).unsqueeze(-1).float()
51
+ # generate forecasting results
52
  forecast_length = 192
53
  outputs = model.generate(seqs, patch_len = 48, stride_len=48, max_output_length=forecast_length, inference_patch_len=48)
54
+ print(outputs.shape)
55
+
56
  ```
57
 
58
+ ### Fine-tune
59
+ For usage examples, please see test_finetune.py
60
 
61
  ## Citation
62