clean up default args
Browse files- src/retool_trainer.py +12 -7
src/retool_trainer.py
CHANGED
|
@@ -65,12 +65,18 @@ class ReToolTrainer(Trainer): # Change this line
|
|
| 65 |
# Store processing_class for compatibility
|
| 66 |
self.processing_class = processing_class or self.tokenizer
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# Add reward function handling (since Trainer doesn't have this)
|
| 69 |
self.reward_funcs = reward_funcs or [self._binary_reward_function]
|
| 70 |
-
|
| 71 |
-
# Rest of the ReTool-specific code stays exactly the same!
|
| 72 |
-
self.eos_id = eos_id or self.processing_class.eos_token_id
|
| 73 |
-
|
| 74 |
|
| 75 |
# ReTool specific attributes
|
| 76 |
self.eos_id = eos_id or self.processing_class.eos_token_id
|
|
@@ -99,16 +105,15 @@ class ReToolTrainer(Trainer): # Change this line
|
|
| 99 |
do_sample=True,
|
| 100 |
pad_token_id=self.processing_class.pad_token_id,
|
| 101 |
bos_token_id=self.processing_class.bos_token_id,
|
| 102 |
-
eos_token_id=
|
| 103 |
temperature=self.temperature,
|
| 104 |
top_p=self.top_p,
|
| 105 |
top_k=self.top_k,
|
| 106 |
min_p=self.min_p,
|
| 107 |
return_dict_in_generate=True,
|
| 108 |
use_cache=True,
|
|
|
|
| 109 |
)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
def _get_interpreter_token_ids(self) -> list[int]:
|
| 113 |
"""Get token IDs for <interpreter> and </interpreter> tags."""
|
| 114 |
start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
|
|
|
|
| 65 |
# Store processing_class for compatibility
|
| 66 |
self.processing_class = processing_class or self.tokenizer
|
| 67 |
|
| 68 |
+
# Processing class
|
| 69 |
+
if processing_class is None:
|
| 70 |
+
self.processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
| 71 |
+
else:
|
| 72 |
+
# Store processing_class for compatibility
|
| 73 |
+
self.processing_class = processing_class or self.tokenizer
|
| 74 |
+
if processing_class.pad_token is None:
|
| 75 |
+
self.processing_class.pad_token = processing_class.eos_token
|
| 76 |
+
|
| 77 |
+
|
| 78 |
# Add reward function handling (since Trainer doesn't have this)
|
| 79 |
self.reward_funcs = reward_funcs or [self._binary_reward_function]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# ReTool specific attributes
|
| 82 |
self.eos_id = eos_id or self.processing_class.eos_token_id
|
|
|
|
| 105 |
do_sample=True,
|
| 106 |
pad_token_id=self.processing_class.pad_token_id,
|
| 107 |
bos_token_id=self.processing_class.bos_token_id,
|
| 108 |
+
eos_token_id=self.eos_id, # default stop on EOS
|
| 109 |
temperature=self.temperature,
|
| 110 |
top_p=self.top_p,
|
| 111 |
top_k=self.top_k,
|
| 112 |
min_p=self.min_p,
|
| 113 |
return_dict_in_generate=True,
|
| 114 |
use_cache=True,
|
| 115 |
+
cache_implementation=args.cache_implementation, #args.cache_implementation = 'Offloaded Cache'
|
| 116 |
)
|
|
|
|
|
|
|
| 117 |
def _get_interpreter_token_ids(self) -> list[int]:
|
| 118 |
"""Get token IDs for <interpreter> and </interpreter> tags."""
|
| 119 |
start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
|