change ckpt.pt path
Browse files- handler.py +6 -1
- pipeline.py +3 -1
- requirements.txt +1 -0
handler.py
CHANGED
|
@@ -1,11 +1,16 @@
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
import pipeline
|
|
|
|
| 3 |
|
| 4 |
class EndpointHandler():
|
| 5 |
def __init__(self, path=""):
|
| 6 |
# Preload all the elements you are going to need at inference.
|
| 7 |
# pseudo:
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 11 |
"""
|
|
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
import pipeline
|
| 3 |
+
import gdown
|
| 4 |
|
| 5 |
class EndpointHandler():
|
| 6 |
def __init__(self, path=""):
|
| 7 |
# Preload all the elements you are going to need at inference.
|
| 8 |
# pseudo:
|
| 9 |
+
file_name = 'ckpt.pt'
|
| 10 |
+
url = 'https://drive.google.com/file/d/1jt5zyFcyVUOd5kC_yMrcj3Wqk7kAzuPU/view?usp=sharing'
|
| 11 |
+
gdown.download(url, file_name, quiet=False)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
|
| 15 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 16 |
"""
|
pipeline.py
CHANGED
|
@@ -4,6 +4,7 @@ import torch
|
|
| 4 |
import tiktoken
|
| 5 |
from model import GPTConfig, GPT
|
| 6 |
|
|
|
|
| 7 |
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
|
| 8 |
out_dir = 'out-stinfo' # ignored if init_from is not 'resume'
|
| 9 |
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
|
@@ -29,7 +30,8 @@ def infer():
|
|
| 29 |
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
|
| 30 |
device_type=device_type, dtype=ptdtype)
|
| 31 |
|
| 32 |
-
ckpt_path = os.path.join(os.getcwd(), out_dir, 'ckpt.pt')
|
|
|
|
| 33 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 34 |
gptconf = GPTConfig(**checkpoint['model_args'])
|
| 35 |
model = GPT(gptconf)
|
|
|
|
| 4 |
import tiktoken
|
| 5 |
from model import GPTConfig, GPT
|
| 6 |
|
| 7 |
+
|
| 8 |
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
|
| 9 |
out_dir = 'out-stinfo' # ignored if init_from is not 'resume'
|
| 10 |
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
|
|
|
| 30 |
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
|
| 31 |
device_type=device_type, dtype=ptdtype)
|
| 32 |
|
| 33 |
+
# ckpt_path = os.path.join(os.getcwd(), out_dir, 'ckpt.pt')
|
| 34 |
+
ckpt_path = 'ckpt.pt'
|
| 35 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 36 |
gptconf = GPTConfig(**checkpoint['model_args'])
|
| 37 |
model = GPT(gptconf)
|
requirements.txt
CHANGED
|
@@ -2,3 +2,4 @@
|
|
| 2 |
torch
|
| 3 |
tiktoken
|
| 4 |
numpy
|
|
|
|
|
|
| 2 |
torch
|
| 3 |
tiktoken
|
| 4 |
numpy
|
| 5 |
+
gdown
|