Andreas Varvarigos committed on
Commit
cb37d55
·
verified ·
1 Parent(s): c71f098

Update src/train.py

Browse files
Files changed (1) hide show
  1. src/train.py +2 -8
src/train.py CHANGED
@@ -6,7 +6,7 @@ import networkx as nx
6
  from tqdm import tqdm
7
  from peft import (LoraConfig, get_peft_model,
8
  prepare_model_for_kbit_training)
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
10
 
11
 
12
 
@@ -30,17 +30,11 @@ class QloraTrainer_CS:
30
  model_id = self.config['inference']["base_model"]
31
  print(model_id)
32
 
33
- bnb_config = BitsAndBytesConfig(
34
- load_in_8bit=True,
35
- bnb_8bit_use_double_quant=True,
36
- bnb_8bit_quant_type="nf8",
37
- bnb_8bit_compute_dtype=torch.bfloat16
38
- )
39
  tokenizer = AutoTokenizer.from_pretrained(model_id)
40
  tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
41
  if not tokenizer.pad_token:
42
  tokenizer.pad_token = tokenizer.eos_token
43
- model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16)
44
  if model.device.type != 'cuda':
45
  model.to('cuda')
46
 
 
6
  from tqdm import tqdm
7
  from peft import (LoraConfig, get_peft_model,
8
  prepare_model_for_kbit_training)
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
 
12
 
 
30
  model_id = self.config['inference']["base_model"]
31
  print(model_id)
32
 
 
 
 
 
 
 
33
  tokenizer = AutoTokenizer.from_pretrained(model_id)
34
  tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
35
  if not tokenizer.pad_token:
36
  tokenizer.pad_token = tokenizer.eos_token
37
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
38
  if model.device.type != 'cuda':
39
  model.to('cuda')
40