DDDano333 commited on
Commit
22c6032
·
1 Parent(s): 936fde4

cleaned out line

Browse files
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -47,7 +47,7 @@ def train(rank, world_size):
47
  "decapoda-research/llama-7b-hf",
48
  load_in_8bit=True,
49
  device_map="auto",
50
- ).to(device)
51
  model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
52
  tokenizer = LLaMATokenizer.from_pretrained(
53
  "decapoda-research/llama-7b-hf", add_eos_token=True
 
47
  "decapoda-research/llama-7b-hf",
48
  load_in_8bit=True,
49
  device_map="auto",
50
+ )
51
  model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
52
  tokenizer = LLaMATokenizer.from_pretrained(
53
  "decapoda-research/llama-7b-hf", add_eos_token=True