Spaces:
Runtime error
Runtime error
Commit
·
725dc81
1
Parent(s):
b43e55e
Update generate.py
Browse files — generate.py (+5, −5)
generate.py
CHANGED
|
@@ -21,17 +21,17 @@ TYPE_WRITER=1 # whether output streamly
|
|
| 21 |
|
| 22 |
args = parser.parse_args()
|
| 23 |
print(args)
|
| 24 |
-
tokenizer = LlamaTokenizer.from_pretrained(
|
| 25 |
|
| 26 |
LOAD_8BIT = True
|
| 27 |
|
| 28 |
|
| 29 |
|
| 30 |
# fix the path for local checkpoint
|
| 31 |
-
lora_bin_path = os.path.join(
|
| 32 |
print(lora_bin_path)
|
| 33 |
-
if not os.path.exists(lora_bin_path) and
|
| 34 |
-
pytorch_bin_path = os.path.join(
|
| 35 |
print(pytorch_bin_path)
|
| 36 |
if os.path.exists(pytorch_bin_path):
|
| 37 |
os.rename(pytorch_bin_path, lora_bin_path)
|
|
@@ -140,7 +140,7 @@ def evaluate(
|
|
| 140 |
**kwargs,
|
| 141 |
)
|
| 142 |
with torch.no_grad():
|
| 143 |
-
if
|
| 144 |
for generation_output in model.stream_generate(
|
| 145 |
input_ids=input_ids,
|
| 146 |
generation_config=generation_config,
|
|
|
|
| 21 |
|
| 22 |
args = parser.parse_args()
|
| 23 |
print(args)
|
| 24 |
+
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODE)
|
| 25 |
|
| 26 |
LOAD_8BIT = True
|
| 27 |
|
| 28 |
|
| 29 |
|
| 30 |
# fix the path for local checkpoint
|
| 31 |
+
lora_bin_path = os.path.join(LORA_PATH, "adapter_model.bin")
|
| 32 |
print(lora_bin_path)
|
| 33 |
+
if not os.path.exists(lora_bin_path) and USE_LOCAL:
|
| 34 |
+
pytorch_bin_path = os.path.join(LORA_PATH, "pytorch_model.bin")
|
| 35 |
print(pytorch_bin_path)
|
| 36 |
if os.path.exists(pytorch_bin_path):
|
| 37 |
os.rename(pytorch_bin_path, lora_bin_path)
|
|
|
|
| 140 |
**kwargs,
|
| 141 |
)
|
| 142 |
with torch.no_grad():
|
| 143 |
+
if TYPE_WRITER:
|
| 144 |
for generation_output in model.stream_generate(
|
| 145 |
input_ids=input_ids,
|
| 146 |
generation_config=generation_config,
|