Update utilities/modeling.py
Browse files- utilities/modeling.py +6 -2
utilities/modeling.py
CHANGED
|
@@ -37,7 +37,9 @@ def get_peft(model, peft, max_seq_length, random_seed):
|
|
| 37 |
return model
|
| 38 |
|
| 39 |
|
| 40 |
-
def get_trainer(model, tokenizer, dataset, sft,
|
|
|
|
|
|
|
| 41 |
|
| 42 |
trainer = SFTTrainer(
|
| 43 |
model = model,
|
|
@@ -68,6 +70,7 @@ def get_trainer(model, tokenizer, dataset, sft, data_field, max_seq_length, rand
|
|
| 68 |
|
| 69 |
|
| 70 |
def prepare_trainer(model_name, max_seq_length, random_seed,
|
|
|
|
| 71 |
peft, sft, dataset, data_field):
|
| 72 |
|
| 73 |
print("Loading Model")
|
|
@@ -77,7 +80,8 @@ def prepare_trainer(model_name, max_seq_length, random_seed,
|
|
| 77 |
model = get_peft(model, peft, max_seq_length, random_seed)
|
| 78 |
|
| 79 |
print("Getting Trainer Model")
|
| 80 |
-
trainer = get_trainer(model, tokenizer, dataset, data_field, max_seq_length, random_seed
|
|
|
|
| 81 |
|
| 82 |
return trainer
|
| 83 |
|
|
|
|
| 37 |
return model
|
| 38 |
|
| 39 |
|
| 40 |
+
def get_trainer(model, tokenizer, dataset, sft,
|
| 41 |
+
data_field, max_seq_length, random_seed,
|
| 42 |
+
num_epochs, max_steps):
|
| 43 |
|
| 44 |
trainer = SFTTrainer(
|
| 45 |
model = model,
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def prepare_trainer(model_name, max_seq_length, random_seed,
|
| 73 |
+
num_epochs, max_steps,
|
| 74 |
peft, sft, dataset, data_field):
|
| 75 |
|
| 76 |
print("Loading Model")
|
|
|
|
| 80 |
model = get_peft(model, peft, max_seq_length, random_seed)
|
| 81 |
|
| 82 |
print("Getting Trainer Model")
|
| 83 |
+
trainer = get_trainer(model, tokenizer, dataset, data_field, max_seq_length, random_seed,
|
| 84 |
+
num_epochs, max_steps)
|
| 85 |
|
| 86 |
return trainer
|
| 87 |
|