JayceAnova commited on
Commit
65c947a
·
verified ·
1 Parent(s): ff8573a

Update test_ddp.py

Browse files
Files changed (1) hide show
  1. test_ddp.py +9 -8
test_ddp.py CHANGED
@@ -11,7 +11,8 @@ from torch.nn.parallel import DistributedDataParallel
11
  from peft import PeftModel
12
  from torch.utils.data import DataLoader
13
  from tqdm import tqdm
14
- from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
 
15
 
16
  from utils import *
17
  from collator import TestCollator
@@ -33,12 +34,12 @@ def test_ddp(args):
33
  device_map = {"": local_rank}
34
  device = torch.device("cuda",local_rank)
35
 
36
- tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
37
- args.lora=True
38
  if args.lora:
39
- model = LlamaForCausalLM.from_pretrained(
40
  args.base_model,
41
- torch_dtype=torch.float16,
42
  low_cpu_mem_usage=True,
43
  device_map=device_map,
44
  )
@@ -46,13 +47,13 @@ def test_ddp(args):
46
  model = PeftModel.from_pretrained(
47
  model,
48
  args.ckpt_path,
49
- torch_dtype=torch.float16,
50
  device_map=device_map,
51
  )
52
  else:
53
- model = LlamaForCausalLM.from_pretrained(
54
  args.ckpt_path,
55
- torch_dtype=torch.float16,
56
  low_cpu_mem_usage=True,
57
  device_map=device_map,
58
  )
 
11
  from peft import PeftModel
12
  from torch.utils.data import DataLoader
13
  from tqdm import tqdm
14
+ # from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
16
 
17
  from utils import *
18
  from collator import TestCollator
 
34
  device_map = {"": local_rank}
35
  device = torch.device("cuda",local_rank)
36
 
37
+ tokenizer = AutoTokenizer.from_pretrained(args.ckpt_path)
38
+ #args.lora=True
39
  if args.lora:
40
+ model = AutoModelForCausalLM.from_pretrained(
41
  args.base_model,
42
+ torch_dtype=torch.bfloat16,
43
  low_cpu_mem_usage=True,
44
  device_map=device_map,
45
  )
 
47
  model = PeftModel.from_pretrained(
48
  model,
49
  args.ckpt_path,
50
+ torch_dtype=torch.bfloat16,
51
  device_map=device_map,
52
  )
53
  else:
54
+ model = AutoModelForCausalLM.from_pretrained(
55
  args.ckpt_path,
56
+ torch_dtype=torch.bfloat16,
57
  low_cpu_mem_usage=True,
58
  device_map=device_map,
59
  )