neuralbroker commited on
Commit
cd53253
·
verified ·
1 Parent(s): 5812ada

Update clean backend-only project docs and eval

Browse files
Files changed (1) hide show
  1. scripts/test_inference.py +7 -3
scripts/test_inference.py CHANGED
@@ -7,6 +7,7 @@ from __future__ import annotations
7
 
8
  import argparse
9
  from pathlib import Path
 
10
 
11
  REPO_ROOT = Path(__file__).resolve().parents[1]
12
  CHECKPOINT_CANDIDATES = [
@@ -59,10 +60,13 @@ def main() -> None:
59
  tokenizer = AutoTokenizer.from_pretrained(str(checkpoint_path), trust_remote_code=True)
60
 
61
  peft_config = PeftConfig.from_pretrained(str(checkpoint_path))
62
- print(f"Loading base model: {peft_config.base_model_name_or_path}")
 
 
 
63
  base_model = AutoModelForCausalLM.from_pretrained(
64
- peft_config.base_model_name_or_path,
65
- torch_dtype=torch.bfloat16,
66
  device_map="auto",
67
  trust_remote_code=True,
68
  )
 
7
 
8
  import argparse
9
  from pathlib import Path
10
+ from typing import Any, cast
11
 
12
  REPO_ROOT = Path(__file__).resolve().parents[1]
13
  CHECKPOINT_CANDIDATES = [
 
60
  tokenizer = AutoTokenizer.from_pretrained(str(checkpoint_path), trust_remote_code=True)
61
 
62
  peft_config = PeftConfig.from_pretrained(str(checkpoint_path))
63
+ base_model_name = peft_config.base_model_name_or_path
64
+ if not base_model_name:
65
+ raise SystemExit("Checkpoint PEFT config is missing base_model_name_or_path")
66
+ print(f"Loading base model: {base_model_name}")
67
  base_model = AutoModelForCausalLM.from_pretrained(
68
+ str(base_model_name),
69
+ torch_dtype=cast(Any, torch).bfloat16,
70
  device_map="auto",
71
  trust_remote_code=True,
72
  )