Upload trainer.py
Browse files- LLAVA-Cherry/trainer.py +37 -3
LLAVA-Cherry/trainer.py
CHANGED
|
@@ -17,6 +17,11 @@
|
|
| 17 |
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
| 18 |
"""
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
import contextlib
|
| 21 |
import copy
|
| 22 |
import functools
|
|
@@ -2762,6 +2767,13 @@ class Trainer:
|
|
| 2762 |
Return:
|
| 2763 |
`torch.Tensor`: The tensor with training loss on this batch.
|
| 2764 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2765 |
model.train()
|
| 2766 |
inputs = self._prepare_inputs(inputs)
|
| 2767 |
|
|
@@ -2776,16 +2788,31 @@ class Trainer:
|
|
| 2776 |
del inputs['dataset_id']
|
| 2777 |
del inputs['data_info']
|
| 2778 |
#######################################################
|
|
|
|
| 2779 |
|
| 2780 |
with self.compute_loss_context_manager():
|
| 2781 |
-
loss = self.compute_loss(model, inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2782 |
|
| 2783 |
#######################################################
|
| 2784 |
import json
|
| 2785 |
for i in range(len(data_info_temp)):
|
| 2786 |
-
data_info_temp[i]['loss'] = float(loss[0][i])
|
|
|
|
|
|
|
| 2787 |
|
| 2788 |
-
|
|
|
|
|
|
|
|
|
|
| 2789 |
with open(file_path, 'a', encoding='utf-8') as file:
|
| 2790 |
# json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4)
|
| 2791 |
for content in data_info_temp:
|
|
@@ -2825,6 +2852,13 @@ class Trainer:
|
|
| 2825 |
else:
|
| 2826 |
labels = None
|
| 2827 |
outputs = model(**inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2828 |
# Save past state if it exists
|
| 2829 |
# TODO: this needs to be fixed and made cleaner later.
|
| 2830 |
if self.args.past_index >= 0:
|
|
|
|
| 17 |
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
| 18 |
"""
|
| 19 |
|
| 20 |
+
#########################################################
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
has_run = False
|
| 23 |
+
#########################################################
|
| 24 |
+
|
| 25 |
import contextlib
|
| 26 |
import copy
|
| 27 |
import functools
|
|
|
|
| 2767 |
Return:
|
| 2768 |
`torch.Tensor`: The tensor with training loss on this batch.
|
| 2769 |
"""
|
| 2770 |
+
|
| 2771 |
+
# #######################################################
|
| 2772 |
+
# # import pdb; pdb.set_trace()
|
| 2773 |
+
# import pprint
|
| 2774 |
+
# pprint.pprint(inputs)
|
| 2775 |
+
# #######################################################
|
| 2776 |
+
|
| 2777 |
model.train()
|
| 2778 |
inputs = self._prepare_inputs(inputs)
|
| 2779 |
|
|
|
|
| 2788 |
del inputs['dataset_id']
|
| 2789 |
del inputs['data_info']
|
| 2790 |
#######################################################
|
| 2791 |
+
|
| 2792 |
|
| 2793 |
with self.compute_loss_context_manager():
|
| 2794 |
+
# loss = self.compute_loss(model, inputs)
|
| 2795 |
+
(loss, outputs) = self.compute_loss(model, inputs,return_outputs=True)
|
| 2796 |
+
|
| 2797 |
+
|
| 2798 |
+
import pprint
|
| 2799 |
+
# pprint.pprint(outputs)
|
| 2800 |
+
# import pdb; pdb.set_trace()
|
| 2801 |
+
last_token_logits_yes = outputs.logits[:, -1, :]
|
| 2802 |
+
yes_target_token_id = 4874
|
| 2803 |
+
yes_target_logprob = torch.log_softmax(last_token_logits_yes, dim=-1)[0, yes_target_token_id].item()
|
| 2804 |
|
| 2805 |
#######################################################
|
| 2806 |
import json
|
| 2807 |
for i in range(len(data_info_temp)):
|
| 2808 |
+
# data_info_temp[i]['loss'] = float(loss[0][i])
|
| 2809 |
+
data_info_temp[i]['yes_target_logprob'] = yes_target_logprob
|
| 2810 |
+
data_info_temp[i]['logits_shape'] = outputs.logits.shape
|
| 2811 |
|
| 2812 |
+
from datetime import datetime
|
| 2813 |
+
current_time = datetime.now().strftime('%Y_%m_%d')
|
| 2814 |
+
|
| 2815 |
+
file_path = '/data/zbz5349/ICLR_2024/ACL_2025/LLaVA_Fliter/inference_demo/cherry_AskLLM_infer_result_' + current_time + '.jsonl'
|
| 2816 |
with open(file_path, 'a', encoding='utf-8') as file:
|
| 2817 |
# json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4)
|
| 2818 |
for content in data_info_temp:
|
|
|
|
| 2852 |
else:
|
| 2853 |
labels = None
|
| 2854 |
outputs = model(**inputs)
|
| 2855 |
+
|
| 2856 |
+
# #######################################################
|
| 2857 |
+
# import pdb; pdb.set_trace()
|
| 2858 |
+
# import pprint
|
| 2859 |
+
# pprint.pprint(outputs)
|
| 2860 |
+
# #######################################################
|
| 2861 |
+
|
| 2862 |
# Save past state if it exists
|
| 2863 |
# TODO: this needs to be fixed and made cleaner later.
|
| 2864 |
if self.args.past_index >= 0:
|