wang-sy committed on
Commit
a85dd96
·
1 Parent(s): 69c984c

change ckpt.pt path

Browse files
Files changed (3) hide show
  1. handler.py +6 -1
  2. pipeline.py +3 -1
  3. requirements.txt +1 -0
handler.py CHANGED
@@ -1,11 +1,16 @@
1
  from typing import Dict, List, Any
2
  import pipeline
 
3
 
4
  class EndpointHandler():
5
  def __init__(self, path=""):
6
  # Preload all the elements you are going to need at inference.
7
  # pseudo:
8
- a = 1
 
 
 
 
9
 
10
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
11
  """
 
1
  from typing import Dict, List, Any
2
  import pipeline
3
+ import gdown
4
 
5
  class EndpointHandler():
6
  def __init__(self, path=""):
7
  # Preload all the elements you are going to need at inference.
8
  # pseudo:
9
+ file_name = 'ckpt.pt'
10
+ url = 'https://drive.google.com/file/d/1jt5zyFcyVUOd5kC_yMrcj3Wqk7kAzuPU/view?usp=sharing'
11
+ gdown.download(url, file_name, quiet=False)
12
+
13
+
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
pipeline.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import tiktoken
5
  from model import GPTConfig, GPT
6
 
 
7
  init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
8
  out_dir = 'out-stinfo' # ignored if init_from is not 'resume'
9
  start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
@@ -29,7 +30,8 @@ def infer():
29
  ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
30
  device_type=device_type, dtype=ptdtype)
31
 
32
- ckpt_path = os.path.join(os.getcwd(), out_dir, 'ckpt.pt')
 
33
  checkpoint = torch.load(ckpt_path, map_location=device)
34
  gptconf = GPTConfig(**checkpoint['model_args'])
35
  model = GPT(gptconf)
 
4
  import tiktoken
5
  from model import GPTConfig, GPT
6
 
7
+
8
  init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
9
  out_dir = 'out-stinfo' # ignored if init_from is not 'resume'
10
  start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
 
30
  ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
31
  device_type=device_type, dtype=ptdtype)
32
 
33
+ # ckpt_path = os.path.join(os.getcwd(), out_dir, 'ckpt.pt')
34
+ ckpt_path = 'ckpt.pt'
35
  checkpoint = torch.load(ckpt_path, map_location=device)
36
  gptconf = GPTConfig(**checkpoint['model_args'])
37
  model = GPT(gptconf)
requirements.txt CHANGED
@@ -2,3 +2,4 @@
2
  torch
3
  tiktoken
4
  numpy
 
 
2
  torch
3
  tiktoken
4
  numpy
5
+ gdown