Add files using upload-large-folder tool
Browse files- r1-a/final_dataset/preference_relative/dataset_info.json +84 -0
- r1-a/final_dataset/preference_relative/state.json +33 -0
- r1-a/final_dataset/preference_relative_processed_shards/logs/shard_0_gpu_0.log +16 -0
- r1-a/final_dataset/preference_relative_processed_shards/shard_0/dataset_info.json +100 -0
- r1-a/final_dataset/preference_relative_processed_shards/shard_0/state.json +37 -0
- r1-a/final_dataset/prompt_only_relative_paths/dataset_info.json +24 -0
- r1-a/final_dataset/prompt_only_relative_paths/state.json +13 -0
- r1-a/response_generation/glm4voice.py +579 -0
- r1-a/response_generation/gpt4o.py +464 -0
- r1-a/response_generation/gpt4o_mini.py +464 -0
- r1-a/response_generation/gpt5o_retry.py +461 -0
- r1-a/response_generation/kimi.py +532 -0
- r1-a/response_generation/minicpm.py +519 -0
- r1-a/response_generation/minicpm/MiniCPM-o/.gitignore +3 -0
- r1-a/response_generation/minicpm/MiniCPM-o/LICENSE +201 -0
- r1-a/response_generation/minicpm/MiniCPM-o/README.md +0 -0
- r1-a/response_generation/minicpm/MiniCPM-o/README_zh.md +2524 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/cgbench.py +1760 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/cmmmu.py +354 -0
- r1-a/response_generation/qwenomni.py +451 -0
r1-a/final_dataset/preference_relative/dataset_info.json
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"citation": "",
|
| 3 |
+
"description": "",
|
| 4 |
+
"features": {
|
| 5 |
+
"question_text": {
|
| 6 |
+
"dtype": "string",
|
| 7 |
+
"_type": "Value"
|
| 8 |
+
},
|
| 9 |
+
"question_audio": {
|
| 10 |
+
"dtype": "string",
|
| 11 |
+
"_type": "Value"
|
| 12 |
+
},
|
| 13 |
+
"source_dataset": {
|
| 14 |
+
"dtype": "string",
|
| 15 |
+
"_type": "Value"
|
| 16 |
+
},
|
| 17 |
+
"metadata": {
|
| 18 |
+
"dtype": "string",
|
| 19 |
+
"_type": "Value"
|
| 20 |
+
},
|
| 21 |
+
"model_1": {
|
| 22 |
+
"dtype": "string",
|
| 23 |
+
"_type": "Value"
|
| 24 |
+
},
|
| 25 |
+
"prompt_name_1": {
|
| 26 |
+
"dtype": "string",
|
| 27 |
+
"_type": "Value"
|
| 28 |
+
},
|
| 29 |
+
"prompt_text_1": {
|
| 30 |
+
"dtype": "string",
|
| 31 |
+
"_type": "Value"
|
| 32 |
+
},
|
| 33 |
+
"response_text_1": {
|
| 34 |
+
"dtype": "string",
|
| 35 |
+
"_type": "Value"
|
| 36 |
+
},
|
| 37 |
+
"response_audio_path_1": {
|
| 38 |
+
"dtype": "string",
|
| 39 |
+
"_type": "Value"
|
| 40 |
+
},
|
| 41 |
+
"model_2": {
|
| 42 |
+
"dtype": "string",
|
| 43 |
+
"_type": "Value"
|
| 44 |
+
},
|
| 45 |
+
"prompt_name_2": {
|
| 46 |
+
"dtype": "string",
|
| 47 |
+
"_type": "Value"
|
| 48 |
+
},
|
| 49 |
+
"prompt_text_2": {
|
| 50 |
+
"dtype": "string",
|
| 51 |
+
"_type": "Value"
|
| 52 |
+
},
|
| 53 |
+
"response_text_2": {
|
| 54 |
+
"dtype": "string",
|
| 55 |
+
"_type": "Value"
|
| 56 |
+
},
|
| 57 |
+
"response_audio_path_2": {
|
| 58 |
+
"dtype": "string",
|
| 59 |
+
"_type": "Value"
|
| 60 |
+
},
|
| 61 |
+
"model_3": {
|
| 62 |
+
"dtype": "string",
|
| 63 |
+
"_type": "Value"
|
| 64 |
+
},
|
| 65 |
+
"prompt_name_3": {
|
| 66 |
+
"dtype": "string",
|
| 67 |
+
"_type": "Value"
|
| 68 |
+
},
|
| 69 |
+
"prompt_text_3": {
|
| 70 |
+
"dtype": "string",
|
| 71 |
+
"_type": "Value"
|
| 72 |
+
},
|
| 73 |
+
"response_text_3": {
|
| 74 |
+
"dtype": "string",
|
| 75 |
+
"_type": "Value"
|
| 76 |
+
},
|
| 77 |
+
"response_audio_path_3": {
|
| 78 |
+
"dtype": "string",
|
| 79 |
+
"_type": "Value"
|
| 80 |
+
}
|
| 81 |
+
},
|
| 82 |
+
"homepage": "",
|
| 83 |
+
"license": ""
|
| 84 |
+
}
|
r1-a/final_dataset/preference_relative/state.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "6ada4a1b690526f4",
|
| 8 |
+
"_format_columns": [
|
| 9 |
+
"question_text",
|
| 10 |
+
"question_audio",
|
| 11 |
+
"source_dataset",
|
| 12 |
+
"metadata",
|
| 13 |
+
"model_1",
|
| 14 |
+
"prompt_name_1",
|
| 15 |
+
"prompt_text_1",
|
| 16 |
+
"response_text_1",
|
| 17 |
+
"response_audio_path_1",
|
| 18 |
+
"model_2",
|
| 19 |
+
"prompt_name_2",
|
| 20 |
+
"prompt_text_2",
|
| 21 |
+
"response_text_2",
|
| 22 |
+
"response_audio_path_2",
|
| 23 |
+
"model_3",
|
| 24 |
+
"prompt_name_3",
|
| 25 |
+
"prompt_text_3",
|
| 26 |
+
"response_text_3",
|
| 27 |
+
"response_audio_path_3"
|
| 28 |
+
],
|
| 29 |
+
"_format_kwargs": {},
|
| 30 |
+
"_format_type": null,
|
| 31 |
+
"_output_all_columns": false,
|
| 32 |
+
"_split": null
|
| 33 |
+
}
|
r1-a/final_dataset/preference_relative_processed_shards/logs/shard_0_gpu_0.log
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-05-04 14:41:37,640 - INFO - [Shard 0] - Process started for Shard 0 on GPU 0 (logical device cuda:0)
|
| 2 |
+
2025-05-04 14:41:37,640 - INFO - [Shard 0] - Arguments: Namespace(shard_index=0, gpu_id=0, wer_threshold=0.4, pipeline_batch_size=16, map_batch_size=16, num_check_workers=4)
|
| 3 |
+
2025-05-04 14:41:37,640 - INFO - [Shard 0] - Loading dataset from /home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative
|
| 4 |
+
2025-05-04 16:30:18,197 - ERROR - [Shard 0] - Failed to load dataset:
|
| 5 |
+
Traceback (most recent call last):
|
| 6 |
+
File "/home/chenyifu/audio-r1/r1-a/dataset/retts.py", line 380, in main
|
| 7 |
+
logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
|
| 8 |
+
File "/home/chenyifu/audio-r1/r1-a/dataset/retts.py", line 380, in main
|
| 9 |
+
logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
|
| 10 |
+
File "/home/chenyifu/miniconda3/envs/cosyvoice/lib/python3.10/bdb.py", line 90, in trace_dispatch
|
| 11 |
+
return self.dispatch_line(frame)
|
| 12 |
+
File "/home/chenyifu/miniconda3/envs/cosyvoice/lib/python3.10/bdb.py", line 115, in dispatch_line
|
| 13 |
+
if self.quitting: raise BdbQuit
|
| 14 |
+
bdb.BdbQuit
|
| 15 |
+
2025-05-04 16:30:18,199 - WARNING - [Shard 0] - Processing did not complete or failed early. No statistics to log.
|
| 16 |
+
2025-05-04 16:30:18,199 - INFO - [Shard 0] - Process for Shard 0 on GPU 0 finished.
|
r1-a/final_dataset/preference_relative_processed_shards/shard_0/dataset_info.json
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"citation": "",
|
| 3 |
+
"description": "",
|
| 4 |
+
"features": {
|
| 5 |
+
"question_text": {
|
| 6 |
+
"dtype": "string",
|
| 7 |
+
"_type": "Value"
|
| 8 |
+
},
|
| 9 |
+
"question_audio": {
|
| 10 |
+
"dtype": "string",
|
| 11 |
+
"_type": "Value"
|
| 12 |
+
},
|
| 13 |
+
"source_dataset": {
|
| 14 |
+
"dtype": "string",
|
| 15 |
+
"_type": "Value"
|
| 16 |
+
},
|
| 17 |
+
"metadata": {
|
| 18 |
+
"dtype": "string",
|
| 19 |
+
"_type": "Value"
|
| 20 |
+
},
|
| 21 |
+
"model_1": {
|
| 22 |
+
"dtype": "string",
|
| 23 |
+
"_type": "Value"
|
| 24 |
+
},
|
| 25 |
+
"prompt_name_1": {
|
| 26 |
+
"dtype": "string",
|
| 27 |
+
"_type": "Value"
|
| 28 |
+
},
|
| 29 |
+
"prompt_text_1": {
|
| 30 |
+
"dtype": "string",
|
| 31 |
+
"_type": "Value"
|
| 32 |
+
},
|
| 33 |
+
"response_text_1": {
|
| 34 |
+
"dtype": "string",
|
| 35 |
+
"_type": "Value"
|
| 36 |
+
},
|
| 37 |
+
"response_audio_path_1": {
|
| 38 |
+
"dtype": "string",
|
| 39 |
+
"_type": "Value"
|
| 40 |
+
},
|
| 41 |
+
"model_2": {
|
| 42 |
+
"dtype": "string",
|
| 43 |
+
"_type": "Value"
|
| 44 |
+
},
|
| 45 |
+
"prompt_name_2": {
|
| 46 |
+
"dtype": "string",
|
| 47 |
+
"_type": "Value"
|
| 48 |
+
},
|
| 49 |
+
"prompt_text_2": {
|
| 50 |
+
"dtype": "string",
|
| 51 |
+
"_type": "Value"
|
| 52 |
+
},
|
| 53 |
+
"response_text_2": {
|
| 54 |
+
"dtype": "string",
|
| 55 |
+
"_type": "Value"
|
| 56 |
+
},
|
| 57 |
+
"response_audio_path_2": {
|
| 58 |
+
"dtype": "string",
|
| 59 |
+
"_type": "Value"
|
| 60 |
+
},
|
| 61 |
+
"model_3": {
|
| 62 |
+
"dtype": "string",
|
| 63 |
+
"_type": "Value"
|
| 64 |
+
},
|
| 65 |
+
"prompt_name_3": {
|
| 66 |
+
"dtype": "string",
|
| 67 |
+
"_type": "Value"
|
| 68 |
+
},
|
| 69 |
+
"prompt_text_3": {
|
| 70 |
+
"dtype": "string",
|
| 71 |
+
"_type": "Value"
|
| 72 |
+
},
|
| 73 |
+
"response_text_3": {
|
| 74 |
+
"dtype": "string",
|
| 75 |
+
"_type": "Value"
|
| 76 |
+
},
|
| 77 |
+
"response_audio_path_3": {
|
| 78 |
+
"dtype": "string",
|
| 79 |
+
"_type": "Value"
|
| 80 |
+
},
|
| 81 |
+
"asr_transcription": {
|
| 82 |
+
"dtype": "string",
|
| 83 |
+
"_type": "Value"
|
| 84 |
+
},
|
| 85 |
+
"wer": {
|
| 86 |
+
"dtype": "float32",
|
| 87 |
+
"_type": "Value"
|
| 88 |
+
},
|
| 89 |
+
"is_bad_tts": {
|
| 90 |
+
"dtype": "bool",
|
| 91 |
+
"_type": "Value"
|
| 92 |
+
},
|
| 93 |
+
"error_message": {
|
| 94 |
+
"dtype": "string",
|
| 95 |
+
"_type": "Value"
|
| 96 |
+
}
|
| 97 |
+
},
|
| 98 |
+
"homepage": "",
|
| 99 |
+
"license": ""
|
| 100 |
+
}
|
r1-a/final_dataset/preference_relative_processed_shards/shard_0/state.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "00775948802a271b",
|
| 8 |
+
"_format_columns": [
|
| 9 |
+
"asr_transcription",
|
| 10 |
+
"error_message",
|
| 11 |
+
"is_bad_tts",
|
| 12 |
+
"metadata",
|
| 13 |
+
"model_1",
|
| 14 |
+
"model_2",
|
| 15 |
+
"model_3",
|
| 16 |
+
"prompt_name_1",
|
| 17 |
+
"prompt_name_2",
|
| 18 |
+
"prompt_name_3",
|
| 19 |
+
"prompt_text_1",
|
| 20 |
+
"prompt_text_2",
|
| 21 |
+
"prompt_text_3",
|
| 22 |
+
"question_audio",
|
| 23 |
+
"question_text",
|
| 24 |
+
"response_audio_path_1",
|
| 25 |
+
"response_audio_path_2",
|
| 26 |
+
"response_audio_path_3",
|
| 27 |
+
"response_text_1",
|
| 28 |
+
"response_text_2",
|
| 29 |
+
"response_text_3",
|
| 30 |
+
"source_dataset",
|
| 31 |
+
"wer"
|
| 32 |
+
],
|
| 33 |
+
"_format_kwargs": {},
|
| 34 |
+
"_format_type": null,
|
| 35 |
+
"_output_all_columns": false,
|
| 36 |
+
"_split": null
|
| 37 |
+
}
|
r1-a/final_dataset/prompt_only_relative_paths/dataset_info.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"citation": "",
|
| 3 |
+
"description": "",
|
| 4 |
+
"features": {
|
| 5 |
+
"source_dataset": {
|
| 6 |
+
"dtype": "string",
|
| 7 |
+
"_type": "Value"
|
| 8 |
+
},
|
| 9 |
+
"question_text": {
|
| 10 |
+
"dtype": "string",
|
| 11 |
+
"_type": "Value"
|
| 12 |
+
},
|
| 13 |
+
"question_audio": {
|
| 14 |
+
"dtype": "string",
|
| 15 |
+
"_type": "Value"
|
| 16 |
+
},
|
| 17 |
+
"metadata": {
|
| 18 |
+
"dtype": "string",
|
| 19 |
+
"_type": "Value"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"homepage": "",
|
| 23 |
+
"license": ""
|
| 24 |
+
}
|
r1-a/final_dataset/prompt_only_relative_paths/state.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_data_files": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "data-00000-of-00001.arrow"
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"_fingerprint": "419ded2384418f0a",
|
| 8 |
+
"_format_columns": null,
|
| 9 |
+
"_format_kwargs": {},
|
| 10 |
+
"_format_type": null,
|
| 11 |
+
"_output_all_columns": false,
|
| 12 |
+
"_split": null
|
| 13 |
+
}
|
r1-a/response_generation/glm4voice.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import base64
|
| 4 |
+
import uuid
|
| 5 |
+
import time
|
| 6 |
+
import re
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import concurrent.futures
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import threading
|
| 11 |
+
import itertools
|
| 12 |
+
import traceback # For detailed error logging
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
from zhipuai import ZhipuAI
|
| 17 |
+
# Import specific error type if available and helpful
|
| 18 |
+
# Attempt to import specific error, handle if it doesn't exist
|
| 19 |
+
try:
|
| 20 |
+
from zhipuai.core._errors import APIStatusError
|
| 21 |
+
except ImportError:
|
| 22 |
+
# Define a dummy class if the specific error isn't available
|
| 23 |
+
# This allows the except block to still catch general exceptions
|
| 24 |
+
# that might represent API status issues if the SDK changes.
|
| 25 |
+
print("Warning: zhipuai.core._errors.APIStatusError not found. Using generic Exception for status errors.")
|
| 26 |
+
class APIStatusError(Exception):
|
| 27 |
+
def __init__(self, message, status_code=None, body=None):
|
| 28 |
+
super().__init__(message)
|
| 29 |
+
self.status_code = status_code
|
| 30 |
+
self.body = body
|
| 31 |
+
self.message = message # Add message attribute for consistency
|
| 32 |
+
|
| 33 |
+
from datasets import load_from_disk, Dataset
|
| 34 |
+
from dotenv import load_dotenv
|
| 35 |
+
|
| 36 |
+
# --- Configuration (User's Original Settings) ---
|
| 37 |
+
load_dotenv()
|
| 38 |
+
|
| 39 |
+
# 1. API Client Setup
|
| 40 |
+
GLM_MODEL_NAME = "glm-4-voice" # <<< User's original model name
|
| 41 |
+
|
| 42 |
+
# --- API Key Rotation Setup (User's Original Keys & Logic) ---
|
| 43 |
+
ZHIPUAI_API_KEYS = [
|
| 44 |
+
"14a67189b8bc4ee489e83b6247c36d0e.AIPUNrII50wREvsh",
|
| 45 |
+
"72120787822c4123a9654965ff90e4e6.JS1nuey9MncQscPa",
|
| 46 |
+
"d41b3b5bb49f4c8680b3836e7fc49bbf.u0jGxYc5sYPeRr5p",
|
| 47 |
+
"bc9bccd6ddd145fc844a014521c26868.JwsZXHzA3l32dDwz",
|
| 48 |
+
"0e5a05d709794737923ebd122e07d491.sL67ALh6BiLYaaGW", # New key
|
| 49 |
+
"db87c1fda8af4eb8b505f36e791d700d.w5M0Q3ZssT55tvlW", # New key
|
| 50 |
+
"1594ac60fbca4973809f4da425238e0c.ZMMfchqbok992Dmu", # New key
|
| 51 |
+
"469c0fa3b14e4913b1d14bc5d6f0c858.0KdQjFqdi66VPMnb",
|
| 52 |
+
"b9b538bb0e134438bacaf922b023d1fd.sogFUUp57UJ8YSd6",
|
| 53 |
+
"50bb382993a345cfa35833fc89caaa52.oR921jSW8iwzCV22",
|
| 54 |
+
"44512bbede5940f7964db7694bfc04df.yhDEQyPOXQCqh1Mn",
|
| 55 |
+
"99aba409b55c432696b9d5f1ff565d30.GmfRNngBOo8qDUbf"
|
| 56 |
+
] # <<< User's original keys
|
| 57 |
+
|
| 58 |
+
if not ZHIPUAI_API_KEYS:
|
| 59 |
+
print("FATAL: No ZHIPUAI_API_KEYS provided in the list.")
|
| 60 |
+
exit(1)
|
| 61 |
+
|
| 62 |
+
# Make sure keys are unique if duplicates were accidental
|
| 63 |
+
unique_keys = list(dict.fromkeys(ZHIPUAI_API_KEYS))
|
| 64 |
+
if len(unique_keys) != len(ZHIPUAI_API_KEYS):
|
| 65 |
+
print(f"Warning: Duplicate API keys found and removed. Using {len(unique_keys)} unique keys.")
|
| 66 |
+
ZHIPUAI_API_KEYS = unique_keys
|
| 67 |
+
|
| 68 |
+
key_cycler = itertools.cycle(ZHIPUAI_API_KEYS)
|
| 69 |
+
key_lock = threading.Lock()
|
| 70 |
+
disabled_keys = set() # Shared set to store disabled keys
|
| 71 |
+
|
| 72 |
+
class AllKeysDisabledError(Exception):
|
| 73 |
+
"""Custom exception raised when all API keys are disabled."""
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
def get_next_active_key():
|
| 77 |
+
"""
|
| 78 |
+
Thread-safely gets the next API key from the cycle, skipping disabled keys.
|
| 79 |
+
Raises AllKeysDisabledError if all keys are disabled.
|
| 80 |
+
(User's Original Logic)
|
| 81 |
+
"""
|
| 82 |
+
with key_lock:
|
| 83 |
+
initial_key_count = len(ZHIPUAI_API_KEYS)
|
| 84 |
+
checked_count = 0
|
| 85 |
+
while checked_count < initial_key_count:
|
| 86 |
+
potential_key = next(key_cycler)
|
| 87 |
+
if potential_key not in disabled_keys:
|
| 88 |
+
return potential_key
|
| 89 |
+
checked_count += 1
|
| 90 |
+
# Prevent infinite loop if somehow cycle changes mid-operation (shouldn't happen)
|
| 91 |
+
if checked_count > initial_key_count * 2:
|
| 92 |
+
print("Warning: Potential issue in get_next_active_key cycle detection.")
|
| 93 |
+
break
|
| 94 |
+
# If we exit the loop, all keys have been checked and are disabled
|
| 95 |
+
if len(disabled_keys) == initial_key_count:
|
| 96 |
+
raise AllKeysDisabledError("All API keys have been disabled.")
|
| 97 |
+
else:
|
| 98 |
+
# This case should ideally not be reached if logic is sound
|
| 99 |
+
# but indicates a potential problem finding an active key
|
| 100 |
+
print(f"Warning: Could not find an active key after checking {checked_count}. Disabled: {len(disabled_keys)}/{initial_key_count}")
|
| 101 |
+
raise RuntimeError("Failed to find an active API key.")
|
| 102 |
+
# --- End API Key Rotation Setup ---
|
| 103 |
+
|
| 104 |
+
# 2. Dataset Paths (User's Original Paths)
|
| 105 |
+
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks" # <<< User's original path
|
| 106 |
+
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_glm" # <<< User's original path
|
| 107 |
+
|
| 108 |
+
# 3. Output Audio Configuration (User's Original Settings)
|
| 109 |
+
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/glm_voice" # <<< User's original path
|
| 110 |
+
OUTPUT_AUDIO_FORMAT = "wav" # <<< User's original setting
|
| 111 |
+
OUTPUT_AUDIO_SAMPLERATE = 44100 # <<< User's original setting
|
| 112 |
+
|
| 113 |
+
# 4. API Call Settings (User's Original Settings)
|
| 114 |
+
API_RETRY_DELAY = 5 # <<< User's original setting
|
| 115 |
+
API_MAX_RETRIES = 3 # <<< User's original setting
|
| 116 |
+
MAX_WORKERS = 10 # <<< User's original setting
|
| 117 |
+
|
| 118 |
+
# --- Helper Functions (User's Original Functions) ---
|
| 119 |
+
def encode_audio_base64(audio_path):
|
| 120 |
+
# ... (implementation unchanged from user's script) ...
|
| 121 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 122 |
+
print(f"Warning: Input audio file not found or path is empty: {audio_path}")
|
| 123 |
+
return None
|
| 124 |
+
try:
|
| 125 |
+
with open(audio_path, "rb") as audio_file:
|
| 126 |
+
return base64.b64encode(audio_file.read()).decode("utf-8")
|
| 127 |
+
except Exception as e:
|
| 128 |
+
print(f"Error encoding audio file {audio_path}: {e}")
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
def parse_ultra_history(history_str):
|
| 132 |
+
# ... (implementation unchanged from user's script) ...
|
| 133 |
+
messages = []
|
| 134 |
+
pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
|
| 135 |
+
matches = pattern.findall(history_str)
|
| 136 |
+
if not matches:
|
| 137 |
+
return [] # Return empty list if no matches, as per user's original code
|
| 138 |
+
for role_tag, content in matches:
|
| 139 |
+
role = role_tag.lower()
|
| 140 |
+
cleaned_content = content.strip()
|
| 141 |
+
if cleaned_content:
|
| 142 |
+
messages.append({"role": role, "content": cleaned_content})
|
| 143 |
+
return messages
|
| 144 |
+
|
| 145 |
+
# --- Modified API Call Worker Function (Handles Key Disabling & History Flattening) ---
|
| 146 |
+
def call_glm_voice_api_worker(task_info):
|
| 147 |
+
"""
|
| 148 |
+
Worker function to call GLM Voice API, handling key disabling for error 1113,
|
| 149 |
+
and flattening history into the user prompt with clear markers.
|
| 150 |
+
(Incorporates Method 2 flattening into user's worker structure)
|
| 151 |
+
"""
|
| 152 |
+
row_idx = task_info["row_idx"]
|
| 153 |
+
slot_idx = task_info["slot_idx"]
|
| 154 |
+
current_api_key = task_info["api_key"]
|
| 155 |
+
history_messages = task_info["history_messages"] # Original parsed history
|
| 156 |
+
prompt_text = task_info["prompt_text"] # The user's current text request
|
| 157 |
+
question_audio_path = task_info["question_audio_path"]
|
| 158 |
+
output_audio_filepath = task_info["output_audio_filepath"]
|
| 159 |
+
|
| 160 |
+
retries = 0
|
| 161 |
+
local_glm_client = None
|
| 162 |
+
|
| 163 |
+
while retries < API_MAX_RETRIES:
|
| 164 |
+
# --- Initialize or Re-initialize client (User's Original Logic) ---
|
| 165 |
+
if local_glm_client is None or getattr(local_glm_client, 'api_key', None) != current_api_key:
|
| 166 |
+
try:
|
| 167 |
+
with key_lock:
|
| 168 |
+
if current_api_key in disabled_keys:
|
| 169 |
+
print(f"Info (Row {row_idx}, Slot {slot_idx}): Assigned key ...{current_api_key[-6:]} was disabled before use, getting new key.")
|
| 170 |
+
current_api_key = get_next_active_key()
|
| 171 |
+
task_info["api_key"] = current_api_key # Update task_info potentially for logging?
|
| 172 |
+
print(f" [Thread-{threading.get_ident()}] Initializing client for Row {row_idx}, Slot {slot_idx} (Key: ...{current_api_key[-6:]})")
|
| 173 |
+
local_glm_client = ZhipuAI(api_key=current_api_key)
|
| 174 |
+
except AllKeysDisabledError:
|
| 175 |
+
print(f"FATAL (Row {row_idx}, Slot {slot_idx}): All API keys are disabled. Cannot proceed with task.")
|
| 176 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: All Keys Disabled]", "saved_audio_path": None}
|
| 177 |
+
except Exception as client_init_e:
|
| 178 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Failed to initialize ZhipuAI client with key ...{current_api_key[-6:]}: {client_init_e}")
|
| 179 |
+
retries += 1
|
| 180 |
+
time.sleep(API_RETRY_DELAY)
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
# --- Attempt API Call ---
|
| 184 |
+
try:
|
| 185 |
+
# 1. Prepare Input Audio (User's Original Logic)
|
| 186 |
+
base64_audio_data = encode_audio_base64(question_audio_path)
|
| 187 |
+
if not base64_audio_data:
|
| 188 |
+
# This is a data error, not an API error, fail the task immediately (User's Original Logic)
|
| 189 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GLM API call - missing input audio.")
|
| 190 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
|
| 191 |
+
input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# 2. *** Flatten History and Construct Combined User Text Prompt (Method 2 Implementation) ***
|
| 195 |
+
text_parts = []
|
| 196 |
+
if history_messages:
|
| 197 |
+
print(f" (Row {row_idx}, Slot {slot_idx}) Flattening history ({len(history_messages)} turns) into prompt.")
|
| 198 |
+
text_parts.append("--- Start of Conversation History ---")
|
| 199 |
+
for msg in history_messages:
|
| 200 |
+
role_tag = "[User]" if msg['role'] == 'user' else "[Assistant]"
|
| 201 |
+
# Ensure content is string, handle potential non-string data defensively
|
| 202 |
+
content_str = str(msg.get('content', '')).strip()
|
| 203 |
+
if content_str: # Avoid adding empty messages
|
| 204 |
+
text_parts.append(f"{role_tag}: {content_str}")
|
| 205 |
+
text_parts.append("--- End of Conversation History ---")
|
| 206 |
+
text_parts.append("\n--- Current Task ---") # Clear separator
|
| 207 |
+
# Explicit instruction referencing history and audio
|
| 208 |
+
text_parts.append("Based on the conversation history above and the accompanying audio input, please respond to the following request:")
|
| 209 |
+
else:
|
| 210 |
+
# No history, just provide the current prompt directly
|
| 211 |
+
print(f" (Row {row_idx}, Slot {slot_idx}) No history found. Using prompt directly.")
|
| 212 |
+
text_parts.append("--- Current Task ---")
|
| 213 |
+
text_parts.append("Please respond to the following request based on the accompanying audio input:")
|
| 214 |
+
|
| 215 |
+
# Add the user's actual current request text
|
| 216 |
+
if prompt_text: # Only add if not empty
|
| 217 |
+
text_parts.append(prompt_text.strip())
|
| 218 |
+
|
| 219 |
+
combined_user_text = "\n".join(text_parts)
|
| 220 |
+
# --- End Flattening Logic ---
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# 3. Construct User Message Content List (Text + Audio)
|
| 224 |
+
user_content_list = [
|
| 225 |
+
{"type": "text", "text": combined_user_text}, # Use the combined text
|
| 226 |
+
{"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
# 4. Construct Final Messages List (Only the single combined user message)
|
| 230 |
+
# This replaces the user's original 'messages = history_messages + [{"role": "user", "content": user_content_list}]'
|
| 231 |
+
messages = [{"role": "user", "content": user_content_list}]
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# 5. Make API Call (User's Original Logic)
|
| 235 |
+
# Optional: print(f"Debug (Row {row_idx}, Slot {slot_idx}): Sending messages structure:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
|
| 236 |
+
response = local_glm_client.chat.completions.create(
|
| 237 |
+
model=GLM_MODEL_NAME,
|
| 238 |
+
messages=messages, # Send the single, combined user message
|
| 239 |
+
stream=False
|
| 240 |
+
# Add other parameters like temperature if the user had them originally (they didn't)
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# 6. Process SUCCESSFUL Response (User's Original Logic -unchanged-)
|
| 244 |
+
if response and response.choices:
|
| 245 |
+
message = response.choices[0].message
|
| 246 |
+
collected_text = message.content
|
| 247 |
+
audio_info = getattr(message, 'audio', None) # Use getattr for safety as per user's original code
|
| 248 |
+
if audio_info and 'data' in audio_info:
|
| 249 |
+
audio_base64_string = audio_info['data']
|
| 250 |
+
try:
|
| 251 |
+
decoded_data = base64.b64decode(audio_base64_string)
|
| 252 |
+
if len(decoded_data) == 0: # Check after decode (User's Original Check)
|
| 253 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM returned empty audio data.")
|
| 254 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
|
| 255 |
+
|
| 256 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 257 |
+
# Soundfile saving logic (User's Original Logic -unchanged-)
|
| 258 |
+
with BytesIO(decoded_data) as bio:
|
| 259 |
+
try:
|
| 260 |
+
audio_data, samplerate = sf.read(bio, dtype='int16')
|
| 261 |
+
except Exception:
|
| 262 |
+
bio.seek(0) # Rewind buffer before trying float
|
| 263 |
+
try:
|
| 264 |
+
audio_data_float, samplerate = sf.read(bio, dtype='float32')
|
| 265 |
+
# Convert float to int16
|
| 266 |
+
audio_data = (audio_data_float * 32767).astype(np.int16)
|
| 267 |
+
except Exception as sf_read_err_float:
|
| 268 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Soundfile failed to read audio data: {sf_read_err_float}")
|
| 269 |
+
# Return text, audio failed
|
| 270 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
|
| 271 |
+
|
| 272 |
+
# Use detected samplerate, fallback to configured rate if detection failed
|
| 273 |
+
write_samplerate = samplerate if samplerate > 0 else OUTPUT_AUDIO_SAMPLERATE
|
| 274 |
+
sf.write(output_audio_filepath, audio_data, write_samplerate)
|
| 275 |
+
|
| 276 |
+
# TASK SUCCEEDED!
|
| 277 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": output_audio_filepath}
|
| 278 |
+
|
| 279 |
+
except base64.binascii.Error as b64_e:
|
| 280 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM b64 decode failed: {b64_e}")
|
| 281 |
+
# Return text, audio failed
|
| 282 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
|
| 283 |
+
except Exception as e:
|
| 284 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Saving GLM audio failed: {e}")
|
| 285 |
+
# Return text, audio failed
|
| 286 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
|
| 287 |
+
else: # No audio in successful text response (User's Original Logic)
|
| 288 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): No audio data in GLM response.")
|
| 289 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
|
| 290 |
+
else: # Invalid/empty successful response (User's Original Logic)
|
| 291 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Invalid/empty GLM API response. Response: {response}")
|
| 292 |
+
# Treat as a retryable error for the task
|
| 293 |
+
retries += 1
|
| 294 |
+
time.sleep(API_RETRY_DELAY)
|
| 295 |
+
continue # Go to next iteration of while loop
|
| 296 |
+
|
| 297 |
+
# --- Handle API Errors (User's Original Logic -unchanged-) ---
|
| 298 |
+
except APIStatusError as e:
|
| 299 |
+
# --- Log the error details ---
|
| 300 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): APIStatusError Encountered")
|
| 301 |
+
print(f" Status Code: {getattr(e, 'status_code', 'N/A')}") # Use getattr for safety
|
| 302 |
+
error_details = getattr(e, 'body', getattr(e, 'message', str(e)))
|
| 303 |
+
print(f" Error Details: {error_details}")
|
| 304 |
+
# --- End Logging ---
|
| 305 |
+
|
| 306 |
+
# Check for the specific "account overdue" error (User's Original Logic)
|
| 307 |
+
is_overdue_error = False
|
| 308 |
+
status_code = getattr(e, 'status_code', None)
|
| 309 |
+
# Adjust check to handle both 429 and potential 400 errors with code 1113 in body
|
| 310 |
+
if status_code == 429 or (status_code == 400 and '1113' in str(error_details)):
|
| 311 |
+
try:
|
| 312 |
+
error_body = {}
|
| 313 |
+
# Try parsing if details look like JSON
|
| 314 |
+
if isinstance(error_details, (str, bytes)) and error_details.strip().startswith('{'):
|
| 315 |
+
error_body = json.loads(error_details)
|
| 316 |
+
elif isinstance(error_details, dict):
|
| 317 |
+
error_body = error_details # If body is already a dict
|
| 318 |
+
|
| 319 |
+
if isinstance(error_body, dict) and str(error_body.get("error", {}).get("code", "")) == "1113":
|
| 320 |
+
is_overdue_error = True
|
| 321 |
+
except (json.JSONDecodeError, AttributeError):
|
| 322 |
+
# Can't parse body or access attributes, assume not the specific error for safety
|
| 323 |
+
pass
|
| 324 |
+
except Exception as parse_err:
|
| 325 |
+
print(f"Warning: Error parsing API error body: {parse_err}")
|
| 326 |
+
|
| 327 |
+
if is_overdue_error:
|
| 328 |
+
key_to_disable = current_api_key
|
| 329 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Account overdue (1113) for Key ...{key_to_disable[-6:]}. Disabling key.")
|
| 330 |
+
with key_lock:
|
| 331 |
+
disabled_keys.add(key_to_disable)
|
| 332 |
+
print(f" Disabled keys count: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
|
| 333 |
+
|
| 334 |
+
# Don't increment retries here, try getting a new key immediately
|
| 335 |
+
try:
|
| 336 |
+
current_api_key = get_next_active_key() # Get a new key
|
| 337 |
+
print(f" (Row {row_idx}, Slot {slot_idx}) Switched to new key ...{current_api_key[-6:]} for next attempt.")
|
| 338 |
+
local_glm_client = None # Force re-initialization with new key
|
| 339 |
+
continue # Go immediately to the next iteration of the while loop with the new key
|
| 340 |
+
except AllKeysDisabledError:
|
| 341 |
+
print(f"FATAL (Row {row_idx}, Slot {slot_idx}): All API keys are disabled after key ...{key_to_disable[-6:]} failed. Cannot retry task.")
|
| 342 |
+
# Return failure for this task as no keys are left
|
| 343 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: All Keys Disabled]", "saved_audio_path": None}
|
| 344 |
+
|
| 345 |
+
else:
|
| 346 |
+
# Other APIStatusError (rate limit, server error, etc.) - treat as retryable
|
| 347 |
+
retries += 1
|
| 348 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM API Call Attempt {retries}/{API_MAX_RETRIES} failed: HTTP {status_code}, {error_details}")
|
| 349 |
+
if retries < API_MAX_RETRIES:
|
| 350 |
+
time.sleep(API_RETRY_DELAY)
|
| 351 |
+
# Continue loop to retry with the *same* key (unless it was just disabled above)
|
| 352 |
+
continue
|
| 353 |
+
else:
|
| 354 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Max retries reached after API error.")
|
| 355 |
+
# Return failure for the task
|
| 356 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries after status error]", "saved_audio_path": None}
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
# Handle other unexpected errors during API call or processing (User's Original Logic)
|
| 360 |
+
retries += 1
|
| 361 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Unexpected Error Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
|
| 362 |
+
print(traceback.format_exc()) # Print traceback for unexpected errors
|
| 363 |
+
if retries < API_MAX_RETRIES:
|
| 364 |
+
time.sleep(API_RETRY_DELAY)
|
| 365 |
+
continue # Continue loop to retry
|
| 366 |
+
else:
|
| 367 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Max retries reached after unexpected error.")
|
| 368 |
+
# Return failure for the task
|
| 369 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries after unexpected error]", "saved_audio_path": None}
|
| 370 |
+
|
| 371 |
+
# If loop finishes without returning, max retries were hit (User's Original Logic)
|
| 372 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts (may include key switches).")
|
| 373 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# --- Main Processing Logic (User's Original Logic -unchanged-) ---
|
| 377 |
+
|
| 378 |
+
print("Loading dataset...")
|
| 379 |
+
try:
|
| 380 |
+
dataset = load_from_disk(INPUT_DATASET_DIR)
|
| 381 |
+
print(f"Dataset loaded successfully with {len(dataset)} rows from {INPUT_DATASET_DIR}.")
|
| 382 |
+
except Exception as e:
|
| 383 |
+
print(f"FATAL: Error loading dataset from {INPUT_DATASET_DIR}: {e}")
|
| 384 |
+
exit(1)
|
| 385 |
+
|
| 386 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 387 |
+
|
| 388 |
+
# --- Pre-calculation Step for GLM (User's Original Logic -unchanged-) ---
|
| 389 |
+
print("Pre-calculating GLM tasks and assigning initial API keys...")
|
| 390 |
+
tasks_to_process = []
|
| 391 |
+
original_data = list(dataset) # Convert to list for easier updates later
|
| 392 |
+
initial_keys_available = True
|
| 393 |
+
|
| 394 |
+
for idx, row in enumerate(tqdm(original_data, desc="Scanning dataset for GLM tasks")):
|
| 395 |
+
if not initial_keys_available:
|
| 396 |
+
# Stop scanning if we know no keys are left
|
| 397 |
+
print("Stopping task scanning as no active keys are available.")
|
| 398 |
+
break
|
| 399 |
+
|
| 400 |
+
for i in range(1, 4):
|
| 401 |
+
model_key = f"model_{i}"
|
| 402 |
+
response_text_key = f"response_text_{i}"
|
| 403 |
+
prompt_text_key = f"prompt_text_{i}"
|
| 404 |
+
model_assigned = row.get(model_key)
|
| 405 |
+
# Check if response exists and is not empty string (User's original check was just existence)
|
| 406 |
+
response_text_exists = row.get(response_text_key) is not None and str(row.get(response_text_key)).strip() != ""
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
if model_assigned == "glm_voice" and not response_text_exists: # Check using configured model name
|
| 410 |
+
question_audio_path = row.get('question_audio')
|
| 411 |
+
# Add check if audio path exists on disk
|
| 412 |
+
if not question_audio_path or not os.path.exists(question_audio_path):
|
| 413 |
+
print(f"Warning (Row {idx}, Slot {i}): Skipping GLM task - Missing or invalid 'question_audio' path: {question_audio_path}")
|
| 414 |
+
continue # Skip this slot if audio is missing
|
| 415 |
+
|
| 416 |
+
# --- Get initial active API key (User's Original Logic) ---
|
| 417 |
+
try:
|
| 418 |
+
assigned_key = get_next_active_key()
|
| 419 |
+
except AllKeysDisabledError:
|
| 420 |
+
print("FATAL: All API keys are disabled during initial task scanning. Cannot proceed.")
|
| 421 |
+
initial_keys_available = False
|
| 422 |
+
break # Stop processing this row
|
| 423 |
+
except Exception as key_err:
|
| 424 |
+
print(f"FATAL: Error getting initial API key: {key_err}. Stopping.")
|
| 425 |
+
initial_keys_available = False
|
| 426 |
+
break
|
| 427 |
+
# ---
|
| 428 |
+
|
| 429 |
+
metadata_str = row.get('metadata', "{}")
|
| 430 |
+
source_dataset = row.get('source_dataset')
|
| 431 |
+
metadata = {}
|
| 432 |
+
try:
|
| 433 |
+
# Handle case where metadata might already be a dict or is a JSON string
|
| 434 |
+
if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
|
| 435 |
+
elif isinstance(metadata_str, dict): metadata = metadata_str
|
| 436 |
+
except json.JSONDecodeError:
|
| 437 |
+
print(f"Warning (Row {idx}): Could not parse metadata string: {metadata_str[:100]}...")
|
| 438 |
+
pass # Continue with empty metadata
|
| 439 |
+
|
| 440 |
+
# Parse history here - it will be flattened later in the worker
|
| 441 |
+
history_messages = []
|
| 442 |
+
if source_dataset == 'ultra':
|
| 443 |
+
history_str = metadata.get('history', '')
|
| 444 |
+
if history_str: history_messages = parse_ultra_history(history_str)
|
| 445 |
+
|
| 446 |
+
unique_id = str(uuid.uuid4()).replace("-", "")
|
| 447 |
+
output_audio_filename = f"glm_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 448 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 449 |
+
|
| 450 |
+
task_info = {
|
| 451 |
+
"row_idx": idx,
|
| 452 |
+
"slot_idx": i,
|
| 453 |
+
"api_key": assigned_key, # Initial key
|
| 454 |
+
"history_messages": history_messages, # Pass the original parsed history
|
| 455 |
+
"prompt_text": row.get(prompt_text_key, ""),
|
| 456 |
+
"question_audio_path": question_audio_path,
|
| 457 |
+
"output_audio_filepath": output_audio_filepath,
|
| 458 |
+
}
|
| 459 |
+
tasks_to_process.append(task_info)
|
| 460 |
+
# Process only the first unfilled GLM slot found per row (User's Implicit Logic)
|
| 461 |
+
break # Stop checking slots for this row
|
| 462 |
+
|
| 463 |
+
if not initial_keys_available: break # Exit outer loop too
|
| 464 |
+
|
| 465 |
+
total_tasks = len(tasks_to_process)
|
| 466 |
+
if total_tasks == 0:
|
| 467 |
+
if not initial_keys_available:
|
| 468 |
+
print("No tasks processed because all initial keys were disabled.")
|
| 469 |
+
else:
|
| 470 |
+
print("No GLM Voice tasks found needing processing.")
|
| 471 |
+
exit(0)
|
| 472 |
+
|
| 473 |
+
print(f"Found {total_tasks} GLM Voice tasks to process using initially {len(ZHIPUAI_API_KEYS)} API keys.")
|
| 474 |
+
if len(disabled_keys) > 0: # Should be 0 here, but for safety
|
| 475 |
+
print(f"Note: {len(disabled_keys)} keys already marked as disabled (should not happen at this stage).")
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# --- Threaded Execution for GLM (User's Original Logic -unchanged-) ---
|
| 479 |
+
print(f"Starting GLM processing with up to {MAX_WORKERS} worker threads...")
|
| 480 |
+
start_total_time = time.time()
|
| 481 |
+
results = {}
|
| 482 |
+
tasks_completed = 0
|
| 483 |
+
tasks_failed = 0
|
| 484 |
+
executor_shutdown = False # Flag to stop submitting new tasks if all keys die
|
| 485 |
+
|
| 486 |
+
# Use context manager for ThreadPoolExecutor
|
| 487 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 488 |
+
# Create futures mapping back to task info for easier result merging
|
| 489 |
+
future_to_task = {executor.submit(call_glm_voice_api_worker, task): task for task in tasks_to_process}
|
| 490 |
+
|
| 491 |
+
for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GLM tasks"):
|
| 492 |
+
task = future_to_task[future] # Get the original task info associated with this future
|
| 493 |
+
row_idx = task["row_idx"]
|
| 494 |
+
slot_idx = task["slot_idx"]
|
| 495 |
+
try:
|
| 496 |
+
result = future.result() # Get the result from the worker
|
| 497 |
+
results[(row_idx, slot_idx)] = result # Store result using (row, slot) tuple as key
|
| 498 |
+
|
| 499 |
+
# Check if the task failed because all keys got disabled during its execution
|
| 500 |
+
if result["response_text"] == "[ERROR: All Keys Disabled]" and not executor_shutdown:
|
| 501 |
+
print("\n--- CRITICAL: All Keys Disabled detected during execution. Stopping submission of new tasks. ---")
|
| 502 |
+
# Potentially cancel remaining futures if possible/desired
|
| 503 |
+
# Note: Standard ThreadPoolExecutor doesn't easily support cancelling submitted tasks
|
| 504 |
+
# We will just let running tasks finish but won't submit new ones if we had that logic.
|
| 505 |
+
# For now, just set flag and log.
|
| 506 |
+
executor_shutdown = True # Prevent theoretical resubmission logic
|
| 507 |
+
tasks_failed += 1 # Count this task as failed
|
| 508 |
+
# Check for other errors in the result text or missing audio path
|
| 509 |
+
elif result["saved_audio_path"] is None or "[ERROR" in result["response_text"]:
|
| 510 |
+
tasks_failed += 1
|
| 511 |
+
tasks_completed += 1
|
| 512 |
+
|
| 513 |
+
except Exception as exc: # Catch exceptions raised *by* the future (e.g., if worker itself crashes)
|
| 514 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): GLM Task generated an unhandled exception: {exc}")
|
| 515 |
+
print(traceback.format_exc())
|
| 516 |
+
# Store an error result
|
| 517 |
+
results[(row_idx, slot_idx)] = {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[ERROR: Worker Crash - {type(exc).__name__}]", "saved_audio_path": None}
|
| 518 |
+
tasks_failed += 1
|
| 519 |
+
tasks_completed += 1
|
| 520 |
+
# No finally block needed here unless cleaning up future_to_task is desired
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
end_total_time = time.time()
|
| 524 |
+
print("\n--- GLM Processing Complete ---")
|
| 525 |
+
print(f"Total GLM tasks attempted: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
|
| 526 |
+
print(f"Final disabled key count: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
|
| 527 |
+
print(f"Total GLM processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
# --- Merge Results back into the dataset structure (User's Original Logic -unchanged-) ---
|
| 531 |
+
print("Merging GLM results...")
|
| 532 |
+
updated_data = original_data # Use the list created earlier
|
| 533 |
+
for (row_idx, slot_idx), result in tqdm(results.items(), desc="Merging GLM results"):
|
| 534 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 535 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 536 |
+
# Check index validity before updating
|
| 537 |
+
if 0 <= row_idx < len(updated_data):
|
| 538 |
+
# Ensure the item at the index is a dictionary (it should be if loaded from dataset)
|
| 539 |
+
if isinstance(updated_data[row_idx], dict):
|
| 540 |
+
updated_data[row_idx][response_text_key] = result["response_text"]
|
| 541 |
+
updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
|
| 542 |
+
else:
|
| 543 |
+
print(f"Warning: Item at index {row_idx} is not a dictionary. Skipping merge for Slot {slot_idx}.")
|
| 544 |
+
else:
|
| 545 |
+
print(f"Warning: Invalid row index {row_idx} encountered during GLM result merge.")
|
| 546 |
+
|
| 547 |
+
# --- Save the final updated dataset (User's Original Logic -unchanged, including fallback) ---
|
| 548 |
+
if updated_data:
|
| 549 |
+
print(f"\nSaving updated dataset with GLM results to {OUTPUT_DATASET_DIR}...")
|
| 550 |
+
try:
|
| 551 |
+
# Use the features from the original loaded dataset if available
|
| 552 |
+
updated_dataset = Dataset.from_list(updated_data, features=dataset.features if dataset else None)
|
| 553 |
+
updated_dataset.save_to_disk(OUTPUT_DATASET_DIR)
|
| 554 |
+
print("Updated dataset saved successfully.")
|
| 555 |
+
except Exception as final_save_e:
|
| 556 |
+
print(f"Error saving final dataset using datasets lib: {final_save_e}")
|
| 557 |
+
print(f"Final disabled key count at save: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
|
| 558 |
+
print("Attempting to save as JSON lines as fallback...")
|
| 559 |
+
# Fallback to JSON Lines (User's original fallback logic)
|
| 560 |
+
output_jsonl_path = OUTPUT_DATASET_DIR.rstrip('/') + ".jsonl" # Ensure no trailing slash before adding extension
|
| 561 |
+
try:
|
| 562 |
+
with open(output_jsonl_path, 'w', encoding='utf-8') as f:
|
| 563 |
+
for item in updated_data:
|
| 564 |
+
# Attempt to make item JSON serializable
|
| 565 |
+
serializable_item = {}
|
| 566 |
+
for k, v in item.items():
|
| 567 |
+
if isinstance(v, (str, int, float, bool, list, dict)) or v is None:
|
| 568 |
+
serializable_item[k] = v
|
| 569 |
+
elif isinstance(v, np.ndarray):
|
| 570 |
+
serializable_item[k] = v.tolist() # Convert numpy arrays
|
| 571 |
+
else:
|
| 572 |
+
serializable_item[k] = str(v) # Convert other types to string as fallback
|
| 573 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 574 |
+
print(f"Fallback save successful to {output_jsonl_path}")
|
| 575 |
+
except Exception as json_save_e:
|
| 576 |
+
print(f"Error saving as JSON lines: {json_save_e}")
|
| 577 |
+
|
| 578 |
+
else:
|
| 579 |
+
print("No data was available to save (potentially all keys disabled early or no tasks processed).")
|
r1-a/response_generation/gpt4o.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import base64
|
| 4 |
+
import uuid
|
| 5 |
+
import time
|
| 6 |
+
import re
|
| 7 |
+
import random
|
| 8 |
+
import concurrent.futures
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import threading
|
| 11 |
+
import traceback # For detailed error logging
|
| 12 |
+
|
| 13 |
+
import requests # Use requests library for HTTP calls
|
| 14 |
+
# Make sure numpy is imported if needed for potential fallback serialization
|
| 15 |
+
import numpy as np
|
| 16 |
+
from datasets import load_from_disk, Dataset
|
| 17 |
+
from dotenv import load_dotenv
|
| 18 |
+
|
| 19 |
+
# --- Configuration ---
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
# 1. API Client Setup
|
| 23 |
+
GPT4O_MODEL_NAME = "gpt4o" # How it's identified in your dataset's model columns
|
| 24 |
+
API_MODEL_NAME = "gpt-4o-audio-preview" # Actual model name for the API call
|
| 25 |
+
API_ENDPOINT = "https://api.vansai.cn/v1/chat/completions"
|
| 26 |
+
try:
|
| 27 |
+
# Assuming a single key for this service based on the original script
|
| 28 |
+
API_TOKEN = "sk-uOJ27X9jNsYh1PDx1e665b0f92434bEc9bD53bE6D3BaD29a"
|
| 29 |
+
if not API_TOKEN:
|
| 30 |
+
raise ValueError("AIGCBEST_API_KEY environment variable not set.")
|
| 31 |
+
print("AIGCBEST API Key loaded.")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"FATAL: Error getting API Key: {e}")
|
| 34 |
+
exit(1)
|
| 35 |
+
|
| 36 |
+
# 2. Dataset Paths
|
| 37 |
+
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
|
| 38 |
+
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o"
|
| 39 |
+
|
| 40 |
+
# 3. Output Audio Configuration
|
| 41 |
+
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_2"
|
| 42 |
+
OUTPUT_AUDIO_FORMAT = "wav" # API will be requested to return wav
|
| 43 |
+
AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse']
|
| 44 |
+
|
| 45 |
+
# 4. API Call Settings
|
| 46 |
+
API_TIMEOUT = 120
|
| 47 |
+
API_RETRY_DELAY = 5
|
| 48 |
+
API_MAX_RETRIES = 3 # Max attempts *for the task*
|
| 49 |
+
MAX_WORKERS = 8 # Adjust based on API rate limits and system resources
|
| 50 |
+
|
| 51 |
+
# 5. Checkpoint Saving Configuration # <-- NEW
|
| 52 |
+
CHECKPOINT_INTERVAL = 500 # Save every 500 completed tasks
|
| 53 |
+
|
| 54 |
+
# --- Helper Functions (encode_audio_base64 and parse_ultra_history remain the same) ---
|
| 55 |
+
|
| 56 |
+
def encode_audio_base64(audio_path):
|
| 57 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 58 |
+
print(f"Warning: Input audio file not found or path is empty: {audio_path}")
|
| 59 |
+
return None
|
| 60 |
+
try:
|
| 61 |
+
with open(audio_path, "rb") as audio_file:
|
| 62 |
+
return base64.b64encode(audio_file.read()).decode("utf-8")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"Error encoding audio file {audio_path}: {e}")
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
def parse_ultra_history(history_str):
|
| 68 |
+
messages = []
|
| 69 |
+
pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
|
| 70 |
+
matches = pattern.findall(history_str)
|
| 71 |
+
if not matches:
|
| 72 |
+
return []
|
| 73 |
+
for role_tag, content in matches:
|
| 74 |
+
role = role_tag.lower()
|
| 75 |
+
cleaned_content = content.strip()
|
| 76 |
+
if cleaned_content:
|
| 77 |
+
messages.append({"role": role, "content": cleaned_content})
|
| 78 |
+
return messages
|
| 79 |
+
|
| 80 |
+
# --- Modified API Call Worker Function for GPT-4o (Reduced Prints) ---
|
| 81 |
+
def call_gpt4o_api_worker(task_info):
|
| 82 |
+
"""
|
| 83 |
+
Worker function to call the custom GPT-4o API for a single task.
|
| 84 |
+
"""
|
| 85 |
+
row_idx = task_info["row_idx"]
|
| 86 |
+
slot_idx = task_info["slot_idx"]
|
| 87 |
+
history_messages = task_info["history_messages"]
|
| 88 |
+
prompt_text = task_info["prompt_text"]
|
| 89 |
+
question_text = task_info["question_text"]
|
| 90 |
+
question_audio_path = task_info["question_audio_path"]
|
| 91 |
+
output_audio_filepath = task_info["output_audio_filepath"]
|
| 92 |
+
|
| 93 |
+
retries = 0
|
| 94 |
+
headers = {
|
| 95 |
+
'Accept': 'application/json',
|
| 96 |
+
'Authorization': f'Bearer {API_TOKEN}', # Use the single loaded token
|
| 97 |
+
'Content-Type': 'application/json'
|
| 98 |
+
}
|
| 99 |
+
selected_voice = random.choice(AVAILABLE_VOICES)
|
| 100 |
+
# print(f" [Thread-{threading.get_ident()}] Processing Row {row_idx}, Slot {slot_idx} (GPT4o Voice: {selected_voice})") # Optional log
|
| 101 |
+
|
| 102 |
+
while retries < API_MAX_RETRIES:
|
| 103 |
+
try:
|
| 104 |
+
# 1. Prepare Input Audio
|
| 105 |
+
base64_audio_data = encode_audio_base64(question_audio_path)
|
| 106 |
+
if not base64_audio_data:
|
| 107 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
|
| 108 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
|
| 109 |
+
|
| 110 |
+
input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
|
| 111 |
+
|
| 112 |
+
# 2. Construct User Message Content
|
| 113 |
+
combined_text = f"{prompt_text}"
|
| 114 |
+
user_content_list = [
|
| 115 |
+
{"type": "text", "text": combined_text},
|
| 116 |
+
{"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
|
| 117 |
+
]
|
| 118 |
+
messages = history_messages + [{"role": "user", "content": user_content_list}]
|
| 119 |
+
|
| 120 |
+
# 4. Construct Payload
|
| 121 |
+
payload = {
|
| 122 |
+
"model": API_MODEL_NAME,
|
| 123 |
+
"modalities": ["text", "audio"],
|
| 124 |
+
"audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
|
| 125 |
+
"messages": messages
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
# 5. Make API Call
|
| 129 |
+
response = requests.post(
|
| 130 |
+
API_ENDPOINT,
|
| 131 |
+
headers=headers,
|
| 132 |
+
json=payload,
|
| 133 |
+
timeout=API_TIMEOUT
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# 6. Process Response
|
| 137 |
+
if response.status_code == 200:
|
| 138 |
+
try:
|
| 139 |
+
response_data = response.json()
|
| 140 |
+
# Make parsing more robust
|
| 141 |
+
choices = response_data.get('choices')
|
| 142 |
+
if not choices or not isinstance(choices, list) or len(choices) == 0:
|
| 143 |
+
raise ValueError("Invalid or empty 'choices' field in response.")
|
| 144 |
+
|
| 145 |
+
message_content = choices[0].get('message', {})
|
| 146 |
+
if not message_content:
|
| 147 |
+
raise ValueError("Missing 'message' field in the first choice.")
|
| 148 |
+
|
| 149 |
+
audio_info = message_content.get('audio', {})
|
| 150 |
+
if not isinstance(audio_info, dict): audio_info = {} # Handle case where audio might be null or not a dict
|
| 151 |
+
|
| 152 |
+
audio_base64_string = audio_info.get('data', '')
|
| 153 |
+
# Try getting text from 'content' if 'transcript' is missing/empty in 'audio'
|
| 154 |
+
collected_text = audio_info.get('transcript', '').strip()
|
| 155 |
+
if not collected_text:
|
| 156 |
+
text_content_list = message_content.get('content', [])
|
| 157 |
+
if isinstance(text_content_list, list):
|
| 158 |
+
for item in text_content_list:
|
| 159 |
+
if isinstance(item, dict) and item.get("type") == "text":
|
| 160 |
+
collected_text = item.get("text", "").strip()
|
| 161 |
+
break # Take the first text part found
|
| 162 |
+
# Still no text? Try the top-level message content directly if it's a string
|
| 163 |
+
elif isinstance(message_content.get('content'), str):
|
| 164 |
+
collected_text = message_content['content'].strip()
|
| 165 |
+
|
| 166 |
+
if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
|
| 167 |
+
if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
|
| 168 |
+
|
| 169 |
+
saved_audio_path = None
|
| 170 |
+
if audio_base64_string:
|
| 171 |
+
try:
|
| 172 |
+
wav_bytes = base64.b64decode(audio_base64_string)
|
| 173 |
+
if len(wav_bytes) == 0:
|
| 174 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
|
| 175 |
+
else:
|
| 176 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 177 |
+
with open(output_audio_filepath, "wb") as f:
|
| 178 |
+
f.write(wav_bytes)
|
| 179 |
+
saved_audio_path = output_audio_filepath
|
| 180 |
+
# print(f" Audio saved to: {output_audio_filepath}") # Less verbose log
|
| 181 |
+
except base64.binascii.Error as b64_err:
|
| 182 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
|
| 183 |
+
except Exception as e:
|
| 184 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
|
| 185 |
+
|
| 186 |
+
# TASK SUCCEEDED (even if audio saving failed, text might be valid)
|
| 187 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
|
| 188 |
+
|
| 189 |
+
except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
|
| 190 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
|
| 191 |
+
print(f" Response Text (start): {response.text[:500]}...")
|
| 192 |
+
retries += 1
|
| 193 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 194 |
+
time.sleep(API_RETRY_DELAY)
|
| 195 |
+
continue
|
| 196 |
+
except Exception as e: # Catch-all for unexpected errors during processing
|
| 197 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
|
| 198 |
+
print(traceback.format_exc())
|
| 199 |
+
retries += 1
|
| 200 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 201 |
+
time.sleep(API_RETRY_DELAY)
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
else: # Handle non-200 status codes
|
| 205 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
|
| 206 |
+
retries += 1
|
| 207 |
+
if retries < API_MAX_RETRIES:
|
| 208 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 209 |
+
time.sleep(API_RETRY_DELAY)
|
| 210 |
+
continue # Go to next iteration of while loop
|
| 211 |
+
else:
|
| 212 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
|
| 213 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
except requests.exceptions.Timeout:
|
| 217 |
+
retries += 1
|
| 218 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
|
| 219 |
+
if retries < API_MAX_RETRIES:
|
| 220 |
+
time.sleep(API_RETRY_DELAY)
|
| 221 |
+
continue
|
| 222 |
+
else:
|
| 223 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
|
| 224 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
|
| 225 |
+
|
| 226 |
+
except requests.exceptions.RequestException as e:
|
| 227 |
+
retries += 1
|
| 228 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
|
| 229 |
+
if retries < API_MAX_RETRIES:
|
| 230 |
+
time.sleep(API_RETRY_DELAY)
|
| 231 |
+
continue
|
| 232 |
+
else:
|
| 233 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
|
| 234 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
|
| 235 |
+
|
| 236 |
+
except Exception as e:
|
| 237 |
+
retries += 1
|
| 238 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
|
| 239 |
+
print(traceback.format_exc())
|
| 240 |
+
if retries < API_MAX_RETRIES:
|
| 241 |
+
time.sleep(API_RETRY_DELAY)
|
| 242 |
+
continue
|
| 243 |
+
else:
|
| 244 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
|
| 245 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
|
| 246 |
+
|
| 247 |
+
# If loop finishes without returning, max retries were hit
|
| 248 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts.")
|
| 249 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
|
| 250 |
+
|
| 251 |
+
# --- Checkpoint Saving Function --- # <-- NEW (Copied from previous response)
|
| 252 |
+
def save_checkpoint(data_to_save, output_dir, dataset_features):
|
| 253 |
+
"""Saves the current state of the data to disk."""
|
| 254 |
+
if not data_to_save:
|
| 255 |
+
print("Checkpoint: No data available to save.")
|
| 256 |
+
return
|
| 257 |
+
|
| 258 |
+
# Ensure output directory exists before saving
|
| 259 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
|
| 262 |
+
try:
|
| 263 |
+
# Convert list of dicts back to Dataset object
|
| 264 |
+
checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
|
| 265 |
+
checkpoint_dataset.save_to_disk(output_dir)
|
| 266 |
+
print(f"Checkpoint: Saved successfully to {output_dir}")
|
| 267 |
+
except Exception as ckpt_save_e:
|
| 268 |
+
print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
|
| 269 |
+
# Fallback to JSON Lines (optional, but good practice)
|
| 270 |
+
output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl") # Save inside the dir
|
| 271 |
+
print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
|
| 272 |
+
try:
|
| 273 |
+
with open(output_jsonl_path, 'w', encoding='utf-8') as f:
|
| 274 |
+
for item in data_to_save:
|
| 275 |
+
# Basic serialization handling for common types like numpy arrays
|
| 276 |
+
serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
|
| 277 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 278 |
+
print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
|
| 279 |
+
except Exception as json_save_e:
|
| 280 |
+
print(f"Error saving checkpoint as JSON lines: {json_save_e}")
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# --- Main Processing Logic ---
|
| 284 |
+
|
| 285 |
+
print("Checking for existing checkpoint/output dataset...")
|
| 286 |
+
dataset = None
|
| 287 |
+
original_features = None # Initialize
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
# 检查输出目录是否存在,并且看起来像一个 Hugging Face datasets 目录
|
| 291 |
+
# (dataset_info.json 或 state.json 是常见的指示文件)
|
| 292 |
+
potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
|
| 293 |
+
potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
|
| 294 |
+
|
| 295 |
+
if os.path.exists(OUTPUT_DATASET_DIR) and \
|
| 296 |
+
(os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
|
| 297 |
+
|
| 298 |
+
print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
|
| 299 |
+
try:
|
| 300 |
+
dataset = load_from_disk(OUTPUT_DATASET_DIR)
|
| 301 |
+
original_features = dataset.features # 获取已保存数据集的特征
|
| 302 |
+
print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
|
| 303 |
+
except Exception as load_ckpt_e:
|
| 304 |
+
print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
|
| 305 |
+
print("Falling back to loading original input dataset.")
|
| 306 |
+
dataset = None # Ensure we proceed to load original if checkpoint load failed
|
| 307 |
+
else:
|
| 308 |
+
print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
|
| 309 |
+
# If no checkpoint, ensure dataset is None so original loading happens
|
| 310 |
+
|
| 311 |
+
# 如果 dataset 仍然是 None (因为没有找到 checkpoint 或加载失败)
|
| 312 |
+
if dataset is None:
|
| 313 |
+
print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
|
| 314 |
+
dataset = load_from_disk(INPUT_DATASET_DIR)
|
| 315 |
+
original_features = dataset.features
|
| 316 |
+
print(f"Original dataset loaded successfully with {len(dataset)} rows.")
|
| 317 |
+
|
| 318 |
+
except Exception as initial_load_e:
|
| 319 |
+
print(f"FATAL: Error during initial dataset loading (original or checkpoint): {initial_load_e}")
|
| 320 |
+
print(traceback.format_exc()) # 打印详细错误
|
| 321 |
+
exit(1)
|
| 322 |
+
|
| 323 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 324 |
+
|
| 325 |
+
# --- Pre-calculation Step for GPT-4o ---
|
| 326 |
+
print("Pre-calculating GPT-4o tasks...")
|
| 327 |
+
tasks_to_process = []
|
| 328 |
+
# Use a list of dictionaries, which is mutable and easier for direct updates
|
| 329 |
+
updated_data = list(dataset)
|
| 330 |
+
|
| 331 |
+
for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset for GPT-4o tasks")):
|
| 332 |
+
for i in range(1, 4):
|
| 333 |
+
model_key = f"model_{i}"
|
| 334 |
+
response_text_key = f"response_text_{i}"
|
| 335 |
+
prompt_text_key = f"prompt_text_{i}"
|
| 336 |
+
response_audio_key = f"response_audio_path_{i}" # Key for storing the *new* audio path
|
| 337 |
+
|
| 338 |
+
model_assigned = row.get(model_key)
|
| 339 |
+
response_text_exists = row.get(response_text_key) is not None
|
| 340 |
+
|
| 341 |
+
# Check for the specific model name used in the dataset
|
| 342 |
+
if model_assigned == GPT4O_MODEL_NAME and not response_text_exists:
|
| 343 |
+
question_audio_path = row.get('question_audio')
|
| 344 |
+
if not question_audio_path or not os.path.exists(question_audio_path): # Check path validity here
|
| 345 |
+
print(f"Warning (Row {idx}, Slot {i}): Skipping GPT-4o task - Missing or invalid 'question_audio' path: {question_audio_path}")
|
| 346 |
+
# Pre-fill error? Let's just skip task creation for now.
|
| 347 |
+
# If needed: updated_data[idx][response_text_key] = "[ERROR: Missing input audio]"
|
| 348 |
+
# If needed: updated_data[idx][response_audio_key] = None
|
| 349 |
+
continue # Skip this task
|
| 350 |
+
|
| 351 |
+
metadata_str = row.get('metadata', "{}")
|
| 352 |
+
source_dataset = row.get('source_dataset')
|
| 353 |
+
metadata = {}
|
| 354 |
+
try:
|
| 355 |
+
if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
|
| 356 |
+
elif isinstance(metadata_str, dict): metadata = metadata_str
|
| 357 |
+
except json.JSONDecodeError: pass
|
| 358 |
+
|
| 359 |
+
history_messages = []
|
| 360 |
+
if source_dataset == 'ultra':
|
| 361 |
+
history_str = metadata.get('history', '')
|
| 362 |
+
if history_str: history_messages = parse_ultra_history(history_str)
|
| 363 |
+
|
| 364 |
+
unique_id = str(uuid.uuid4()).replace("-", "")
|
| 365 |
+
output_audio_filename = f"gpt4o_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 366 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 367 |
+
|
| 368 |
+
task_info = {
|
| 369 |
+
"row_idx": idx,
|
| 370 |
+
"slot_idx": i,
|
| 371 |
+
# No API key needed here as it's global/single
|
| 372 |
+
"history_messages": history_messages,
|
| 373 |
+
"prompt_text": row.get(prompt_text_key, ""),
|
| 374 |
+
"question_text": row.get('question_text', ""), # Pass question text
|
| 375 |
+
"question_audio_path": question_audio_path,
|
| 376 |
+
"output_audio_filepath": output_audio_filepath,
|
| 377 |
+
}
|
| 378 |
+
tasks_to_process.append(task_info)
|
| 379 |
+
# Decide if you process all slots or just the first unfilled one
|
| 380 |
+
# break # Uncomment this line if you only want the *first* unfilled gpt4o slot per row processed
|
| 381 |
+
|
| 382 |
+
total_tasks = len(tasks_to_process)
|
| 383 |
+
if total_tasks == 0:
|
| 384 |
+
print("No GPT-4o tasks found needing processing.")
|
| 385 |
+
exit(0)
|
| 386 |
+
|
| 387 |
+
print(f"Found {total_tasks} GPT-4o tasks to process.")
|
| 388 |
+
|
| 389 |
+
# --- Threaded Execution with Checkpointing for GPT-4o --- # <-- MODIFIED SECTION
|
| 390 |
+
print(f"Starting GPT-4o processing with up to {MAX_WORKERS} worker threads...")
|
| 391 |
+
start_total_time = time.time()
|
| 392 |
+
# results = {} # No longer needed
|
| 393 |
+
tasks_completed = 0
|
| 394 |
+
tasks_failed = 0
|
| 395 |
+
completed_since_last_save = 0 # <-- Counter for checkpointing
|
| 396 |
+
|
| 397 |
+
# Use context manager for ThreadPoolExecutor
|
| 398 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 399 |
+
future_to_task = {executor.submit(call_gpt4o_api_worker, task): task for task in tasks_to_process}
|
| 400 |
+
|
| 401 |
+
for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GPT-4o tasks"):
|
| 402 |
+
task_info = future_to_task[future] # Get original task info
|
| 403 |
+
row_idx = task_info["row_idx"]
|
| 404 |
+
slot_idx = task_info["slot_idx"]
|
| 405 |
+
result = None # Define result scope
|
| 406 |
+
|
| 407 |
+
try:
|
| 408 |
+
result = future.result()
|
| 409 |
+
# --- Direct Update and Checkpointing Logic ---
|
| 410 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 411 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 412 |
+
|
| 413 |
+
if 0 <= row_idx < len(updated_data):
|
| 414 |
+
updated_data[row_idx][response_text_key] = result["response_text"]
|
| 415 |
+
updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
|
| 416 |
+
if result["saved_audio_path"] is None or "[ERROR" in result["response_text"]: # Check for error marker
|
| 417 |
+
tasks_failed += 1
|
| 418 |
+
else:
|
| 419 |
+
print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
|
| 420 |
+
tasks_failed += 1 # Count as failed if index is bad
|
| 421 |
+
|
| 422 |
+
tasks_completed += 1
|
| 423 |
+
completed_since_last_save += 1 # Increment checkpoint counter
|
| 424 |
+
|
| 425 |
+
# Check if it's time to save a checkpoint
|
| 426 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 427 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 428 |
+
completed_since_last_save = 0 # Reset counter
|
| 429 |
+
|
| 430 |
+
except Exception as exc: # Catch exceptions raised *by* the future/worker if not handled inside
|
| 431 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): GPT-4o Task generated an unhandled exception: {exc}")
|
| 432 |
+
print(traceback.format_exc())
|
| 433 |
+
# Attempt to record error in the main data structure
|
| 434 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 435 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 436 |
+
if 0 <= row_idx < len(updated_data):
|
| 437 |
+
updated_data[row_idx][response_text_key] = f"[ERROR: Worker Crash - {exc}]"
|
| 438 |
+
updated_data[row_idx][response_audio_key] = None
|
| 439 |
+
else:
|
| 440 |
+
print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
|
| 441 |
+
|
| 442 |
+
tasks_failed += 1
|
| 443 |
+
tasks_completed += 1 # Count as completed (though failed)
|
| 444 |
+
completed_since_last_save += 1 # Also increment for checkpointing
|
| 445 |
+
|
| 446 |
+
# Check if it's time to save a checkpoint even after an error
|
| 447 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 448 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 449 |
+
completed_since_last_save = 0 # Reset counter
|
| 450 |
+
|
| 451 |
+
end_total_time = time.time()
|
| 452 |
+
print("\n--- GPT-4o Processing Complete ---")
|
| 453 |
+
print(f"Total GPT-4o tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
|
| 454 |
+
print(f"Total GPT-4o processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# --- Final Save ---
|
| 458 |
+
# Save one last time to ensure any remaining processed items (< CHECKPOINT_INTERVAL) are saved
|
| 459 |
+
print("\nPerforming final save...")
|
| 460 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 461 |
+
|
| 462 |
+
print("\nScript finished.")
|
| 463 |
+
|
| 464 |
+
# --- (Removed the old merging and saving logic as it's now handled by save_checkpoint) ---
|
r1-a/response_generation/gpt4o_mini.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import base64
|
| 4 |
+
import uuid
|
| 5 |
+
import time
|
| 6 |
+
import re
|
| 7 |
+
import random
|
| 8 |
+
import concurrent.futures
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import threading
|
| 11 |
+
import traceback # For detailed error logging
|
| 12 |
+
|
| 13 |
+
import requests # Use requests library for HTTP calls
|
| 14 |
+
# Make sure numpy is imported if needed for potential fallback serialization
|
| 15 |
+
import numpy as np
|
| 16 |
+
from datasets import load_from_disk, Dataset
|
| 17 |
+
from dotenv import load_dotenv
|
| 18 |
+
|
| 19 |
+
# --- Configuration ---
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
# 1. API Client Setup
|
| 23 |
+
GPT4O_MODEL_NAME = "freeze_omni" # How it's identified in your dataset's model columns
|
| 24 |
+
API_MODEL_NAME = "gpt-4o-mini-audio-preview" # Actual model name for the API call
|
| 25 |
+
API_ENDPOINT = "https://api2.aigcbest.top/v1/chat/completions"
|
| 26 |
+
try:
|
| 27 |
+
# Assuming a single key for this service based on the original script
|
| 28 |
+
API_TOKEN = "sk-J6Y4OBCEG0D75suEZoj22eFmiwO1DHzLCqvt4bRmyZRTMlTa"
|
| 29 |
+
if not API_TOKEN:
|
| 30 |
+
raise ValueError("AIGCBEST_API_KEY environment variable not set.")
|
| 31 |
+
print("AIGCBEST API Key loaded.")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"FATAL: Error getting API Key: {e}")
|
| 34 |
+
exit(1)
|
| 35 |
+
|
| 36 |
+
# 2. Dataset Paths
|
| 37 |
+
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
|
| 38 |
+
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o_mini"
|
| 39 |
+
|
| 40 |
+
# 3. Output Audio Configuration
|
| 41 |
+
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_mini"
|
| 42 |
+
OUTPUT_AUDIO_FORMAT = "wav" # API will be requested to return wav
|
| 43 |
+
AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse']
|
| 44 |
+
|
| 45 |
+
# 4. API Call Settings
|
| 46 |
+
API_TIMEOUT = 240
|
| 47 |
+
API_RETRY_DELAY = 5
|
| 48 |
+
API_MAX_RETRIES = 3 # Max attempts *for the task*
|
| 49 |
+
MAX_WORKERS = 8 # Adjust based on API rate limits and system resources
|
| 50 |
+
|
| 51 |
+
# 5. Checkpoint Saving Configuration # <-- NEW
|
| 52 |
+
CHECKPOINT_INTERVAL = 500 # Save every 500 completed tasks
|
| 53 |
+
|
| 54 |
+
# --- Helper Functions (encode_audio_base64 and parse_ultra_history remain the same) ---
|
| 55 |
+
|
| 56 |
+
def encode_audio_base64(audio_path):
|
| 57 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 58 |
+
print(f"Warning: Input audio file not found or path is empty: {audio_path}")
|
| 59 |
+
return None
|
| 60 |
+
try:
|
| 61 |
+
with open(audio_path, "rb") as audio_file:
|
| 62 |
+
return base64.b64encode(audio_file.read()).decode("utf-8")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"Error encoding audio file {audio_path}: {e}")
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
def parse_ultra_history(history_str):
|
| 68 |
+
messages = []
|
| 69 |
+
pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
|
| 70 |
+
matches = pattern.findall(history_str)
|
| 71 |
+
if not matches:
|
| 72 |
+
return []
|
| 73 |
+
for role_tag, content in matches:
|
| 74 |
+
role = role_tag.lower()
|
| 75 |
+
cleaned_content = content.strip()
|
| 76 |
+
if cleaned_content:
|
| 77 |
+
messages.append({"role": role, "content": cleaned_content})
|
| 78 |
+
return messages
|
| 79 |
+
|
| 80 |
+
# --- Modified API Call Worker Function for GPT-4o (Reduced Prints) ---
|
| 81 |
+
def call_gpt4o_api_worker(task_info):
|
| 82 |
+
"""
|
| 83 |
+
Worker function to call the custom GPT-4o API for a single task.
|
| 84 |
+
"""
|
| 85 |
+
row_idx = task_info["row_idx"]
|
| 86 |
+
slot_idx = task_info["slot_idx"]
|
| 87 |
+
history_messages = task_info["history_messages"]
|
| 88 |
+
prompt_text = task_info["prompt_text"]
|
| 89 |
+
question_text = task_info["question_text"]
|
| 90 |
+
question_audio_path = task_info["question_audio_path"]
|
| 91 |
+
output_audio_filepath = task_info["output_audio_filepath"]
|
| 92 |
+
|
| 93 |
+
retries = 0
|
| 94 |
+
headers = {
|
| 95 |
+
'Accept': 'application/json',
|
| 96 |
+
'Authorization': f'Bearer {API_TOKEN}', # Use the single loaded token
|
| 97 |
+
'Content-Type': 'application/json'
|
| 98 |
+
}
|
| 99 |
+
selected_voice = random.choice(AVAILABLE_VOICES)
|
| 100 |
+
# print(f" [Thread-{threading.get_ident()}] Processing Row {row_idx}, Slot {slot_idx} (GPT4o Voice: {selected_voice})") # Optional log
|
| 101 |
+
|
| 102 |
+
while retries < API_MAX_RETRIES:
|
| 103 |
+
try:
|
| 104 |
+
# 1. Prepare Input Audio
|
| 105 |
+
base64_audio_data = encode_audio_base64(question_audio_path)
|
| 106 |
+
if not base64_audio_data:
|
| 107 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
|
| 108 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
|
| 109 |
+
|
| 110 |
+
input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
|
| 111 |
+
|
| 112 |
+
# 2. Construct User Message Content
|
| 113 |
+
combined_text = f"{prompt_text}"
|
| 114 |
+
user_content_list = [
|
| 115 |
+
{"type": "text", "text": combined_text},
|
| 116 |
+
{"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
|
| 117 |
+
]
|
| 118 |
+
messages = history_messages + [{"role": "user", "content": user_content_list}]
|
| 119 |
+
|
| 120 |
+
# 4. Construct Payload
|
| 121 |
+
payload = {
|
| 122 |
+
"model": API_MODEL_NAME,
|
| 123 |
+
"modalities": ["text", "audio"],
|
| 124 |
+
"audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
|
| 125 |
+
"messages": messages
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
# 5. Make API Call
|
| 129 |
+
response = requests.post(
|
| 130 |
+
API_ENDPOINT,
|
| 131 |
+
headers=headers,
|
| 132 |
+
json=payload,
|
| 133 |
+
timeout=API_TIMEOUT
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# 6. Process Response
|
| 137 |
+
if response.status_code == 200:
|
| 138 |
+
try:
|
| 139 |
+
response_data = response.json()
|
| 140 |
+
# Make parsing more robust
|
| 141 |
+
choices = response_data.get('choices')
|
| 142 |
+
if not choices or not isinstance(choices, list) or len(choices) == 0:
|
| 143 |
+
raise ValueError("Invalid or empty 'choices' field in response.")
|
| 144 |
+
|
| 145 |
+
message_content = choices[0].get('message', {})
|
| 146 |
+
if not message_content:
|
| 147 |
+
raise ValueError("Missing 'message' field in the first choice.")
|
| 148 |
+
|
| 149 |
+
audio_info = message_content.get('audio', {})
|
| 150 |
+
if not isinstance(audio_info, dict): audio_info = {} # Handle case where audio might be null or not a dict
|
| 151 |
+
|
| 152 |
+
audio_base64_string = audio_info.get('data', '')
|
| 153 |
+
# Try getting text from 'content' if 'transcript' is missing/empty in 'audio'
|
| 154 |
+
collected_text = audio_info.get('transcript', '').strip()
|
| 155 |
+
if not collected_text:
|
| 156 |
+
text_content_list = message_content.get('content', [])
|
| 157 |
+
if isinstance(text_content_list, list):
|
| 158 |
+
for item in text_content_list:
|
| 159 |
+
if isinstance(item, dict) and item.get("type") == "text":
|
| 160 |
+
collected_text = item.get("text", "").strip()
|
| 161 |
+
break # Take the first text part found
|
| 162 |
+
# Still no text? Try the top-level message content directly if it's a string
|
| 163 |
+
elif isinstance(message_content.get('content'), str):
|
| 164 |
+
collected_text = message_content['content'].strip()
|
| 165 |
+
|
| 166 |
+
if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
|
| 167 |
+
if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
|
| 168 |
+
|
| 169 |
+
saved_audio_path = None
|
| 170 |
+
if audio_base64_string:
|
| 171 |
+
try:
|
| 172 |
+
wav_bytes = base64.b64decode(audio_base64_string)
|
| 173 |
+
if len(wav_bytes) == 0:
|
| 174 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
|
| 175 |
+
else:
|
| 176 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 177 |
+
with open(output_audio_filepath, "wb") as f:
|
| 178 |
+
f.write(wav_bytes)
|
| 179 |
+
saved_audio_path = output_audio_filepath
|
| 180 |
+
# print(f" Audio saved to: {output_audio_filepath}") # Less verbose log
|
| 181 |
+
except base64.binascii.Error as b64_err:
|
| 182 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
|
| 183 |
+
except Exception as e:
|
| 184 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
|
| 185 |
+
|
| 186 |
+
# TASK SUCCEEDED (even if audio saving failed, text might be valid)
|
| 187 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
|
| 188 |
+
|
| 189 |
+
except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
|
| 190 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
|
| 191 |
+
print(f" Response Text (start): {response.text[:500]}...")
|
| 192 |
+
retries += 1
|
| 193 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 194 |
+
time.sleep(API_RETRY_DELAY)
|
| 195 |
+
continue
|
| 196 |
+
except Exception as e: # Catch-all for unexpected errors during processing
|
| 197 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
|
| 198 |
+
print(traceback.format_exc())
|
| 199 |
+
retries += 1
|
| 200 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 201 |
+
time.sleep(API_RETRY_DELAY)
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
else: # Handle non-200 status codes
|
| 205 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
|
| 206 |
+
retries += 1
|
| 207 |
+
if retries < API_MAX_RETRIES:
|
| 208 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 209 |
+
time.sleep(API_RETRY_DELAY)
|
| 210 |
+
continue # Go to next iteration of while loop
|
| 211 |
+
else:
|
| 212 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
|
| 213 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
except requests.exceptions.Timeout:
|
| 217 |
+
retries += 1
|
| 218 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
|
| 219 |
+
if retries < API_MAX_RETRIES:
|
| 220 |
+
time.sleep(API_RETRY_DELAY)
|
| 221 |
+
continue
|
| 222 |
+
else:
|
| 223 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
|
| 224 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
|
| 225 |
+
|
| 226 |
+
except requests.exceptions.RequestException as e:
|
| 227 |
+
retries += 1
|
| 228 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
|
| 229 |
+
if retries < API_MAX_RETRIES:
|
| 230 |
+
time.sleep(API_RETRY_DELAY)
|
| 231 |
+
continue
|
| 232 |
+
else:
|
| 233 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
|
| 234 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
|
| 235 |
+
|
| 236 |
+
except Exception as e:
|
| 237 |
+
retries += 1
|
| 238 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
|
| 239 |
+
print(traceback.format_exc())
|
| 240 |
+
if retries < API_MAX_RETRIES:
|
| 241 |
+
time.sleep(API_RETRY_DELAY)
|
| 242 |
+
continue
|
| 243 |
+
else:
|
| 244 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
|
| 245 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
|
| 246 |
+
|
| 247 |
+
# If loop finishes without returning, max retries were hit
|
| 248 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts.")
|
| 249 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
|
| 250 |
+
|
| 251 |
+
# --- Checkpoint Saving Function --- # <-- NEW (Copied from previous response)
|
| 252 |
+
def save_checkpoint(data_to_save, output_dir, dataset_features):
|
| 253 |
+
"""Saves the current state of the data to disk."""
|
| 254 |
+
if not data_to_save:
|
| 255 |
+
print("Checkpoint: No data available to save.")
|
| 256 |
+
return
|
| 257 |
+
|
| 258 |
+
# Ensure output directory exists before saving
|
| 259 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
|
| 262 |
+
try:
|
| 263 |
+
# Convert list of dicts back to Dataset object
|
| 264 |
+
checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
|
| 265 |
+
checkpoint_dataset.save_to_disk(output_dir)
|
| 266 |
+
print(f"Checkpoint: Saved successfully to {output_dir}")
|
| 267 |
+
except Exception as ckpt_save_e:
|
| 268 |
+
print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
|
| 269 |
+
# Fallback to JSON Lines (optional, but good practice)
|
| 270 |
+
output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl") # Save inside the dir
|
| 271 |
+
print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
|
| 272 |
+
try:
|
| 273 |
+
with open(output_jsonl_path, 'w', encoding='utf-8') as f:
|
| 274 |
+
for item in data_to_save:
|
| 275 |
+
# Basic serialization handling for common types like numpy arrays
|
| 276 |
+
serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
|
| 277 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 278 |
+
print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
|
| 279 |
+
except Exception as json_save_e:
|
| 280 |
+
print(f"Error saving checkpoint as JSON lines: {json_save_e}")
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# --- Main Processing Logic ---
|
| 284 |
+
|
| 285 |
+
print("Checking for existing checkpoint/output dataset...")
|
| 286 |
+
dataset = None
|
| 287 |
+
original_features = None # Initialize
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
# 检查输出目录是否存在,并且看起来像一个 Hugging Face datasets 目录
|
| 291 |
+
# (dataset_info.json 或 state.json 是常见的指示文件)
|
| 292 |
+
potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
|
| 293 |
+
potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
|
| 294 |
+
|
| 295 |
+
if os.path.exists(OUTPUT_DATASET_DIR) and \
|
| 296 |
+
(os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
|
| 297 |
+
|
| 298 |
+
print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
|
| 299 |
+
try:
|
| 300 |
+
dataset = load_from_disk(OUTPUT_DATASET_DIR)
|
| 301 |
+
original_features = dataset.features # 获取已保存数据集的特征
|
| 302 |
+
print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
|
| 303 |
+
except Exception as load_ckpt_e:
|
| 304 |
+
print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
|
| 305 |
+
print("Falling back to loading original input dataset.")
|
| 306 |
+
dataset = None # Ensure we proceed to load original if checkpoint load failed
|
| 307 |
+
else:
|
| 308 |
+
print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
|
| 309 |
+
# If no checkpoint, ensure dataset is None so original loading happens
|
| 310 |
+
|
| 311 |
+
# 如果 dataset 仍然是 None (因为没有找到 checkpoint 或加载失败)
|
| 312 |
+
if dataset is None:
|
| 313 |
+
print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
|
| 314 |
+
dataset = load_from_disk(INPUT_DATASET_DIR)
|
| 315 |
+
original_features = dataset.features
|
| 316 |
+
print(f"Original dataset loaded successfully with {len(dataset)} rows.")
|
| 317 |
+
|
| 318 |
+
except Exception as initial_load_e:
|
| 319 |
+
print(f"FATAL: Error during initial dataset loading (original or checkpoint): {initial_load_e}")
|
| 320 |
+
print(traceback.format_exc()) # 打印详细错误
|
| 321 |
+
exit(1)
|
| 322 |
+
|
| 323 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 324 |
+
|
| 325 |
+
# --- Pre-calculation Step for GPT-4o ---
|
| 326 |
+
print("Pre-calculating GPT-4o tasks...")
|
| 327 |
+
tasks_to_process = []
|
| 328 |
+
# Use a list of dictionaries, which is mutable and easier for direct updates
|
| 329 |
+
updated_data = list(dataset)
|
| 330 |
+
|
| 331 |
+
for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset for GPT-4o tasks")):
|
| 332 |
+
for i in range(1, 4):
|
| 333 |
+
model_key = f"model_{i}"
|
| 334 |
+
response_text_key = f"response_text_{i}"
|
| 335 |
+
prompt_text_key = f"prompt_text_{i}"
|
| 336 |
+
response_audio_key = f"response_audio_path_{i}" # Key for storing the *new* audio path
|
| 337 |
+
|
| 338 |
+
model_assigned = row.get(model_key)
|
| 339 |
+
response_text_exists = row.get(response_text_key) is not None
|
| 340 |
+
|
| 341 |
+
# Check for the specific model name used in the dataset
|
| 342 |
+
if model_assigned == GPT4O_MODEL_NAME and not response_text_exists:
|
| 343 |
+
question_audio_path = row.get('question_audio')
|
| 344 |
+
if not question_audio_path or not os.path.exists(question_audio_path): # Check path validity here
|
| 345 |
+
print(f"Warning (Row {idx}, Slot {i}): Skipping GPT-4o task - Missing or invalid 'question_audio' path: {question_audio_path}")
|
| 346 |
+
# Pre-fill error? Let's just skip task creation for now.
|
| 347 |
+
# If needed: updated_data[idx][response_text_key] = "[ERROR: Missing input audio]"
|
| 348 |
+
# If needed: updated_data[idx][response_audio_key] = None
|
| 349 |
+
continue # Skip this task
|
| 350 |
+
|
| 351 |
+
metadata_str = row.get('metadata', "{}")
|
| 352 |
+
source_dataset = row.get('source_dataset')
|
| 353 |
+
metadata = {}
|
| 354 |
+
try:
|
| 355 |
+
if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
|
| 356 |
+
elif isinstance(metadata_str, dict): metadata = metadata_str
|
| 357 |
+
except json.JSONDecodeError: pass
|
| 358 |
+
|
| 359 |
+
history_messages = []
|
| 360 |
+
if source_dataset == 'ultra':
|
| 361 |
+
history_str = metadata.get('history', '')
|
| 362 |
+
if history_str: history_messages = parse_ultra_history(history_str)
|
| 363 |
+
|
| 364 |
+
unique_id = str(uuid.uuid4()).replace("-", "")
|
| 365 |
+
output_audio_filename = f"gpt4o_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 366 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 367 |
+
|
| 368 |
+
task_info = {
|
| 369 |
+
"row_idx": idx,
|
| 370 |
+
"slot_idx": i,
|
| 371 |
+
# No API key needed here as it's global/single
|
| 372 |
+
"history_messages": history_messages,
|
| 373 |
+
"prompt_text": row.get(prompt_text_key, ""),
|
| 374 |
+
"question_text": row.get('question_text', ""), # Pass question text
|
| 375 |
+
"question_audio_path": question_audio_path,
|
| 376 |
+
"output_audio_filepath": output_audio_filepath,
|
| 377 |
+
}
|
| 378 |
+
tasks_to_process.append(task_info)
|
| 379 |
+
# Decide if you process all slots or just the first unfilled one
|
| 380 |
+
# break # Uncomment this line if you only want the *first* unfilled gpt4o slot per row processed
|
| 381 |
+
|
| 382 |
+
total_tasks = len(tasks_to_process)
|
| 383 |
+
if total_tasks == 0:
|
| 384 |
+
print("No GPT-4o tasks found needing processing.")
|
| 385 |
+
exit(0)
|
| 386 |
+
|
| 387 |
+
print(f"Found {total_tasks} GPT-4o tasks to process.")
|
| 388 |
+
|
| 389 |
+
# --- Threaded Execution with Checkpointing for GPT-4o --- # <-- MODIFIED SECTION
|
| 390 |
+
print(f"Starting GPT-4o processing with up to {MAX_WORKERS} worker threads...")
|
| 391 |
+
start_total_time = time.time()
|
| 392 |
+
# results = {} # No longer needed
|
| 393 |
+
tasks_completed = 0
|
| 394 |
+
tasks_failed = 0
|
| 395 |
+
completed_since_last_save = 0 # <-- Counter for checkpointing
|
| 396 |
+
|
| 397 |
+
# Use context manager for ThreadPoolExecutor
|
| 398 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 399 |
+
future_to_task = {executor.submit(call_gpt4o_api_worker, task): task for task in tasks_to_process}
|
| 400 |
+
|
| 401 |
+
for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GPT-4o tasks"):
|
| 402 |
+
task_info = future_to_task[future] # Get original task info
|
| 403 |
+
row_idx = task_info["row_idx"]
|
| 404 |
+
slot_idx = task_info["slot_idx"]
|
| 405 |
+
result = None # Define result scope
|
| 406 |
+
|
| 407 |
+
try:
|
| 408 |
+
result = future.result()
|
| 409 |
+
# --- Direct Update and Checkpointing Logic ---
|
| 410 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 411 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 412 |
+
|
| 413 |
+
if 0 <= row_idx < len(updated_data):
|
| 414 |
+
updated_data[row_idx][response_text_key] = result["response_text"]
|
| 415 |
+
updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
|
| 416 |
+
if result["saved_audio_path"] is None or "[ERROR" in result["response_text"]: # Check for error marker
|
| 417 |
+
tasks_failed += 1
|
| 418 |
+
else:
|
| 419 |
+
print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
|
| 420 |
+
tasks_failed += 1 # Count as failed if index is bad
|
| 421 |
+
|
| 422 |
+
tasks_completed += 1
|
| 423 |
+
completed_since_last_save += 1 # Increment checkpoint counter
|
| 424 |
+
|
| 425 |
+
# Check if it's time to save a checkpoint
|
| 426 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 427 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 428 |
+
completed_since_last_save = 0 # Reset counter
|
| 429 |
+
|
| 430 |
+
except Exception as exc: # Catch exceptions raised *by* the future/worker if not handled inside
|
| 431 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): GPT-4o Task generated an unhandled exception: {exc}")
|
| 432 |
+
print(traceback.format_exc())
|
| 433 |
+
# Attempt to record error in the main data structure
|
| 434 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 435 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 436 |
+
if 0 <= row_idx < len(updated_data):
|
| 437 |
+
updated_data[row_idx][response_text_key] = f"[ERROR: Worker Crash - {exc}]"
|
| 438 |
+
updated_data[row_idx][response_audio_key] = None
|
| 439 |
+
else:
|
| 440 |
+
print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
|
| 441 |
+
|
| 442 |
+
tasks_failed += 1
|
| 443 |
+
tasks_completed += 1 # Count as completed (though failed)
|
| 444 |
+
completed_since_last_save += 1 # Also increment for checkpointing
|
| 445 |
+
|
| 446 |
+
# Check if it's time to save a checkpoint even after an error
|
| 447 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 448 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 449 |
+
completed_since_last_save = 0 # Reset counter
|
| 450 |
+
|
| 451 |
+
end_total_time = time.time()
|
| 452 |
+
print("\n--- GPT-4o Processing Complete ---")
|
| 453 |
+
print(f"Total GPT-4o tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
|
| 454 |
+
print(f"Total GPT-4o processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# --- Final Save ---
|
| 458 |
+
# Save one last time to ensure any remaining processed items (< CHECKPOINT_INTERVAL) are saved
|
| 459 |
+
print("\nPerforming final save...")
|
| 460 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 461 |
+
|
| 462 |
+
print("\nScript finished.")
|
| 463 |
+
|
| 464 |
+
# --- (Removed the old merging and saving logic as it's now handled by save_checkpoint) ---
|
r1-a/response_generation/gpt5o_retry.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import base64
|
| 4 |
+
import uuid
|
| 5 |
+
import time
|
| 6 |
+
import re
|
| 7 |
+
import random
|
| 8 |
+
import concurrent.futures
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import threading
|
| 11 |
+
import traceback # For detailed error logging
|
| 12 |
+
|
| 13 |
+
import requests # Use requests library for HTTP calls
|
| 14 |
+
import numpy as np # Import numpy for potential fallback serialization
|
| 15 |
+
from datasets import load_from_disk, Dataset
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
|
| 18 |
+
# --- Configuration ---
|
| 19 |
+
load_dotenv()
|
| 20 |
+
|
| 21 |
+
# --- !!! KEY CONFIGURATION FOR RETRY SCRIPT !!! ---
|
| 22 |
+
|
| 23 |
+
# 1. Identify the model you are retrying
|
| 24 |
+
TARGET_MODEL_NAME = "gpt4o" # Or "qwen_omni" if retrying Qwen
|
| 25 |
+
|
| 26 |
+
# 2. Set the INPUT/OUTPUT dataset directory to the PREVIOUS script's OUTPUT directory
|
| 27 |
+
# This is where the partially processed data (with errors) resides.
|
| 28 |
+
# The script will LOAD from here and SAVE back to here.
|
| 29 |
+
DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o" # Adjust if needed
|
| 30 |
+
|
| 31 |
+
# 3. Set the audio output directory (can be the same as before)
|
| 32 |
+
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_2" # Adjust if needed
|
| 33 |
+
|
| 34 |
+
# 4. API Configuration (Specific to the model being retried)
|
| 35 |
+
API_MODEL_NAME = "gpt-4o-audio-preview" # Actual model name for the API call
|
| 36 |
+
API_ENDPOINT = "https://api2.aigcbest.top/v1/chat/completions"
|
| 37 |
+
try:
|
| 38 |
+
API_TOKEN = "sk-D6jMssP7AZw3ZU6LEZaljdNMO1zif6wzef6XVh4kOgZAhQzI" # Use the correct key
|
| 39 |
+
if not API_TOKEN:
|
| 40 |
+
raise ValueError("API_TOKEN environment variable not set.")
|
| 41 |
+
print(f"{TARGET_MODEL_NAME} API Key loaded.")
|
| 42 |
+
except Exception as e:
|
| 43 |
+
print(f"FATAL: Error getting API Key: {e}")
|
| 44 |
+
exit(1)
|
| 45 |
+
|
| 46 |
+
# 5. Output Audio Configuration (Specific to the model being retried)
|
| 47 |
+
OUTPUT_AUDIO_FORMAT = "wav"
|
| 48 |
+
AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse'] # GPT-4o voices
|
| 49 |
+
|
| 50 |
+
# 6. API Call Settings
|
| 51 |
+
API_TIMEOUT = 120
|
| 52 |
+
API_RETRY_DELAY = 5
|
| 53 |
+
API_MAX_RETRIES = 3
|
| 54 |
+
MAX_WORKERS = 8
|
| 55 |
+
|
| 56 |
+
# 7. Checkpoint Saving Configuration
|
| 57 |
+
CHECKPOINT_INTERVAL = 50 # Save every 500 *retried* tasks completed
|
| 58 |
+
|
| 59 |
+
# --- Error Markers to Look For ---
|
| 60 |
+
# These prefixes indicate a failed task that needs retrying
|
| 61 |
+
ERROR_MARKERS = ("[API ERROR", "[ERROR")
|
| 62 |
+
|
| 63 |
+
# --- Helper Functions (encode_audio_base64, parse_ultra_history - unchanged) ---
|
| 64 |
+
|
| 65 |
+
def encode_audio_base64(audio_path):
|
| 66 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 67 |
+
print(f"Warning: Input audio file not found or path is empty: {audio_path}")
|
| 68 |
+
return None
|
| 69 |
+
try:
|
| 70 |
+
with open(audio_path, "rb") as audio_file:
|
| 71 |
+
return base64.b64encode(audio_file.read()).decode("utf-8")
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"Error encoding audio file {audio_path}: {e}")
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
def parse_ultra_history(history_str):
|
| 77 |
+
messages = []
|
| 78 |
+
pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
|
| 79 |
+
matches = pattern.findall(history_str)
|
| 80 |
+
if not matches:
|
| 81 |
+
return []
|
| 82 |
+
for role_tag, content in matches:
|
| 83 |
+
role = role_tag.lower()
|
| 84 |
+
cleaned_content = content.strip()
|
| 85 |
+
if cleaned_content:
|
| 86 |
+
messages.append({"role": role, "content": cleaned_content})
|
| 87 |
+
return messages
|
| 88 |
+
|
| 89 |
+
# --- API Call Worker Function (Use the correct one for the target model - GPT-4o version shown) ---
|
| 90 |
+
# --- (This function call_gpt4o_api_worker is copied directly from the previous script) ---
|
| 91 |
+
def call_gpt4o_api_worker(task_info):
|
| 92 |
+
"""
|
| 93 |
+
Worker function to call the custom GPT-4o API for a single task.
|
| 94 |
+
(Identical to the function in the previous script)
|
| 95 |
+
"""
|
| 96 |
+
row_idx = task_info["row_idx"]
|
| 97 |
+
slot_idx = task_info["slot_idx"]
|
| 98 |
+
history_messages = task_info["history_messages"]
|
| 99 |
+
prompt_text = task_info["prompt_text"]
|
| 100 |
+
question_text = task_info["question_text"]
|
| 101 |
+
question_audio_path = task_info["question_audio_path"]
|
| 102 |
+
output_audio_filepath = task_info["output_audio_filepath"]
|
| 103 |
+
|
| 104 |
+
retries = 0
|
| 105 |
+
headers = {
|
| 106 |
+
'Accept': 'application/json',
|
| 107 |
+
'Authorization': f'Bearer {API_TOKEN}', # Use the single loaded token
|
| 108 |
+
'Content-Type': 'application/json'
|
| 109 |
+
}
|
| 110 |
+
selected_voice = random.choice(AVAILABLE_VOICES)
|
| 111 |
+
|
| 112 |
+
while retries < API_MAX_RETRIES:
|
| 113 |
+
try:
|
| 114 |
+
# 1. Prepare Input Audio
|
| 115 |
+
base64_audio_data = encode_audio_base64(question_audio_path)
|
| 116 |
+
if not base64_audio_data:
|
| 117 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
|
| 118 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
|
| 119 |
+
|
| 120 |
+
input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
|
| 121 |
+
|
| 122 |
+
# 2. Construct User Message Content
|
| 123 |
+
combined_text = f"{prompt_text}"
|
| 124 |
+
user_content_list = [
|
| 125 |
+
{"type": "text", "text": combined_text},
|
| 126 |
+
{"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
|
| 127 |
+
]
|
| 128 |
+
messages = history_messages + [{"role": "user", "content": user_content_list}]
|
| 129 |
+
|
| 130 |
+
# 4. Construct Payload
|
| 131 |
+
payload = {
|
| 132 |
+
"model": API_MODEL_NAME,
|
| 133 |
+
"modalities": ["text", "audio"],
|
| 134 |
+
"audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
|
| 135 |
+
"messages": messages
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# 5. Make API Call
|
| 139 |
+
response = requests.post(
|
| 140 |
+
API_ENDPOINT,
|
| 141 |
+
headers=headers,
|
| 142 |
+
json=payload,
|
| 143 |
+
timeout=API_TIMEOUT
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# 6. Process Response
|
| 147 |
+
if response.status_code == 200:
|
| 148 |
+
try:
|
| 149 |
+
response_data = response.json()
|
| 150 |
+
choices = response_data.get('choices')
|
| 151 |
+
if not choices or not isinstance(choices, list) or len(choices) == 0:
|
| 152 |
+
raise ValueError("Invalid or empty 'choices' field in response.")
|
| 153 |
+
message_content = choices[0].get('message', {})
|
| 154 |
+
if not message_content:
|
| 155 |
+
raise ValueError("Missing 'message' field in the first choice.")
|
| 156 |
+
audio_info = message_content.get('audio', {})
|
| 157 |
+
if not isinstance(audio_info, dict): audio_info = {}
|
| 158 |
+
|
| 159 |
+
audio_base64_string = audio_info.get('data', '')
|
| 160 |
+
collected_text = audio_info.get('transcript', '').strip()
|
| 161 |
+
if not collected_text:
|
| 162 |
+
text_content_list = message_content.get('content', [])
|
| 163 |
+
if isinstance(text_content_list, list):
|
| 164 |
+
for item in text_content_list:
|
| 165 |
+
if isinstance(item, dict) and item.get("type") == "text":
|
| 166 |
+
collected_text = item.get("text", "").strip()
|
| 167 |
+
break
|
| 168 |
+
elif isinstance(message_content.get('content'), str):
|
| 169 |
+
collected_text = message_content['content'].strip()
|
| 170 |
+
|
| 171 |
+
if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
|
| 172 |
+
if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
|
| 173 |
+
|
| 174 |
+
saved_audio_path = None
|
| 175 |
+
if audio_base64_string:
|
| 176 |
+
try:
|
| 177 |
+
wav_bytes = base64.b64decode(audio_base64_string)
|
| 178 |
+
if len(wav_bytes) == 0:
|
| 179 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
|
| 180 |
+
else:
|
| 181 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 182 |
+
with open(output_audio_filepath, "wb") as f:
|
| 183 |
+
f.write(wav_bytes)
|
| 184 |
+
saved_audio_path = output_audio_filepath
|
| 185 |
+
except base64.binascii.Error as b64_err:
|
| 186 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
|
| 187 |
+
except Exception as e:
|
| 188 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
|
| 189 |
+
|
| 190 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
|
| 191 |
+
|
| 192 |
+
except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
|
| 193 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
|
| 194 |
+
print(f" Response Text (start): {response.text[:500]}...")
|
| 195 |
+
retries += 1
|
| 196 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 197 |
+
time.sleep(API_RETRY_DELAY)
|
| 198 |
+
continue
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
|
| 201 |
+
print(traceback.format_exc())
|
| 202 |
+
retries += 1
|
| 203 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 204 |
+
time.sleep(API_RETRY_DELAY)
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
else: # Handle non-200 status codes
|
| 208 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
|
| 209 |
+
retries += 1
|
| 210 |
+
if retries < API_MAX_RETRIES:
|
| 211 |
+
print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
|
| 212 |
+
time.sleep(API_RETRY_DELAY)
|
| 213 |
+
continue
|
| 214 |
+
else:
|
| 215 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
|
| 216 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
|
| 217 |
+
|
| 218 |
+
except requests.exceptions.Timeout:
|
| 219 |
+
retries += 1
|
| 220 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
|
| 221 |
+
if retries < API_MAX_RETRIES:
|
| 222 |
+
time.sleep(API_RETRY_DELAY)
|
| 223 |
+
continue
|
| 224 |
+
else:
|
| 225 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
|
| 226 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
|
| 227 |
+
except requests.exceptions.RequestException as e:
|
| 228 |
+
retries += 1
|
| 229 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
|
| 230 |
+
if retries < API_MAX_RETRIES:
|
| 231 |
+
time.sleep(API_RETRY_DELAY)
|
| 232 |
+
continue
|
| 233 |
+
else:
|
| 234 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
|
| 235 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
|
| 236 |
+
except Exception as e:
|
| 237 |
+
retries += 1
|
| 238 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
|
| 239 |
+
print(traceback.format_exc())
|
| 240 |
+
if retries < API_MAX_RETRIES:
|
| 241 |
+
time.sleep(API_RETRY_DELAY)
|
| 242 |
+
continue
|
| 243 |
+
else:
|
| 244 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
|
| 245 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
|
| 246 |
+
|
| 247 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts (Worker Loop Exited).")
|
| 248 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# --- Checkpoint Saving Function (Unchanged) ---
|
| 252 |
+
def save_checkpoint(data_to_save, output_dir, dataset_features):
|
| 253 |
+
"""Saves the current state of the data to disk."""
|
| 254 |
+
if not data_to_save:
|
| 255 |
+
print("Checkpoint: No data available to save.")
|
| 256 |
+
return
|
| 257 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 258 |
+
print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
|
| 259 |
+
try:
|
| 260 |
+
checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
|
| 261 |
+
checkpoint_dataset.save_to_disk(output_dir)
|
| 262 |
+
print(f"Checkpoint: Saved successfully to {output_dir}")
|
| 263 |
+
except Exception as ckpt_save_e:
|
| 264 |
+
print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
|
| 265 |
+
output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl")
|
| 266 |
+
print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
|
| 267 |
+
try:
|
| 268 |
+
with open(output_jsonl_path, 'w', encoding='utf-8') as f:
|
| 269 |
+
for item in data_to_save:
|
| 270 |
+
serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
|
| 271 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 272 |
+
print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
|
| 273 |
+
except Exception as json_save_e:
|
| 274 |
+
print(f"Error saving checkpoint as JSON lines: {json_save_e}")
|
| 275 |
+
|
| 276 |
+
# --- Main Processing Logic (Retry Focus) ---
|
| 277 |
+
|
| 278 |
+
print(f"--- Starting Retry Script for {TARGET_MODEL_NAME} ---")
|
| 279 |
+
print(f"Loading dataset to retry from: {DATASET_DIR}")
|
| 280 |
+
|
| 281 |
+
try:
|
| 282 |
+
# Attempt to load the dataset from the specified directory
|
| 283 |
+
if not os.path.exists(DATASET_DIR) or \
|
| 284 |
+
not (os.path.exists(os.path.join(DATASET_DIR, "dataset_info.json")) or \
|
| 285 |
+
os.path.exists(os.path.join(DATASET_DIR, "state.json"))):
|
| 286 |
+
print(f"FATAL: Dataset directory not found or invalid: {DATASET_DIR}")
|
| 287 |
+
print("Please ensure this path points to the OUTPUT directory of the previous script run.")
|
| 288 |
+
exit(1)
|
| 289 |
+
|
| 290 |
+
dataset = load_from_disk(DATASET_DIR)
|
| 291 |
+
original_features = dataset.features # Store features for saving
|
| 292 |
+
print(f"Dataset loaded successfully with {len(dataset)} rows.")
|
| 293 |
+
|
| 294 |
+
except Exception as e:
|
| 295 |
+
print(f"FATAL: Error loading dataset from {DATASET_DIR}: {e}")
|
| 296 |
+
print(traceback.format_exc())
|
| 297 |
+
exit(1)
|
| 298 |
+
|
| 299 |
+
# Ensure audio output directory exists
|
| 300 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 301 |
+
|
| 302 |
+
# --- Pre-calculation Step for Retrying Failed Tasks ---
|
| 303 |
+
print(f"Scanning dataset for failed {TARGET_MODEL_NAME} tasks to retry...")
|
| 304 |
+
tasks_to_process = []
|
| 305 |
+
# Use a list of dictionaries, which is mutable and easier for direct updates
|
| 306 |
+
updated_data = list(dataset) # Load data into memory for modification
|
| 307 |
+
|
| 308 |
+
for idx, row in enumerate(tqdm(updated_data, desc=f"Scanning for failed {TARGET_MODEL_NAME} tasks")):
|
| 309 |
+
for i in range(1, 4): # Check slots 1, 2, 3
|
| 310 |
+
model_key = f"model_{i}"
|
| 311 |
+
response_text_key = f"response_text_{i}"
|
| 312 |
+
prompt_text_key = f"prompt_text_{i}"
|
| 313 |
+
response_audio_key = f"response_audio_path_{i}"
|
| 314 |
+
|
| 315 |
+
model_assigned = row.get(model_key)
|
| 316 |
+
response_text_value = row.get(response_text_key)
|
| 317 |
+
|
| 318 |
+
# --- Core Retry Logic ---
|
| 319 |
+
# Check if the model assigned matches the one we are retrying
|
| 320 |
+
if model_assigned == TARGET_MODEL_NAME:
|
| 321 |
+
# Check if the response text indicates an error
|
| 322 |
+
is_error = False
|
| 323 |
+
if isinstance(response_text_value, str):
|
| 324 |
+
cleaned_text = response_text_value.strip()
|
| 325 |
+
if cleaned_text.startswith(ERROR_MARKERS): # Check if it starts with any error prefix
|
| 326 |
+
is_error = True
|
| 327 |
+
# Optional: You might also want to retry if text is None or empty,
|
| 328 |
+
# but the primary goal is retrying explicit errors.
|
| 329 |
+
# elif response_text_value is None or response_text_value == "":
|
| 330 |
+
# is_error = True # Uncomment if needed
|
| 331 |
+
|
| 332 |
+
if is_error:
|
| 333 |
+
print(f"\nInfo (Row {idx}, Slot {i}): Found failed task to retry. Current text: '{str(response_text_value)[:100]}...'") # Log finding
|
| 334 |
+
|
| 335 |
+
# --- Gather info needed for the task (same as original script) ---
|
| 336 |
+
question_audio_path = row.get('question_audio')
|
| 337 |
+
if not question_audio_path or not os.path.exists(question_audio_path):
|
| 338 |
+
print(f"Warning (Row {idx}, Slot {i}): Skipping retry - Missing or invalid 'question_audio' path: {question_audio_path}")
|
| 339 |
+
# Keep the old error message in updated_data for this case
|
| 340 |
+
continue # Skip this specific task retry
|
| 341 |
+
|
| 342 |
+
metadata_str = row.get('metadata', "{}")
|
| 343 |
+
source_dataset = row.get('source_dataset')
|
| 344 |
+
metadata = {}
|
| 345 |
+
try:
|
| 346 |
+
if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
|
| 347 |
+
elif isinstance(metadata_str, dict): metadata = metadata_str
|
| 348 |
+
except json.JSONDecodeError: pass
|
| 349 |
+
|
| 350 |
+
history_messages = []
|
| 351 |
+
if source_dataset == 'ultra':
|
| 352 |
+
history_str = metadata.get('history', '')
|
| 353 |
+
if history_str: history_messages = parse_ultra_history(history_str)
|
| 354 |
+
|
| 355 |
+
unique_id = str(uuid.uuid4()).replace("-", "")
|
| 356 |
+
# Generate a *new* filename for the potential audio output
|
| 357 |
+
output_audio_filename = f"{TARGET_MODEL_NAME}_retry_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 358 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 359 |
+
|
| 360 |
+
task_info = {
|
| 361 |
+
"row_idx": idx,
|
| 362 |
+
"slot_idx": i,
|
| 363 |
+
"history_messages": history_messages,
|
| 364 |
+
"prompt_text": row.get(prompt_text_key, ""),
|
| 365 |
+
"question_text": row.get('question_text', ""),
|
| 366 |
+
"question_audio_path": question_audio_path,
|
| 367 |
+
"output_audio_filepath": output_audio_filepath,
|
| 368 |
+
}
|
| 369 |
+
tasks_to_process.append(task_info)
|
| 370 |
+
# Decide if you want to retry all failed slots in a row or just the first one found
|
| 371 |
+
# break # Uncomment if you only want to retry the FIRST failed slot per row
|
| 372 |
+
|
| 373 |
+
total_tasks = len(tasks_to_process)
|
| 374 |
+
if total_tasks == 0:
|
| 375 |
+
print(f"No failed {TARGET_MODEL_NAME} tasks found needing reprocessing in {DATASET_DIR}.")
|
| 376 |
+
exit(0)
|
| 377 |
+
|
| 378 |
+
print(f"Found {total_tasks} failed {TARGET_MODEL_NAME} tasks to retry.")
|
| 379 |
+
|
| 380 |
+
# --- Threaded Execution with Checkpointing (Identical structure to previous script) ---
|
| 381 |
+
print(f"Starting reprocessing with up to {MAX_WORKERS} worker threads...")
|
| 382 |
+
start_total_time = time.time()
|
| 383 |
+
tasks_completed = 0
|
| 384 |
+
tasks_failed_retries = 0 # Count failures during the *retry* attempt
|
| 385 |
+
completed_since_last_save = 0
|
| 386 |
+
|
| 387 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 388 |
+
# Ensure the correct worker function is called based on TARGET_MODEL_NAME
|
| 389 |
+
api_worker_function = call_gpt4o_api_worker # Default to GPT-4o
|
| 390 |
+
# Add logic here if TARGET_MODEL_NAME could be Qwen
|
| 391 |
+
# if TARGET_MODEL_NAME == "qwen_omni":
|
| 392 |
+
# api_worker_function = call_qwen_omni_api_worker # Assuming you have this function defined/imported
|
| 393 |
+
|
| 394 |
+
future_to_task = {executor.submit(api_worker_function, task): task for task in tasks_to_process}
|
| 395 |
+
|
| 396 |
+
for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Reprocessing tasks"):
|
| 397 |
+
task_info = future_to_task[future]
|
| 398 |
+
row_idx = task_info["row_idx"]
|
| 399 |
+
slot_idx = task_info["slot_idx"]
|
| 400 |
+
result = None
|
| 401 |
+
|
| 402 |
+
try:
|
| 403 |
+
result = future.result()
|
| 404 |
+
# --- Direct Update and Checkpointing Logic ---
|
| 405 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 406 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 407 |
+
|
| 408 |
+
if 0 <= row_idx < len(updated_data):
|
| 409 |
+
# Update the data in memory
|
| 410 |
+
updated_data[row_idx][response_text_key] = result["response_text"]
|
| 411 |
+
updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
|
| 412 |
+
# Check if the *retry* attempt failed
|
| 413 |
+
if result["saved_audio_path"] is None or str(result["response_text"]).strip().startswith(ERROR_MARKERS):
|
| 414 |
+
tasks_failed_retries += 1
|
| 415 |
+
print(f"Warning (Row {row_idx}, Slot {i}): Retry attempt failed. Result: {str(result['response_text'])[:100]}...")
|
| 416 |
+
else:
|
| 417 |
+
print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
|
| 418 |
+
tasks_failed_retries += 1
|
| 419 |
+
|
| 420 |
+
tasks_completed += 1
|
| 421 |
+
completed_since_last_save += 1
|
| 422 |
+
|
| 423 |
+
# Checkpoint saving
|
| 424 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 425 |
+
# Save the updated data back to the SAME directory
|
| 426 |
+
save_checkpoint(updated_data, DATASET_DIR, original_features)
|
| 427 |
+
completed_since_last_save = 0
|
| 428 |
+
|
| 429 |
+
except Exception as exc:
|
| 430 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Retry Task generated an unhandled exception: {exc}")
|
| 431 |
+
print(traceback.format_exc())
|
| 432 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 433 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 434 |
+
if 0 <= row_idx < len(updated_data):
|
| 435 |
+
updated_data[row_idx][response_text_key] = f"[ERROR: Retry Worker Crash - {exc}]" # Mark as worker crash during retry
|
| 436 |
+
updated_data[row_idx][response_audio_key] = None
|
| 437 |
+
else:
|
| 438 |
+
print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
|
| 439 |
+
|
| 440 |
+
tasks_failed_retries += 1
|
| 441 |
+
tasks_completed += 1
|
| 442 |
+
completed_since_last_save += 1
|
| 443 |
+
|
| 444 |
+
# Checkpoint saving after error
|
| 445 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 446 |
+
save_checkpoint(updated_data, DATASET_DIR, original_features)
|
| 447 |
+
completed_since_last_save = 0
|
| 448 |
+
|
| 449 |
+
end_total_time = time.time()
|
| 450 |
+
print("\n--- Reprocessing Complete ---")
|
| 451 |
+
print(f"Total tasks retried: {tasks_completed}")
|
| 452 |
+
print(f" Succeeded on retry: {tasks_completed - tasks_failed_retries}")
|
| 453 |
+
print(f" Failed on retry: {tasks_failed_retries}")
|
| 454 |
+
print(f"Total reprocessing time: {(end_total_time - start_total_time)/60:.2f} minutes")
|
| 455 |
+
|
| 456 |
+
# --- Final Save ---
|
| 457 |
+
# Save the final state of the updated data back to the original location
|
| 458 |
+
print("\nPerforming final save of the reprocessed dataset...")
|
| 459 |
+
save_checkpoint(updated_data, DATASET_DIR, original_features)
|
| 460 |
+
|
| 461 |
+
print(f"\nRetry script finished. Updated dataset saved in: {DATASET_DIR}")
|
r1-a/response_generation/kimi.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
import re # For parsing history
|
| 5 |
+
import uuid # For generating unique filenames
|
| 6 |
+
import torch # Kimi might return tensors
|
| 7 |
+
import soundfile as sf # For saving Kimi audio output
|
| 8 |
+
import sys
|
| 9 |
+
from datasets import load_from_disk, Dataset, Features, Audio, Value
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
import datetime # For ETA formatting
|
| 12 |
+
from tqdm import tqdm # Import tqdm
|
| 13 |
+
import traceback # For detailed error printing
|
| 14 |
+
|
| 15 |
+
# --- Kimi-Audio Project Path Setup ---
|
| 16 |
+
# <--- *** IMPORTANT: Update this path to the PARENT directory containing the 'kimia_infer' folder *** --->
|
| 17 |
+
kimia_project_parent_dir = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio"
|
| 18 |
+
|
| 19 |
+
# Check if the path exists and add it to sys.path
|
| 20 |
+
if os.path.isdir(kimia_project_parent_dir):
|
| 21 |
+
if kimia_project_parent_dir not in sys.path:
|
| 22 |
+
sys.path.insert(0, kimia_project_parent_dir)
|
| 23 |
+
print(f"Added '{kimia_project_parent_dir}' to Python path.")
|
| 24 |
+
# Try importing KimiAudio only after potentially adding the path
|
| 25 |
+
try:
|
| 26 |
+
from kimia_infer.api.kimia import KimiAudio # Kimi model class
|
| 27 |
+
except ImportError as import_err:
|
| 28 |
+
print(f"Error: Could not import KimiAudio from '{kimia_project_parent_dir}'.")
|
| 29 |
+
print(f"ImportError: {import_err}")
|
| 30 |
+
print("Please ensure the 'kimia_infer' directory exists within the specified path and check dependencies.")
|
| 31 |
+
exit(1)
|
| 32 |
+
else:
|
| 33 |
+
print(f"Error: Kimi project parent directory not found: '{kimia_project_parent_dir}'")
|
| 34 |
+
print("Please update the 'kimia_project_parent_dir' variable in the script.")
|
| 35 |
+
exit(1)
|
| 36 |
+
|
| 37 |
+
# --- Configuration ---
|
| 38 |
+
load_dotenv() # Load environment variables if needed (e.g., API keys, though not typical for local Kimi)
|
| 39 |
+
|
| 40 |
+
# 1. Model & Tokenizer Setup (Kimi Specific)
|
| 41 |
+
KIMI_MODEL_NAME = "kimi_audio" # Identifier used in your dataset's model_N columns
|
| 42 |
+
KIMI_MODEL_PATH = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio/checkpoint/Kimi-Audio-7B-Instruct" # Path to your Kimi model checkpoint directory
|
| 43 |
+
# KIMI_DEVICE = 'cuda' # KimiAudio class likely handles device selection based on availability. Verify its internal logic if issues arise.
|
| 44 |
+
# KIMI_DTYPE = torch.bfloat16 # KimiAudio likely handles dtype internally.
|
| 45 |
+
|
| 46 |
+
# 2. Dataset Paths
|
| 47 |
+
INPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_sampling_tasks" # Original source
|
| 48 |
+
OUTPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_tasks_with_kimi" # Where Kimi processed data is saved/resumed from
|
| 49 |
+
|
| 50 |
+
# 3. Output Audio Configuration (Kimi Specific)
|
| 51 |
+
OUTPUT_AUDIO_ROOT_DIR = "/home/chenyifu/audio-r1/r1-a/generated_audio/kimi" # Where Kimi generated audio files are saved
|
| 52 |
+
OUTPUT_AUDIO_FORMAT = "wav"
|
| 53 |
+
OUTPUT_AUDIO_SAMPLERATE = 24000 # Kimi example uses 24kHz output. Confirm this matches your model's expected/native output SR.
|
| 54 |
+
|
| 55 |
+
# 4. Kimi Call Settings (Based on example, adjust as needed)
|
| 56 |
+
KIMI_SAMPLING_PARAMS = {
|
| 57 |
+
"audio_temperature": 0.8,
|
| 58 |
+
"audio_top_k": 10,
|
| 59 |
+
"text_temperature": 0.0, # 0.0 for deterministic text, increase for more variety
|
| 60 |
+
"text_top_k": 5, # Relevant if text_temperature > 0
|
| 61 |
+
"audio_repetition_penalty": 1.0,
|
| 62 |
+
"audio_repetition_window_size": 64,
|
| 63 |
+
"text_repetition_penalty": 1.0,
|
| 64 |
+
"text_repetition_window_size": 16,
|
| 65 |
+
# "max_new_tokens": 128 # Add if needed and supported by KimiAudio.generate
|
| 66 |
+
}
|
| 67 |
+
KIMI_OUTPUT_TYPE = "both" # Generate both audio and text
|
| 68 |
+
|
| 69 |
+
# 5. Periodic Save Settings
|
| 70 |
+
SAVE_EVERY_N_SAMPLES = 50 # Save after processing this many samples
|
| 71 |
+
|
| 72 |
+
# --- Helper Functions ---
|
| 73 |
+
|
| 74 |
+
def format_time(seconds):
|
| 75 |
+
"""Formats seconds into a human-readable string H:MM:SS"""
|
| 76 |
+
if seconds < 0:
|
| 77 |
+
return "N/A"
|
| 78 |
+
return str(datetime.timedelta(seconds=int(seconds)))
|
| 79 |
+
|
| 80 |
+
# REMOVED load_audio_minicpm - Kimi takes the path directly
|
| 81 |
+
|
| 82 |
+
def parse_ultra_history(history_str):
|
| 83 |
+
"""Parses the specific history string format from ultra metadata for Kimi."""
|
| 84 |
+
messages = []
|
| 85 |
+
# Relaxed pattern to capture content even if tags are slightly off or whitespace varies
|
| 86 |
+
pattern = re.compile(r"\[\s*(USER|ASSISTANT)\s*\]\s*([\s\S]*?)(?=\s*\[\s*(?:USER|ASSISTANT)\s*\]|$)")
|
| 87 |
+
matches = pattern.findall(history_str)
|
| 88 |
+
if not matches and history_str and history_str.strip():
|
| 89 |
+
# Simple fallback if standard pattern fails but there's content
|
| 90 |
+
if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
|
| 91 |
+
role = "user"
|
| 92 |
+
content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
|
| 93 |
+
if content: messages.append({"role": role, "message_type": "text", "content": content}) # Add Kimi message_type
|
| 94 |
+
elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
|
| 95 |
+
role = "assistant"
|
| 96 |
+
content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
|
| 97 |
+
if content: messages.append({"role": role, "message_type": "text", "content": content}) # Add Kimi message_type
|
| 98 |
+
else:
|
| 99 |
+
print(f"Warning: Could not parse history string format: {history_str[:100]}...")
|
| 100 |
+
return messages # Return whatever was parsed, even if empty
|
| 101 |
+
|
| 102 |
+
for role_tag, content in matches:
|
| 103 |
+
role = role_tag.strip().lower()
|
| 104 |
+
cleaned_content = content.strip()
|
| 105 |
+
if cleaned_content:
|
| 106 |
+
# IMPORTANT: Add message_type='text' for Kimi history
|
| 107 |
+
messages.append({"role": role, "message_type": "text", "content": cleaned_content})
|
| 108 |
+
return messages
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# --- Kimi Model Interaction Function ---
|
| 112 |
+
def call_kimi_model(model, messages_input, sampling_params, output_audio_filepath, output_sample_rate):
|
| 113 |
+
"""Calls the Kimi-Audio model, saves audio, returns text and audio path."""
|
| 114 |
+
try:
|
| 115 |
+
# 1. Ensure Output Directory Exists
|
| 116 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 117 |
+
|
| 118 |
+
# 2. Call Kimi's Generate Function
|
| 119 |
+
wav_output, text_output = model.generate(
|
| 120 |
+
messages_input,
|
| 121 |
+
**sampling_params,
|
| 122 |
+
output_type=KIMI_OUTPUT_TYPE # Use 'both'
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# 3. Process and Save Audio Output
|
| 126 |
+
saved_audio_path = None
|
| 127 |
+
if wav_output is not None and isinstance(wav_output, torch.Tensor) and wav_output.numel() > 0: # Check if tensor is not empty
|
| 128 |
+
try:
|
| 129 |
+
# Ensure tensor is on CPU, reshape (if needed, often view(-1)), convert to numpy
|
| 130 |
+
# Check KimiAudio output format - might already be 1D or need specific shape
|
| 131 |
+
audio_data = wav_output.detach().cpu().view(-1).numpy()
|
| 132 |
+
|
| 133 |
+
# Ensure data is float32 or int16 as supported by soundfile/WAV
|
| 134 |
+
if audio_data.dtype != 'float32':
|
| 135 |
+
# Attempt conversion, potentially scale if it's int
|
| 136 |
+
# print(f" Info: Converting Kimi audio output from {audio_data.dtype} to float32 for saving.")
|
| 137 |
+
if np.issubdtype(audio_data.dtype, np.integer):
|
| 138 |
+
# Scale integer types to [-1, 1] float range if necessary
|
| 139 |
+
# Example: if int16 -> audio_data = audio_data.astype(np.float32) / 32768.0
|
| 140 |
+
# Adjust scaling based on the actual integer range if known
|
| 141 |
+
audio_data = audio_data.astype(np.float32) # Simplest conversion, might need scaling
|
| 142 |
+
else:
|
| 143 |
+
audio_data = audio_data.astype(np.float32)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
sf.write(output_audio_filepath, audio_data, output_sample_rate)
|
| 147 |
+
|
| 148 |
+
# Check if file was actually created and has size
|
| 149 |
+
if os.path.exists(output_audio_filepath) and os.path.getsize(output_audio_filepath) > 100: # Check for a reasonable size threshold
|
| 150 |
+
saved_audio_path = output_audio_filepath
|
| 151 |
+
else:
|
| 152 |
+
print(f" Error: Kimi generate finished but output audio file seems empty or too small at {output_audio_filepath}")
|
| 153 |
+
if os.path.exists(output_audio_filepath):
|
| 154 |
+
try: os.remove(output_audio_filepath)
|
| 155 |
+
except OSError as rm_err: print(f" Warning: Could not remove empty/small file {output_audio_filepath}: {rm_err}")
|
| 156 |
+
except ImportError:
|
| 157 |
+
print("Error: NumPy library not found. Please install it (`pip install numpy`)")
|
| 158 |
+
return "[ERROR: NumPy Missing]", None # Indicate failure clearly
|
| 159 |
+
except Exception as sf_err:
|
| 160 |
+
print(f" Error saving Kimi audio output to {output_audio_filepath}: {sf_err}")
|
| 161 |
+
traceback.print_exc()
|
| 162 |
+
if os.path.exists(output_audio_filepath):
|
| 163 |
+
try: os.remove(output_audio_filepath)
|
| 164 |
+
except OSError as rm_err: print(f" Warning: Could not remove potentially corrupt file {output_audio_filepath}: {rm_err}")
|
| 165 |
+
elif wav_output is None:
|
| 166 |
+
print(" Warning: Kimi model did not return an audio tensor (wav_output is None).")
|
| 167 |
+
elif isinstance(wav_output, torch.Tensor) and wav_output.numel() == 0:
|
| 168 |
+
print(" Warning: Kimi model returned an empty audio tensor.")
|
| 169 |
+
else:
|
| 170 |
+
print(f" Warning: Kimi model returned unexpected audio output type: {type(wav_output)}. Expected torch.Tensor.")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# 4. Process Text Output
|
| 174 |
+
response_text_cleaned = ""
|
| 175 |
+
if isinstance(text_output, str):
|
| 176 |
+
response_text_cleaned = text_output.strip()
|
| 177 |
+
elif text_output is not None:
|
| 178 |
+
response_text_cleaned = str(text_output).strip() # Convert just in case
|
| 179 |
+
else:
|
| 180 |
+
# If text is None but audio might exist, use a specific marker
|
| 181 |
+
if saved_audio_path:
|
| 182 |
+
response_text_cleaned = "[Audio Generated, No Text Output]"
|
| 183 |
+
else:
|
| 184 |
+
response_text_cleaned = "[ERROR: No Text Output]"
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Return text (even if audio failed) and the path (or None)
|
| 188 |
+
return response_text_cleaned, saved_audio_path
|
| 189 |
+
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"\n --- Error during Kimi model call ---")
|
| 192 |
+
# Avoid printing potentially huge message list directly
|
| 193 |
+
first_message = messages_input[0] if messages_input else "N/A"
|
| 194 |
+
last_message_content = messages_input[-1]['content'] if messages_input else "N/A"
|
| 195 |
+
if isinstance(last_message_content, str) and len(last_message_content) > 100 :
|
| 196 |
+
last_message_preview = last_message_content[:100] + "..."
|
| 197 |
+
else:
|
| 198 |
+
last_message_preview = last_message_content
|
| 199 |
+
|
| 200 |
+
print(f" Input Messages Info: Count={len(messages_input)}, First={first_message}, Last Content Preview='{last_message_preview}'")
|
| 201 |
+
print(f" Exception Type: {type(e).__name__}")
|
| 202 |
+
print(f" Error Details: {e}")
|
| 203 |
+
print(" Traceback:")
|
| 204 |
+
traceback.print_exc()
|
| 205 |
+
print(" --- End Error Details ---")
|
| 206 |
+
|
| 207 |
+
# Attempt cleanup of potentially incomplete output file
|
| 208 |
+
if 'output_audio_filepath' in locals() and os.path.exists(output_audio_filepath):
|
| 209 |
+
try:
|
| 210 |
+
os.remove(output_audio_filepath)
|
| 211 |
+
except OSError as rm_err:
|
| 212 |
+
print(f" Warning: Could not remove file {output_audio_filepath} after error: {rm_err}")
|
| 213 |
+
# Return clear error markers
|
| 214 |
+
return "[ERROR: Kimi Model Call Failed]", None
|
| 215 |
+
|
| 216 |
+
# --- Dataset Saving Function (Modified for Kimi context) ---
|
| 217 |
+
def save_checkpoint(data_list, features, output_dir, fallback_dir=None):
|
| 218 |
+
"""Saves the current state of the data list as a Hugging Face Dataset."""
|
| 219 |
+
if not data_list:
|
| 220 |
+
print("\nSkipping checkpoint save: data list is empty.")
|
| 221 |
+
return
|
| 222 |
+
|
| 223 |
+
print(f"\nSaving checkpoint with {len(data_list)} rows to {output_dir}...")
|
| 224 |
+
try:
|
| 225 |
+
# Ensure the list contains dictionaries
|
| 226 |
+
data_to_save = [dict(item) for item in data_list]
|
| 227 |
+
|
| 228 |
+
# --- Feature Check/Adaptation (Optional but recommended) ---
|
| 229 |
+
# Sometimes saving fails if data types changed unexpectedly (e.g., None -> str)
|
| 230 |
+
# It's safer to create the Dataset *without* features first, then cast
|
| 231 |
+
temp_dataset = Dataset.from_list(data_to_save)
|
| 232 |
+
# Now cast to the original features, allowing potential None/type mismatches
|
| 233 |
+
# This might raise warnings but is often more robust than direct from_list with features
|
| 234 |
+
updated_dataset = temp_dataset.cast(features)
|
| 235 |
+
# --- End Feature Check ---
|
| 236 |
+
|
| 237 |
+
# Ensure output directory exists before saving
|
| 238 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 239 |
+
updated_dataset.save_to_disk(output_dir)
|
| 240 |
+
print("Checkpoint saved successfully.")
|
| 241 |
+
|
| 242 |
+
except Exception as e:
|
| 243 |
+
print(f"Error saving checkpoint dataset using save_to_disk to {output_dir}: {e}")
|
| 244 |
+
traceback.print_exc()
|
| 245 |
+
if fallback_dir:
|
| 246 |
+
# Use Kimi-specific name in fallback path
|
| 247 |
+
fallback_path = os.path.join(fallback_dir, f"updated_{KIMI_MODEL_NAME}_data_checkpoint_{int(time.time())}.jsonl")
|
| 248 |
+
print(f"Attempting to save data as JSON Lines fallback to: {fallback_path}")
|
| 249 |
+
try:
|
| 250 |
+
os.makedirs(fallback_dir, exist_ok=True)
|
| 251 |
+
with open(fallback_path, 'w', encoding='utf-8') as f:
|
| 252 |
+
# Reuse data_to_save which is already list of dicts
|
| 253 |
+
for item in data_to_save:
|
| 254 |
+
# Ensure all values are serializable
|
| 255 |
+
serializable_item = {}
|
| 256 |
+
for k, v in item.items():
|
| 257 |
+
if isinstance(v, (datetime.datetime, datetime.date)):
|
| 258 |
+
serializable_item[k] = v.isoformat()
|
| 259 |
+
elif isinstance(v, bytes):
|
| 260 |
+
serializable_item[k] = v.decode('utf-8', errors='ignore')
|
| 261 |
+
elif isinstance(v, torch.Tensor): # Handle potential tensors if not caught earlier
|
| 262 |
+
print(f" Warning: Found unexpected Tensor for key '{k}' in fallback save. Converting to list.")
|
| 263 |
+
serializable_item[k] = v.tolist()
|
| 264 |
+
elif not isinstance(v, (str, int, float, bool, list, dict, type(None))):
|
| 265 |
+
print(f" Warning: Converting non-standard type {type(v)} for key '{k}' to string for JSON fallback.")
|
| 266 |
+
serializable_item[k] = str(v)
|
| 267 |
+
else:
|
| 268 |
+
serializable_item[k] = v
|
| 269 |
+
try:
|
| 270 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 271 |
+
except TypeError as json_type_err:
|
| 272 |
+
print(f" Skipping row due to JSON serialization error: {json_type_err} in item part: {k}={v}")
|
| 273 |
+
print("Fallback JSON Lines checkpoint saved successfully.")
|
| 274 |
+
except Exception as json_e:
|
| 275 |
+
print(f"Error saving fallback JSON Lines checkpoint: {json_e}")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# =============================================
|
| 279 |
+
# --- Main Processing Logic ---
|
| 280 |
+
# =============================================
|
| 281 |
+
|
| 282 |
+
# --- STEP 1: Dataset Loading (Modified for Resumption) ---
|
| 283 |
+
print("="*30)
|
| 284 |
+
print("STEP 1: Loading Dataset")
|
| 285 |
+
print("="*30)
|
| 286 |
+
dataset = None
|
| 287 |
+
original_features = None # Initialize
|
| 288 |
+
|
| 289 |
+
# Check if the Kimi-specific output directory exists
|
| 290 |
+
if os.path.exists(OUTPUT_DATASET_DIR):
|
| 291 |
+
print(f"Found existing Kimi processed dataset directory at: {OUTPUT_DATASET_DIR}")
|
| 292 |
+
print("Attempting to load it to resume processing...")
|
| 293 |
+
try:
|
| 294 |
+
dataset = load_from_disk(OUTPUT_DATASET_DIR)
|
| 295 |
+
original_features = dataset.features # Get features from the loaded dataset
|
| 296 |
+
print(f"Resumed Kimi dataset loaded successfully with {len(dataset)} rows.")
|
| 297 |
+
print(f"Features from resumed dataset: {original_features}")
|
| 298 |
+
except Exception as e:
|
| 299 |
+
print(f"Warning: Error loading existing Kimi dataset from {OUTPUT_DATASET_DIR}: {e}")
|
| 300 |
+
traceback.print_exc()
|
| 301 |
+
print("Will attempt to load the original input dataset instead.")
|
| 302 |
+
dataset = None # Reset dataset variable
|
| 303 |
+
else:
|
| 304 |
+
print(f"No existing Kimi processed dataset found at {OUTPUT_DATASET_DIR}.")
|
| 305 |
+
print("Will attempt to load the original input dataset.")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# If dataset is still None, load from the original input directory
|
| 309 |
+
if dataset is None:
|
| 310 |
+
print(f"\nLoading original input dataset from: {INPUT_DATASET_DIR}")
|
| 311 |
+
if not os.path.exists(INPUT_DATASET_DIR):
|
| 312 |
+
print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
|
| 313 |
+
exit(1)
|
| 314 |
+
try:
|
| 315 |
+
dataset = load_from_disk(INPUT_DATASET_DIR)
|
| 316 |
+
original_features = dataset.features # Get features from the input dataset
|
| 317 |
+
print(f"Original input dataset loaded successfully with {len(dataset)} rows.")
|
| 318 |
+
print(f"Features from input dataset: {original_features}")
|
| 319 |
+
except Exception as e:
|
| 320 |
+
print(f"FATAL: Error loading original input dataset from {INPUT_DATASET_DIR}: {e}")
|
| 321 |
+
traceback.print_exc()
|
| 322 |
+
exit(1)
|
| 323 |
+
|
| 324 |
+
# --- Ensure dataset and features were loaded ---
|
| 325 |
+
if dataset is None or original_features is None:
|
| 326 |
+
print("FATAL: Failed to load any dataset. Exiting.")
|
| 327 |
+
exit(1)
|
| 328 |
+
# --- End Dataset Loading ---
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# --- STEP 2: Pre-computation - Identify Kimi Tasks ---
|
| 332 |
+
print("\n" + "="*30)
|
| 333 |
+
print(f"STEP 2: Identifying '{KIMI_MODEL_NAME}' Tasks to Process")
|
| 334 |
+
print("="*30)
|
| 335 |
+
pkusafe_tasks_indices = []
|
| 336 |
+
other_tasks_indices = []
|
| 337 |
+
|
| 338 |
+
# Iterate through the loaded dataset structure
|
| 339 |
+
for idx, row in enumerate(dataset):
|
| 340 |
+
source_dataset = row.get('source_dataset')
|
| 341 |
+
processed_in_row = False # Flag to ensure we only pick one Kimi slot per row
|
| 342 |
+
for i in range(1, 4): # Check slots 1, 2, 3
|
| 343 |
+
model_key = f"model_{i}"
|
| 344 |
+
response_text_key = f"response_text_{i}"
|
| 345 |
+
# Check if the slot is assigned to Kimi and is NOT yet filled (text response missing)
|
| 346 |
+
is_target_model_task = row.get(model_key) == KIMI_MODEL_NAME
|
| 347 |
+
is_unfilled = not row.get(response_text_key) # True if None or empty string
|
| 348 |
+
|
| 349 |
+
if is_target_model_task and is_unfilled and not processed_in_row:
|
| 350 |
+
task_info = (idx, i) # Store tuple of (original_row_index, slot_index)
|
| 351 |
+
if source_dataset == 'pkusafe':
|
| 352 |
+
pkusafe_tasks_indices.append(task_info)
|
| 353 |
+
else:
|
| 354 |
+
other_tasks_indices.append(task_info)
|
| 355 |
+
processed_in_row = True # Mark row as having a task identified
|
| 356 |
+
|
| 357 |
+
# Combine lists, prioritizing pkusafe
|
| 358 |
+
tasks_to_process_indices = pkusafe_tasks_indices + other_tasks_indices
|
| 359 |
+
total_tasks_to_process = len(tasks_to_process_indices)
|
| 360 |
+
|
| 361 |
+
print(f"Found {len(pkusafe_tasks_indices)} 'pkusafe' tasks and {len(other_tasks_indices)} other tasks requiring '{KIMI_MODEL_NAME}' processing in the loaded dataset.")
|
| 362 |
+
print(f"Total tasks remaining to process: {total_tasks_to_process}")
|
| 363 |
+
|
| 364 |
+
if total_tasks_to_process == 0:
|
| 365 |
+
print(f"\nNo remaining tasks to process for {KIMI_MODEL_NAME} based on the loaded dataset.")
|
| 366 |
+
# Optionally, perform a final save for consistency
|
| 367 |
+
# print("Performing a final save to ensure consistency...")
|
| 368 |
+
# final_data_list = [dict(row) for row in dataset]
|
| 369 |
+
# fallback_save_dir_final = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
|
| 370 |
+
# save_checkpoint(final_data_list, original_features, OUTPUT_DATASET_DIR, fallback_save_dir_final)
|
| 371 |
+
print("Exiting.")
|
| 372 |
+
exit(0)
|
| 373 |
+
# --- End Pre-computation Step ---
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# --- STEP 3: Load Kimi Model ---
|
| 377 |
+
print("\n" + "="*30)
|
| 378 |
+
print(f"STEP 3: Loading {KIMI_MODEL_NAME} Model")
|
| 379 |
+
print("="*30)
|
| 380 |
+
try:
|
| 381 |
+
# Load Kimi model using the class imported earlier
|
| 382 |
+
model = KimiAudio(model_path=KIMI_MODEL_PATH, load_detokenizer=True) # Assuming detokenizer is needed based on example
|
| 383 |
+
print(f"{KIMI_MODEL_NAME} model loaded successfully from {KIMI_MODEL_PATH}.")
|
| 384 |
+
except NameError:
|
| 385 |
+
print("FATAL: KimiAudio class not defined. Import likely failed earlier.")
|
| 386 |
+
exit(1)
|
| 387 |
+
except Exception as e:
|
| 388 |
+
print(f"Error loading {KIMI_MODEL_NAME} model from {KIMI_MODEL_PATH}: {e}")
|
| 389 |
+
traceback.print_exc()
|
| 390 |
+
exit(1)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# --- STEP 4: Prepare for Processing ---
|
| 394 |
+
print("\n" + "="*30)
|
| 395 |
+
print(f"STEP 4: Preparing for {KIMI_MODEL_NAME} Processing")
|
| 396 |
+
print("="*30)
|
| 397 |
+
# Create output directories if they don't exist
|
| 398 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 399 |
+
os.makedirs(OUTPUT_DATASET_DIR, exist_ok=True)
|
| 400 |
+
# Define and create fallback directory for Kimi
|
| 401 |
+
fallback_save_dir = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
|
| 402 |
+
os.makedirs(fallback_save_dir, exist_ok=True)
|
| 403 |
+
print(f"Audio outputs will be saved in: {OUTPUT_AUDIO_ROOT_DIR}")
|
| 404 |
+
print(f"Dataset checkpoints will be saved in: {OUTPUT_DATASET_DIR}")
|
| 405 |
+
print(f"Fallback checkpoints (JSONL) in: {fallback_save_dir}")
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# Create a mutable list of dictionaries from the loaded dataset for updates
|
| 409 |
+
updated_data = [dict(row) for row in dataset] # Convert each row to a dictionary
|
| 410 |
+
|
| 411 |
+
tasks_processed_count = 0 # Count successful completions for average time calculation
|
| 412 |
+
start_total_time = time.time()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# --- STEP 5: Start Processing Loop ---
|
| 416 |
+
print("\n" + "="*30)
|
| 417 |
+
print(f"STEP 5: Starting {KIMI_MODEL_NAME} Processing Loop ({total_tasks_to_process} Tasks)")
|
| 418 |
+
print("="*30)
|
| 419 |
+
# Use tqdm for the progress bar, iterating over the identified task indices
|
| 420 |
+
pbar = tqdm(enumerate(tasks_to_process_indices), total=total_tasks_to_process, desc=f"Processing {KIMI_MODEL_NAME} Tasks")
|
| 421 |
+
for loop_idx, (row_idx, slot_i) in pbar:
|
| 422 |
+
# Get the row data *from our mutable list* using the original index
|
| 423 |
+
row = updated_data[row_idx] # This is already a dictionary
|
| 424 |
+
|
| 425 |
+
# Set description in tqdm dynamically
|
| 426 |
+
pbar.set_description(f"Processing Row {row_idx}, Slot {slot_i}")
|
| 427 |
+
|
| 428 |
+
prompt_text_key = f"prompt_text_{slot_i}"
|
| 429 |
+
response_text_key = f"response_text_{slot_i}"
|
| 430 |
+
response_audio_key = f"response_audio_path_{slot_i}"
|
| 431 |
+
model_key = f"model_{slot_i}"
|
| 432 |
+
|
| 433 |
+
# --- Sanity Check ---
|
| 434 |
+
if row.get(model_key) != KIMI_MODEL_NAME:
|
| 435 |
+
tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Model is '{row.get(model_key)}', not '{KIMI_MODEL_NAME}'.")
|
| 436 |
+
continue
|
| 437 |
+
if row.get(response_text_key):
|
| 438 |
+
tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Already has response text '{str(row.get(response_text_key))[:50]}...'.")
|
| 439 |
+
continue
|
| 440 |
+
|
| 441 |
+
# --- Prepare Kimi Model Inputs ---
|
| 442 |
+
prompt_text = row.get(prompt_text_key, "")
|
| 443 |
+
question_audio_path = row.get('question_audio')
|
| 444 |
+
metadata_str = row.get('metadata', "{}")
|
| 445 |
+
source_dataset = row.get('source_dataset')
|
| 446 |
+
|
| 447 |
+
# Check for essential input audio path validity
|
| 448 |
+
if not question_audio_path or not os.path.exists(question_audio_path):
|
| 449 |
+
tqdm.write(f" Error: Input audio path missing or invalid for Row {row_idx}: '{question_audio_path}'. Skipping model call.")
|
| 450 |
+
updated_data[row_idx][response_text_key] = "[ERROR: Missing Input Audio]"
|
| 451 |
+
updated_data[row_idx][response_audio_key] = None
|
| 452 |
+
continue # Move to the next task in the loop
|
| 453 |
+
|
| 454 |
+
# --- Construct Kimi `messages` list ---
|
| 455 |
+
kimi_messages = []
|
| 456 |
+
|
| 457 |
+
# 1. Parse History (if any)
|
| 458 |
+
if source_dataset == 'ultra' and metadata_str:
|
| 459 |
+
try:
|
| 460 |
+
metadata = json.loads(metadata_str)
|
| 461 |
+
history_str = metadata.get('history', '')
|
| 462 |
+
if history_str:
|
| 463 |
+
# Ensure history messages have 'message_type': 'text'
|
| 464 |
+
history_messages_parsed = parse_ultra_history(history_str)
|
| 465 |
+
kimi_messages.extend(history_messages_parsed)
|
| 466 |
+
except json.JSONDecodeError:
|
| 467 |
+
tqdm.write(f" Warning: Could not parse metadata JSON for row {row_idx}")
|
| 468 |
+
except Exception as hist_e:
|
| 469 |
+
tqdm.write(f" Warning: Error processing history for row {row_idx}: {hist_e}")
|
| 470 |
+
# Add elif blocks here for history parsing from other datasets if needed
|
| 471 |
+
|
| 472 |
+
# 2. Add Current User Turn (Text Prompt + Audio Path)
|
| 473 |
+
# Add text prompt first, if it exists and is not empty
|
| 474 |
+
if prompt_text and prompt_text.strip():
|
| 475 |
+
kimi_messages.append({"role": "user", "message_type": "text", "content": prompt_text.strip()})
|
| 476 |
+
# Add the user audio query using its path
|
| 477 |
+
kimi_messages.append({"role": "user", "message_type": "audio", "content": question_audio_path})
|
| 478 |
+
|
| 479 |
+
# Generate unique output audio filename
|
| 480 |
+
unique_id = str(uuid.uuid4())
|
| 481 |
+
output_audio_filename = f"{KIMI_MODEL_NAME}_row{row_idx}_slot{slot_i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 482 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 483 |
+
|
| 484 |
+
# --- Call Kimi Model ---
|
| 485 |
+
# tqdm.write(f" Calling {KIMI_MODEL_NAME} for Row {row_idx}, Slot {slot_i}...") # Less verbose log
|
| 486 |
+
call_start_time = time.time()
|
| 487 |
+
response_text, saved_audio_path = call_kimi_model(
|
| 488 |
+
model,
|
| 489 |
+
kimi_messages,
|
| 490 |
+
KIMI_SAMPLING_PARAMS,
|
| 491 |
+
output_audio_filepath,
|
| 492 |
+
OUTPUT_AUDIO_SAMPLERATE
|
| 493 |
+
)
|
| 494 |
+
call_end_time = time.time()
|
| 495 |
+
audio_basename = os.path.basename(str(saved_audio_path)) if saved_audio_path else "None"
|
| 496 |
+
tqdm.write(f" Row {row_idx}, Slot {slot_i}: Finished in {call_end_time - call_start_time:.2f}s. Text: '{str(response_text)[:50]}...', Audio: {audio_basename}")
|
| 497 |
+
|
| 498 |
+
# Store results back into the main data list (updated_data)
|
| 499 |
+
updated_data[row_idx][response_text_key] = response_text # Store text/error marker
|
| 500 |
+
updated_data[row_idx][response_audio_key] = saved_audio_path # Store path or None
|
| 501 |
+
|
| 502 |
+
# Increment success counter based on successful generation (e.g., text isn't an error marker)
|
| 503 |
+
# Consider if audio generation failure should also mark task as failed.
|
| 504 |
+
# Current logic counts success if text seems okay.
|
| 505 |
+
if response_text is not None and not response_text.startswith("[ERROR"):
|
| 506 |
+
tasks_processed_count += 1
|
| 507 |
+
|
| 508 |
+
# --- Periodic Saving ---
|
| 509 |
+
processed_count_in_loop = loop_idx + 1
|
| 510 |
+
if processed_count_in_loop % SAVE_EVERY_N_SAMPLES == 0 or processed_count_in_loop == total_tasks_to_process:
|
| 511 |
+
save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
|
| 512 |
+
|
| 513 |
+
# --- STEP 6: Final Summary and Save ---
|
| 514 |
+
end_total_time = time.time()
|
| 515 |
+
print("\n" + "="*30)
|
| 516 |
+
print(f"STEP 6: {KIMI_MODEL_NAME} Processing Complete - Summary")
|
| 517 |
+
print("="*30)
|
| 518 |
+
print(f"Total tasks identified for processing in this run: {total_tasks_to_process}")
|
| 519 |
+
print(f"Total tasks successfully processed (generated text): {tasks_processed_count}") # Update definition if needed
|
| 520 |
+
total_duration = end_total_time - start_total_time
|
| 521 |
+
print(f"Total processing time for this run: {format_time(total_duration)}")
|
| 522 |
+
if tasks_processed_count > 0:
|
| 523 |
+
avg_time = total_duration / tasks_processed_count
|
| 524 |
+
print(f"Average time per successfully processed task in this run: {avg_time:.2f} seconds")
|
| 525 |
+
else:
|
| 526 |
+
print("Average time per task: N/A (no tasks successfully processed in this run)")
|
| 527 |
+
|
| 528 |
+
# --- Final Save ---
|
| 529 |
+
print("\nPerforming final save of the dataset...")
|
| 530 |
+
save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
|
| 531 |
+
|
| 532 |
+
print("\nScript finished.")
|
r1-a/response_generation/minicpm.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
import re # For parsing history
|
| 5 |
+
import uuid # For generating unique filenames
|
| 6 |
+
import random # For random voice selection
|
| 7 |
+
import torch # For MiniCPM-o
|
| 8 |
+
import librosa # For audio loading
|
| 9 |
+
from transformers import AutoModel, AutoTokenizer # For MiniCPM-o
|
| 10 |
+
from datasets import load_from_disk, Dataset, Features, Audio, Value # Import necessary types
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
import datetime # For ETA formatting
|
| 13 |
+
from tqdm import tqdm # Import tqdm
|
| 14 |
+
import traceback # For detailed error printing
|
| 15 |
+
|
| 16 |
+
# --- Configuration ---
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
# 1. Model & Tokenizer Setup
|
| 20 |
+
MINICPMO_MODEL_NAME = "minicpm" # Name used in the dataset to identify tasks for this model
|
| 21 |
+
MINICPMO_HF_ID = 'openbmb/MiniCPM-o-2_6'
|
| 22 |
+
MINICPMO_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 23 |
+
MINICPMO_DTYPE = torch.bfloat16
|
| 24 |
+
MINICPMO_ATTN_IMPL = 'sdpa'
|
| 25 |
+
|
| 26 |
+
# 2. Dataset Paths
|
| 27 |
+
INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks" # Original source
|
| 28 |
+
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_minicpmo" # Where processed data is saved/resumed from
|
| 29 |
+
|
| 30 |
+
# 3. Output Audio Configuration
|
| 31 |
+
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/minicpmo" # Where generated audio files are saved
|
| 32 |
+
OUTPUT_AUDIO_FORMAT = "wav"
|
| 33 |
+
OUTPUT_AUDIO_SAMPLERATE = 16000
|
| 34 |
+
|
| 35 |
+
# --- !!! IMPORTANT: Update these paths to your actual reference voice files !!! ---
|
| 36 |
+
REF_VOICE_PATHS = {
|
| 37 |
+
"female": "/root/autodl-tmp/audio-r1/r1-a/response_generation/minicpm/MiniCPM-o/assets/input_examples/assistant_female_voice.wav",
|
| 38 |
+
"male": "/root/autodl-tmp/audio-r1/r1-a/response_generation/minicpm/MiniCPM-o/assets/input_examples/assistant_male_voice.wav",
|
| 39 |
+
"default_female": "/root/autodl-tmp/audio-r1/r1-a/response_generation/minicpm/MiniCPM-o/assets/input_examples/assistant_default_female_voice.wav"
|
| 40 |
+
}
|
| 41 |
+
# --- End Reference Voice Paths ---
|
| 42 |
+
# Check voice paths exist early
|
| 43 |
+
for key, path in REF_VOICE_PATHS.items():
|
| 44 |
+
if not os.path.exists(path):
|
| 45 |
+
print(f"FATAL ERROR: Reference voice file not found for '{key}': {path}")
|
| 46 |
+
print("Please ensure the reference voice files exist at the specified paths in REF_VOICE_PATHS.")
|
| 47 |
+
exit(1) # Exit early if critical files are missing
|
| 48 |
+
|
| 49 |
+
AVAILABLE_MINICPMO_VOICES = list(REF_VOICE_PATHS.keys())
|
| 50 |
+
|
| 51 |
+
# 4. MiniCPM-o Call Settings
|
| 52 |
+
MODEL_MAX_NEW_TOKENS = 128
|
| 53 |
+
MODEL_TEMPERATURE = 0.3
|
| 54 |
+
MODEL_SAMPLING = True
|
| 55 |
+
|
| 56 |
+
# 5. Periodic Save Settings
|
| 57 |
+
SAVE_EVERY_N_SAMPLES = 50 # Save after processing this many samples
|
| 58 |
+
|
| 59 |
+
# --- Helper Functions ---
|
| 60 |
+
|
| 61 |
+
def format_time(seconds):
|
| 62 |
+
"""Formats seconds into a human-readable string H:MM:SS"""
|
| 63 |
+
if seconds < 0:
|
| 64 |
+
return "N/A"
|
| 65 |
+
return str(datetime.timedelta(seconds=int(seconds)))
|
| 66 |
+
|
| 67 |
+
def load_audio_minicpm(audio_path, target_sr=16000):
|
| 68 |
+
"""Loads audio using librosa, handling potential errors."""
|
| 69 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 70 |
+
# print(f"Warning: Audio file not found or path is empty: {audio_path}") # Less verbose
|
| 71 |
+
return None
|
| 72 |
+
try:
|
| 73 |
+
audio_array, sr = librosa.load(audio_path, sr=None, mono=True)
|
| 74 |
+
if sr != target_sr:
|
| 75 |
+
# print(f" Resampling audio from {sr} Hz to {target_sr} Hz...") # Less verbose
|
| 76 |
+
audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)
|
| 77 |
+
return audio_array
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"\nWarning: Error loading/processing audio file {audio_path}: {e}")
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
def parse_ultra_history(history_str):
|
| 83 |
+
"""Parses the specific history string format from ultra metadata."""
|
| 84 |
+
messages = []
|
| 85 |
+
# Relaxed pattern to capture content even if tags are slightly off or whitespace varies
|
| 86 |
+
pattern = re.compile(r"\[\s*(USER|ASSISTANT)\s*\]\s*([\s\S]*?)(?=\s*\[\s*(?:USER|ASSISTANT)\s*\]|$)")
|
| 87 |
+
matches = pattern.findall(history_str)
|
| 88 |
+
if not matches and history_str and history_str.strip():
|
| 89 |
+
# Simple fallback if standard pattern fails but there's content
|
| 90 |
+
if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
|
| 91 |
+
role = "user"
|
| 92 |
+
content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
|
| 93 |
+
if content: messages.append({"role": role, "content": content})
|
| 94 |
+
elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
|
| 95 |
+
role = "assistant"
|
| 96 |
+
content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
|
| 97 |
+
if content: messages.append({"role": role, "content": content})
|
| 98 |
+
else:
|
| 99 |
+
print(f"Warning: Could not parse history string format: {history_str[:100]}...")
|
| 100 |
+
return messages # Return whatever was parsed, even if empty
|
| 101 |
+
|
| 102 |
+
for role_tag, content in matches:
|
| 103 |
+
role = role_tag.strip().lower()
|
| 104 |
+
cleaned_content = content.strip()
|
| 105 |
+
if cleaned_content:
|
| 106 |
+
messages.append({"role": role, "content": cleaned_content})
|
| 107 |
+
# else: # Removed warning for empty content for brevity
|
| 108 |
+
# print(f"Warning: Empty content found for role {role_tag} in history.")
|
| 109 |
+
return messages
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# --- MiniCPM-o Model Interaction Function ---
|
| 113 |
+
def call_minicpmo_model(model, tokenizer, history_messages, prompt_text, question_audio_path, output_audio_filepath):
|
| 114 |
+
"""Calls the local MiniCPM-o model, saves audio, returns text and audio path."""
|
| 115 |
+
try:
|
| 116 |
+
# 1. Select and Load Random Reference Voice
|
| 117 |
+
selected_voice_key = random.choice(AVAILABLE_MINICPMO_VOICES)
|
| 118 |
+
ref_voice_path = REF_VOICE_PATHS[selected_voice_key]
|
| 119 |
+
ref_audio_array = load_audio_minicpm(ref_voice_path, target_sr=OUTPUT_AUDIO_SAMPLERATE)
|
| 120 |
+
if ref_audio_array is None:
|
| 121 |
+
print(f" Error: Failed to load reference voice: {ref_voice_path}")
|
| 122 |
+
return None, None # Signal failure
|
| 123 |
+
|
| 124 |
+
# 2. Generate System Prompt
|
| 125 |
+
sys_prompt = model.get_sys_prompt(ref_audio=ref_audio_array, mode='audio_assistant', language='en')
|
| 126 |
+
|
| 127 |
+
# 3. Load User Question Audio
|
| 128 |
+
user_audio_array = load_audio_minicpm(question_audio_path, target_sr=OUTPUT_AUDIO_SAMPLERATE)
|
| 129 |
+
if user_audio_array is None:
|
| 130 |
+
print(f" Error: Failed to load user question audio: {question_audio_path}")
|
| 131 |
+
return None, None # Signal failure
|
| 132 |
+
|
| 133 |
+
# 4. Construct User Message
|
| 134 |
+
user_message_content = []
|
| 135 |
+
if prompt_text and prompt_text.strip():
|
| 136 |
+
user_message_content.append(prompt_text.strip())
|
| 137 |
+
# Ensure user_audio_array is added only if loaded successfully
|
| 138 |
+
if user_audio_array is not None:
|
| 139 |
+
user_message_content.append(user_audio_array) # Add audio array
|
| 140 |
+
else:
|
| 141 |
+
print(" Warning: Proceeding without user audio due to loading error.")
|
| 142 |
+
# Optionally decide if you want to proceed without user audio or return error
|
| 143 |
+
# return None, None # If user audio is essential
|
| 144 |
+
|
| 145 |
+
user_message = {'role': 'user', 'content': user_message_content}
|
| 146 |
+
|
| 147 |
+
# 5. Construct Full Message List
|
| 148 |
+
msgs = [sys_prompt] + history_messages + [user_message]
|
| 149 |
+
|
| 150 |
+
# 6. Ensure Output Directory Exists
|
| 151 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 152 |
+
|
| 153 |
+
# 7. Call Model's Chat Function
|
| 154 |
+
response_obj = model.chat(
|
| 155 |
+
msgs=msgs,
|
| 156 |
+
tokenizer=tokenizer,
|
| 157 |
+
sampling=MODEL_SAMPLING,
|
| 158 |
+
max_new_tokens=MODEL_MAX_NEW_TOKENS,
|
| 159 |
+
use_tts_template=True,
|
| 160 |
+
generate_audio=True,
|
| 161 |
+
temperature=MODEL_TEMPERATURE,
|
| 162 |
+
output_audio_path=output_audio_filepath # Model saves the audio directly
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# --- Extract text from the response object ---
|
| 166 |
+
response_text = None
|
| 167 |
+
if hasattr(response_obj, 'text'):
|
| 168 |
+
response_text = response_obj.text
|
| 169 |
+
elif hasattr(response_obj, 'content'):
|
| 170 |
+
response_text = response_obj.content
|
| 171 |
+
elif isinstance(response_obj, str):
|
| 172 |
+
response_text = response_obj
|
| 173 |
+
else:
|
| 174 |
+
print(f" Warning: Could not automatically extract text from model response object of type {type(response_obj)}. Response object dir: {dir(response_obj)}")
|
| 175 |
+
response_text = "[ERROR: Could not extract text]"
|
| 176 |
+
|
| 177 |
+
# Ensure response_text is a string before stripping
|
| 178 |
+
response_text_cleaned = ""
|
| 179 |
+
if isinstance(response_text, str):
|
| 180 |
+
response_text_cleaned = response_text.strip()
|
| 181 |
+
elif response_text is not None:
|
| 182 |
+
response_text_cleaned = str(response_text).strip()
|
| 183 |
+
|
| 184 |
+
# 8. Check if audio file was actually created by the model
|
| 185 |
+
if os.path.exists(output_audio_filepath) and os.path.getsize(output_audio_filepath) > 0: # Check size too
|
| 186 |
+
# Success: Return text and the path where the model saved the audio
|
| 187 |
+
return response_text_cleaned, output_audio_filepath
|
| 188 |
+
else:
|
| 189 |
+
print(f" Error: Model finished but output audio file not found or empty at {output_audio_filepath}")
|
| 190 |
+
# Attempt cleanup if file exists but is empty
|
| 191 |
+
if os.path.exists(output_audio_filepath):
|
| 192 |
+
try:
|
| 193 |
+
os.remove(output_audio_filepath)
|
| 194 |
+
except OSError as rm_err:
|
| 195 |
+
print(f" Warning: Could not remove empty file {output_audio_filepath}: {rm_err}")
|
| 196 |
+
return response_text_cleaned, None # Return text (if any) but signal audio failure
|
| 197 |
+
|
| 198 |
+
except Exception as e:
|
| 199 |
+
print(f"\n --- Error during MiniCPM-o model call for {os.path.basename(question_audio_path)} ---")
|
| 200 |
+
print(f" Exception Type: {type(e).__name__}")
|
| 201 |
+
print(f" Error Details: {e}")
|
| 202 |
+
print(" Traceback:")
|
| 203 |
+
traceback.print_exc()
|
| 204 |
+
print(" --- End Error Details ---")
|
| 205 |
+
# Attempt cleanup of potentially incomplete output file
|
| 206 |
+
if 'output_audio_filepath' in locals() and os.path.exists(output_audio_filepath):
|
| 207 |
+
try:
|
| 208 |
+
os.remove(output_audio_filepath)
|
| 209 |
+
except OSError as rm_err:
|
| 210 |
+
print(f" Warning: Could not remove file {output_audio_filepath} after error: {rm_err}")
|
| 211 |
+
return None, None # Indicate failure
|
| 212 |
+
|
| 213 |
+
# --- Dataset Saving Function ---
|
| 214 |
+
def save_checkpoint(data_list, features, output_dir, fallback_dir=None):
|
| 215 |
+
"""Saves the current state of the data list as a Hugging Face Dataset."""
|
| 216 |
+
if not data_list:
|
| 217 |
+
print("\nSkipping checkpoint save: data list is empty.")
|
| 218 |
+
return
|
| 219 |
+
|
| 220 |
+
print(f"\nSaving checkpoint with {len(data_list)} rows to {output_dir}...")
|
| 221 |
+
try:
|
| 222 |
+
# Ensure the list contains dictionaries, not Dataset rows or other objects
|
| 223 |
+
data_to_save = [dict(item) for item in data_list]
|
| 224 |
+
# Create dataset from the current list of dictionaries using original features
|
| 225 |
+
updated_dataset = Dataset.from_list(data_to_save, features=features)
|
| 226 |
+
# Ensure output directory exists before saving
|
| 227 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 228 |
+
updated_dataset.save_to_disk(output_dir)
|
| 229 |
+
print("Checkpoint saved successfully.")
|
| 230 |
+
except Exception as e:
|
| 231 |
+
print(f"Error saving checkpoint dataset using save_to_disk: {e}")
|
| 232 |
+
traceback.print_exc()
|
| 233 |
+
if fallback_dir:
|
| 234 |
+
fallback_path = os.path.join(fallback_dir, f"updated_minicpmo_data_checkpoint_{int(time.time())}.jsonl")
|
| 235 |
+
print(f"Attempting to save data as JSON Lines fallback to: {fallback_path}")
|
| 236 |
+
try:
|
| 237 |
+
os.makedirs(fallback_dir, exist_ok=True)
|
| 238 |
+
with open(fallback_path, 'w', encoding='utf-8') as f:
|
| 239 |
+
for item in data_to_save:
|
| 240 |
+
# Ensure all values are serializable
|
| 241 |
+
serializable_item = {}
|
| 242 |
+
for k, v in item.items():
|
| 243 |
+
if isinstance(v, (datetime.datetime, datetime.date)):
|
| 244 |
+
serializable_item[k] = v.isoformat()
|
| 245 |
+
elif isinstance(v, bytes):
|
| 246 |
+
serializable_item[k] = v.decode('utf-8', errors='ignore')
|
| 247 |
+
# Add handling for specific non-serializable types if they appear
|
| 248 |
+
elif not isinstance(v, (str, int, float, bool, list, dict, type(None))):
|
| 249 |
+
print(f" Warning: Converting non-standard type {type(v)} for key '{k}' to string for JSON fallback.")
|
| 250 |
+
serializable_item[k] = str(v)
|
| 251 |
+
else:
|
| 252 |
+
serializable_item[k] = v
|
| 253 |
+
try:
|
| 254 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 255 |
+
except TypeError as json_type_err:
|
| 256 |
+
print(f" Skipping row due to JSON serialization error: {json_type_err} in item part: {k}={v}")
|
| 257 |
+
print("Fallback JSON Lines checkpoint saved successfully.")
|
| 258 |
+
except Exception as json_e:
|
| 259 |
+
print(f"Error saving fallback JSON Lines checkpoint: {json_e}")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# =============================================
|
| 263 |
+
# --- Main Processing Logic ---
|
| 264 |
+
# =============================================
|
| 265 |
+
|
| 266 |
+
# --- Dataset Loading (Modified for Resumption) ---
|
| 267 |
+
print("="*30)
|
| 268 |
+
print("STEP 1: Loading Dataset")
|
| 269 |
+
print("="*30)
|
| 270 |
+
dataset = None
|
| 271 |
+
original_features = None # Initialize
|
| 272 |
+
|
| 273 |
+
if os.path.exists(OUTPUT_DATASET_DIR):
|
| 274 |
+
print(f"Found existing processed dataset directory at: {OUTPUT_DATASET_DIR}")
|
| 275 |
+
print("Attempting to load it to resume processing...")
|
| 276 |
+
try:
|
| 277 |
+
# Need write permissions check sometimes? If saving fails later.
|
| 278 |
+
dataset = load_from_disk(OUTPUT_DATASET_DIR)
|
| 279 |
+
original_features = dataset.features # Get features from the loaded dataset
|
| 280 |
+
print(f"Resumed dataset loaded successfully with {len(dataset)} rows.")
|
| 281 |
+
print(f"Features from resumed dataset: {original_features}")
|
| 282 |
+
except Exception as e:
|
| 283 |
+
print(f"Warning: Error loading existing dataset from {OUTPUT_DATASET_DIR}: {e}")
|
| 284 |
+
traceback.print_exc()
|
| 285 |
+
print("Will attempt to load the original input dataset instead.")
|
| 286 |
+
dataset = None # Reset dataset variable
|
| 287 |
+
else:
|
| 288 |
+
print(f"No existing processed dataset found at {OUTPUT_DATASET_DIR}.")
|
| 289 |
+
print("Will attempt to load the original input dataset.")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# If dataset is still None (either output dir didn't exist or loading it failed), load from input
|
| 293 |
+
if dataset is None:
|
| 294 |
+
print(f"\nLoading original input dataset from: {INPUT_DATASET_DIR}")
|
| 295 |
+
if not os.path.exists(INPUT_DATASET_DIR):
|
| 296 |
+
print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
|
| 297 |
+
exit(1)
|
| 298 |
+
try:
|
| 299 |
+
dataset = load_from_disk(INPUT_DATASET_DIR)
|
| 300 |
+
original_features = dataset.features # Get features from the input dataset
|
| 301 |
+
print(f"Original input dataset loaded successfully with {len(dataset)} rows.")
|
| 302 |
+
print(f"Features from input dataset: {original_features}")
|
| 303 |
+
except Exception as e:
|
| 304 |
+
print(f"FATAL: Error loading original input dataset from {INPUT_DATASET_DIR}: {e}")
|
| 305 |
+
traceback.print_exc()
|
| 306 |
+
exit(1)
|
| 307 |
+
|
| 308 |
+
# --- Ensure dataset was loaded ---
|
| 309 |
+
if dataset is None or original_features is None:
|
| 310 |
+
print("FATAL: Failed to load any dataset. Exiting.")
|
| 311 |
+
exit(1)
|
| 312 |
+
# --- End Dataset Loading Modification ---
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# --- Pre-computation Step: Identify and Prioritize Tasks ---
|
| 316 |
+
print("\n" + "="*30)
|
| 317 |
+
print("STEP 2: Identifying Tasks to Process")
|
| 318 |
+
print("="*30)
|
| 319 |
+
# NO CHANGES NEEDED HERE. This logic will now run on the dataset loaded above
|
| 320 |
+
# (which could be the original input or the partially processed output).
|
| 321 |
+
# It correctly identifies tasks where model is 'minicpm' and response_text is missing.
|
| 322 |
+
pkusafe_tasks_indices = []
|
| 323 |
+
other_tasks_indices = []
|
| 324 |
+
|
| 325 |
+
# Iterate through the loaded dataset structure
|
| 326 |
+
for idx, row in enumerate(dataset):
|
| 327 |
+
source_dataset = row.get('source_dataset')
|
| 328 |
+
processed_in_row = False # Flag to ensure we only pick one slot per row initially
|
| 329 |
+
for i in range(1, 4): # Check slots 1, 2, 3
|
| 330 |
+
model_key = f"model_{i}"
|
| 331 |
+
response_text_key = f"response_text_{i}"
|
| 332 |
+
# Check if the slot is assigned to minicpm and is NOT yet filled
|
| 333 |
+
is_minicpm_task = row.get(model_key) == MINICPMO_MODEL_NAME
|
| 334 |
+
# Crucially, check if the response text field is missing or empty in the loaded data
|
| 335 |
+
is_unfilled = not row.get(response_text_key) # True if None or empty string
|
| 336 |
+
|
| 337 |
+
if is_minicpm_task and is_unfilled and not processed_in_row:
|
| 338 |
+
task_info = (idx, i) # Store tuple of (original_row_index, slot_index)
|
| 339 |
+
if source_dataset == 'pkusafe':
|
| 340 |
+
pkusafe_tasks_indices.append(task_info)
|
| 341 |
+
else:
|
| 342 |
+
other_tasks_indices.append(task_info)
|
| 343 |
+
processed_in_row = True # Mark as processed for this row for task identification
|
| 344 |
+
|
| 345 |
+
# Combine lists, prioritizing pkusafe
|
| 346 |
+
tasks_to_process_indices = pkusafe_tasks_indices + other_tasks_indices
|
| 347 |
+
total_tasks_to_process = len(tasks_to_process_indices)
|
| 348 |
+
|
| 349 |
+
print(f"Found {len(pkusafe_tasks_indices)} 'pkusafe' tasks and {len(other_tasks_indices)} other tasks requiring '{MINICPMO_MODEL_NAME}' processing in the loaded dataset.")
|
| 350 |
+
print(f"Total tasks remaining to process: {total_tasks_to_process}")
|
| 351 |
+
|
| 352 |
+
if total_tasks_to_process == 0:
|
| 353 |
+
print("\nNo remaining tasks to process for MiniCPM-o based on the loaded dataset.")
|
| 354 |
+
# Optionally, perform a final save here if you want ensure the output dir reflects the 'completed' state
|
| 355 |
+
# print("Performing a final save to ensure consistency...")
|
| 356 |
+
# final_data_list = [dict(row) for row in dataset] # Convert dataset rows back to dicts
|
| 357 |
+
# fallback_save_dir_final = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), "minicpmo_checkpoints_fallback")
|
| 358 |
+
# save_checkpoint(final_data_list, original_features, OUTPUT_DATASET_DIR, fallback_save_dir_final)
|
| 359 |
+
print("Exiting.")
|
| 360 |
+
exit(0)
|
| 361 |
+
# --- End Pre-computation Step ---
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
# --- Load Model (Only if tasks exist) ---
|
| 365 |
+
print("\n" + "="*30)
|
| 366 |
+
print("STEP 3: Loading Model")
|
| 367 |
+
print("="*30)
|
| 368 |
+
print(f"Loading MiniCPM-o model ({MINICPMO_HF_ID}) and tokenizer...")
|
| 369 |
+
try:
|
| 370 |
+
model = AutoModel.from_pretrained(
|
| 371 |
+
MINICPMO_HF_ID,
|
| 372 |
+
trust_remote_code=True,
|
| 373 |
+
attn_implementation=MINICPMO_ATTN_IMPL,
|
| 374 |
+
torch_dtype=MINICPMO_DTYPE
|
| 375 |
+
)
|
| 376 |
+
model = model.eval().to(MINICPMO_DEVICE)
|
| 377 |
+
tokenizer = AutoTokenizer.from_pretrained(MINICPMO_HF_ID, trust_remote_code=True)
|
| 378 |
+
|
| 379 |
+
print("Initializing TTS...")
|
| 380 |
+
model.init_tts()
|
| 381 |
+
model.tts.float() # Use float32 for TTS stability
|
| 382 |
+
print(f"Model and TTS initialized successfully on {MINICPMO_DEVICE}.")
|
| 383 |
+
except Exception as e:
|
| 384 |
+
print(f"Error loading MiniCPM-o model or tokenizer: {e}")
|
| 385 |
+
traceback.print_exc()
|
| 386 |
+
exit(1)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# --- Prepare for Processing ---
|
| 390 |
+
# Create output directory for MiniCPM-o audio if it doesn't exist
|
| 391 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 392 |
+
# Ensure the main output dataset directory exists for saving checkpoints
|
| 393 |
+
os.makedirs(OUTPUT_DATASET_DIR, exist_ok=True)
|
| 394 |
+
# Define and create fallback directory
|
| 395 |
+
fallback_save_dir = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), "minicpmo_checkpoints_fallback")
|
| 396 |
+
os.makedirs(fallback_save_dir, exist_ok=True)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# Create a mutable list of dictionaries from the loaded dataset for updates
|
| 400 |
+
# This is crucial as Hugging Face datasets are typically immutable
|
| 401 |
+
updated_data = [dict(row) for row in dataset] # Convert each row to a dictionary
|
| 402 |
+
|
| 403 |
+
tasks_processed_count = 0 # Count successful completions for average time calculation
|
| 404 |
+
start_total_time = time.time()
|
| 405 |
+
|
| 406 |
+
print("\n" + "="*30)
|
| 407 |
+
print(f"STEP 4: Starting MiniCPM-o Processing for {total_tasks_to_process} Tasks")
|
| 408 |
+
print("="*30)
|
| 409 |
+
# Use tqdm for the progress bar, iterating over the identified task indices
|
| 410 |
+
pbar = tqdm(enumerate(tasks_to_process_indices), total=total_tasks_to_process, desc="Processing MiniCPM-o Tasks")
|
| 411 |
+
for loop_idx, (row_idx, slot_i) in pbar:
|
| 412 |
+
# Get the row data *from our mutable list* using the original index
|
| 413 |
+
row = updated_data[row_idx] # This is already a dictionary
|
| 414 |
+
|
| 415 |
+
# Set description in tqdm dynamically
|
| 416 |
+
pbar.set_description(f"Processing Row {row_idx}, Slot {slot_i}")
|
| 417 |
+
|
| 418 |
+
prompt_text_key = f"prompt_text_{slot_i}"
|
| 419 |
+
response_text_key = f"response_text_{slot_i}"
|
| 420 |
+
response_audio_key = f"response_audio_path_{slot_i}"
|
| 421 |
+
model_key = f"model_{slot_i}" # Get model key for verification
|
| 422 |
+
|
| 423 |
+
# --- Sanity Check: Ensure this is still a valid MiniCPM-o task ---
|
| 424 |
+
# (This might be redundant if identification was perfect, but good for safety)
|
| 425 |
+
if row.get(model_key) != MINICPMO_MODEL_NAME:
|
| 426 |
+
tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Model is no longer '{MINICPMO_MODEL_NAME}'.")
|
| 427 |
+
continue
|
| 428 |
+
if row.get(response_text_key): # Check again if it got filled somehow concurrently (unlikely here)
|
| 429 |
+
tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Already has response text '{str(row.get(response_text_key))[:50]}...'.")
|
| 430 |
+
continue
|
| 431 |
+
|
| 432 |
+
# --- Prepare Model Inputs ---
|
| 433 |
+
prompt_text = row.get(prompt_text_key, "")
|
| 434 |
+
question_audio_path = row.get('question_audio')
|
| 435 |
+
metadata_str = row.get('metadata', "{}")
|
| 436 |
+
source_dataset = row.get('source_dataset') # Used for history parsing
|
| 437 |
+
|
| 438 |
+
# Basic check for essential input audio
|
| 439 |
+
if not question_audio_path or not os.path.exists(question_audio_path):
|
| 440 |
+
tqdm.write(f" Error: Input audio path missing or invalid for Row {row_idx}: '{question_audio_path}'. Skipping model call.")
|
| 441 |
+
# Update the specific row in the list (mark as failed/skipped)
|
| 442 |
+
updated_data[row_idx][response_text_key] = "[ERROR: Missing Input Audio]"
|
| 443 |
+
updated_data[row_idx][response_audio_key] = None
|
| 444 |
+
continue # Move to the next task
|
| 445 |
+
|
| 446 |
+
# Parse History
|
| 447 |
+
history_messages = []
|
| 448 |
+
if source_dataset == 'ultra' and metadata_str:
|
| 449 |
+
try:
|
| 450 |
+
metadata = json.loads(metadata_str)
|
| 451 |
+
history_str = metadata.get('history', '')
|
| 452 |
+
if history_str:
|
| 453 |
+
history_messages = parse_ultra_history(history_str)
|
| 454 |
+
except json.JSONDecodeError:
|
| 455 |
+
tqdm.write(f" Warning: Could not parse metadata JSON for row {row_idx}")
|
| 456 |
+
except Exception as hist_e:
|
| 457 |
+
tqdm.write(f" Warning: Error processing history for row {row_idx}: {hist_e}")
|
| 458 |
+
# Add elif blocks here if other datasets have different history formats in metadata
|
| 459 |
+
|
| 460 |
+
# Generate unique output audio filename
|
| 461 |
+
unique_id = str(uuid.uuid4())
|
| 462 |
+
output_audio_filename = f"minicpmo_row{row_idx}_slot{slot_i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 463 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 464 |
+
|
| 465 |
+
# --- Call Model ---
|
| 466 |
+
# tqdm.write(f" Calling model for Row {row_idx}, Slot {slot_i} (Source: {source_dataset}). Output: {output_audio_filepath}") # More verbose
|
| 467 |
+
call_start_time = time.time()
|
| 468 |
+
response_text, saved_audio_path = call_minicpmo_model(
|
| 469 |
+
model,
|
| 470 |
+
tokenizer,
|
| 471 |
+
history_messages,
|
| 472 |
+
prompt_text,
|
| 473 |
+
question_audio_path,
|
| 474 |
+
output_audio_filepath
|
| 475 |
+
)
|
| 476 |
+
call_end_time = time.time()
|
| 477 |
+
tqdm.write(f" Row {row_idx}, Slot {slot_i}: Finished in {call_end_time - call_start_time:.2f}s. Text: '{str(response_text)[:50]}...', Audio: {os.path.basename(str(saved_audio_path))}")
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
# Store results directly into the list item (updated_data)
|
| 481 |
+
updated_data[row_idx][response_text_key] = response_text if response_text is not None else "[ERROR: Model Call Failed]"
|
| 482 |
+
updated_data[row_idx][response_audio_key] = saved_audio_path # Will be None if audio saving/generation failed
|
| 483 |
+
|
| 484 |
+
if response_text is not None and saved_audio_path is not None: # Count as successfully processed only if both text and audio are generated
|
| 485 |
+
tasks_processed_count += 1
|
| 486 |
+
|
| 487 |
+
# --- Periodic Saving ---
|
| 488 |
+
# Save after processing N samples (using loop_idx + 1 because index is 0-based)
|
| 489 |
+
# Also save on the very last iteration
|
| 490 |
+
processed_count_in_loop = loop_idx + 1
|
| 491 |
+
if processed_count_in_loop % SAVE_EVERY_N_SAMPLES == 0 or processed_count_in_loop == total_tasks_to_process:
|
| 492 |
+
save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
|
| 493 |
+
|
| 494 |
+
# Optional small delay if needed for hardware cooling, etc.
|
| 495 |
+
# time.sleep(0.1)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# --- Final Summary ---
|
| 499 |
+
end_total_time = time.time()
|
| 500 |
+
print("\n" + "="*30)
|
| 501 |
+
print("STEP 5: Processing Complete - Summary")
|
| 502 |
+
print("="*30)
|
| 503 |
+
print(f"Total tasks identified for processing in this run: {total_tasks_to_process}")
|
| 504 |
+
print(f"Total tasks successfully processed (generated text & audio) in this run: {tasks_processed_count}")
|
| 505 |
+
total_duration = end_total_time - start_total_time
|
| 506 |
+
print(f"Total processing time for this run: {format_time(total_duration)}")
|
| 507 |
+
if tasks_processed_count > 0:
|
| 508 |
+
avg_time = total_duration / tasks_processed_count
|
| 509 |
+
print(f"Average time per successfully processed task in this run: {avg_time:.2f} seconds")
|
| 510 |
+
else:
|
| 511 |
+
print("Average time per task: N/A (no tasks successfully processed in this run)")
|
| 512 |
+
|
| 513 |
+
# --- Final Save ---
|
| 514 |
+
# This ensures the very last state is saved, even if the last iteration didn't trigger the periodic save exactly.
|
| 515 |
+
# It might be redundant if SAVE_EVERY_N_SAMPLES aligns perfectly, but it's safe to include.
|
| 516 |
+
print("\nPerforming final save of the dataset...")
|
| 517 |
+
save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
|
| 518 |
+
|
| 519 |
+
print("\nScript finished.")
|
r1-a/response_generation/minicpm/MiniCPM-o/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.bk
|
| 2 |
+
__pycache__
|
| 3 |
+
.DS_Store
|
r1-a/response_generation/minicpm/MiniCPM-o/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2024 OpenBMB
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
r1-a/response_generation/minicpm/MiniCPM-o/README.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
r1-a/response_generation/minicpm/MiniCPM-o/README_zh.md
ADDED
|
@@ -0,0 +1,2524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<img src="./assets/MiniCPM-o.png" width="300em" ></img>
|
| 4 |
+
|
| 5 |
+
**端侧可用的 GPT-4o 级视觉、语音、多模态实时流式大模型**
|
| 6 |
+
|
| 7 |
+
<strong>中文 |
|
| 8 |
+
[English](./README.md)</strong>
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
<span style="display: inline-flex; align-items: center; margin-right: 2px;">
|
| 13 |
+
<a href="docs/wechat.md" target="_blank"> 微信社区</a> |
|
| 14 |
+
</span>
|
| 15 |
+
<span style="display: inline-flex; align-items: center; margin-left: 2px;">
|
| 16 |
+
MiniCPM-V <a href="docs/best_practice_summary_zh.md" target="_blank"> 📖 最佳实践</a>
|
| 17 |
+
</span>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
<p align="center">
|
| 21 |
+
MiniCPM-o 2.6 <a href="https://huggingface.co/openbmb/MiniCPM-o-2_6">🤗</a> <a href="https://minicpm-omni-webdemo-us.modelbest.cn/"> 🤖</a> | MiniCPM-V 2.6 <a href="https://huggingface.co/openbmb/MiniCPM-V-2_6">🤗</a> <a href="http://120.92.209.146:8887/">🤖</a> |
|
| 22 |
+
📄 技术报告 [<a href="https://openbmb.notion.site/MiniCPM-o-2-6-GPT-4o-188ede1b7a558084b3aedd669cb80730">中文</a>/<a href="https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9">English</a>]
|
| 23 |
+
</p>
|
| 24 |
+
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
**MiniCPM-o** 是从 MiniCPM-V 升级的最新端侧多模态大模型系列。该系列模型可以以端到端方式,接受图像、视频、文本、音频作为输入,并生成高质量文本和语音输出。自2024年2月以来,我们以实现高性能和高效部署为目标,发布了6个版本的模型。目前系列中最值得关注的模型包括:
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
- **MiniCPM-o 2.6**: 🔥🔥🔥 MiniCPM-o 系列的最新、性能最佳模型。总参数量 8B,**视觉、语音和多模态流式能力达到了 GPT-4o-202405 级别**,是开源社区中模态支持最丰富、性能最佳的模型之一。在新的语音模式中,MiniCPM-o 2.6 **支持可配置声音的中英双语语音对话,还具备情感/语速/风格控制、端到端声音克隆、角色扮演等进阶能力**。模型也进一步提升了 MiniCPM-V 2.6 的 **OCR、可信行为、多语言支持和视频理解等视觉能力**。基于其领先的视觉 token 密度,MiniCPM-V 2.6 成为了**首个支持在 iPad 等端侧设备上进行多模态实时流式交互**的多模态大模型。
|
| 32 |
+
|
| 33 |
+
- **MiniCPM-V 2.6**: MiniCPM-V 系列中性能最佳的模型。总参数量 8B,单图、多图和视频理解性能**超越了 GPT-4V**。它取得了优于 **GPT-4o mini、Gemini 1.5 Pro 和 Claude 3.5 Sonnet**等的单图理解表现,并成为了首个支持在 iPad 等端侧设备上进行实时视频理解的多模态大模型。
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
## 更新日志 <!-- omit in toc -->
|
| 37 |
+
|
| 38 |
+
#### 📌 置顶
|
| 39 |
+
|
| 40 |
+
* [2025.03.01] 🚀🚀🚀 MiniCPM-o 系列的对齐技术 RLAIF-V 被 CVPR 2025 接收了!其[代码](https://github.com/RLHF-V/RLAIF-V)、[数据](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)、[论文](https://arxiv.org/abs/2405.17220)均已开源。
|
| 41 |
+
|
| 42 |
+
* [2025.01.24] 📢📢📢 MiniCPM-o 2.6 技术报告已发布! 欢迎点击[这里](https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9)查看.
|
| 43 |
+
|
| 44 |
+
* [2025.01.23] 💡💡💡 MiniCPM-o 2.6 现在已被北大团队开发的 [Align-Anything](https://github.com/PKU-Alignment/align-anything),一个用于对齐全模态大模型的框架集成,支持 DPO 和 SFT 在视觉和音频模态上的微调。欢迎试用!
|
| 45 |
+
|
| 46 |
+
* [2025.01.19] 📢 **注意!** 我们正在努力将 MiniCPM-o 2.6 的支持合并到 llama.cpp、ollama、vLLM 的官方仓库,但还未完成。请大家暂时先使用我们提供的 fork 来进行部署:[llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md)、[ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md)、[vllm](https://github.com/OpenBMB/MiniCPM-o?tab=readme-ov-file#efficient-inference-with-llamacpp-ollama-vllm)。 **合并完成前,使用官方仓库可能会导致不可预期的问题**。
|
| 47 |
+
|
| 48 |
+
* [2025.01.19] ⭐️⭐️⭐️ MiniCPM-o 在 GitHub Trending 上登顶, Hugging Face Trending 上也达到了第二!
|
| 49 |
+
|
| 50 |
+
* [2025.01.17] 我们更新了 MiniCPM-o 2.6 int4 量化版本的使用方式,解决了模型初始化的问题,欢迎点击[这里](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4)试用!
|
| 51 |
+
|
| 52 |
+
* [2025.01.13] 🔥🔥🔥 我们开源了 MiniCPM-o 2.6,该模型视觉、语音和多模态流式能力达到了 GPT-4o-202405 级别,进一步优化了 MiniCPM-V 2.6 的众多亮点能力,还支持了很多有趣的新功能。欢迎试用!
|
| 53 |
+
|
| 54 |
+
* [2024.08.17] 🚀🚀🚀 llama.cpp [官方仓库](https://github.com/ggerganov/llama.cpp)正式支持 MiniCPM-V 2.6 啦!点击[这里](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf)查看各种大小的 GGUF 版本。
|
| 55 |
+
|
| 56 |
+
* [2024.08.06] 🔥🔥🔥 我们开源了 MiniCPM-V 2.6,该模型在单图、多图和视频理解方面取得了优于 GPT-4V 的表现。我们还进一步���升了 MiniCPM-Llama3-V 2.5 的多项亮点能力,并首次支持了 iPad 上的实时视频理解。欢迎试用!
|
| 57 |
+
|
| 58 |
+
* [2024.08.03] MiniCPM-Llama3-V 2.5 技术报告已发布!欢迎点击[这里](https://arxiv.org/abs/2408.01800)查看。
|
| 59 |
+
|
| 60 |
+
* [2024.05.23] 🔥🔥🔥 MiniCPM-V 在 GitHub Trending 和 Hugging Face Trending 上登顶!MiniCPM-Llama3-V 2.5 Demo 被 Hugging Face 的 Gradio 官方账户推荐,欢迎点击[这里](https://huggingface.co/spaces/openbmb/MiniCPM-Llama3-V-2_5)体验!
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
<br>
|
| 64 |
+
|
| 65 |
+
<details>
|
| 66 |
+
<summary>点击查看完整更新日志。</summary>
|
| 67 |
+
|
| 68 |
+
* [2024.08.15] MiniCPM-V 2.6 现在支持多图像 SFT。有关更多详细信息,请参阅[微调文档](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune)
|
| 69 |
+
* [2024.08.14] MiniCPM-V 2.6 现在可以通过 SWIFT 框架 [微调](https://github.com/modelscope/ms-swift/issues/1613) 了!
|
| 70 |
+
* [2024.08.10] 🚀🚀🚀 llama.cpp [官方仓库](https://github.com/ggerganov/llama.cpp)正式支持 MiniCPM-Llama3-V 2.5 啦!点击[这里](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main)查看各种大小的 GGUF 版本。
|
| 71 |
+
* [2024.07.19] MiniCPM-Llama3-V 2.5 现已支持[vLLM](#vllm-部署-) !
|
| 72 |
+
* [2024.06.03] 现在,你可以利用多张低显存显卡(12G/16G)进行GPU串行推理。详情请参见该[文档](https://github.com/OpenBMB/MiniCPM-V/blob/main/docs/inference_on_multiple_gpus.md)配置。
|
| 73 |
+
* [2024.05.28] 💫 我们现在支持 MiniCPM-Llama3-V 2.5 的 LoRA 微调,更多内存使用统计信息可以在[这里](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune#model-fine-tuning-memory-usage-statistics)找到。
|
| 74 |
+
* [2024.05.28] 💥 MiniCPM-Llama3-V 2.5 现在在 llama.cpp 和 ollama 中完全支持其功能!**请拉取我们最新的 fork 来使用**:[llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-v2.5/examples/minicpmv/README.md) & [ollama](https://github.com/OpenBMB/ollama/tree/minicpm-v2.5/examples/minicpm-v2.5)。我们还发布了各种大小的 GGUF 版本,请点击[这里](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main)查看。请注意,**目前官方仓库尚未支持 MiniCPM-Llama3-V 2.5**,我们也正积极推进将这些功能合并到 llama.cpp & ollama 官方仓库,敬请关注!
|
| 75 |
+
* [2024.05.25] MiniCPM-Llama3-V 2.5 [支持流式输出和自定义系统提示词](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5#usage)了,欢迎试用!
|
| 76 |
+
* [2024.05.24] 我们开源了 MiniCPM-Llama3-V 2.5 [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf),支持 [llama.cpp](#llamacpp-部署) 推理!实现端侧 6-8 tokens/s 的流畅解码,欢迎试用!
|
| 77 |
+
* [2024.05.23] 🔍 我们添加了Phi-3-vision-128k-instruct 与 MiniCPM-Llama3-V 2.5的全面对比,包括基准测试评估、多语言能力和推理效率 🌟📊🌍🚀。点击[这里](./docs/compare_with_phi-3_vision.md)查看详细信息。
|
| 78 |
+
* [2024.05.20] 我们开源了 MiniCPM-Llama3-V 2.5,增强了 OCR 能力,支持 30 多种语言,并首次在端侧实现了 GPT-4V 级的多模态能力!我们提供了[高效推理](#手机端部署)和[简易微调](./finetune/readme.md)的支持,欢迎试用!
|
| 79 |
+
* [2024.04.23] 我们增加了MiniCPM-V 2.0对 [vLLM](#vllm-部署-) 的支持,欢迎体验!
|
| 80 |
+
* [2024.04.18] 我们在 HuggingFace Space 新增了 MiniCPM-V 2.0 的 [demo](https://huggingface.co/spaces/openbmb/MiniCPM-V-2),欢迎体验!
|
| 81 |
+
* [2024.04.17] MiniCPM-V 2.0 现在支持用户部署本地 [WebUI Demo](#本地webui-demo部署) 了,欢迎试用!
|
| 82 |
+
* [2024.04.15] MiniCPM-V 2.0 现在可以通过 SWIFT 框架 [微调](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v-2最佳实践.md) 了,支持流式输出!
|
| 83 |
+
* [2024.04.12] 我们开源了 MiniCPM-V 2.0,该模型刷新了 OCRBench 开源模型最佳成绩,在场景文字识别能力上比肩 Gemini Pro,同时还在综合了 11 个主流多模态大模型评测基准的 <a href="https://rank.opencompass.org.cn/leaderboard-multimodal">OpenCompass</a> 榜单上超过了 Qwen-VL-Chat 10B、CogVLM-Chat 17B 和 Yi-VL 34B 等更大参数规模的模型!点击<a href="https://openbmb.vercel.app/minicpm-v-2">这里</a>查看 MiniCPM-V 2.0 技术博客。
|
| 84 |
+
* [2024.03.14] MiniCPM-V 现在支持 SWIFT 框架下的[微调](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v最佳实践.md)了,感谢 [Jintao](https://github.com/Jintao-Huang) 的贡献!
|
| 85 |
+
* [2024.03.01] MiniCPM-V 现在支持在 Mac 电脑上进行部署!
|
| 86 |
+
* [2024.02.01] 我们开源了 MiniCPM-V 和 OmniLMM-12B,分别可以支持高效的端侧部署和同规模领先的多模态能力!
|
| 87 |
+
</details>
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
## 目录 <!-- omit in toc -->
|
| 91 |
+
|
| 92 |
+
- [MiniCPM-o 2.6](#minicpm-o-26)
|
| 93 |
+
- [MiniCPM-V 2.6](#minicpm-v-26)
|
| 94 |
+
- [Chat with Our Demo on Gradio 🤗](#chat-with-our-demo-on-gradio-)
|
| 95 |
+
- [推理](#推理)
|
| 96 |
+
- [模型库](#模型库)
|
| 97 |
+
- [多轮对话](#多轮对话)
|
| 98 |
+
- [多图对话](#多图对话)
|
| 99 |
+
- [少样本上下文对话](#少样本上下文对话)
|
| 100 |
+
- [视频��话](#视频对话)
|
| 101 |
+
- [语音对话](#语音对话)
|
| 102 |
+
- [Mimick](#mimick)
|
| 103 |
+
- [可配置声音的语音对话](#可配置声音的语音对话)
|
| 104 |
+
- [更多语音任务](#更多语音任务)
|
| 105 |
+
- [多模态流式交互](#多模态流式交互)
|
| 106 |
+
- [多卡推理](#多卡推理)
|
| 107 |
+
- [Mac 推理](#mac-推理)
|
| 108 |
+
- [基于 llama.cpp、ollama、vLLM 的高效推理](#基于-llamacppollamavllm-的高效推理)
|
| 109 |
+
- [微调](#微调)
|
| 110 |
+
- [FAQs](#faqs)
|
| 111 |
+
- [模型局限性](#模型局限性)
|
| 112 |
+
|
| 113 |
+
## MiniCPM-o 2.6
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
MiniCPM-o 2.6 是 MiniCPM-o 系列的最新、性能最佳模型。该模型基于 SigLip-400M、Whisper-medium-300M、ChatTTS-200M 和 Qwen2.5-7B 构建,共 8B 参数,通过端到端方式训练和推理。相比 MiniCPM-V 2.6,该模型在性能上有了显著提升,并支持了实时语音对话和多模态流式交互的新功能。MiniCPM-o 2.6 的主要特性包括:
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
- 🔥 **领先的视觉能力。**
|
| 120 |
+
MiniCPM-o 2.6 在 OpenCompass 榜单上(综合 8 个主流多模态评测基准)平均得分 70.2,**以 8B 量级的大小在单图理解方面超越了 GPT-4o-202405、Gemini 1.5 Pro 和 Claude 3.5 Sonnet 等主流商用闭源多模态大模型**。此外,它的多图和视频理解表现也**优于 GPT-4V 和 Claude 3.5 Sonnet**,并展现出了优秀的上下文学习能力。
|
| 121 |
+
|
| 122 |
+
- 🎙 **出色的语音能力。**
|
| 123 |
+
MiniCPM-o 2.6 **支持可配置声音的中英双语实时对话**。MiniCPM-o 2.6 在语音理解任务(如 ASR 和 STT 等)**优于 GPT-4o-realtime**,并在语音对话的语义和声学评估中展现了**开源模型中最高的语音生成性能**。它还支持情绪/语速/风格控制、语音克隆、角色扮演等进阶能力。
|
| 124 |
+
|
| 125 |
+
- 🎬 **强大的多模态流式交互能力。**
|
| 126 |
+
作为一项新功能,MiniCPM-o 2.6 能够**接受连续的视频和音频流,并和用户进行实时语音交互**。在针对实时视频理解、全模态视音频理解、多模态上下文理解的综合评测基准 StreamingBench 中,MiniCPM-o 2.6 取得开源社区最佳水平,并**超过了 GPT-4o-202408 和 Claude 3.5 Sonnet**。
|
| 127 |
+
|
| 128 |
+
- 💪 **强大的 OCR 能力及其他功能。**
|
| 129 |
+
MiniCPM-o 2.6 进一步优化了 MiniCPM-V 2.6 的众多视觉理解能力,其可以处理任意长宽比的图像,像素数可达 180 万(如 1344x1344)。在 OCRBench 上取得**25B 以下最佳水平,超过 GPT-4o-202405 等商用闭源模型**。基于最新的 [RLHF-V](https://rlhf-v.github.io/)、[RLAIF-V](https://github.com/RLHF-V/RLAIF-V/) 和 [VisCPM](https://github.com/OpenBMB/VisCPM) 技术,其具备了**可信的多模态行为**,在 MMHal-Bench 上超过了 GPT-4o 和 Claude 3.5,并支持英语、中文、德语、法语、意大利语、韩语等**30多种语言**。
|
| 130 |
+
|
| 131 |
+
- 🚀 **卓越的效率。**
|
| 132 |
+
除了对个人用户友好的模型大小,MiniCPM-o 2.6 还表现出**最先进的视觉 token 密度**(即每个视觉 token 编码的像素数量)。它**仅需 640 个 token 即可处理 180 万像素图像,比大多数模型少 75%**。这一特性优化了模型的推理速度、首 token 延迟、内存占用和功耗。因此,MiniCPM-o 2.6 可以支持 iPad 等终端设备上的高效**多模态实时流式交互**。
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
- 💫 **易于使用。**
|
| 136 |
+
MiniCPM-o 2.6 可以通过多种方式轻松使用:(1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md) 支持在本地设备上进行高效的 CPU 推理,(2) [int4](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) 和 [GGUF](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) 格式的量化模型,有 16 种尺寸,(3) [vLLM](#基于-llamacppollamavllm-的高效推理) 支持高吞吐量和内存高效的推理,(4) 通过[LLaMA-Factory](./docs/llamafactory_train_and_infer.md)框架针对新领域和任务进行微调,(5) 使用 [Gradio](#本地-webui-demo-) 快速设置本地 WebUI 演示,(6) 部署于服务器的在线 [demo](https://minicpm-omni-webdemo-us.modelbest.cn/)。
|
| 137 |
+
|
| 138 |
+
**模型架构。**
|
| 139 |
+
|
| 140 |
+
- **端到端全模态架构。** 通过**端到端**的方式连接和训练不同模态的编/解码模块以充分利用丰富的多模态知识。模型完全使用 CE 损失端到端训练。
|
| 141 |
+
- **全模态流式机制。** (1) 我们将不同模态的离线编/解码器改造为适用于**流式输入/输出**的在线模块。 (2) 我们针对大语言模型基座设计了**时分复用的全模态流式信息处理机制**,将平行的不同模态的信息流拆分重组为周期性时间片序列。
|
| 142 |
+
- **可配置的声音方案。** 我们设计了新的多模态系统提示,包含传统文本系统提示词,和**用于指定模型声音的语音系统提示词**。模型可在推理时灵活地通过文字或语音样例控制声音风格,并支持端到端声音克隆和音色创建等高级能力。
|
| 143 |
+
|
| 144 |
+
<div align="center">
|
| 145 |
+
<img src="./assets/minicpm-o-26-framework-v2.png" , width=80%>
|
| 146 |
+
</div>
|
| 147 |
+
|
| 148 |
+
<br>
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
### 性能评估 <!-- omit in toc -->
|
| 153 |
+
|
| 154 |
+
<div align="center">
|
| 155 |
+
<img src="./assets/radar.jpg", width=80%>
|
| 156 |
+
</div>
|
| 157 |
+
|
| 158 |
+
<details>
|
| 159 |
+
<summary>点击查看视觉理解能力详细评测结果。</summary>
|
| 160 |
+
|
| 161 |
+
**图像理解能力**
|
| 162 |
+
|
| 163 |
+
<div align="center">
|
| 164 |
+
<table style="margin: 0px auto;">
|
| 165 |
+
<thead>
|
| 166 |
+
<tr>
|
| 167 |
+
<th align="left">Model</th>
|
| 168 |
+
<th>Size</th>
|
| 169 |
+
<th>Token Density<sup>+</sup></th>
|
| 170 |
+
<th>OpenCompass</th>
|
| 171 |
+
<th>OCRBench</th>
|
| 172 |
+
<th>MathVista mini</th>
|
| 173 |
+
<th>ChartQA</th>
|
| 174 |
+
<th>MMVet</th>
|
| 175 |
+
<th>MMStar</th>
|
| 176 |
+
<th>MME</th>
|
| 177 |
+
<th>MMB1.1 test</th>
|
| 178 |
+
<th>AI2D</th>
|
| 179 |
+
<th>MMMU val</th>
|
| 180 |
+
<th>HallusionBench</th>
|
| 181 |
+
<th>TextVQA val</th>
|
| 182 |
+
<th>DocVQA test</th>
|
| 183 |
+
<th>MathVerse mini</th>
|
| 184 |
+
<th>MathVision</th>
|
| 185 |
+
<th>MMHal Score</th>
|
| 186 |
+
</tr>
|
| 187 |
+
</thead>
|
| 188 |
+
<tbody align="center">
|
| 189 |
+
<tr>
|
| 190 |
+
<td colspan="19" align="left"><strong>Proprietary</strong></td>
|
| 191 |
+
</tr>
|
| 192 |
+
<tr>
|
| 193 |
+
<td nowrap="nowrap" align="left">GPT-4o-20240513</td>
|
| 194 |
+
<td>-</td>
|
| 195 |
+
<td>1088</td>
|
| 196 |
+
<td><u>69.9</u></td>
|
| 197 |
+
<td>736</td>
|
| 198 |
+
<td>61.3</td>
|
| 199 |
+
<td>85.7</td>
|
| 200 |
+
<td><strong>69.1</strong></td>
|
| 201 |
+
<td>63.9</td>
|
| 202 |
+
<td>2328.7</td>
|
| 203 |
+
<td>82.2</td>
|
| 204 |
+
<td>84.6</td>
|
| 205 |
+
<td><strong>69.2</strong></td>
|
| 206 |
+
<td><strong>55.0</strong></td>
|
| 207 |
+
<td>-</td>
|
| 208 |
+
<td>92.8</td>
|
| 209 |
+
<td><strong>50.2</strong></td>
|
| 210 |
+
<td><strong>30.4</strong></td>
|
| 211 |
+
<td><u>3.6</u></td>
|
| 212 |
+
</tr>
|
| 213 |
+
<tr>
|
| 214 |
+
<td nowrap="nowrap" align="left">Claude3.5-Sonnet</td>
|
| 215 |
+
<td>-</td>
|
| 216 |
+
<td>750</td>
|
| 217 |
+
<td>67.9</td>
|
| 218 |
+
<td>788</td>
|
| 219 |
+
<td>61.6</td>
|
| 220 |
+
<td><strong>90.8</strong></td>
|
| 221 |
+
<td>66.0</td>
|
| 222 |
+
<td>62.2</td>
|
| 223 |
+
<td>1920.0</td>
|
| 224 |
+
<td>78.5</td>
|
| 225 |
+
<td>80.2</td>
|
| 226 |
+
<td><u>65.9</u></td>
|
| 227 |
+
<td>49.9</td>
|
| 228 |
+
<td>-</td>
|
| 229 |
+
<td><strong>95.2</strong></td>
|
| 230 |
+
<td>-</td>
|
| 231 |
+
<td>-</td>
|
| 232 |
+
<td>3.4</td>
|
| 233 |
+
</tr>
|
| 234 |
+
<tr>
|
| 235 |
+
<td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
|
| 236 |
+
<td>-</td>
|
| 237 |
+
<td>-</td>
|
| 238 |
+
<td>64.4</td>
|
| 239 |
+
<td>754</td>
|
| 240 |
+
<td>57.7</td>
|
| 241 |
+
<td>81.3</td>
|
| 242 |
+
<td>64.0</td>
|
| 243 |
+
<td>59.1</td>
|
| 244 |
+
<td>2110.6</td>
|
| 245 |
+
<td>73.9</td>
|
| 246 |
+
<td>79.1</td>
|
| 247 |
+
<td>60.6</td>
|
| 248 |
+
<td>45.6</td>
|
| 249 |
+
<td>73.5</td>
|
| 250 |
+
<td>86.5</td>
|
| 251 |
+
<td>-</td>
|
| 252 |
+
<td>19.2</td>
|
| 253 |
+
<td>-</td>
|
| 254 |
+
</tr>
|
| 255 |
+
<tr>
|
| 256 |
+
<td nowrap="nowrap" align="left">GPT-4o-mini-20240718</td>
|
| 257 |
+
<td>-</td>
|
| 258 |
+
<td>1088</td>
|
| 259 |
+
<td>64.1</td>
|
| 260 |
+
<td>785</td>
|
| 261 |
+
<td>52.4</td>
|
| 262 |
+
<td>-</td>
|
| 263 |
+
<td>66.9</td>
|
| 264 |
+
<td>54.8</td>
|
| 265 |
+
<td>2003.4</td>
|
| 266 |
+
<td>76.0</td>
|
| 267 |
+
<td>77.8</td>
|
| 268 |
+
<td>60.0</td>
|
| 269 |
+
<td>46.1</td>
|
| 270 |
+
<td>-</td>
|
| 271 |
+
<td>-</td>
|
| 272 |
+
<td>-</td>
|
| 273 |
+
<td>-</td>
|
| 274 |
+
<td>3.3</td>
|
| 275 |
+
</tr>
|
| 276 |
+
<tr>
|
| 277 |
+
<td colspan="19" align="left"><strong>Open Source</strong></td>
|
| 278 |
+
</tr>
|
| 279 |
+
<tr>
|
| 280 |
+
<td nowrap="nowrap" align="left">Cambrian-34B</td>
|
| 281 |
+
<td>34B</td>
|
| 282 |
+
<td><u>1820</u></td>
|
| 283 |
+
<td>58.3</td>
|
| 284 |
+
<td>591</td>
|
| 285 |
+
<td>50.3</td>
|
| 286 |
+
<td>75.6</td>
|
| 287 |
+
<td>53.2</td>
|
| 288 |
+
<td>54.2</td>
|
| 289 |
+
<td>2049.9</td>
|
| 290 |
+
<td>77.8</td>
|
| 291 |
+
<td>79.5</td>
|
| 292 |
+
<td>50.4</td>
|
| 293 |
+
<td>41.6</td>
|
| 294 |
+
<td>76.7</td>
|
| 295 |
+
<td>75.5</td>
|
| 296 |
+
<td>-</td>
|
| 297 |
+
<td>-</td>
|
| 298 |
+
<td>-</td>
|
| 299 |
+
</tr>
|
| 300 |
+
<tr>
|
| 301 |
+
<td nowrap="nowrap" align="left">GLM-4V-9B</td>
|
| 302 |
+
<td>13B</td>
|
| 303 |
+
<td>784</td>
|
| 304 |
+
<td>59.1</td>
|
| 305 |
+
<td>776</td>
|
| 306 |
+
<td>51.1</td>
|
| 307 |
+
<td>-</td>
|
| 308 |
+
<td>58.0</td>
|
| 309 |
+
<td>54.8</td>
|
| 310 |
+
<td>2018.8</td>
|
| 311 |
+
<td>67.9</td>
|
| 312 |
+
<td>71.2</td>
|
| 313 |
+
<td>46.9</td>
|
| 314 |
+
<td>45.0</td>
|
| 315 |
+
<td>-</td>
|
| 316 |
+
<td>-</td>
|
| 317 |
+
<td>-</td>
|
| 318 |
+
<td>-</td>
|
| 319 |
+
<td>-</td>
|
| 320 |
+
</tr>
|
| 321 |
+
<tr>
|
| 322 |
+
<td nowrap="nowrap" align="left">Pixtral-12B</td>
|
| 323 |
+
<td>12B</td>
|
| 324 |
+
<td>256</td>
|
| 325 |
+
<td>61.0</td>
|
| 326 |
+
<td>685</td>
|
| 327 |
+
<td>56.9</td>
|
| 328 |
+
<td>81.8</td>
|
| 329 |
+
<td>58.5</td>
|
| 330 |
+
<td>54.5</td>
|
| 331 |
+
<td>-</td>
|
| 332 |
+
<td>72.7</td>
|
| 333 |
+
<td>79.0</td>
|
| 334 |
+
<td>51.1</td>
|
| 335 |
+
<td>47.0</td>
|
| 336 |
+
<td>75.7</td>
|
| 337 |
+
<td>90.7</td>
|
| 338 |
+
<td>-</td>
|
| 339 |
+
<td>-</td>
|
| 340 |
+
<td>-</td>
|
| 341 |
+
</tr>
|
| 342 |
+
<tr>
|
| 343 |
+
<td nowrap="nowrap" align="left">DeepSeek-VL2-27B (4B)</td>
|
| 344 |
+
<td>27B</td>
|
| 345 |
+
<td>672</td>
|
| 346 |
+
<td>66.4</td>
|
| 347 |
+
<td>809</td>
|
| 348 |
+
<td>63.9</td>
|
| 349 |
+
<td>86.0</td>
|
| 350 |
+
<td>60.0</td>
|
| 351 |
+
<td>61.9</td>
|
| 352 |
+
<td>2253.0</td>
|
| 353 |
+
<td>81.2</td>
|
| 354 |
+
<td>83.8</td>
|
| 355 |
+
<td>54.0</td>
|
| 356 |
+
<td>45.3</td>
|
| 357 |
+
<td><u>84.2</u></td>
|
| 358 |
+
<td>93.3</td>
|
| 359 |
+
<td>-</td>
|
| 360 |
+
<td>-</td>
|
| 361 |
+
<td>3.0</td>
|
| 362 |
+
</tr>
|
| 363 |
+
<tr>
|
| 364 |
+
<td nowrap="nowrap" align="left">Qwen2-VL-7B</td>
|
| 365 |
+
<td>8B</td>
|
| 366 |
+
<td>784</td>
|
| 367 |
+
<td>67.1</td>
|
| 368 |
+
<td><u>866</u></td>
|
| 369 |
+
<td>58.2</td>
|
| 370 |
+
<td>83.0</td>
|
| 371 |
+
<td>62.0</td>
|
| 372 |
+
<td>60.7</td>
|
| 373 |
+
<td>2326.0</td>
|
| 374 |
+
<td>81.8</td>
|
| 375 |
+
<td>83.0</td>
|
| 376 |
+
<td>54.1</td>
|
| 377 |
+
<td>50.6</td>
|
| 378 |
+
<td><strong>84.3</strong></td>
|
| 379 |
+
<td><u>94.5</u></td>
|
| 380 |
+
<td>31.9</td>
|
| 381 |
+
<td>16.3</td>
|
| 382 |
+
<td>3.2</td>
|
| 383 |
+
</tr>
|
| 384 |
+
<tr>
|
| 385 |
+
<td nowrap="nowrap" align="left">LLaVA-OneVision-72B</td>
|
| 386 |
+
<td>72B</td>
|
| 387 |
+
<td>182</td>
|
| 388 |
+
<td>68.1</td>
|
| 389 |
+
<td>741</td>
|
| 390 |
+
<td>67.5</td>
|
| 391 |
+
<td>83.7</td>
|
| 392 |
+
<td>60.6</td>
|
| 393 |
+
<td><strong>65.8</strong></td>
|
| 394 |
+
<td>2261.0</td>
|
| 395 |
+
<td><strong>85.0</strong></td>
|
| 396 |
+
<td><u>85.6</u></td>
|
| 397 |
+
<td>56.8</td>
|
| 398 |
+
<td>49.0</td>
|
| 399 |
+
<td>80.5</td>
|
| 400 |
+
<td>91.3</td>
|
| 401 |
+
<td>39.1</td>
|
| 402 |
+
<td>-</td>
|
| 403 |
+
<td>3.5</td>
|
| 404 |
+
</tr>
|
| 405 |
+
<tr>
|
| 406 |
+
<td nowrap="nowrap" align="left">InternVL2.5-8B</td>
|
| 407 |
+
<td>8B</td>
|
| 408 |
+
<td>706</td>
|
| 409 |
+
<td>68.3</td>
|
| 410 |
+
<td>822</td>
|
| 411 |
+
<td><u>64.4</u></td>
|
| 412 |
+
<td>84.8</td>
|
| 413 |
+
<td>62.8</td>
|
| 414 |
+
<td>62.8</td>
|
| 415 |
+
<td>2344.0</td>
|
| 416 |
+
<td><u>83.6</u></td>
|
| 417 |
+
<td>84.5</td>
|
| 418 |
+
<td>56.0</td>
|
| 419 |
+
<td>50.1</td>
|
| 420 |
+
<td>79.1</td>
|
| 421 |
+
<td>93.0</td>
|
| 422 |
+
<td>39.5</td>
|
| 423 |
+
<td>19.7</td>
|
| 424 |
+
<td>3.4</td>
|
| 425 |
+
</tr>
|
| 426 |
+
<tr>
|
| 427 |
+
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
| 428 |
+
<td>8B</td>
|
| 429 |
+
<td><strong>2822</strong></td>
|
| 430 |
+
<td>65.2</td>
|
| 431 |
+
<td>852*</td>
|
| 432 |
+
<td>60.6</td>
|
| 433 |
+
<td>79.4</td>
|
| 434 |
+
<td>60.0</td>
|
| 435 |
+
<td>57.5</td>
|
| 436 |
+
<td><u>2348.4*</u></td>
|
| 437 |
+
<td>78.0</td>
|
| 438 |
+
<td>82.1</td>
|
| 439 |
+
<td>49.8*</td>
|
| 440 |
+
<td>48.1*</td>
|
| 441 |
+
<td>80.1</td>
|
| 442 |
+
<td>90.8</td>
|
| 443 |
+
<td>25.7</td>
|
| 444 |
+
<td>18.3</td>
|
| 445 |
+
<td>3.6</td>
|
| 446 |
+
</tr>
|
| 447 |
+
<tr>
|
| 448 |
+
<td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
|
| 449 |
+
<td>8B</td>
|
| 450 |
+
<td><strong>2822</strong></td>
|
| 451 |
+
<td><strong>70.2</strong></td>
|
| 452 |
+
<td><strong>897*</strong></td>
|
| 453 |
+
<td><strong>71.9*</strong></td>
|
| 454 |
+
<td><u>86.9*</u></td>
|
| 455 |
+
<td><u>67.5</u></td>
|
| 456 |
+
<td><u>64.0</u></td>
|
| 457 |
+
<td><strong>2372.0*</strong></td>
|
| 458 |
+
<td>80.5</td>
|
| 459 |
+
<td><strong>85.8</strong></td>
|
| 460 |
+
<td>50.4*</td>
|
| 461 |
+
<td><u>51.9</u></td>
|
| 462 |
+
<td>82.0</td>
|
| 463 |
+
<td>93.5</td>
|
| 464 |
+
<td><u>41.4*</u></td>
|
| 465 |
+
<td><u>23.1*</u></td>
|
| 466 |
+
<td><strong>3.8</strong></td>
|
| 467 |
+
</tr>
|
| 468 |
+
</tbody>
|
| 469 |
+
</table>
|
| 470 |
+
</div>
|
| 471 |
+
* 我们使用思维链提示词来评估这些基准,对于 MME 我们只在 Cognition 任务上使用了思维链。
|
| 472 |
+
+ Token Density:每个视觉 token 在最大分辨率下编码的像素数,即最大分辨率下的像素数 / 视觉 token 数。
|
| 473 |
+
|
| 474 |
+
注意:闭源模型的 Token Density 由 API 收费方式估算得到。
|
| 475 |
+
|
| 476 |
+
**多图和视频理解能力**
|
| 477 |
+
|
| 478 |
+
<div align="center">
|
| 479 |
+
|
| 480 |
+
<table style="margin: 0px auto;">
|
| 481 |
+
<thead>
|
| 482 |
+
<tr>
|
| 483 |
+
<th align="left">Model</th>
|
| 484 |
+
<th>Size</th>
|
| 485 |
+
<th>BLINK val</th>
|
| 486 |
+
<th>Mantis Eval</th>
|
| 487 |
+
<th>MIRB</th>
|
| 488 |
+
<th>Video-MME (wo / w subs)</th>
|
| 489 |
+
</tr>
|
| 490 |
+
</thead>
|
| 491 |
+
<tbody align="center">
|
| 492 |
+
<tr>
|
| 493 |
+
<td colspan="6" align="left"><strong>Proprietary</strong></td>
|
| 494 |
+
</tr>
|
| 495 |
+
<tr>
|
| 496 |
+
<td nowrap="nowrap" align="left">GPT-4o-20240513</td>
|
| 497 |
+
<td>-</td>
|
| 498 |
+
<td><strong>68</strong></td>
|
| 499 |
+
<td>-</td>
|
| 500 |
+
<td>-</td>
|
| 501 |
+
<td><strong>71.9/77.2<strong></td>
|
| 502 |
+
</tr>
|
| 503 |
+
<tr>
|
| 504 |
+
<td nowrap="nowrap" align="left">GPT4V</td>
|
| 505 |
+
<td>-</td>
|
| 506 |
+
<td>54.6</td>
|
| 507 |
+
<td>62.7</td>
|
| 508 |
+
<td>53.1</td>
|
| 509 |
+
<td>59.9/63.3</td>
|
| 510 |
+
</tr>
|
| 511 |
+
<tr>
|
| 512 |
+
<td colspan="6" align="left"><strong>Open-source</strong></td>
|
| 513 |
+
</tr>
|
| 514 |
+
<tr>
|
| 515 |
+
<td nowrap="nowrap" align="left">LLaVA-NeXT-Interleave 14B</td>
|
| 516 |
+
<td>14B</td>
|
| 517 |
+
<td>52.6</td>
|
| 518 |
+
<td>66.4</td>
|
| 519 |
+
<td>30.2</td>
|
| 520 |
+
<td>-</td>
|
| 521 |
+
</tr>
|
| 522 |
+
<tr>
|
| 523 |
+
<td nowrap="nowrap" align="left">LLaVA-OneVision-72B</td>
|
| 524 |
+
<td>72B</td>
|
| 525 |
+
<td>55.4</td>
|
| 526 |
+
<td><strong>77.6</strong></td>
|
| 527 |
+
<td>-</td>
|
| 528 |
+
<td><u>66.2/69.5</u></td>
|
| 529 |
+
</tr>
|
| 530 |
+
<tr>
|
| 531 |
+
<td nowrap="nowrap" align="left">MANTIS 8B</td>
|
| 532 |
+
<td>8B</td>
|
| 533 |
+
<td>49.1</td>
|
| 534 |
+
<td>59.5</td>
|
| 535 |
+
<td>34.8</td>
|
| 536 |
+
<td>-</td>
|
| 537 |
+
</tr>
|
| 538 |
+
<tr>
|
| 539 |
+
<td nowrap="nowrap" align="left">Qwen2-VL-7B</td>
|
| 540 |
+
<td>8B</td>
|
| 541 |
+
<td>53.2</td>
|
| 542 |
+
<td>69.6*</td>
|
| 543 |
+
<td><strong>67.6*</strong></td>
|
| 544 |
+
<td>63.3/69.0</td>
|
| 545 |
+
</tr>
|
| 546 |
+
<tr>
|
| 547 |
+
<td nowrap="nowrap" align="left">InternVL2.5-8B</td>
|
| 548 |
+
<td>8B</td>
|
| 549 |
+
<td>54.8</td>
|
| 550 |
+
<td>67.7</td>
|
| 551 |
+
<td>52.5</td>
|
| 552 |
+
<td>64.2/66.9</td>
|
| 553 |
+
</tr>
|
| 554 |
+
<tr>
|
| 555 |
+
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
| 556 |
+
<td>8B</td>
|
| 557 |
+
<td>53</td>
|
| 558 |
+
<td>69.1</td>
|
| 559 |
+
<td>53.8</td>
|
| 560 |
+
<td>60.9/63.6</td>
|
| 561 |
+
</tr>
|
| 562 |
+
<tr>
|
| 563 |
+
<td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
|
| 564 |
+
<td>8B</td>
|
| 565 |
+
<td><u>56.7</u></td>
|
| 566 |
+
<td><u>71.9</u></td>
|
| 567 |
+
<td><u>58.6</u></td>
|
| 568 |
+
<td>63.9/67.9</td>
|
| 569 |
+
</tr>
|
| 570 |
+
</tbody>
|
| 571 |
+
</table>
|
| 572 |
+
|
| 573 |
+
</div>
|
| 574 |
+
* 正式开源模型权重的评测结果。
|
| 575 |
+
|
| 576 |
+
</details>
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
<details>
|
| 580 |
+
<summary>点击查看语音理解和生成能力的详细评测结果。</summary>
|
| 581 |
+
|
| 582 |
+
**语音理解能力**
|
| 583 |
+
|
| 584 |
+
<div align="center">
|
| 585 |
+
<table style="margin: 0px auto;">
|
| 586 |
+
<thead>
|
| 587 |
+
<tr>
|
| 588 |
+
<th align="left">Task</th>
|
| 589 |
+
<th>Size</th>
|
| 590 |
+
<th colspan="3">ASR (zh)</th>
|
| 591 |
+
<th colspan="3">ASR (en)</th>
|
| 592 |
+
<th colspan="2">AST</th>
|
| 593 |
+
<th>Emotion</th>
|
| 594 |
+
</tr>
|
| 595 |
+
<tr>
|
| 596 |
+
<th align="left">Metric</th>
|
| 597 |
+
<td></td>
|
| 598 |
+
<th colspan="3">CER↓</th>
|
| 599 |
+
<th colspan="3">WER↓</th>
|
| 600 |
+
<th colspan="2">BLEU↑</th>
|
| 601 |
+
<th>ACC↑</th>
|
| 602 |
+
</tr>
|
| 603 |
+
<tr>
|
| 604 |
+
<th align="left">Dataset</th>
|
| 605 |
+
<td></td>
|
| 606 |
+
<th>AISHELL-1</th>
|
| 607 |
+
<th>Fleurs zh</th>
|
| 608 |
+
<th>WenetSpeech test-net</th>
|
| 609 |
+
<th>LibriSpeech test-clean</th>
|
| 610 |
+
<th>GigaSpeech</th>
|
| 611 |
+
<th>TED-LIUM</th>
|
| 612 |
+
<th>CoVoST en2zh</th>
|
| 613 |
+
<th>CoVoST zh2en</th>
|
| 614 |
+
<th>MELD emotion</th>
|
| 615 |
+
</tr>
|
| 616 |
+
</thead>
|
| 617 |
+
<tbody align="center">
|
| 618 |
+
<tr>
|
| 619 |
+
<td colspan="11" align="left"><strong>Proprietary</strong></td>
|
| 620 |
+
</tr>
|
| 621 |
+
<tr>
|
| 622 |
+
<td nowrap="nowrap" align="left">GPT-4o-Realtime</td>
|
| 623 |
+
<td>-</td>
|
| 624 |
+
<td>7.3*</td>
|
| 625 |
+
<td><u>5.4*</u></td>
|
| 626 |
+
<td>28.9*</td>
|
| 627 |
+
<td>2.6*</td>
|
| 628 |
+
<td>12.9*</td>
|
| 629 |
+
<td>4.8*</td>
|
| 630 |
+
<td>37.1*</td>
|
| 631 |
+
<td>15.7*</td>
|
| 632 |
+
<td>33.2*</td>
|
| 633 |
+
</tr>
|
| 634 |
+
<tr>
|
| 635 |
+
<td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
|
| 636 |
+
<td>-</td>
|
| 637 |
+
<td>4.5*</td>
|
| 638 |
+
<td>5.9*</td>
|
| 639 |
+
<td>14.3*</td>
|
| 640 |
+
<td>2.9*</td>
|
| 641 |
+
<td>10.6*</td>
|
| 642 |
+
<td><strong>3.0*</strong></td>
|
| 643 |
+
<td><u>47.3*</u></td>
|
| 644 |
+
<td>22.6*</td>
|
| 645 |
+
<td>48.4*</td>
|
| 646 |
+
</tr>
|
| 647 |
+
<tr>
|
| 648 |
+
<td colspan="11" align="left"><strong>Open-Source</strong></td>
|
| 649 |
+
</tr>
|
| 650 |
+
<tr>
|
| 651 |
+
<td nowrap="nowrap" align="left">Qwen2-Audio-7B</td>
|
| 652 |
+
<td>8B</td>
|
| 653 |
+
<td>-</td>
|
| 654 |
+
<td>7.5</td>
|
| 655 |
+
<td>-</td>
|
| 656 |
+
<td><strong>1.6</strong></td>
|
| 657 |
+
<td>-</td>
|
| 658 |
+
<td>-</td>
|
| 659 |
+
<td>45.2</td>
|
| 660 |
+
<td><u>24.4</u></td>
|
| 661 |
+
<td><strong>55.3</strong></td>
|
| 662 |
+
</tr>
|
| 663 |
+
<tr>
|
| 664 |
+
<td nowrap="nowrap" align="left">Qwen2-Audio-7B-Instruct</td>
|
| 665 |
+
<td>8B</td>
|
| 666 |
+
<td>2.6*</td>
|
| 667 |
+
<td>6.9*</td>
|
| 668 |
+
<td><u>10.3*</u></td>
|
| 669 |
+
<td>3.1*</td>
|
| 670 |
+
<td><u>9.7</u>*</td>
|
| 671 |
+
<td>5.9*</td>
|
| 672 |
+
<td>39.5*</td>
|
| 673 |
+
<td>22.9*</td>
|
| 674 |
+
<td>17.4*</td>
|
| 675 |
+
</tr>
|
| 676 |
+
<tr>
|
| 677 |
+
<td nowrap="nowrap" align="left">GLM-4-Voice-Base</td>
|
| 678 |
+
<td>9B</td>
|
| 679 |
+
<td><u>2.5</u></td>
|
| 680 |
+
<td>-</td>
|
| 681 |
+
<td>-</td>
|
| 682 |
+
<td>2.8</td>
|
| 683 |
+
<td>-</td>
|
| 684 |
+
<td>-</td>
|
| 685 |
+
<td>-</td>
|
| 686 |
+
<td>-</td>
|
| 687 |
+
</tr>
|
| 688 |
+
<tr>
|
| 689 |
+
<td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
|
| 690 |
+
<td>8B</td>
|
| 691 |
+
<td><strong>1.6</strong></td>
|
| 692 |
+
<td><strong>4.4</strong></td>
|
| 693 |
+
<td><strong>6.9</strong></td>
|
| 694 |
+
<td><u>1.7</u></td>
|
| 695 |
+
<td><strong>8.7</strong></td>
|
| 696 |
+
<td><strong>3.0</strong></td>
|
| 697 |
+
<td><strong>48.2</strong></td>
|
| 698 |
+
<td><strong>27.2</strong></td>
|
| 699 |
+
<td><u>52.4</u></td>
|
| 700 |
+
</tr>
|
| 701 |
+
</tbody>
|
| 702 |
+
</table>
|
| 703 |
+
</div>
|
| 704 |
+
* 正式开源模型权重的评测结果。<br><br>
|
| 705 |
+
|
| 706 |
+
**语音生成能力。**
|
| 707 |
+
|
| 708 |
+
<div align="center">
|
| 709 |
+
<table style="margin: 0px auto;">
|
| 710 |
+
<thead>
|
| 711 |
+
<tr>
|
| 712 |
+
<th align="left">Task</th>
|
| 713 |
+
<th>Size</th>
|
| 714 |
+
<th colspan="9">SpeechQA</th>
|
| 715 |
+
</tr>
|
| 716 |
+
<tr>
|
| 717 |
+
<th align="left">Metric</th>
|
| 718 |
+
<th></th>
|
| 719 |
+
<th colspan="3">ACC↑</th>
|
| 720 |
+
<th>G-Eval (10 point)↑</th>
|
| 721 |
+
<th>Semantic ELO score↑</th>
|
| 722 |
+
<th>Acoustic ELO score↑</th>
|
| 723 |
+
<th>Overall ELO score↑</th>
|
| 724 |
+
<th>UTMOS↑</th>
|
| 725 |
+
<th>ASR-WER↓</th>
|
| 726 |
+
</tr>
|
| 727 |
+
<tr>
|
| 728 |
+
<th align="left">Dataset</th>
|
| 729 |
+
<th></th>
|
| 730 |
+
<th>Speech Llama Q.</th>
|
| 731 |
+
<th>Speech Web Q.</th>
|
| 732 |
+
<th>Speech Trivia QA</th>
|
| 733 |
+
<th>Speech AlpacaEval</th>
|
| 734 |
+
<th colspan="5">AudioArena</th>
|
| 735 |
+
</tr>
|
| 736 |
+
</thead>
|
| 737 |
+
<tbody align="center">
|
| 738 |
+
<tr>
|
| 739 |
+
<td colspan="11" align="left"><strong>Proprietary</strong></td>
|
| 740 |
+
</tr>
|
| 741 |
+
<tr>
|
| 742 |
+
<td nowrap="nowrap" align="left">GPT-4o-Realtime</td>
|
| 743 |
+
<td></td>
|
| 744 |
+
<td><strong>71.7</strong></td>
|
| 745 |
+
<td><strong>51.6</strong></td>
|
| 746 |
+
<td><strong>69.7</strong></td>
|
| 747 |
+
<td><strong>7.4</strong></td>
|
| 748 |
+
<td><strong>1157</strong></td>
|
| 749 |
+
<td><strong>1203</strong></td>
|
| 750 |
+
<td><strong>1200</strong></td>
|
| 751 |
+
<td><strong>4.2</strong></td>
|
| 752 |
+
<td><strong>2.3</strong></td>
|
| 753 |
+
</tr>
|
| 754 |
+
<tr>
|
| 755 |
+
<td colspan="11" align="left"><strong>Open-Source</strong></td>
|
| 756 |
+
</tr>
|
| 757 |
+
<tr>
|
| 758 |
+
<td nowrap="nowrap" align="left">GLM-4-Voice</td>
|
| 759 |
+
<td>9B</td>
|
| 760 |
+
<td>50.0</td>
|
| 761 |
+
<td>32.0</td>
|
| 762 |
+
<td>36.4</td>
|
| 763 |
+
<td><u>5.1</u></td>
|
| 764 |
+
<td>999</td>
|
| 765 |
+
<td>1147</td>
|
| 766 |
+
<td>1035</td>
|
| 767 |
+
<td><u>4.1</u></td>
|
| 768 |
+
<td><u>11.7</u></td>
|
| 769 |
+
</tr>
|
| 770 |
+
<tr>
|
| 771 |
+
<td nowrap="nowrap" align="left">Llama-Omni</td>
|
| 772 |
+
<td>8B</td>
|
| 773 |
+
<td>45.3</td>
|
| 774 |
+
<td>22.9</td>
|
| 775 |
+
<td>10.7</td>
|
| 776 |
+
<td>3.9</td>
|
| 777 |
+
<td>960</td>
|
| 778 |
+
<td>878</td>
|
| 779 |
+
<td>897</td>
|
| 780 |
+
<td>3.2</td>
|
| 781 |
+
<td>24.3</td>
|
| 782 |
+
</tr>
|
| 783 |
+
<tr>
|
| 784 |
+
<td nowrap="nowrap" align="left">VITA-1.5</td>
|
| 785 |
+
<td>8B</td>
|
| 786 |
+
<td>46.7</td>
|
| 787 |
+
<td>28.1</td>
|
| 788 |
+
<td>23.3</td>
|
| 789 |
+
<td>2.0</td>
|
| 790 |
+
<td>-</td>
|
| 791 |
+
<td>-</td>
|
| 792 |
+
<td>-</td>
|
| 793 |
+
<td>-</td>
|
| 794 |
+
<td>-</td>
|
| 795 |
+
</tr>
|
| 796 |
+
<tr>
|
| 797 |
+
<td nowrap="nowrap" align="left">Moshi</td>
|
| 798 |
+
<td>7B</td>
|
| 799 |
+
<td>43.7</td>
|
| 800 |
+
<td>23.8</td>
|
| 801 |
+
<td>16.7</td>
|
| 802 |
+
<td>2.4</td>
|
| 803 |
+
<td>871</td>
|
| 804 |
+
<td>808</td>
|
| 805 |
+
<td>875</td>
|
| 806 |
+
<td>2.8</td>
|
| 807 |
+
<td>8.2</td>
|
| 808 |
+
</tr>
|
| 809 |
+
<tr>
|
| 810 |
+
<td nowrap="nowrap" align="left">Mini-Omni</td>
|
| 811 |
+
<td>1B</td>
|
| 812 |
+
<td>22.0</td>
|
| 813 |
+
<td>12.8</td>
|
| 814 |
+
<td>6.9</td>
|
| 815 |
+
<td>2.5</td>
|
| 816 |
+
<td>926</td>
|
| 817 |
+
<td>803</td>
|
| 818 |
+
<td>865</td>
|
| 819 |
+
<td>3.4</td>
|
| 820 |
+
<td>10.0</td>
|
| 821 |
+
</tr>
|
| 822 |
+
<tr>
|
| 823 |
+
<td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
|
| 824 |
+
<td>8B</td>
|
| 825 |
+
<td><u>61.0</u></td>
|
| 826 |
+
<td><u>40.0</u></td>
|
| 827 |
+
<td><u>40.2</u></td>
|
| 828 |
+
<td><u>5.1</u></td>
|
| 829 |
+
<td><u>1088</u></td>
|
| 830 |
+
<td><u>1163</u></td>
|
| 831 |
+
<td><u>1131</u></td>
|
| 832 |
+
<td><strong>4.2</strong></td>
|
| 833 |
+
<td>9.8</td>
|
| 834 |
+
</tr>
|
| 835 |
+
</tbody>
|
| 836 |
+
</table>
|
| 837 |
+
</div>
|
| 838 |
+
所有的结果都基于 <a href="https://github.com/OpenBMB/UltraEval-Audio" target="_blank">AudioEvals</a>。<br><br>
|
| 839 |
+
|
| 840 |
+
**端到端声音克隆能力。**
|
| 841 |
+
|
| 842 |
+
<div align="center">
|
| 843 |
+
<table style="margin: 0px auto;">
|
| 844 |
+
<thead>
|
| 845 |
+
<tr>
|
| 846 |
+
<th align="left">Task</th>
|
| 847 |
+
<th colspan="2">TTS</th>
|
| 848 |
+
</tr>
|
| 849 |
+
<tr>
|
| 850 |
+
<th align="left">Metric</th>
|
| 851 |
+
<th>SIMO↑</th>
|
| 852 |
+
<th>SIMO↑</th>
|
| 853 |
+
</tr>
|
| 854 |
+
<tr>
|
| 855 |
+
<th align="left">Dataset</th>
|
| 856 |
+
<th>Seed-TTS test-zh</th>
|
| 857 |
+
<th>Seed-TTS test-en</th>
|
| 858 |
+
</tr>
|
| 859 |
+
</thead>
|
| 860 |
+
<tbody align="center">
|
| 861 |
+
<tr>
|
| 862 |
+
<td nowrap="nowrap" align="left">F5-TTS</td>
|
| 863 |
+
<td><strong>76</strong></td>
|
| 864 |
+
<td><strong>67</strong></td>
|
| 865 |
+
</tr>
|
| 866 |
+
<tr>
|
| 867 |
+
<td nowrap="nowrap" align="left">CosyVoice</td>
|
| 868 |
+
<td><u>75</u></td>
|
| 869 |
+
<td><u>64</u></td>
|
| 870 |
+
</tr>
|
| 871 |
+
<tr>
|
| 872 |
+
<td nowrap="nowrap" align="left">FireRedTTS</td>
|
| 873 |
+
<td>63</td>
|
| 874 |
+
<td>46</td>
|
| 875 |
+
</tr>
|
| 876 |
+
<tr>
|
| 877 |
+
<td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
|
| 878 |
+
<td>57</td>
|
| 879 |
+
<td>47</td>
|
| 880 |
+
</tr>
|
| 881 |
+
</tbody>
|
| 882 |
+
</table>
|
| 883 |
+
</div>
|
| 884 |
+
|
| 885 |
+
</details>
|
| 886 |
+
|
| 887 |
+
<details>
|
| 888 |
+
<summary>点击查看多模态流式交互能力评测详细结果。</summary>
|
| 889 |
+
|
| 890 |
+
**多模态流式交互能力**: StreamingBench 分数
|
| 891 |
+
|
| 892 |
+
<table style="margin: 0px auto;">
|
| 893 |
+
<thead>
|
| 894 |
+
<tr>
|
| 895 |
+
<th align="left">Model</th>
|
| 896 |
+
<th>Size</th>
|
| 897 |
+
<th>Real-Time Video Understanding</th>
|
| 898 |
+
<th>Omni-Source Understanding</th>
|
| 899 |
+
<th>Contextual Understanding</th>
|
| 900 |
+
<th>Overall</th>
|
| 901 |
+
</tr>
|
| 902 |
+
</thead>
|
| 903 |
+
<tbody align="center">
|
| 904 |
+
<tr>
|
| 905 |
+
<td colspan="7" align="left"><strong>Proprietary</strong></td>
|
| 906 |
+
</tr>
|
| 907 |
+
<tr>
|
| 908 |
+
<td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
|
| 909 |
+
<td>-</td>
|
| 910 |
+
<td><u>77.4</u></td>
|
| 911 |
+
<td><strong>67.8</strong></td>
|
| 912 |
+
<td><strong>51.1</strong></td>
|
| 913 |
+
<td><strong>70.3</strong></td>
|
| 914 |
+
</tr>
|
| 915 |
+
<tr>
|
| 916 |
+
<td nowrap="nowrap" align="left">GPT-4o-202408</td>
|
| 917 |
+
<td>-</td>
|
| 918 |
+
<td>74.5</td>
|
| 919 |
+
<td>51.0</td>
|
| 920 |
+
<td><u>48.0</u></td>
|
| 921 |
+
<td>64.1</td>
|
| 922 |
+
</tr>
|
| 923 |
+
<tr>
|
| 924 |
+
<td nowrap="nowrap" align="left">Claude-3.5-Sonnet</td>
|
| 925 |
+
<td>-</td>
|
| 926 |
+
<td>74.0</td>
|
| 927 |
+
<td>41.4</td>
|
| 928 |
+
<td>37.8</td>
|
| 929 |
+
<td>59.7</td>
|
| 930 |
+
</tr>
|
| 931 |
+
<tr>
|
| 932 |
+
<td colspan="9" align="left"><strong>Open-source</strong></td>
|
| 933 |
+
</tr>
|
| 934 |
+
<tr>
|
| 935 |
+
<td nowrap="nowrap" align="left">VILA-1.5</td>
|
| 936 |
+
<td>8B</td>
|
| 937 |
+
<td>61.5</td>
|
| 938 |
+
<td>37.5</td>
|
| 939 |
+
<td>26.7</td>
|
| 940 |
+
<td>49.5</td>
|
| 941 |
+
</tr>
|
| 942 |
+
<tr>
|
| 943 |
+
<td nowrap="nowrap" align="left">LongVA</td>
|
| 944 |
+
<td>7B</td>
|
| 945 |
+
<td>63.1</td>
|
| 946 |
+
<td>35.9</td>
|
| 947 |
+
<td>30.2</td>
|
| 948 |
+
<td>50.7</td>
|
| 949 |
+
</tr>
|
| 950 |
+
<tr>
|
| 951 |
+
<td nowrap="nowrap" align="left">LLaVA-Next-Video-34B</td>
|
| 952 |
+
<td>34B</td>
|
| 953 |
+
<td>69.8</td>
|
| 954 |
+
<td>41.7</td>
|
| 955 |
+
<td>34.3</td>
|
| 956 |
+
<td>56.7</td>
|
| 957 |
+
</tr>
|
| 958 |
+
<tr>
|
| 959 |
+
<td nowrap="nowrap" align="left">Qwen2-VL-7B</td>
|
| 960 |
+
<td>8B</td>
|
| 961 |
+
<td>71.2</td>
|
| 962 |
+
<td>40.7</td>
|
| 963 |
+
<td>33.1</td>
|
| 964 |
+
<td>57.0</td>
|
| 965 |
+
</tr>
|
| 966 |
+
<tr>
|
| 967 |
+
<td nowrap="nowrap" align="left">InternVL2-8B</td>
|
| 968 |
+
<td>8B</td>
|
| 969 |
+
<td>70.1</td>
|
| 970 |
+
<td>42.7</td>
|
| 971 |
+
<td>34.1</td>
|
| 972 |
+
<td>57.0</td>
|
| 973 |
+
</tr>
|
| 974 |
+
<tr>
|
| 975 |
+
<td nowrap="nowrap" align="left">VITA-1.5</td>
|
| 976 |
+
<td>8B</td>
|
| 977 |
+
<td>70.9</td>
|
| 978 |
+
<td>40.8</td>
|
| 979 |
+
<td>35.8</td>
|
| 980 |
+
<td>57.4</td>
|
| 981 |
+
</tr>
|
| 982 |
+
<tr>
|
| 983 |
+
<td nowrap="nowrap" align="left">LLaVA-OneVision-7B</td>
|
| 984 |
+
<td>8B</td>
|
| 985 |
+
<td>74.3</td>
|
| 986 |
+
<td>40.8</td>
|
| 987 |
+
<td>31.0</td>
|
| 988 |
+
<td>58.4</td>
|
| 989 |
+
</tr>
|
| 990 |
+
<tr>
|
| 991 |
+
<td nowrap="nowrap" align="left">InternLM-XC2.5-OL-7B</td>
|
| 992 |
+
<td>8B</td>
|
| 993 |
+
<td>75.4</td>
|
| 994 |
+
<td>46.2</td>
|
| 995 |
+
<td>33.6</td>
|
| 996 |
+
<td>60.8</td>
|
| 997 |
+
</tr>
|
| 998 |
+
<tr>
|
| 999 |
+
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
| 1000 |
+
<td>8B</td>
|
| 1001 |
+
<td>72.4</td>
|
| 1002 |
+
<td>40.2</td>
|
| 1003 |
+
<td>33.4</td>
|
| 1004 |
+
<td>57.7</td>
|
| 1005 |
+
</tr>
|
| 1006 |
+
<tr>
|
| 1007 |
+
<td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
|
| 1008 |
+
<td>8B</td>
|
| 1009 |
+
<td><strong>79.9</strong></td>
|
| 1010 |
+
<td><u>53.4</u></td>
|
| 1011 |
+
<td>38.5</td>
|
| 1012 |
+
<td><u>66.0</u></td>
|
| 1013 |
+
</tr>
|
| 1014 |
+
</tbody>
|
| 1015 |
+
</table>
|
| 1016 |
+
|
| 1017 |
+
</details>
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
### 典型示例 <!-- omit in toc -->
|
| 1021 |
+
|
| 1022 |
+
以下为 MiniCPM-o 2.6 的 iPad Pro 实机演示和 web demo 演示样例:
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
<div align="center">
|
| 1026 |
+
<a href="https://www.youtube.com/watch?v=vRIMbxJzStY&t=2s"><img src="./assets/minicpmo2_6/2dot6_o_demo_video_img.png", width=70%></a>
|
| 1027 |
+
</div>
|
| 1028 |
+
<br>
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
<div style="display: flex; flex-direction: column; align-items: center;">
|
| 1033 |
+
<img src="assets/minicpmo2_6/minicpmo2_6_math_intersect.png" alt="math" style="margin-bottom: 5px;">
|
| 1034 |
+
<img src="assets/minicpmo2_6/minicpmo2_6_diagram_train_NN.png" alt="diagram" style="margin-bottom: 5px;">
|
| 1035 |
+
<img src="assets/minicpmo2_6/minicpmo2_6_multi-image_bike.png" alt="bike" style="margin-bottom: 5px;">
|
| 1036 |
+
</div>
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
<details>
|
| 1040 |
+
<summary>Click to view more details of MiniCPM-V 2.6</summary>
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
## MiniCPM-V 2.6
|
| 1044 |
+
|
| 1045 |
+
**MiniCPM-V 2.6** 是 MiniCPM-V 系列中最新、性能最佳的模型。该模型基于 SigLip-400M 和 Qwen2-7B 构建,共 8B 参数。与 MiniCPM-Llama3-V 2.5 相比,MiniCPM-V 2.6 性能提升显著,并引入了多图和视频理解的新功能。MiniCPM-V 2.6 的主要特点包括:
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
- 🔥 **领先的性能。**
|
| 1049 |
+
MiniCPM-V 2.6 在最新版本 OpenCompass 榜单上(综合 8 个主流多模态评测基准)平均得分 65.2,**以8B量级的大小在单图理解方面超越了 GPT-4o mini、GPT-4V、Gemini 1.5 Pro 和 Claude 3.5 Sonnet 等主流商用闭源多模态大模型**。
|
| 1050 |
+
|
| 1051 |
+
- 🖼️ **多图理解和上下文学习。**
|
| 1052 |
+
MiniCPM-V 2.6 还支持**多图对话和推理**。它在 Mantis-Eval、BLINK、Mathverse mv 和 Sciverse mv 等主流多图评测基准中取得了**最佳水平**,并展现出了优秀的上下文学习能力。
|
| 1053 |
+
|
| 1054 |
+
- 🎬 **视频理解。**
|
| 1055 |
+
MiniCPM-V 2.6 还可以**接受视频输入**,进行对话和提供涵盖时序和空间信息的详细视频描述。模型在 有/无字幕 评测场景下的 Video-MME 表现均超过了 **GPT-4V、Claude 3.5 Sonnet 和 LLaVA-NeXT-Video-34B**等商用闭源模型。
|
| 1056 |
+
|
| 1057 |
+
- 💪 **强大的 OCR 能力及其他功能。**
|
| 1058 |
+
MiniCPM-V 2.6 可以处理任意长宽比的图像,像素数可达 180 万(如 1344x1344)。在 OCRBench 上取得**最佳水平,超过 GPT-4o、GPT-4V 和 Gemini 1.5 Pro 等商用闭源模型**。基于最新的 [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/) 和 [VisCPM](https://github.com/OpenBMB/VisCPM) 技术,其具备了**可信的多模态行为**,在 Object HalBench 上的幻觉率显著低于 GPT-4o 和 GPT-4V,并支持英语、中文、德语、法语、意大利语、韩语等**多种语言**。
|
| 1059 |
+
|
| 1060 |
+
- 🚀 **卓越的效率。**
|
| 1061 |
+
除了对个人用户友好的模型大小,MiniCPM-V 2.6 还表现出**最先进的视觉 token 密度**(即每个视觉 token 编码的像素数量)。它**仅需 640 个 token 即可处理 180 万像素图像,比大多数模型少 75%**。这一特性优化了模型的推理速度、首 token 延迟、内存占用和功耗。因此,MiniCPM-V 2.6 可以支持 iPad 等终端设备上的高效**实时视频理解**。
|
| 1062 |
+
|
| 1063 |
+
- 💫 **易于使用。**
|
| 1064 |
+
MiniCPM-V 2.6 可以通过多种方式轻松使用:(1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md) 和 [ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md) 支持在本地设备上进行高效的 CPU 推理,(2) [int4](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) 和 [GGUF](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) 格式的量化模型,有 16 种尺寸,(3) [vLLM](#vllm-部署-) 支持高吞吐量和内存高效的推理,(4) 针对新领域和任务进行微调,(5) 使用 [Gradio](#本地-webui-demo-) 快速设置本地 WebUI 演示,(6) 在线[demo](http://120.92.209.146:8887/)即可体验。
|
| 1065 |
+
|
| 1066 |
+
### 性能评估 <!-- omit in toc -->
|
| 1067 |
+
<div align="center">
|
| 1068 |
+
<img src=assets/radar_final.png width=90% />
|
| 1069 |
+
</div>
|
| 1070 |
+
|
| 1071 |
+
<details>
|
| 1072 |
+
<summary>点击查看 OpenCompass, MME, MMVet, OCRBench, MMMU, MathVista, MMB, AI2D, TextVQA, DocVQA, HallusionBench, Object HalBench 上的单图评测结果详情。 </summary>
|
| 1073 |
+
<div align="center">
|
| 1074 |
+
|
| 1075 |
+
<table style="margin: 0px auto;">
|
| 1076 |
+
<thead>
|
| 1077 |
+
<tr>
|
| 1078 |
+
<th align="left">Model</th>
|
| 1079 |
+
<th>Size</th>
|
| 1080 |
+
<th>Token Density<sup>+</sup></th>
|
| 1081 |
+
<th>OpenCompass</th>
|
| 1082 |
+
<th>MME</th>
|
| 1083 |
+
<th>MMVet</th>
|
| 1084 |
+
<th>OCRBench</th>
|
| 1085 |
+
<th>MMMU val</th>
|
| 1086 |
+
<th>MathVista mini</th>
|
| 1087 |
+
<th>MMB1.1 test</th>
|
| 1088 |
+
<th>AI2D</th>
|
| 1089 |
+
<th>TextVQA val</th>
|
| 1090 |
+
<th>DocVQA test</th>
|
| 1091 |
+
<th>HallusionBench</th>
|
| 1092 |
+
<th>Object HalBench</th>
|
| 1093 |
+
</tr>
|
| 1094 |
+
</thead>
|
| 1095 |
+
<tbody align="center">
|
| 1096 |
+
<tr>
|
| 1097 |
+
<td colspan="15" align="left"><strong>Proprietary</strong></td>
|
| 1098 |
+
</tr>
|
| 1099 |
+
<tr>
|
| 1100 |
+
<td nowrap="nowrap" align="left">GPT-4o</td>
|
| 1101 |
+
<td>-</td>
|
| 1102 |
+
<td>1088</td>
|
| 1103 |
+
<td>69.9</td>
|
| 1104 |
+
<td>2328.7</td>
|
| 1105 |
+
<td>69.1</td>
|
| 1106 |
+
<td>736</td>
|
| 1107 |
+
<td>69.2</td>
|
| 1108 |
+
<td>61.3</td>
|
| 1109 |
+
<td>82.2</td>
|
| 1110 |
+
<td>84.6</td>
|
| 1111 |
+
<td>-</td>
|
| 1112 |
+
<td>92.8</td>
|
| 1113 |
+
<td>55.0</td>
|
| 1114 |
+
<td>17.6</td>
|
| 1115 |
+
</tr>
|
| 1116 |
+
<tr>
|
| 1117 |
+
<td nowrap="nowrap" align="left">Claude 3.5 Sonnet</td>
|
| 1118 |
+
<td>-</td>
|
| 1119 |
+
<td>750</td>
|
| 1120 |
+
<td>67.9</td>
|
| 1121 |
+
<td>1920.0</td>
|
| 1122 |
+
<td>66.0</td>
|
| 1123 |
+
<td>788</td>
|
| 1124 |
+
<td>65.9</td>
|
| 1125 |
+
<td>61.6</td>
|
| 1126 |
+
<td>78.5</td>
|
| 1127 |
+
<td>80.2</td>
|
| 1128 |
+
<td>-</td>
|
| 1129 |
+
<td>95.2</td>
|
| 1130 |
+
<td>49.9</td>
|
| 1131 |
+
<td>13.8</td>
|
| 1132 |
+
</tr>
|
| 1133 |
+
<tr>
|
| 1134 |
+
<td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
|
| 1135 |
+
<td>-</td>
|
| 1136 |
+
<td>-</td>
|
| 1137 |
+
<td>64.4</td>
|
| 1138 |
+
<td>2110.6</td>
|
| 1139 |
+
<td>64.0</td>
|
| 1140 |
+
<td>754</td>
|
| 1141 |
+
<td>60.6</td>
|
| 1142 |
+
<td>57.7</td>
|
| 1143 |
+
<td>73.9</td>
|
| 1144 |
+
<td>79.1</td>
|
| 1145 |
+
<td>73.5</td>
|
| 1146 |
+
<td>86.5</td>
|
| 1147 |
+
<td>45.6</td>
|
| 1148 |
+
<td>-</td>
|
| 1149 |
+
</tr>
|
| 1150 |
+
<tr>
|
| 1151 |
+
<td nowrap="nowrap" align="left">GPT-4o mini</td>
|
| 1152 |
+
<td>-</td>
|
| 1153 |
+
<td>1088</td>
|
| 1154 |
+
<td>64.1</td>
|
| 1155 |
+
<td>2003.4</td>
|
| 1156 |
+
<td>66.9</td>
|
| 1157 |
+
<td>785</td>
|
| 1158 |
+
<td>60.0</td>
|
| 1159 |
+
<td>52.4</td>
|
| 1160 |
+
<td>76.0</td>
|
| 1161 |
+
<td>77.8</td>
|
| 1162 |
+
<td>-</td>
|
| 1163 |
+
<td>-</td>
|
| 1164 |
+
<td>46.1</td>
|
| 1165 |
+
<td>12.4</td>
|
| 1166 |
+
</tr>
|
| 1167 |
+
<tr>
|
| 1168 |
+
<td nowrap="nowrap" align="left">GPT-4V</td>
|
| 1169 |
+
<td>-</td>
|
| 1170 |
+
<td>1088</td>
|
| 1171 |
+
<td>63.5</td>
|
| 1172 |
+
<td>2070.2</td>
|
| 1173 |
+
<td>67.5</td>
|
| 1174 |
+
<td>656</td>
|
| 1175 |
+
<td>61.7</td>
|
| 1176 |
+
<td>54.7</td>
|
| 1177 |
+
<td>79.8</td>
|
| 1178 |
+
<td>78.6</td>
|
| 1179 |
+
<td>78.0</td>
|
| 1180 |
+
<td>87.2</td>
|
| 1181 |
+
<td>43.9</td>
|
| 1182 |
+
<td>14.2</td>
|
| 1183 |
+
</tr>
|
| 1184 |
+
<tr>
|
| 1185 |
+
<td nowrap="nowrap" align="left">Step-1V</td>
|
| 1186 |
+
<td>-</td>
|
| 1187 |
+
<td>-</td>
|
| 1188 |
+
<td>59.5</td>
|
| 1189 |
+
<td>2206.4</td>
|
| 1190 |
+
<td>63.3</td>
|
| 1191 |
+
<td>625</td>
|
| 1192 |
+
<td>49.9</td>
|
| 1193 |
+
<td>44.8</td>
|
| 1194 |
+
<td>78.0</td>
|
| 1195 |
+
<td>79.2</td>
|
| 1196 |
+
<td>71.6</td>
|
| 1197 |
+
<td>-</td>
|
| 1198 |
+
<td>48.4</td>
|
| 1199 |
+
<td>-</td>
|
| 1200 |
+
</tr>
|
| 1201 |
+
<tr>
|
| 1202 |
+
<td nowrap="nowrap" align="left">Qwen-VL-Max</td>
|
| 1203 |
+
<td>-</td>
|
| 1204 |
+
<td>784</td>
|
| 1205 |
+
<td>58.3</td>
|
| 1206 |
+
<td>2281.7</td>
|
| 1207 |
+
<td>61.8</td>
|
| 1208 |
+
<td>684</td>
|
| 1209 |
+
<td>52.0</td>
|
| 1210 |
+
<td>43.4</td>
|
| 1211 |
+
<td>74.6</td>
|
| 1212 |
+
<td>75.7</td>
|
| 1213 |
+
<td>79.5</td>
|
| 1214 |
+
<td>93.1</td>
|
| 1215 |
+
<td>41.2</td>
|
| 1216 |
+
<td>13.4</td>
|
| 1217 |
+
</tr>
|
| 1218 |
+
<tr>
|
| 1219 |
+
<td colspan="15" align="left"><strong>Open-source</strong></td>
|
| 1220 |
+
</tr>
|
| 1221 |
+
<tr>
|
| 1222 |
+
<td nowrap="nowrap" align="left">LLaVA-NeXT-Yi-34B</td>
|
| 1223 |
+
<td>34B</td>
|
| 1224 |
+
<td>157</td>
|
| 1225 |
+
<td>55.0</td>
|
| 1226 |
+
<td>2006.5</td>
|
| 1227 |
+
<td>50.7</td>
|
| 1228 |
+
<td>574</td>
|
| 1229 |
+
<td>48.8</td>
|
| 1230 |
+
<td>40.4</td>
|
| 1231 |
+
<td>77.8</td>
|
| 1232 |
+
<td>78.9</td>
|
| 1233 |
+
<td>69.3</td>
|
| 1234 |
+
<td>-</td>
|
| 1235 |
+
<td>34.8</td>
|
| 1236 |
+
<td>12.6</td>
|
| 1237 |
+
</tr>
|
| 1238 |
+
<tr>
|
| 1239 |
+
<td nowrap="nowrap" align="left">Mini-Gemini-HD-34B</td>
|
| 1240 |
+
<td>34B</td>
|
| 1241 |
+
<td>157</td>
|
| 1242 |
+
<td>-</td>
|
| 1243 |
+
<td>2141</td>
|
| 1244 |
+
<td>59.3</td>
|
| 1245 |
+
<td>518</td>
|
| 1246 |
+
<td>48.0</td>
|
| 1247 |
+
<td>43.3</td>
|
| 1248 |
+
<td>-</td>
|
| 1249 |
+
<td>80.5</td>
|
| 1250 |
+
<td>74.1</td>
|
| 1251 |
+
<td>78.9</td>
|
| 1252 |
+
<td>-</td>
|
| 1253 |
+
<td>-</td>
|
| 1254 |
+
</tr>
|
| 1255 |
+
<tr>
|
| 1256 |
+
<td nowrap="nowrap" align="left">Cambrian-34B</td>
|
| 1257 |
+
<td>34B</td>
|
| 1258 |
+
<td>1820</td>
|
| 1259 |
+
<td>58.3</td>
|
| 1260 |
+
<td>2049.9</td>
|
| 1261 |
+
<td>53.2</td>
|
| 1262 |
+
<td>591</td>
|
| 1263 |
+
<td>50.4</td>
|
| 1264 |
+
<td>50.3</td>
|
| 1265 |
+
<td>77.8</td>
|
| 1266 |
+
<td>79.5</td>
|
| 1267 |
+
<td>76.7</td>
|
| 1268 |
+
<td>75.5</td>
|
| 1269 |
+
<td>41.6</td>
|
| 1270 |
+
<td>14.7</td>
|
| 1271 |
+
</tr>
|
| 1272 |
+
<tr>
|
| 1273 |
+
<td nowrap="nowrap" align="left">GLM-4V-9B</td>
|
| 1274 |
+
<td>13B</td>
|
| 1275 |
+
<td>784</td>
|
| 1276 |
+
<td>59.1</td>
|
| 1277 |
+
<td>2018.8</td>
|
| 1278 |
+
<td>58.0</td>
|
| 1279 |
+
<td>776</td>
|
| 1280 |
+
<td>46.9</td>
|
| 1281 |
+
<td>51.1</td>
|
| 1282 |
+
<td>67.9</td>
|
| 1283 |
+
<td>71.2</td>
|
| 1284 |
+
<td>-</td>
|
| 1285 |
+
<td>-</td>
|
| 1286 |
+
<td>45.0</td>
|
| 1287 |
+
<td>-</td>
|
| 1288 |
+
</tr>
|
| 1289 |
+
<tr>
|
| 1290 |
+
<td nowrap="nowrap" align="left">InternVL2-8B</td>
|
| 1291 |
+
<td>8B</td>
|
| 1292 |
+
<td>706</td>
|
| 1293 |
+
<td>64.1</td>
|
| 1294 |
+
<td>2215.1</td>
|
| 1295 |
+
<td>54.3</td>
|
| 1296 |
+
<td>794</td>
|
| 1297 |
+
<td><strong>51.2</strong></td>
|
| 1298 |
+
<td>58.3</td>
|
| 1299 |
+
<td><strong>79.4</strong></td>
|
| 1300 |
+
<td><strong>83.6</strong></td>
|
| 1301 |
+
<td>77.4</td>
|
| 1302 |
+
<td><strong>91.6</strong></td>
|
| 1303 |
+
<td>45.0</td>
|
| 1304 |
+
<td>21.3</td>
|
| 1305 |
+
</tr>
|
| 1306 |
+
<tr>
|
| 1307 |
+
<td nowrap="nowrap" align="left">MiniCPM-Llama-V 2.5</td>
|
| 1308 |
+
<td>8B</td>
|
| 1309 |
+
<td>1882</td>
|
| 1310 |
+
<td>58.8</td>
|
| 1311 |
+
<td>2024.6</td>
|
| 1312 |
+
<td>52.8</td>
|
| 1313 |
+
<td>725</td>
|
| 1314 |
+
<td>45.8</td>
|
| 1315 |
+
<td>54.3</td>
|
| 1316 |
+
<td>72.0</td>
|
| 1317 |
+
<td>78.4</td>
|
| 1318 |
+
<td>76.6</td>
|
| 1319 |
+
<td>84.8</td>
|
| 1320 |
+
<td>42.4</td>
|
| 1321 |
+
<td>10.3</td>
|
| 1322 |
+
</tr>
|
| 1323 |
+
<tr>
|
| 1324 |
+
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
| 1325 |
+
<td>8B</td>
|
| 1326 |
+
<td><strong>2822</strong></td>
|
| 1327 |
+
<td><strong>65.2</strong></td>
|
| 1328 |
+
<td><strong>2348.4</strong>*</td>
|
| 1329 |
+
<td><strong>60.0</strong></td>
|
| 1330 |
+
<td><strong>852</strong>*</td>
|
| 1331 |
+
<td>49.8*</td>
|
| 1332 |
+
<td><strong>60.6</strong></td>
|
| 1333 |
+
<td>78.0</td>
|
| 1334 |
+
<td>82.1</td>
|
| 1335 |
+
<td><strong>80.1<strong></td>
|
| 1336 |
+
<td>90.8</td>
|
| 1337 |
+
<td><strong>48.1</strong>*</td>
|
| 1338 |
+
<td><strong>8.2</strong></td>
|
| 1339 |
+
</tr>
|
| 1340 |
+
</tbody>
|
| 1341 |
+
</table>
|
| 1342 |
+
|
| 1343 |
+
</div>
|
| 1344 |
+
* 我们使用思维链提示词来评估这些基准。
|
| 1345 |
+
|
| 1346 |
+
<sup>+</sup> Token Density:每个视觉 token 在最大分辨率下编码的像素数,即最大分辨率下的像素数 / 视觉 token 数。
|
| 1347 |
+
|
| 1348 |
+
注意:闭源模型的 Token Density 由 API 收费方式估算得到。
|
| 1349 |
+
</details>
|
| 1350 |
+
|
| 1351 |
+
|
| 1352 |
+
<details>
|
| 1353 |
+
<summary>点击查看 Mantis Eval, BLINK, Mathverse mv, Sciverse mv, MIRB 上的多图评测结果详情。</summary>
|
| 1354 |
+
<div align="center">
|
| 1355 |
+
|
| 1356 |
+
<table style="margin: 0px auto;">
|
| 1357 |
+
<thead>
|
| 1358 |
+
<tr>
|
| 1359 |
+
<th align="left">Model</th>
|
| 1360 |
+
<th>Size</th>
|
| 1361 |
+
<th>Mantis Eval</th>
|
| 1362 |
+
<th>BLINK val</th>
|
| 1363 |
+
<th>Mathverse mv</th>
|
| 1364 |
+
<th>Sciverse mv</th>
|
| 1365 |
+
<th>MIRB</th>
|
| 1366 |
+
</tr>
|
| 1367 |
+
</thead>
|
| 1368 |
+
<tbody align="center">
|
| 1369 |
+
<tr>
|
| 1370 |
+
<td colspan="7" align="left"><strong>Proprietary</strong></td>
|
| 1371 |
+
</tr>
|
| 1372 |
+
<tr>
|
| 1373 |
+
<td nowrap="nowrap" align="left">GPT-4V</td>
|
| 1374 |
+
<td>-</td>
|
| 1375 |
+
<td>62.7</td>
|
| 1376 |
+
<td>54.6</td>
|
| 1377 |
+
<td>60.3</td>
|
| 1378 |
+
<td>66.9</td>
|
| 1379 |
+
<td>53.1</td>
|
| 1380 |
+
</tr>
|
| 1381 |
+
<tr>
|
| 1382 |
+
<td nowrap="nowrap" align="left">LLaVA-NeXT-Interleave-14B</td>
|
| 1383 |
+
<td>14B</td>
|
| 1384 |
+
<td>66.4</td>
|
| 1385 |
+
<td>52.6</td>
|
| 1386 |
+
<td>32.7</td>
|
| 1387 |
+
<td>30.2</td>
|
| 1388 |
+
<td>-</td>
|
| 1389 |
+
</tr>
|
| 1390 |
+
<tr>
|
| 1391 |
+
<td colspan="7" align="left"><strong>Open-source</strong></td>
|
| 1392 |
+
</tr>
|
| 1393 |
+
<tr>
|
| 1394 |
+
<td nowrap="nowrap" align="left">Emu2-Chat</td>
|
| 1395 |
+
<td>37B</td>
|
| 1396 |
+
<td>37.8</td>
|
| 1397 |
+
<td>36.2</td>
|
| 1398 |
+
<td>-</td>
|
| 1399 |
+
<td>27.2</td>
|
| 1400 |
+
<td>-</td>
|
| 1401 |
+
</tr>
|
| 1402 |
+
<tr>
|
| 1403 |
+
<td nowrap="nowrap" align="left">CogVLM</td>
|
| 1404 |
+
<td>17B</td>
|
| 1405 |
+
<td>45.2</td>
|
| 1406 |
+
<td>41.1</td>
|
| 1407 |
+
<td>-</td>
|
| 1408 |
+
<td>-</td>
|
| 1409 |
+
<td>-</td>
|
| 1410 |
+
</tr>
|
| 1411 |
+
<tr>
|
| 1412 |
+
<td nowrap="nowrap" align="left">VPG-C</td>
|
| 1413 |
+
<td>7B</td>
|
| 1414 |
+
<td>52.4</td>
|
| 1415 |
+
<td>43.1</td>
|
| 1416 |
+
<td>24.3</td>
|
| 1417 |
+
<td>23.1</td>
|
| 1418 |
+
<td>-</td>
|
| 1419 |
+
</tr>
|
| 1420 |
+
<tr>
|
| 1421 |
+
<td nowrap="nowrap" align="left">VILA 8B</td>
|
| 1422 |
+
<td>8B</td>
|
| 1423 |
+
<td>51.2</td>
|
| 1424 |
+
<td>39.3</td>
|
| 1425 |
+
<td>-</td>
|
| 1426 |
+
<td>36.5</td>
|
| 1427 |
+
<td>-</td>
|
| 1428 |
+
</tr>
|
| 1429 |
+
<tr>
|
| 1430 |
+
<td nowrap="nowrap" align="left">InternLM-XComposer-2.5</td>
|
| 1431 |
+
<td>8B</td>
|
| 1432 |
+
<td>53.1*</td>
|
| 1433 |
+
<td>48.9</td>
|
| 1434 |
+
<td>32.1*</td>
|
| 1435 |
+
<td>-</td>
|
| 1436 |
+
<td>42.5</td>
|
| 1437 |
+
</tr>
|
| 1438 |
+
<tr>
|
| 1439 |
+
<td nowrap="nowrap" align="left">InternVL2-8B</td>
|
| 1440 |
+
<td>8B</td>
|
| 1441 |
+
<td>59.0*</td>
|
| 1442 |
+
<td>50.9</td>
|
| 1443 |
+
<td>30.5*</td>
|
| 1444 |
+
<td>34.4*</td>
|
| 1445 |
+
<td><strong>56.9*</strong></td>
|
| 1446 |
+
</tr>
|
| 1447 |
+
<tr>
|
| 1448 |
+
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
| 1449 |
+
<td>8B</td>
|
| 1450 |
+
<td><strong>69.1</strong></td>
|
| 1451 |
+
<td><strong>53.0</strong></td>
|
| 1452 |
+
<td><strong>84.9</strong></td>
|
| 1453 |
+
<td><strong>74.9</strong></td>
|
| 1454 |
+
<td>53.8</td>
|
| 1455 |
+
</tr>
|
| 1456 |
+
</tbody>
|
| 1457 |
+
</table>
|
| 1458 |
+
|
| 1459 |
+
|
| 1460 |
+
</div>
|
| 1461 |
+
* 正式开源模型权重的评测结果。
|
| 1462 |
+
</details>
|
| 1463 |
+
|
| 1464 |
+
<details>
|
| 1465 |
+
<summary>点击查看 Video-MME 和 Video-ChatGPT 上的视频评测结果详情。</summary>
|
| 1466 |
+
<div align="center">
|
| 1467 |
+
|
| 1468 |
+
<table style="margin: 0px auto;">
|
| 1469 |
+
<thead>
|
| 1470 |
+
<tr>
|
| 1471 |
+
<th align="left">Model</th>
|
| 1472 |
+
<th>Size</th>
|
| 1473 |
+
<th colspan="2">Video-MME</th>
|
| 1474 |
+
<th colspan="5">Video-ChatGPT</th>
|
| 1475 |
+
</tr>
|
| 1476 |
+
<tr>
|
| 1477 |
+
<th align="left"></th>
|
| 1478 |
+
<th></th>
|
| 1479 |
+
<th>w/o subs</th>
|
| 1480 |
+
<th>w subs</th>
|
| 1481 |
+
<th>Correctness</th>
|
| 1482 |
+
<th>Detail</th>
|
| 1483 |
+
<th>Context</th>
|
| 1484 |
+
<th>Temporal</th>
|
| 1485 |
+
<th>Consistency</th>
|
| 1486 |
+
</tr>
|
| 1487 |
+
</thead>
|
| 1488 |
+
<tbody align="center">
|
| 1489 |
+
<tr>
|
| 1490 |
+
<td colspan="9" align="left"><strong>Proprietary</strong></td>
|
| 1491 |
+
</tr>
|
| 1492 |
+
<tr>
|
| 1493 |
+
<td nowrap="nowrap" align="left">Claude 3.5 Sonnet</td>
|
| 1494 |
+
<td>-</td>
|
| 1495 |
+
<td>60.0</td>
|
| 1496 |
+
<td>62.9</td>
|
| 1497 |
+
<td>-</td>
|
| 1498 |
+
<td>-</td>
|
| 1499 |
+
<td>-</td>
|
| 1500 |
+
<td>-</td>
|
| 1501 |
+
<td>-</td>
|
| 1502 |
+
</tr>
|
| 1503 |
+
<tr>
|
| 1504 |
+
<td nowrap="nowrap" align="left">GPT-4V</td>
|
| 1505 |
+
<td>-</td>
|
| 1506 |
+
<td>59.9</td>
|
| 1507 |
+
<td>63.3</td>
|
| 1508 |
+
<td>-</td>
|
| 1509 |
+
<td>-</td>
|
| 1510 |
+
<td>-</td>
|
| 1511 |
+
<td>-</td>
|
| 1512 |
+
<td>-</td>
|
| 1513 |
+
</tr>
|
| 1514 |
+
<tr>
|
| 1515 |
+
<td colspan="9" align="left"><strong>Open-source</strong></td>
|
| 1516 |
+
</tr>
|
| 1517 |
+
<tr>
|
| 1518 |
+
<td nowrap="nowrap" align="left">LLaVA-NeXT-7B</td>
|
| 1519 |
+
<td>7B</td>
|
| 1520 |
+
<td>-</td>
|
| 1521 |
+
<td>-</td>
|
| 1522 |
+
<td>3.39</td>
|
| 1523 |
+
<td>3.29</td>
|
| 1524 |
+
<td>3.92</td>
|
| 1525 |
+
<td>2.60</td>
|
| 1526 |
+
<td>3.12</td>
|
| 1527 |
+
</tr>
|
| 1528 |
+
<tr>
|
| 1529 |
+
<td nowrap="nowrap" align="left">LLaVA-NeXT-34B</td>
|
| 1530 |
+
<td>34B</td>
|
| 1531 |
+
<td>-</td>
|
| 1532 |
+
<td>-</td>
|
| 1533 |
+
<td>3.29</td>
|
| 1534 |
+
<td>3.23</td>
|
| 1535 |
+
<td>3.83</td>
|
| 1536 |
+
<td>2.51</td>
|
| 1537 |
+
<td>3.47</td>
|
| 1538 |
+
</tr>
|
| 1539 |
+
<tr>
|
| 1540 |
+
<td nowrap="nowrap" align="left">CogVLM2-Video</td>
|
| 1541 |
+
<td>12B</td>
|
| 1542 |
+
<td>-</td>
|
| 1543 |
+
<td>-</td>
|
| 1544 |
+
<td>3.49</td>
|
| 1545 |
+
<td><strong>3.46</strong></td>
|
| 1546 |
+
<td>3.23</td>
|
| 1547 |
+
<td><strong>2.98</strong></td>
|
| 1548 |
+
<td><strong>3.64</strong></td>
|
| 1549 |
+
</tr>
|
| 1550 |
+
<tr>
|
| 1551 |
+
<td nowrap="nowrap" align="left">LongVA</td>
|
| 1552 |
+
<td>7B</td>
|
| 1553 |
+
<td>52.4</td>
|
| 1554 |
+
<td>54.3</td>
|
| 1555 |
+
<td>3.05</td>
|
| 1556 |
+
<td>3.09</td>
|
| 1557 |
+
<td>3.77</td>
|
| 1558 |
+
<td>2.44</td>
|
| 1559 |
+
<td><strong>3.64</strong></td>
|
| 1560 |
+
</tr>
|
| 1561 |
+
<tr>
|
| 1562 |
+
<td nowrap="nowrap" align="left">InternVL2-8B</td>
|
| 1563 |
+
<td>8B</td>
|
| 1564 |
+
<td>54.0</td>
|
| 1565 |
+
<td>56.9</td>
|
| 1566 |
+
<td>-</td>
|
| 1567 |
+
<td>-</td>
|
| 1568 |
+
<td>-</td>
|
| 1569 |
+
<td>-</td>
|
| 1570 |
+
<td>-</td>
|
| 1571 |
+
</tr>
|
| 1572 |
+
<tr>
|
| 1573 |
+
<td nowrap="nowrap" align="left">InternLM-XComposer-2.5</td>
|
| 1574 |
+
<td>8B</td>
|
| 1575 |
+
<td>55.8</td>
|
| 1576 |
+
<td>-</td>
|
| 1577 |
+
<td>-</td>
|
| 1578 |
+
<td>-</td>
|
| 1579 |
+
<td>-</td>
|
| 1580 |
+
<td>-</td>
|
| 1581 |
+
<td>-</td>
|
| 1582 |
+
</tr>
|
| 1583 |
+
<tr>
|
| 1584 |
+
<td nowrap="nowrap" align="left">LLaVA-NeXT-Video</td>
|
| 1585 |
+
<td>32B</td>
|
| 1586 |
+
<td>60.2</td>
|
| 1587 |
+
<td>63.0</td>
|
| 1588 |
+
<td>3.48</td>
|
| 1589 |
+
<td>3.37</td>
|
| 1590 |
+
<td><strong>3.95</strong></td>
|
| 1591 |
+
<td>2.64</td>
|
| 1592 |
+
<td>3.28</td>
|
| 1593 |
+
</tr>
|
| 1594 |
+
<tr>
|
| 1595 |
+
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
| 1596 |
+
<td>8B</td>
|
| 1597 |
+
<td><strong>60.9</strong></td>
|
| 1598 |
+
<td><strong>63.6</strong></td>
|
| 1599 |
+
<td><strong>3.59</strong></td>
|
| 1600 |
+
<td>3.28</td>
|
| 1601 |
+
<td>3.93</td>
|
| 1602 |
+
<td>2.73</td>
|
| 1603 |
+
<td>3.62</td>
|
| 1604 |
+
</tr>
|
| 1605 |
+
</tbody>
|
| 1606 |
+
</table>
|
| 1607 |
+
</div>
|
| 1608 |
+
</details>
|
| 1609 |
+
|
| 1610 |
+
|
| 1611 |
+
<details>
|
| 1612 |
+
<summary>点击查看 TextVQA, VizWiz, VQAv2, OK-VQA上的少样本评测结果详情。</summary>
|
| 1613 |
+
<div align="center">
|
| 1614 |
+
|
| 1615 |
+
<table style="margin: 0px auto;">
|
| 1616 |
+
<thead>
|
| 1617 |
+
<tr>
|
| 1618 |
+
<th align="left">Model</th>
|
| 1619 |
+
<th>Size</th>
|
| 1620 |
+
<th>Shot</th>
|
| 1621 |
+
<th>TextVQA val</th>
|
| 1622 |
+
<th>VizWiz test-dev</th>
|
| 1623 |
+
<th>VQAv2 test-dev</th>
|
| 1624 |
+
<th>OK-VQA val</th>
|
| 1625 |
+
</tr>
|
| 1626 |
+
</thead>
|
| 1627 |
+
<tbody align="center">
|
| 1628 |
+
<tr>
|
| 1629 |
+
<td align="left" nowrap="nowrap" rowspan="3">Flamingo</td>
|
| 1630 |
+
<td rowspan="3">80B</td>
|
| 1631 |
+
<td>0*</td>
|
| 1632 |
+
<td>35.0</td>
|
| 1633 |
+
<td>31.6</td>
|
| 1634 |
+
<td>56.3</td>
|
| 1635 |
+
<td>40.6</td>
|
| 1636 |
+
</tr>
|
| 1637 |
+
<tr>
|
| 1638 |
+
<td>4</td>
|
| 1639 |
+
<td>36.5</td>
|
| 1640 |
+
<td>39.6</td>
|
| 1641 |
+
<td>63.1</td>
|
| 1642 |
+
<td><strong>57.4</strong></td>
|
| 1643 |
+
</tr>
|
| 1644 |
+
<tr>
|
| 1645 |
+
<td>8</td>
|
| 1646 |
+
<td>37.3</td>
|
| 1647 |
+
<td>44.8</td>
|
| 1648 |
+
<td>65.6</td>
|
| 1649 |
+
<td>57.5</td>
|
| 1650 |
+
</tr>
|
| 1651 |
+
<tr>
|
| 1652 |
+
<td align="left" nowrap="nowrap" rowspan="3">IDEFICS</td>
|
| 1653 |
+
<td rowspan="3">80B</td>
|
| 1654 |
+
<td>0*</td>
|
| 1655 |
+
<td>30.9</td>
|
| 1656 |
+
<td>36.0</td>
|
| 1657 |
+
<td>60.0</td>
|
| 1658 |
+
<td>45.2</td>
|
| 1659 |
+
</tr>
|
| 1660 |
+
<tr>
|
| 1661 |
+
<td>4</td>
|
| 1662 |
+
<td>34.3</td>
|
| 1663 |
+
<td>40.4</td>
|
| 1664 |
+
<td>63.6</td>
|
| 1665 |
+
<td>52.4</td>
|
| 1666 |
+
</tr>
|
| 1667 |
+
<tr>
|
| 1668 |
+
<td>8</td>
|
| 1669 |
+
<td>35.7</td>
|
| 1670 |
+
<td>46.1</td>
|
| 1671 |
+
<td>64.8</td>
|
| 1672 |
+
<td>55.1</td>
|
| 1673 |
+
</tr>
|
| 1674 |
+
<tr>
|
| 1675 |
+
<td align="left" nowrap="nowrap" rowspan="3">OmniCorpus</td>
|
| 1676 |
+
<td rowspan="3">7B</td>
|
| 1677 |
+
<td>0*</td>
|
| 1678 |
+
<td>43.0</td>
|
| 1679 |
+
<td>49.8</td>
|
| 1680 |
+
<td>63.2</td>
|
| 1681 |
+
<td>45.5</td>
|
| 1682 |
+
</tr>
|
| 1683 |
+
<tr>
|
| 1684 |
+
<td>4</td>
|
| 1685 |
+
<td>45.4</td>
|
| 1686 |
+
<td>51.3</td>
|
| 1687 |
+
<td>64.5</td>
|
| 1688 |
+
<td>46.5</td>
|
| 1689 |
+
</tr>
|
| 1690 |
+
<tr>
|
| 1691 |
+
<td>8</td>
|
| 1692 |
+
<td>45.6</td>
|
| 1693 |
+
<td>52.2</td>
|
| 1694 |
+
<td>64.7</td>
|
| 1695 |
+
<td>46.6</td>
|
| 1696 |
+
</tr>
|
| 1697 |
+
<tr>
|
| 1698 |
+
<td align="left" nowrap="nowrap" rowspan="3">Emu2</td>
|
| 1699 |
+
<td rowspan="3">37B</td>
|
| 1700 |
+
<td>0</td>
|
| 1701 |
+
<td>26.4</td>
|
| 1702 |
+
<td>40.4</td>
|
| 1703 |
+
<td>33.5</td>
|
| 1704 |
+
<td>26.7</td>
|
| 1705 |
+
</tr>
|
| 1706 |
+
<tr>
|
| 1707 |
+
<td>4</td>
|
| 1708 |
+
<td>48.2</td>
|
| 1709 |
+
<td>54.6</td>
|
| 1710 |
+
<td>67.0</td>
|
| 1711 |
+
<td>53.2</td>
|
| 1712 |
+
</tr>
|
| 1713 |
+
<tr>
|
| 1714 |
+
<td>8</td>
|
| 1715 |
+
<td>49.3</td>
|
| 1716 |
+
<td>54.7</td>
|
| 1717 |
+
<td>67.8</td>
|
| 1718 |
+
<td>54.1</td>
|
| 1719 |
+
</tr>
|
| 1720 |
+
<tr>
|
| 1721 |
+
<td align="left" nowrap="nowrap" rowspan="2">MM1</td>
|
| 1722 |
+
<td rowspan="2">30B</td>
|
| 1723 |
+
<td>0</td>
|
| 1724 |
+
<td>26.2</td>
|
| 1725 |
+
<td>40.4</td>
|
| 1726 |
+
<td>48.9</td>
|
| 1727 |
+
<td>26.7</td>
|
| 1728 |
+
</tr>
|
| 1729 |
+
<tr>
|
| 1730 |
+
<td>8</td>
|
| 1731 |
+
<td>49.3</td>
|
| 1732 |
+
<td>54.7</td>
|
| 1733 |
+
<td><strong>70.9</strong></td>
|
| 1734 |
+
<td>54.1</td>
|
| 1735 |
+
</tr>
|
| 1736 |
+
<tr>
|
| 1737 |
+
<td align="left" nowrap="nowrap" rowspan="3">MiniCPM-V 2.6<sup>+</sup></td>
|
| 1738 |
+
<td rowspan="3">8B</td>
|
| 1739 |
+
<td>0</td>
|
| 1740 |
+
<td>43.9</td>
|
| 1741 |
+
<td>33.8</td>
|
| 1742 |
+
<td>45.4</td>
|
| 1743 |
+
<td>23.9</td>
|
| 1744 |
+
</tr>
|
| 1745 |
+
<tr>
|
| 1746 |
+
<td>4</td>
|
| 1747 |
+
<td>63.6</td>
|
| 1748 |
+
<td>60.5</td>
|
| 1749 |
+
<td>65.5</td>
|
| 1750 |
+
<td>50.1</td>
|
| 1751 |
+
</tr>
|
| 1752 |
+
<tr>
|
| 1753 |
+
<td>8</td>
|
| 1754 |
+
<td><strong>64.6</strong></td>
|
| 1755 |
+
<td><strong>63.4</strong></td>
|
| 1756 |
+
<td>68.2</td>
|
| 1757 |
+
<td>51.4</td>
|
| 1758 |
+
</tr>
|
| 1759 |
+
</tbody>
|
| 1760 |
+
</table>
|
| 1761 |
+
|
| 1762 |
+
|
| 1763 |
+
</div>
|
| 1764 |
+
* 使用 Flamingo 方式 zero image shot 和 two additional text shots 评估零样本性能。
|
| 1765 |
+
|
| 1766 |
+
<sup>+</sup> 我们在没有进行监督微调 (SFT) 的情况下评估预训练的模型权重 (ckpt)。
|
| 1767 |
+
</details>
|
| 1768 |
+
|
| 1769 |
+
### 典型示例 <!-- omit in toc -->
|
| 1770 |
+
|
| 1771 |
+
<div style="display: flex; flex-direction: column; align-items: center;">
|
| 1772 |
+
<img src="assets/minicpmv2_6/multi_img-bike.png" alt="Bike" style="margin-bottom: 5px;">
|
| 1773 |
+
<img src="assets/minicpmv2_6/multi_img-menu.png" alt="Menu" style="margin-bottom: 5px;">
|
| 1774 |
+
<img src="assets/minicpmv2_6/multi_img-code.png" alt="Code" style="margin-bottom: 5px;">
|
| 1775 |
+
<img src="assets/minicpmv2_6/ICL-Mem.png" alt="Mem" style="margin-bottom: 5px;">
|
| 1776 |
+
<img src="assets/minicpmv2_6/multiling-medal.png" alt="medal" style="margin-bottom: 10px;">
|
| 1777 |
+
</div>
|
| 1778 |
+
<details>
|
| 1779 |
+
<summary>点击查看更多示例。</summary>
|
| 1780 |
+
<div style="display: flex; flex-direction: column; align-items: center;">
|
| 1781 |
+
<img src="assets/minicpmv2_6/ICL-elec.png" alt="elec" style="margin-bottom: 5px;">
|
| 1782 |
+
<img src="assets/minicpmv2_6/multiling-olympic.png" alt="Menu" style="margin-bottom: 10px;">
|
| 1783 |
+
</div>
|
| 1784 |
+
</details>
|
| 1785 |
+
|
| 1786 |
+
我们将 MiniCPM-V 2.6 部署在iPad Pro上,并录制了以下演示视频。
|
| 1787 |
+
|
| 1788 |
+
<table align="center">
|
| 1789 |
+
<p align="center">
|
| 1790 |
+
<img src="assets/gif_cases/ai.gif" width=32%/>
|
| 1791 |
+
|
| 1792 |
+
<img src="assets/gif_cases/beer.gif" width=32%/>
|
| 1793 |
+
</p>
|
| 1794 |
+
</table>
|
| 1795 |
+
|
| 1796 |
+
<table align="center">
|
| 1797 |
+
<p align="center">
|
| 1798 |
+
<video src="https://github.com/user-attachments/assets/21f4b818-ede1-4822-920e-91281725c830" width="360" /> </video>
|
| 1799 |
+
<!-- <video src="https://github.com/user-attachments/assets/c835f757-206b-4d9c-8e36-70d67b453628" width="360" /> </video> -->
|
| 1800 |
+
</p>
|
| 1801 |
+
</table>
|
| 1802 |
+
|
| 1803 |
+
</details>
|
| 1804 |
+
|
| 1805 |
+
## 历史版本模型 <!-- omit in toc -->
|
| 1806 |
+
|
| 1807 |
+
|
| 1808 |
+
| 模型 | 介绍信息和使用教程 |
|
| 1809 |
+
|:----------------------|:-------------------:|
|
| 1810 |
+
| MiniCPM-Llama3-V 2.5 | [文档](./docs/minicpm_llama3_v2dot5.md) |
|
| 1811 |
+
| MiniCPM-V 2.0 | [文档](./docs/minicpm_v2.md) |
|
| 1812 |
+
| MiniCPM-V 1.0 | [文档](./docs/minicpm_v1.md) |
|
| 1813 |
+
| OmniLMM-12B | [文档](./omnilmm.md) |
|
| 1814 |
+
|
| 1815 |
+
|
| 1816 |
+
## Chat with Our Demo on Gradio 🤗
|
| 1817 |
+
|
| 1818 |
+
我们提供由 Hugging Face Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> 支持的在线和本地 Demo。Gradio 是目前最流行的模型部署框架,支持流式输出、进度条、process bars 和其他常用功能。
|
| 1819 |
+
|
| 1820 |
+
### Online Demo <!-- omit in toc -->
|
| 1821 |
+
|
| 1822 |
+
欢迎试用 Online Demo: [MiniCPM-V 2.6](http://120.92.209.146:8887/) | [MiniCPM-Llama3-V 2.5](https://huggingface.co/spaces/openbmb/MiniCPM-Llama3-V-2_5) | [MiniCPM-V 2.0](https://huggingface.co/spaces/openbmb/MiniCPM-V-2) 。
|
| 1823 |
+
|
| 1824 |
+
### 本地 WebUI Demo <!-- omit in toc -->
|
| 1825 |
+
|
| 1826 |
+
您可以使用以下命令轻松构建自己的本地 WebUI Demo。更详细的部署教程请参考[文档](https://modelbest.feishu.cn/wiki/RnjjwnUT7idMSdklQcacd2ktnyN)。
|
| 1827 |
+
|
| 1828 |
+
**实时流式视频/语音通话demo:**
|
| 1829 |
+
1. 启动model server:
|
| 1830 |
+
```shell
|
| 1831 |
+
pip install -r requirements_o2.6.txt
|
| 1832 |
+
|
| 1833 |
+
python web_demos/minicpm-o_2.6/model_server.py
|
| 1834 |
+
```
|
| 1835 |
+
请确保 `transformers==4.44.2`,其他版本目前可能会有兼容性问题,我们正在解决。
|
| 1836 |
+
如果你使用的低版本的 Pytorch,你可能会遇到这个错误`"weight_norm_fwd_first_dim_kernel" not implemented for 'BFloat16'`, 请在模型初始化的时候添加 `self.minicpmo_model.tts.float()`
|
| 1837 |
+
|
| 1838 |
+
2. 启动web server:
|
| 1839 |
+
```shell
|
| 1840 |
+
# Make sure Node and PNPM is installed.
|
| 1841 |
+
sudo apt-get update
|
| 1842 |
+
sudo apt-get install nodejs npm
|
| 1843 |
+
npm install -g pnpm
|
| 1844 |
+
|
| 1845 |
+
|
| 1846 |
+
cd web_demos/minicpm-o_2.6/web_server
|
| 1847 |
+
# 为https创建自签名证书, 要申请浏览器摄像头和麦克风权限须启动https.
|
| 1848 |
+
bash ./make_ssl_cert.sh # output key.pem and cert.pem
|
| 1849 |
+
|
| 1850 |
+
pnpm install # install requirements
|
| 1851 |
+
pnpm run dev # start server
|
| 1852 |
+
```
|
| 1853 |
+
浏览器打开`https://localhost:8088/`,开始体验实时流式视频/语音通话.
|
| 1854 |
+
|
| 1855 |
+
**Chatbot图文对话demo:**
|
| 1856 |
+
```shell
|
| 1857 |
+
pip install -r requirements_o2.6.txt
|
| 1858 |
+
|
| 1859 |
+
python web_demos/minicpm-o_2.6/chatbot_web_demo_o2.6.py
|
| 1860 |
+
```
|
| 1861 |
+
浏览器打开`http://localhost:8000/`,开始体验图文对话Chatbot.
|
| 1862 |
+
|
| 1863 |
+
|
| 1864 |
+
## 推理
|
| 1865 |
+
|
| 1866 |
+
### 模型库
|
| 1867 |
+
|
| 1868 |
+
| 模型 | 设备 | 资源 |          简介 | 下载链接 |
|
| 1869 |
+
|:--------------|:-:|:----------:|:-------------------|:---------------:|
|
| 1870 |
+
| MiniCPM-o 2.6| GPU | 18 GB | 最新版本,提供端侧 GPT-4o 级的视觉、语音、多模态流式交互能力。 | [🤗](https://huggingface.co/openbmb/MiniCPM-o-2_6) [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6) |
|
| 1871 |
+
| MiniCPM-o 2.6 gguf | CPU | 8 GB | gguf 版本,更低的内存占用和更高的推理效率。 | [🤗](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6-gguf) |
|
| 1872 |
+
| MiniCPM-o 2.6 int4 | GPU | 9 GB | int4量化版,更低显存占用。 | [🤗](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4) [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6-int4) |
|
| 1873 |
+
| MiniCPM-V 2.6| GPU | 17 GB | 提供出色的端侧单图、多图、视频理解能力。 | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2_6) [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6) |
|
| 1874 |
+
| MiniCPM-V 2.6 gguf | CPU | 6 GB | gguf 版本,更低的内存占用和更高的推理效率。 | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6-gguf) |
|
| 1875 |
+
| MiniCPM-V 2.6 int4 | GPU | 7 GB | int4量化版,更低显存占用。 | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6-int4) |
|
| 1876 |
+
|
| 1877 |
+
更多[历史版本模型](#legacy-models)
|
| 1878 |
+
|
| 1879 |
+
|
| 1880 |
+
### 多轮对话
|
| 1881 |
+
请确保 `transformers==4.44.2`,其他版本目前可能会有兼容性问题
|
| 1882 |
+
|
| 1883 |
+
```shell
|
| 1884 |
+
pip install -r requirements_o2.6.txt
|
| 1885 |
+
```
|
| 1886 |
+
|
| 1887 |
+
<div align="center">
|
| 1888 |
+
<img src="assets/minicpmo2_6/show_demo.jpg" width="500px">
|
| 1889 |
+
</div>
|
| 1890 |
+
|
| 1891 |
+
|
| 1892 |
+
```python
|
| 1893 |
+
import torch
|
| 1894 |
+
from PIL import Image
|
| 1895 |
+
from transformers import AutoModel, AutoTokenizer
|
| 1896 |
+
|
| 1897 |
+
torch.manual_seed(100)
|
| 1898 |
+
|
| 1899 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
|
| 1900 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
| 1901 |
+
model = model.eval().cuda()
|
| 1902 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
|
| 1903 |
+
|
| 1904 |
+
image = Image.open('./assets/minicpmo2_6/show_demo.jpg').convert('RGB')
|
| 1905 |
+
|
| 1906 |
+
# First round chat
|
| 1907 |
+
question = "What is the landform in the picture?"
|
| 1908 |
+
msgs = [{'role': 'user', 'content': [image, question]}]
|
| 1909 |
+
|
| 1910 |
+
answer = model.chat(
|
| 1911 |
+
msgs=msgs,
|
| 1912 |
+
tokenizer=tokenizer
|
| 1913 |
+
)
|
| 1914 |
+
print(answer)
|
| 1915 |
+
|
| 1916 |
+
# Second round chat, pass history context of multi-turn conversation
|
| 1917 |
+
msgs.append({"role": "assistant", "content": [answer]})
|
| 1918 |
+
msgs.append({"role": "user", "content": ["What should I pay attention to when traveling here?"]})
|
| 1919 |
+
|
| 1920 |
+
answer = model.chat(
|
| 1921 |
+
msgs=msgs,
|
| 1922 |
+
tokenizer=tokenizer
|
| 1923 |
+
)
|
| 1924 |
+
print(answer)
|
| 1925 |
+
```
|
| 1926 |
+
|
| 1927 |
+
你可以得到如下推理结果:
|
| 1928 |
+
|
| 1929 |
+
```
|
| 1930 |
+
"The landform in the picture is a mountain range. The mountains appear to be karst formations, characterized by their steep, rugged peaks and smooth, rounded shapes. These types of mountains are often found in regions with limestone bedrock and are shaped by processes such as erosion and weathering. The reflection of the mountains in the water adds to the scenic beauty of the landscape."
|
| 1931 |
+
|
| 1932 |
+
"When traveling to this scenic location, it's important to pay attention to the weather conditions, as the area appears to be prone to fog and mist, especially during sunrise or sunset. Additionally, ensure you have proper footwear for navigating the potentially slippery terrain around the water. Lastly, respect the natural environment by not disturbing the local flora and fauna."
|
| 1933 |
+
```
|
| 1934 |
+
|
| 1935 |
+
#### 多图对话
|
| 1936 |
+
<details>
|
| 1937 |
+
<summary> 点击查看 MiniCPM-o 2.6 多图输入的 Python 代码。 </summary>
|
| 1938 |
+
|
| 1939 |
+
```python
|
| 1940 |
+
import torch
|
| 1941 |
+
from PIL import Image
|
| 1942 |
+
from transformers import AutoModel, AutoTokenizer
|
| 1943 |
+
|
| 1944 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
|
| 1945 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
| 1946 |
+
model = model.eval().cuda()
|
| 1947 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
|
| 1948 |
+
|
| 1949 |
+
image1 = Image.open('image1.jpg').convert('RGB')
|
| 1950 |
+
image2 = Image.open('image2.jpg').convert('RGB')
|
| 1951 |
+
question = 'Compare image 1 and image 2, tell me about the differences between image 1 and image 2.'
|
| 1952 |
+
|
| 1953 |
+
msgs = [{'role': 'user', 'content': [image1, image2, question]}]
|
| 1954 |
+
|
| 1955 |
+
answer = model.chat(
|
| 1956 |
+
msgs=msgs,
|
| 1957 |
+
tokenizer=tokenizer
|
| 1958 |
+
)
|
| 1959 |
+
print(answer)
|
| 1960 |
+
```
|
| 1961 |
+
</details>
|
| 1962 |
+
|
| 1963 |
+
#### 少样本上下文对话
|
| 1964 |
+
<details>
|
| 1965 |
+
<summary> 点击查看 MiniCPM-o 2.6 少样本上下文对话的 Python 代码。 </summary>
|
| 1966 |
+
|
| 1967 |
+
```python
|
| 1968 |
+
import torch
|
| 1969 |
+
from PIL import Image
|
| 1970 |
+
from transformers import AutoModel, AutoTokenizer
|
| 1971 |
+
|
| 1972 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
|
| 1973 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
| 1974 |
+
model = model.eval().cuda()
|
| 1975 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
|
| 1976 |
+
|
| 1977 |
+
question = "production date"
|
| 1978 |
+
image1 = Image.open('example1.jpg').convert('RGB')
|
| 1979 |
+
answer1 = "2023.08.04"
|
| 1980 |
+
image2 = Image.open('example2.jpg').convert('RGB')
|
| 1981 |
+
answer2 = "2007.04.24"
|
| 1982 |
+
image_test = Image.open('test.jpg').convert('RGB')
|
| 1983 |
+
|
| 1984 |
+
msgs = [
|
| 1985 |
+
{'role': 'user', 'content': [image1, question]}, {'role': 'assistant', 'content': [answer1]},
|
| 1986 |
+
{'role': 'user', 'content': [image2, question]}, {'role': 'assistant', 'content': [answer2]},
|
| 1987 |
+
{'role': 'user', 'content': [image_test, question]}
|
| 1988 |
+
]
|
| 1989 |
+
|
| 1990 |
+
answer = model.chat(
|
| 1991 |
+
msgs=msgs,
|
| 1992 |
+
tokenizer=tokenizer
|
| 1993 |
+
)
|
| 1994 |
+
print(answer)
|
| 1995 |
+
```
|
| 1996 |
+
</details>
|
| 1997 |
+
|
| 1998 |
+
#### 视频对话
|
| 1999 |
+
<details>
|
| 2000 |
+
<summary> 点击查看 MiniCPM-o 2.6 视频输入的 Python 代码。 </summary>
|
| 2001 |
+
|
| 2002 |
+
```python
|
| 2003 |
+
import torch
|
| 2004 |
+
from PIL import Image
|
| 2005 |
+
from transformers import AutoModel, AutoTokenizer
|
| 2006 |
+
from decord import VideoReader, cpu # pip install decord
|
| 2007 |
+
|
| 2008 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
|
| 2009 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
| 2010 |
+
model = model.eval().cuda()
|
| 2011 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
|
| 2012 |
+
|
| 2013 |
+
MAX_NUM_FRAMES=64 # if cuda OOM set a smaller number
|
| 2014 |
+
|
| 2015 |
+
def encode_video(video_path):
|
| 2016 |
+
def uniform_sample(l, n):
|
| 2017 |
+
gap = len(l) / n
|
| 2018 |
+
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
| 2019 |
+
return [l[i] for i in idxs]
|
| 2020 |
+
|
| 2021 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
| 2022 |
+
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
| 2023 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
| 2024 |
+
if len(frame_idx) > MAX_NUM_FRAMES:
|
| 2025 |
+
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
| 2026 |
+
frames = vr.get_batch(frame_idx).asnumpy()
|
| 2027 |
+
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
| 2028 |
+
print('num frames:', len(frames))
|
| 2029 |
+
return frames
|
| 2030 |
+
|
| 2031 |
+
video_path="video_test.mp4"
|
| 2032 |
+
frames = encode_video(video_path)
|
| 2033 |
+
question = "Describe the video"
|
| 2034 |
+
msgs = [
|
| 2035 |
+
{'role': 'user', 'content': frames + [question]},
|
| 2036 |
+
]
|
| 2037 |
+
|
| 2038 |
+
# Set decode params for video
|
| 2039 |
+
params = {}
|
| 2040 |
+
params["use_image_id"] = False
|
| 2041 |
+
params["max_slice_nums"] = 2 # use 1 if cuda OOM and video resolution > 448*448
|
| 2042 |
+
|
| 2043 |
+
answer = model.chat(
|
| 2044 |
+
msgs=msgs,
|
| 2045 |
+
tokenizer=tokenizer,
|
| 2046 |
+
**params
|
| 2047 |
+
)
|
| 2048 |
+
print(answer)
|
| 2049 |
+
```
|
| 2050 |
+
</details>
|
| 2051 |
+
|
| 2052 |
+
|
| 2053 |
+
#### 语音对话
|
| 2054 |
+
<details> <summary> 初始化模型 </summary>
|
| 2055 |
+
|
| 2056 |
+
```python
|
| 2057 |
+
import torch
|
| 2058 |
+
import librosa
|
| 2059 |
+
from transformers import AutoModel, AutoTokenizer
|
| 2060 |
+
|
| 2061 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
|
| 2062 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
| 2063 |
+
model = model.eval().cuda()
|
| 2064 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
|
| 2065 |
+
|
| 2066 |
+
model.init_tts()
|
| 2067 |
+
model.tts.float()
|
| 2068 |
+
```
|
| 2069 |
+
|
| 2070 |
+
</details>
|
| 2071 |
+
|
| 2072 |
+
##### Mimick
|
| 2073 |
+
|
| 2074 |
+
<details> <summary> 点击查看 MiniCPM-o 2.6 端到端语音理解生成的 Python 代码。 </summary>
|
| 2075 |
+
|
| 2076 |
+
- `Mimick` 任务反映了模型的端到端语音建模能力。模型接受音频输入,输出语音识别(ASR)转录结果,并随后以高相似度重建原始音频。重建的音频相似度和原始音频越高,表明模型有越高的���音端到端建模基础能力。
|
| 2077 |
+
```python
|
| 2078 |
+
mimick_prompt = "Please repeat each user's speech, including voice style and speech content."
|
| 2079 |
+
audio_input, _ = librosa.load('xxx.wav', sr=16000, mono=True)
|
| 2080 |
+
msgs = [{'role': 'user', 'content': [mimick_prompt,audio_input]}]
|
| 2081 |
+
res = model.chat(
|
| 2082 |
+
msgs=msgs,
|
| 2083 |
+
tokenizer=tokenizer,
|
| 2084 |
+
sampling=True,
|
| 2085 |
+
max_new_tokens=128,
|
| 2086 |
+
use_tts_template=True,
|
| 2087 |
+
temperature=0.3,
|
| 2088 |
+
generate_audio=True,
|
| 2089 |
+
output_audio_path='output.wav', # save the tts result to output_audio_path
|
| 2090 |
+
)
|
| 2091 |
+
```
|
| 2092 |
+
|
| 2093 |
+
</details>
|
| 2094 |
+
|
| 2095 |
+
##### 可配置声音的语音对话
|
| 2096 |
+
<details> <summary> 点击查看个性化配置 MiniCPM-o 2.6 对话声音的 Python 代码。</summary>
|
| 2097 |
+
|
| 2098 |
+
```python
|
| 2099 |
+
ref_audio, _ = librosa.load('./assets/voice_01.wav', sr=16000, mono=True) # load the reference audio
|
| 2100 |
+
|
| 2101 |
+
# Audio RolePlay: # With this mode, model will role-play the character based on the audio prompt.
|
| 2102 |
+
sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='audio_roleplay', language='en')
|
| 2103 |
+
user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]}
|
| 2104 |
+
|
| 2105 |
+
# Audio Assistant: # With this mode, model will speak with the voice in ref_audio as a AI assistant.
|
| 2106 |
+
# sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='audio_assistant', language='en')
|
| 2107 |
+
# user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]} # Try to ask something!
|
| 2108 |
+
```
|
| 2109 |
+
```python
|
| 2110 |
+
msgs = [sys_prompt, user_question]
|
| 2111 |
+
res = model.chat(
|
| 2112 |
+
msgs=msgs,
|
| 2113 |
+
tokenizer=tokenizer,
|
| 2114 |
+
sampling=True,
|
| 2115 |
+
max_new_tokens=128,
|
| 2116 |
+
use_tts_template=True,
|
| 2117 |
+
generate_audio=True,
|
| 2118 |
+
temperature=0.3,
|
| 2119 |
+
output_audio_path='result.wav',
|
| 2120 |
+
)
|
| 2121 |
+
|
| 2122 |
+
# round two
|
| 2123 |
+
history = msgs.append({'role': 'assistant', 'content': res})
|
| 2124 |
+
user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]}
|
| 2125 |
+
msgs = history.append(user_question)
|
| 2126 |
+
res = model.chat(
|
| 2127 |
+
msgs=msgs,
|
| 2128 |
+
tokenizer=tokenizer,
|
| 2129 |
+
sampling=True,
|
| 2130 |
+
max_new_tokens=128,
|
| 2131 |
+
use_tts_template=True,
|
| 2132 |
+
generate_audio=True,
|
| 2133 |
+
temperature=0.3,
|
| 2134 |
+
output_audio_path='result_round_2.wav',
|
| 2135 |
+
)
|
| 2136 |
+
print(res)
|
| 2137 |
+
```
|
| 2138 |
+
|
| 2139 |
+
</details>
|
| 2140 |
+
|
| 2141 |
+
##### 更多语音任务
|
| 2142 |
+
<details>
|
| 2143 |
+
<summary> 点击查看 MiniCPM-o 2.6 完成更多语音任务的 Python 代码。 </summary>
|
| 2144 |
+
|
| 2145 |
+
```python
|
| 2146 |
+
'''
|
| 2147 |
+
Audio Understanding Task Prompt:
|
| 2148 |
+
Speech:
|
| 2149 |
+
ASR with ZH(same as AST en2zh): 请仔细听这段音频片段,并将其内容逐字记录。
|
| 2150 |
+
ASR with EN(same as AST zh2en): Please listen to the audio snippet carefully and transcribe the content.
|
| 2151 |
+
Speaker Analysis: Based on the speaker's content, speculate on their gender, condition, age range, and health status.
|
| 2152 |
+
General Audio:
|
| 2153 |
+
Audio Caption: Summarize the main content of the audio.
|
| 2154 |
+
Sound Scene Tagging: Utilize one keyword to convey the audio's content or the associated scene.
|
| 2155 |
+
'''
|
| 2156 |
+
task_prompt = "\n"
|
| 2157 |
+
audio_input, _ = librosa.load('xxx.wav', sr=16000, mono=True)
|
| 2158 |
+
|
| 2159 |
+
msgs = [{'role': 'user', 'content': [task_prompt,audio_input]}]
|
| 2160 |
+
|
| 2161 |
+
res = model.chat(
|
| 2162 |
+
msgs=msgs,
|
| 2163 |
+
tokenizer=tokenizer,
|
| 2164 |
+
sampling=True,
|
| 2165 |
+
max_new_tokens=128,
|
| 2166 |
+
use_tts_template=True,
|
| 2167 |
+
generate_audio=True,
|
| 2168 |
+
temperature=0.3,
|
| 2169 |
+
output_audio_path='result.wav',
|
| 2170 |
+
)
|
| 2171 |
+
print(res)
|
| 2172 |
+
```
|
| 2173 |
+
```python
|
| 2174 |
+
'''
|
| 2175 |
+
Speech Generation Task Prompt:
|
| 2176 |
+
Human Instruction-to-Speech: see https://voxinstruct.github.io/VoxInstruct/
|
| 2177 |
+
Example:
|
| 2178 |
+
# 在新闻中,一个年轻男性兴致勃勃地说:“祝福亲爱的祖国母亲美丽富强!”他用低音调和低音量,慢慢地说出了这句话。
|
| 2179 |
+
# Delighting in a surprised tone, an adult male with low pitch and low volume comments:"One even gave my little dog a biscuit" This dialogue takes place at a leisurely pace, delivering a sense of excitement and surprise in the context.
|
| 2180 |
+
|
| 2181 |
+
Voice Cloning or Voice Creation: With this mode, model will act like a TTS model.
|
| 2182 |
+
'''
|
| 2183 |
+
# Human Instruction-to-Speech:
|
| 2184 |
+
task_prompt = '' #Try to make some Human Instruction-to-Speech prompt
|
| 2185 |
+
msgs = [{'role': 'user', 'content': [task_prompt]}] # you can try to use the same audio question
|
| 2186 |
+
|
| 2187 |
+
# Voice Cloning mode: With this mode, model will act like a TTS model.
|
| 2188 |
+
# sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='voice_cloning', language='en')
|
| 2189 |
+
# text_prompt = f"Please read the text below."
|
| 2190 |
+
# user_question = {'role': 'user', 'content': [text_prompt, "content that you want to read"]} # using same voice in sys_prompt to read the text. (Voice Cloning)
|
| 2191 |
+
# user_question = {'role': 'user', 'content': [text_prompt, librosa.load('xxx.wav', sr=16000, mono=True)[0]]} # using same voice in sys_prompt to read 'xxx.wav'. (Voice Creation)
|
| 2192 |
+
|
| 2193 |
+
msgs = [sys_prompt, user_question]
|
| 2194 |
+
res = model.chat(
|
| 2195 |
+
msgs=msgs,
|
| 2196 |
+
tokenizer=tokenizer,
|
| 2197 |
+
sampling=True,
|
| 2198 |
+
max_new_tokens=128,
|
| 2199 |
+
use_tts_template=True,
|
| 2200 |
+
generate_audio=True,
|
| 2201 |
+
temperature=0.3,
|
| 2202 |
+
output_audio_path='result.wav',
|
| 2203 |
+
)
|
| 2204 |
+
|
| 2205 |
+
|
| 2206 |
+
```
|
| 2207 |
+
|
| 2208 |
+
</details>
|
| 2209 |
+
|
| 2210 |
+
#### 多模态流式交互
|
| 2211 |
+
<details>
|
| 2212 |
+
<summary> 点击查看 MiniCPM-o 2.6 多模态流式交互的 Python 代码。 </summary>
|
| 2213 |
+
|
| 2214 |
+
```python
|
| 2215 |
+
import math
|
| 2216 |
+
import numpy as np
|
| 2217 |
+
from PIL import Image
|
| 2218 |
+
from moviepy.editor import VideoFileClip
|
| 2219 |
+
import tempfile
|
| 2220 |
+
import librosa
|
| 2221 |
+
import soundfile as sf
|
| 2222 |
+
import torch
|
| 2223 |
+
from transformers import AutoModel, AutoTokenizer
|
| 2224 |
+
|
| 2225 |
+
def get_video_chunk_content(video_path, flatten=True):
|
| 2226 |
+
video = VideoFileClip(video_path)
|
| 2227 |
+
print('video_duration:', video.duration)
|
| 2228 |
+
|
| 2229 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
|
| 2230 |
+
temp_audio_file_path = temp_audio_file.name
|
| 2231 |
+
video.audio.write_audiofile(temp_audio_file_path, codec="pcm_s16le", fps=16000)
|
| 2232 |
+
audio_np, sr = librosa.load(temp_audio_file_path, sr=16000, mono=True)
|
| 2233 |
+
num_units = math.ceil(video.duration)
|
| 2234 |
+
|
| 2235 |
+
# 1 frame + 1s audio chunk
|
| 2236 |
+
contents= []
|
| 2237 |
+
for i in range(num_units):
|
| 2238 |
+
frame = video.get_frame(i+1)
|
| 2239 |
+
image = Image.fromarray((frame).astype(np.uint8))
|
| 2240 |
+
audio = audio_np[sr*i:sr*(i+1)]
|
| 2241 |
+
if flatten:
|
| 2242 |
+
contents.extend(["<unit>", image, audio])
|
| 2243 |
+
else:
|
| 2244 |
+
contents.append(["<unit>", image, audio])
|
| 2245 |
+
|
| 2246 |
+
return contents
|
| 2247 |
+
|
| 2248 |
+
|
| 2249 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
|
| 2250 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
| 2251 |
+
model = model.eval().cuda()
|
| 2252 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
|
| 2253 |
+
|
| 2254 |
+
model.init_tts()
|
| 2255 |
+
|
| 2256 |
+
# If you are using an older version of PyTorch, you might encounter this issue "weight_norm_fwd_first_dim_kernel" not implemented for 'BFloat16', Please convert the TTS to float32 type.
|
| 2257 |
+
# model.tts.float()
|
| 2258 |
+
|
| 2259 |
+
# https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/assets/Skiing.mp4
|
| 2260 |
+
video_path="assets/Skiing.mp4"
|
| 2261 |
+
sys_msg = model.get_sys_prompt(mode='omni', language='en')
|
| 2262 |
+
# if use voice clone prompt, please set ref_audio
|
| 2263 |
+
# ref_audio_path = '/path/to/ref_audio'
|
| 2264 |
+
# ref_audio, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
|
| 2265 |
+
# sys_msg = model.get_sys_prompt(ref_audio=ref_audio, mode='omni', language='en')
|
| 2266 |
+
|
| 2267 |
+
contents = get_video_chunk_content(video_path)
|
| 2268 |
+
msg = {"role":"user", "content": contents}
|
| 2269 |
+
msgs = [sys_msg, msg]
|
| 2270 |
+
|
| 2271 |
+
# please set generate_audio=True and output_audio_path to save the tts result
|
| 2272 |
+
generate_audio = True
|
| 2273 |
+
output_audio_path = 'output.wav'
|
| 2274 |
+
|
| 2275 |
+
res = model.chat(
|
| 2276 |
+
msgs=msgs,
|
| 2277 |
+
tokenizer=tokenizer,
|
| 2278 |
+
sampling=True,
|
| 2279 |
+
temperature=0.5,
|
| 2280 |
+
max_new_tokens=4096,
|
| 2281 |
+
omni_input=True, # please set omni_input=True when omni inference
|
| 2282 |
+
use_tts_template=True,
|
| 2283 |
+
generate_audio=generate_audio,
|
| 2284 |
+
output_audio_path=output_audio_path,
|
| 2285 |
+
max_slice_nums=1,
|
| 2286 |
+
use_image_id=False,
|
| 2287 |
+
return_dict=True
|
| 2288 |
+
)
|
| 2289 |
+
print(res)
|
| 2290 |
+
```
|
| 2291 |
+
</details>
|
| 2292 |
+
|
| 2293 |
+
<details>
|
| 2294 |
+
<summary> 点击查看多模态流式推理设置。 </summary>
|
| 2295 |
+
|
| 2296 |
+
注意:流式推理存在轻微的性能下降,因为音频编码并非全局的。
|
| 2297 |
+
```python
|
| 2298 |
+
# a new conversation need reset session first, it will reset the kv-cache
|
| 2299 |
+
model.reset_session()
|
| 2300 |
+
|
| 2301 |
+
contents = get_video_chunk_content(video_path, flatten=False)
|
| 2302 |
+
session_id = '123'
|
| 2303 |
+
generate_audio = True
|
| 2304 |
+
|
| 2305 |
+
# 1. prefill system prompt
|
| 2306 |
+
res = model.streaming_prefill(
|
| 2307 |
+
session_id=session_id,
|
| 2308 |
+
msgs=[sys_msg],
|
| 2309 |
+
tokenizer=tokenizer
|
| 2310 |
+
)
|
| 2311 |
+
|
| 2312 |
+
# 2. prefill video/audio chunks
|
| 2313 |
+
for content in contents:
|
| 2314 |
+
msgs = [{"role":"user", "content": content}]
|
| 2315 |
+
res = model.streaming_prefill(
|
| 2316 |
+
session_id=session_id,
|
| 2317 |
+
msgs=msgs,
|
| 2318 |
+
tokenizer=tokenizer
|
| 2319 |
+
)
|
| 2320 |
+
|
| 2321 |
+
# 3. generate
|
| 2322 |
+
res = model.streaming_generate(
|
| 2323 |
+
session_id=session_id,
|
| 2324 |
+
tokenizer=tokenizer,
|
| 2325 |
+
temperature=0.5,
|
| 2326 |
+
generate_audio=generate_audio
|
| 2327 |
+
)
|
| 2328 |
+
|
| 2329 |
+
audios = []
|
| 2330 |
+
text = ""
|
| 2331 |
+
|
| 2332 |
+
if generate_audio:
|
| 2333 |
+
for r in res:
|
| 2334 |
+
audio_wav = r.audio_wav
|
| 2335 |
+
sampling_rate = r.sampling_rate
|
| 2336 |
+
txt = r.text
|
| 2337 |
+
|
| 2338 |
+
audios.append(audio_wav)
|
| 2339 |
+
text += txt
|
| 2340 |
+
|
| 2341 |
+
res = np.concatenate(audios)
|
| 2342 |
+
sf.write("output.wav", res, samplerate=sampling_rate)
|
| 2343 |
+
print("text:", text)
|
| 2344 |
+
print("audio saved to output.wav")
|
| 2345 |
+
else:
|
| 2346 |
+
for r in res:
|
| 2347 |
+
text += r['text']
|
| 2348 |
+
print("text:", text)
|
| 2349 |
+
```
|
| 2350 |
+
|
| 2351 |
+
</details>
|
| 2352 |
+
|
| 2353 |
+
|
| 2354 |
+
### 多卡推理
|
| 2355 |
+
您可以通过将模型的层分布在多个低显存显卡(12 GB 或 16 GB)上,运行 MiniCPM-Llama3-V 2.5。请查看该[教程](https://github.com/OpenBMB/MiniCPM-V/blob/main/docs/inference_on_multiple_gpus.md),详细了解如何使用多张低显存显卡载入模型并进行推理。
|
| 2356 |
+
|
| 2357 |
+
|
| 2358 |
+
### Mac 推理
|
| 2359 |
+
<details>
|
| 2360 |
+
<summary>点击查看 MiniCPM-Llama3-V 2.5 / MiniCPM-V 2.0 基于Mac MPS运行 (Apple silicon 或 AMD GPUs)的示例。 </summary>
|
| 2361 |
+
|
| 2362 |
+
```python
|
| 2363 |
+
# test.py Need more than 16GB memory to run.
|
| 2364 |
+
import torch
|
| 2365 |
+
from PIL import Image
|
| 2366 |
+
from transformers import AutoModel, AutoTokenizer
|
| 2367 |
+
|
| 2368 |
+
model = AutoModel.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True, low_cpu_mem_usage=True)
|
| 2369 |
+
model = model.to(device='mps')
|
| 2370 |
+
|
| 2371 |
+
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)
|
| 2372 |
+
model.eval()
|
| 2373 |
+
|
| 2374 |
+
image = Image.open('./assets/hk_OCR.jpg').convert('RGB')
|
| 2375 |
+
question = 'Where is this photo taken?'
|
| 2376 |
+
msgs = [{'role': 'user', 'content': question}]
|
| 2377 |
+
|
| 2378 |
+
answer, context, _ = model.chat(
|
| 2379 |
+
image=image,
|
| 2380 |
+
msgs=msgs,
|
| 2381 |
+
context=None,
|
| 2382 |
+
tokenizer=tokenizer,
|
| 2383 |
+
sampling=True
|
| 2384 |
+
)
|
| 2385 |
+
print(answer)
|
| 2386 |
+
```
|
| 2387 |
+
运行:
|
| 2388 |
+
```shell
|
| 2389 |
+
PYTORCH_ENABLE_MPS_FALLBACK=1 python test.py
|
| 2390 |
+
```
|
| 2391 |
+
</details>
|
| 2392 |
+
|
| 2393 |
+
|
| 2394 |
+
### 基于 llama.cpp、ollama、vLLM 的高效推理
|
| 2395 |
+
|
| 2396 |
+
llama.cpp 用法请参考[我们的fork llama.cpp](https://github.com/OpenBMB/llama.cpp/tree/minicpmv-main/examples/llava/README-minicpmv2.6.md), 在iPad上可以支持 16~18 token/s 的流畅推理(测试环境:iPad Pro + M4)。
|
| 2397 |
+
|
| 2398 |
+
ollama 用法请参考[我们的fork ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md), 在iPad上可以支持 16~18 token/s 的流畅推理(测试环境:iPad Pro + M4)。
|
| 2399 |
+
|
| 2400 |
+
<details>
|
| 2401 |
+
<summary>点击查看, vLLM 现已官方支持MiniCPM-o 2.6、MiniCPM-V 2.6、MiniCPM-Llama3-V 2.5 和 MiniCPM-V 2.0。 </summary>
|
| 2402 |
+
1. 安装 vLLM(>=0.7.1):
|
| 2403 |
+
|
| 2404 |
+
```shell
|
| 2405 |
+
pip install vllm
|
| 2406 |
+
```
|
| 2407 |
+
|
| 2408 |
+
2. 运行示例代码:(注意:如果使用本地路径的模型,请确保模型代码已更新到Hugging Face上的最新版)
|
| 2409 |
+
|
| 2410 |
+
* [图文示例](https://docs.vllm.ai/en/latest/getting_started/examples/vision_language.html)
|
| 2411 |
+
* [音频示例](https://docs.vllm.ai/en/latest/getting_started/examples/audio_language.html)
|
| 2412 |
+
|
| 2413 |
+
</details>
|
| 2414 |
+
|
| 2415 |
+
|
| 2416 |
+
## 微调
|
| 2417 |
+
|
| 2418 |
+
### 简易微调 <!-- omit in toc -->
|
| 2419 |
+
|
| 2420 |
+
我们支持使用 Huggingface Transformers 库简易地微调 MiniCPM-o 2.6、MiniCPM-V 2.6、MiniCPM-Llama3-V 2.5 和 MiniCPM-V 2.0 模型。
|
| 2421 |
+
|
| 2422 |
+
[参考文档](./finetune/readme.md)
|
| 2423 |
+
|
| 2424 |
+
|
| 2425 |
+
### 使用 Align-Anything <!-- omit in toc -->
|
| 2426 |
+
|
| 2427 |
+
我们支持使用北大团队开发的 [Align-Anything](https://github.com/PKU-Alignment/align-anything) 框架微调 MiniCPM-o 系列模型,同时支持 DPO 和 SFT 在视觉和音频模态上的微调。Align-Anything 是一个用于对齐全模态大模型的高度可扩展框架,开源了[数据集、模型和评测](https://huggingface.co/datasets/PKU-Alignment/align-anything)。它支持了 30+ 开源基准,40+ 模型,以及包含SFT、SimPO、RLHF在内的多种算法,并提供了 30+ 直接可运行的脚本,适合初学者快速上手。
|
| 2428 |
+
|
| 2429 |
+
最佳实践: [MiniCPM-o 2.6](https://github.com/PKU-Alignment/align-anything/tree/main/scripts).
|
| 2430 |
+
|
| 2431 |
+
|
| 2432 |
+
### 使用 LLaMA-Factory <!-- omit in toc -->
|
| 2433 |
+
|
| 2434 |
+
我们支持使用 LLaMA-Factory 微调 MiniCPM-o 2.6 和 MiniCPM-V 2.6。LLaMA-Factory 提供了一种灵活定制 200 多个大型语言模型(LLM)微调(Lora/Full/Qlora)解决方案,无需编写代码,通过内置的 Web 用户界面 LLaMABoard 即可实现训练/推理/评估。它支持多种训练方法,如 sft/ppo/dpo/kto,并且还支持如 Galore/BAdam/LLaMA-Pro/Pissa/LongLoRA 等高级算法。
|
| 2435 |
+
|
| 2436 |
+
最佳实践: [MiniCPM-o 2.6 | MiniCPM-V 2.6](./docs/llamafactory_train_and_infer.md).
|
| 2437 |
+
|
| 2438 |
+
|
| 2439 |
+
### 使用 SWIFT 框架 <!-- omit in toc -->
|
| 2440 |
+
|
| 2441 |
+
我们支持使用 SWIFT 框架微调 MiniCPM-V 系列模型。SWIFT 支持近 200 种大语言模型和多模态大模型的训练、推理、评测和部署。支持 PEFT 提供的轻量训练方案和完整的 Adapters 库支持的最新训练技术如 NEFTune、LoRA+、LLaMA-PRO 等。
|
| 2442 |
+
|
| 2443 |
+
参考文档:[MiniCPM-V 1.0](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v最佳实践.md),[MiniCPM-V 2.0](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v-2最佳实践.md) [MiniCPM-V 2.6](https://github.com/modelscope/ms-swift/issues/1613).
|
| 2444 |
+
|
| 2445 |
+
## FAQs
|
| 2446 |
+
点击查看 [FAQs](./docs/faqs.md)
|
| 2447 |
+
|
| 2448 |
+
|
| 2449 |
+
## 模型局限性
|
| 2450 |
+
|
| 2451 |
+
我们实验发现 MiniCPM-o 2.6 存在一些显著的局限性,需要进一步研究和改进:
|
| 2452 |
+
- **不稳定的语音输出。** 语音生成可能会受到背景噪音和无意义声音的影响,表现不稳定。
|
| 2453 |
+
- **重复响应。** 当遇到连续相似的用户请求时,模型往往会重复相同的回答。
|
| 2454 |
+
- **Web Demo 延迟较高。** 用户在使用远程服务器上部署的 web demo 时可能会产生较高延迟。我们推荐用户在本地部署来获得更低延迟的体验。
|
| 2455 |
+
|
| 2456 |
+
|
| 2457 |
+
## 模型协议 <!-- omit in toc -->
|
| 2458 |
+
|
| 2459 |
+
* 本仓库中代码依照 [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) 协议开源
|
| 2460 |
+
* MiniCPM-o/V 模型权重的使用则需要遵循 [“MiniCPM模型商用许可协议.md”](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%E6%A8%A1%E5%9E%8B%E5%95%86%E7%94%A8%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.md)。
|
| 2461 |
+
* MiniCPM 模型权重对学术研究完全开放,在填写[“问卷”](https://modelbest.feishu.cn/share/base/form/shrcnpV5ZT9EJ6xYjh3Kx0J6v8g)进行登记后亦允许免费商业使用。
|
| 2462 |
+
|
| 2463 |
+
## 声明 <!-- omit in toc -->
|
| 2464 |
+
|
| 2465 |
+
作为多模态大模型,MiniCPM-o/V 系列模型(包括 OmniLMM)通过学习大量的多模态数据来生成内容,但它无法理解、表达个人观点或价值判断,它所输出的任何内容都不代表模型开发者的观点和立场。
|
| 2466 |
+
|
| 2467 |
+
因此用户在使用本项目的系列模型生成的内容时,应自行负责对其进行评估和验证。如果由于使用本项目的系列开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
|
| 2468 |
+
|
| 2469 |
+
|
| 2470 |
+
## 机构 <!-- omit in toc -->
|
| 2471 |
+
|
| 2472 |
+
本项目由以下机构共同开发:
|
| 2473 |
+
|
| 2474 |
+
- <img src="assets/thunlp.png" width="28px"> [清华大学自然语言处理实验室](https://nlp.csai.tsinghua.edu.cn/)
|
| 2475 |
+
- <img src="assets/modelbest.png" width="28px"> [面壁智能](https://modelbest.cn/)
|
| 2476 |
+
|
| 2477 |
+
## 🌟 Star History <!-- omit in toc -->
|
| 2478 |
+
|
| 2479 |
+
|
| 2480 |
+
<!-- <table align="center">
|
| 2481 |
+
<p align="center">
|
| 2482 |
+
<img src="assets/star_history.svg"/>
|
| 2483 |
+
</p>
|
| 2484 |
+
</table> -->
|
| 2485 |
+
|
| 2486 |
+
<picture>
|
| 2487 |
+
<source
|
| 2488 |
+
media="(prefers-color-scheme: dark)"
|
| 2489 |
+
srcset="
|
| 2490 |
+
https://api.star-history.com/svg?repos=OpenBMB/MiniCPM-o&type=Date&theme=dark
|
| 2491 |
+
"
|
| 2492 |
+
/>
|
| 2493 |
+
<source
|
| 2494 |
+
media="(prefers-color-scheme: light)"
|
| 2495 |
+
srcset="
|
| 2496 |
+
https://api.star-history.com/svg?repos=OpenBMB/MiniCPM-o&type=Date
|
| 2497 |
+
"
|
| 2498 |
+
/>
|
| 2499 |
+
<img
|
| 2500 |
+
alt="Star History Chart"
|
| 2501 |
+
src="https://api.star-history.com/svg?repos=OpenBMB/MiniCPM-o&type=Date"
|
| 2502 |
+
/>
|
| 2503 |
+
</picture>
|
| 2504 |
+
|
| 2505 |
+
## 支持技术和其他多模态项目 <!-- omit in toc -->
|
| 2506 |
+
|
| 2507 |
+
👏 欢迎了解 MiniCPM-o/V 背后的支持技术和更多我们的多模态项目!
|
| 2508 |
+
|
| 2509 |
+
[VisCPM](https://github.com/OpenBMB/VisCPM/tree/main) | [RLHF-V](https://github.com/RLHF-V/RLHF-V) | [LLaVA-UHD](https://github.com/thunlp/LLaVA-UHD) | [RLAIF-V](https://github.com/RLHF-V/RLAIF-V)
|
| 2510 |
+
|
| 2511 |
+
|
| 2512 |
+
|
| 2513 |
+
## 引用 <!-- omit in toc -->
|
| 2514 |
+
|
| 2515 |
+
如果您觉得我们模型/代码/论文有帮助,请给我们 ⭐ 和 引用 📝,感谢!
|
| 2516 |
+
|
| 2517 |
+
```bib
|
| 2518 |
+
@article{yao2024minicpm,
|
| 2519 |
+
title={MiniCPM-V: A GPT-4V Level MLLM on Your Phone},
|
| 2520 |
+
author={Yao, Yuan and Yu, Tianyu and Zhang, Ao and Wang, Chongyi and Cui, Junbo and Zhu, Hongji and Cai, Tianchi and Li, Haoyu and Zhao, Weilin and He, Zhihui and others},
|
| 2521 |
+
journal={arXiv preprint arXiv:2408.01800},
|
| 2522 |
+
year={2024}
|
| 2523 |
+
}
|
| 2524 |
+
```
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/cgbench.py
ADDED
|
@@ -0,0 +1,1760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import snapshot_download
|
| 2 |
+
from ..smp import *
|
| 3 |
+
from .video_base import VideoBaseDataset
|
| 4 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
| 5 |
+
from .utils.cgbench import *
|
| 6 |
+
from ..utils import track_progress_rich
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CGBench_MCQ_Grounding_Mini(VideoBaseDataset):
|
| 10 |
+
|
| 11 |
+
dataset = "CG-Bench_MCQ_Grounding_Mini"
|
| 12 |
+
|
| 13 |
+
TYPE = "Video-MCQ-Grounding"
|
| 14 |
+
|
| 15 |
+
MD5 = "54ed3e90a51a6fb375c92b319a715f72"
|
| 16 |
+
|
| 17 |
+
SYS = {
|
| 18 |
+
"long_acc": (
|
| 19 |
+
"You will be provided with sampled frames from a video, along with a "
|
| 20 |
+
"multiple-choice question that includes a question and several answer options.\n"
|
| 21 |
+
"Your task is to analyze the provided frames, infer the most plausible "
|
| 22 |
+
"answer based on the visual information.\n"
|
| 23 |
+
"If the video does not provide enough information, infer the answer based "
|
| 24 |
+
"on the options available and still provide a result. "
|
| 25 |
+
"Therefore, In all cases, an answer must be given.\n"
|
| 26 |
+
"Only output the answer in the following format:\n\n"
|
| 27 |
+
'```json\n{"result": "option"}\n```\n\n'
|
| 28 |
+
'The "option" is the uppercase letter corresponding to your answer.\n\n'
|
| 29 |
+
),
|
| 30 |
+
"clue_acc": (
|
| 31 |
+
"You will be provided with sampled frames from a video, along with a "
|
| 32 |
+
"multiple-choice question that includes a question and several answer options.\n"
|
| 33 |
+
"Your task is to analyze the provided frames, infer the most plausible "
|
| 34 |
+
"answer based on the visual information.\n"
|
| 35 |
+
"If the video does not provide enough information, infer the answer based "
|
| 36 |
+
"on the options available and still provide a result. "
|
| 37 |
+
"Therefore, In all cases, an answer must be given.\n"
|
| 38 |
+
"Only output the answer in the following format:\n\n"
|
| 39 |
+
'```json\n{"result": "option"}\n```\n\n'
|
| 40 |
+
"The 'option' is the uppercase letter corresponding to your answer.\n\n"
|
| 41 |
+
),
|
| 42 |
+
"miou": (
|
| 43 |
+
"You will be provided with uniformly sampled frames from a video and their "
|
| 44 |
+
"timestamps, along with a multiple-choice question that includes a question "
|
| 45 |
+
"and several answer options.\n"
|
| 46 |
+
"Your task is to determine in which intervals the 'clue intervals' exist "
|
| 47 |
+
"that contain visual information needed to answer the question.\n"
|
| 48 |
+
"Only output the answer in the following format:\n\n"
|
| 49 |
+
'```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
|
| 50 |
+
"In this output format, each 'start' and 'end' represents the beginning and "
|
| 51 |
+
"end of an interval in seconds where relevant clues can be found.\n"
|
| 52 |
+
"You must provide at least one interval and at most five intervals. "
|
| 53 |
+
"Intervals exceeding five will NOT be considered valid.\n"
|
| 54 |
+
),
|
| 55 |
+
"miou_wo_frame_time": (
|
| 56 |
+
"You will be provided with uniformly sampled frames from a video, along "
|
| 57 |
+
"with a multiple-choice question that includes a question and several "
|
| 58 |
+
"answer options.\n"
|
| 59 |
+
"Your task is to determine in which intervals the 'clue intervals' exist "
|
| 60 |
+
"that contain visual information needed to answer the question.\n"
|
| 61 |
+
"Only output the answer in the following format:\n\n"
|
| 62 |
+
'```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
|
| 63 |
+
'In this output format, each "start" and "end" represents the start and '
|
| 64 |
+
"end of the video where the relevant clue can be found in the form of a "
|
| 65 |
+
"floating point number between 0 and 1, where 0 represents the start time "
|
| 66 |
+
"of the video and 1 represents the end time of the video.\n"
|
| 67 |
+
"You must provide at least one interval and at most five intervals. "
|
| 68 |
+
"Intervals exceeding five will NOT be considered valid.\n"
|
| 69 |
+
),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
dataset="CG-Bench_MCQ_Grounding_Mini",
|
| 75 |
+
use_subtitle=False,
|
| 76 |
+
use_subtitle_time=False,
|
| 77 |
+
use_frame_time=False,
|
| 78 |
+
nframe=0,
|
| 79 |
+
fps=-1,
|
| 80 |
+
):
|
| 81 |
+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
| 82 |
+
self.use_subtitle = use_subtitle
|
| 83 |
+
self.use_subtitle_time = use_subtitle_time
|
| 84 |
+
self.use_frame_time = use_frame_time
|
| 85 |
+
self.dataset_name = dataset
|
| 86 |
+
lmu_root = LMUDataRoot()
|
| 87 |
+
self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
|
| 88 |
+
|
| 89 |
+
@classmethod
|
| 90 |
+
def supported_datasets(cls):
|
| 91 |
+
return ["CG-Bench_MCQ_Grounding_Mini"]
|
| 92 |
+
|
| 93 |
+
def clue_frame_paths(self, qid, num_frames=8):
|
| 94 |
+
frame_root = osp.join(self.clue_frame_root, qid)
|
| 95 |
+
os.makedirs(frame_root, exist_ok=True)
|
| 96 |
+
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
|
| 97 |
+
|
| 98 |
+
def clue_frame_paths_fps(self, qid, num_frames=8, fps=-1):
|
| 99 |
+
frame_root = osp.join(self.clue_frame_root, qid)
|
| 100 |
+
os.makedirs(frame_root, exist_ok=True)
|
| 101 |
+
return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
|
| 102 |
+
|
| 103 |
+
def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
|
| 104 |
+
|
| 105 |
+
subtitles = []
|
| 106 |
+
|
| 107 |
+
srt_path = osp.join(self.data_root, subtitle_path)
|
| 108 |
+
assert osp.exists(srt_path)
|
| 109 |
+
import pysubs2
|
| 110 |
+
|
| 111 |
+
subs = pysubs2.load(srt_path, encoding="utf-8")
|
| 112 |
+
if not frame_indices:
|
| 113 |
+
for sub in subs:
|
| 114 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 115 |
+
if sub_time:
|
| 116 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 117 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 118 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 119 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 120 |
+
subtitles.append(sub_text)
|
| 121 |
+
else:
|
| 122 |
+
for selected_frame_id in frame_indices:
|
| 123 |
+
cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
|
| 124 |
+
for sub in subs:
|
| 125 |
+
if sub.start < cur_time and sub.end > cur_time:
|
| 126 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 127 |
+
if sub_time:
|
| 128 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 129 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 130 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 131 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 132 |
+
subtitles.append(sub_text)
|
| 133 |
+
|
| 134 |
+
if subtitles:
|
| 135 |
+
subtitles_str = '\n'.join(subtitles)
|
| 136 |
+
return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
|
| 137 |
+
else:
|
| 138 |
+
return ""
|
| 139 |
+
|
| 140 |
+
def prepare_dataset(self, dataset_name="CG-Bench_MCQ_Grounding_Mini", repo_id="CG-Bench/CG-Bench"):
|
| 141 |
+
|
| 142 |
+
def check_integrity(pth):
|
| 143 |
+
data_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 144 |
+
|
| 145 |
+
if not os.path.exists(data_file):
|
| 146 |
+
return False
|
| 147 |
+
|
| 148 |
+
if md5(data_file) != self.MD5:
|
| 149 |
+
return False
|
| 150 |
+
data = load(data_file)
|
| 151 |
+
for video_pth in data["video"]:
|
| 152 |
+
if not osp.exists(osp.join(pth, video_pth)):
|
| 153 |
+
return False
|
| 154 |
+
|
| 155 |
+
return True
|
| 156 |
+
|
| 157 |
+
cache_path = get_cache_path(repo_id)
|
| 158 |
+
|
| 159 |
+
if cache_path is not None and check_integrity(cache_path):
|
| 160 |
+
dataset_path = cache_path
|
| 161 |
+
else:
|
| 162 |
+
|
| 163 |
+
def generate_tsv(pth):
|
| 164 |
+
|
| 165 |
+
tsv_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 166 |
+
|
| 167 |
+
task_modes = ["long_acc", "clue_acc", "miou"]
|
| 168 |
+
all_data = []
|
| 169 |
+
for task_mode in task_modes:
|
| 170 |
+
with open(osp.join(pth, "cgbench_mini.json"), "r") as f:
|
| 171 |
+
data_file = pd.DataFrame(json.load(f))
|
| 172 |
+
|
| 173 |
+
data_file = data_file.assign(index=range(len(data_file)))
|
| 174 |
+
data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
|
| 175 |
+
data_file["subtitle_path"] = data_file["video_uid"].apply(
|
| 176 |
+
lambda x: (
|
| 177 |
+
f"cg_subtitles/{x}.srt"
|
| 178 |
+
if osp.exists(osp.join(dataset_path, f"cg_subtitles/{x}.srt"))
|
| 179 |
+
else ""
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
data_file["clue_video_path"] = ""
|
| 184 |
+
|
| 185 |
+
if task_mode in ["clue_acc"]:
|
| 186 |
+
data_file["clue_video_path"] = data_file["clue_video_path"] = data_file.apply(
|
| 187 |
+
lambda row: f"cg_clue_videos/{row['qid']}.mp4", axis=1
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
data_file["task_mode"] = task_mode
|
| 191 |
+
|
| 192 |
+
if task_mode in ["clue_acc", "long_acc"]:
|
| 193 |
+
data_file["answer"] = data_file["right_answer"]
|
| 194 |
+
|
| 195 |
+
if task_mode == "miou":
|
| 196 |
+
data_file["answer"] = data_file["clue_intervals"]
|
| 197 |
+
|
| 198 |
+
if task_mode in ["long_acc", "miou"]:
|
| 199 |
+
data_file["clue_intervals"] = ""
|
| 200 |
+
|
| 201 |
+
data_file = data_file[
|
| 202 |
+
[
|
| 203 |
+
"index",
|
| 204 |
+
"video_uid",
|
| 205 |
+
"video",
|
| 206 |
+
"duration",
|
| 207 |
+
"domain",
|
| 208 |
+
"choices",
|
| 209 |
+
"sub_category",
|
| 210 |
+
"subtitle_path",
|
| 211 |
+
"question",
|
| 212 |
+
"answer",
|
| 213 |
+
"task_mode",
|
| 214 |
+
"clue_intervals",
|
| 215 |
+
"qid",
|
| 216 |
+
"clue_video_path",
|
| 217 |
+
]
|
| 218 |
+
]
|
| 219 |
+
|
| 220 |
+
all_data.append(data_file)
|
| 221 |
+
|
| 222 |
+
final_data = pd.concat(all_data, ignore_index=True)
|
| 223 |
+
final_data["index"] = range(len(final_data))
|
| 224 |
+
final_data.to_csv(tsv_file, sep="\t", index=False)
|
| 225 |
+
|
| 226 |
+
if modelscope_flag_set():
|
| 227 |
+
from modelscope import dataset_snapshot_download
|
| 228 |
+
|
| 229 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
| 230 |
+
else:
|
| 231 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
|
| 232 |
+
|
| 233 |
+
unzip_hf_zip(dataset_path)
|
| 234 |
+
generate_tsv(dataset_path)
|
| 235 |
+
|
| 236 |
+
tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
|
| 237 |
+
|
| 238 |
+
return dict(data_file=tsv_file, root=dataset_path)
|
| 239 |
+
|
| 240 |
+
def build_prompt(self, line, video_llm):
|
| 241 |
+
|
| 242 |
+
if isinstance(line, int):
|
| 243 |
+
assert line < len(self)
|
| 244 |
+
line = self.data.iloc[line]
|
| 245 |
+
|
| 246 |
+
task_mode = line["task_mode"]
|
| 247 |
+
|
| 248 |
+
message = []
|
| 249 |
+
|
| 250 |
+
origin_use_subtitle_time = self.use_subtitle_time
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
if task_mode in ["long_acc", "clue_acc"]:
|
| 254 |
+
system_prompt = self.SYS[task_mode]
|
| 255 |
+
elif task_mode == "miou":
|
| 256 |
+
if self.use_frame_time and not video_llm:
|
| 257 |
+
system_prompt = self.SYS[task_mode]
|
| 258 |
+
else:
|
| 259 |
+
system_prompt = self.SYS["miou_wo_frame_time"]
|
| 260 |
+
if self.use_subtitle_time is True:
|
| 261 |
+
self.use_subtitle_time = False
|
| 262 |
+
|
| 263 |
+
user_prompt = ""
|
| 264 |
+
|
| 265 |
+
if task_mode in ["long_acc", "miou"]:
|
| 266 |
+
video_path = line["video"]
|
| 267 |
+
|
| 268 |
+
if video_llm:
|
| 269 |
+
message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
|
| 270 |
+
|
| 271 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 272 |
+
if self.nframe:
|
| 273 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 274 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 275 |
+
)
|
| 276 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
|
| 277 |
+
fps=vid_fps, sub_time=self.use_subtitle_time)
|
| 278 |
+
else:
|
| 279 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
|
| 280 |
+
else:
|
| 281 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 282 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 283 |
+
)
|
| 284 |
+
message.extend(dict(type="image", value=im) for im in image_paths)
|
| 285 |
+
|
| 286 |
+
if self.use_frame_time:
|
| 287 |
+
user_prompt += get_timestampes(frame_indices, vid_fps)
|
| 288 |
+
|
| 289 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 290 |
+
user_prompt += self.get_subtitles(
|
| 291 |
+
line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
|
| 292 |
+
sub_time=self.use_subtitle_time
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
elif task_mode == "clue_acc":
|
| 296 |
+
clue_video_path = line["clue_video_path"]
|
| 297 |
+
video_path = line["video"]
|
| 298 |
+
|
| 299 |
+
if video_llm:
|
| 300 |
+
message.append(dict(type="video", value=osp.join(self.data_root, clue_video_path)))
|
| 301 |
+
print(message)
|
| 302 |
+
|
| 303 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 304 |
+
if self.nframe:
|
| 305 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 306 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 307 |
+
)
|
| 308 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
|
| 309 |
+
fps=vid_fps, sub_time=self.use_subtitle_time)
|
| 310 |
+
else:
|
| 311 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
|
| 312 |
+
else:
|
| 313 |
+
if self.nframe > 32:
|
| 314 |
+
self.nframe = 32
|
| 315 |
+
print("The maximum number of frames is 32 when evaluating clue-based mcq in CG-Bench !")
|
| 316 |
+
|
| 317 |
+
clue_intervals = eval(line["clue_intervals"])
|
| 318 |
+
|
| 319 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 320 |
+
video_path, uid=line["qid"], clue_intervals=clue_intervals, num_frames=self.nframe, fps=self.fps
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
message.extend(dict(type="image", value=im) for im in image_paths)
|
| 324 |
+
|
| 325 |
+
if self.use_frame_time:
|
| 326 |
+
user_prompt += get_timestampes(frame_indices, vid_fps)
|
| 327 |
+
|
| 328 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 329 |
+
user_prompt += self.get_subtitles(
|
| 330 |
+
line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
|
| 331 |
+
sub_time=self.use_subtitle_time
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
question = line["question"]
|
| 335 |
+
user_prompt += f"Question: {question}\n\n"
|
| 336 |
+
|
| 337 |
+
choices = eval(line["choices"])
|
| 338 |
+
labels = [chr(ord("A") + i) for i in range(len(choices))]
|
| 339 |
+
user_prompt += "\n".join([f"{label}:{value}" for label, value in zip(labels, choices)]) + "\n\n"
|
| 340 |
+
|
| 341 |
+
message.append(dict(type="text", value=system_prompt + user_prompt))
|
| 342 |
+
|
| 343 |
+
return message
|
| 344 |
+
|
| 345 |
+
finally:
|
| 346 |
+
# Ensure that `use_subtitle_time` is always restored to its original value
|
| 347 |
+
self.use_subtitle_time = origin_use_subtitle_time
|
| 348 |
+
|
| 349 |
+
def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
|
| 350 |
+
|
| 351 |
+
if type(uid) is not str:
|
| 352 |
+
uid = str(uid)
|
| 353 |
+
|
| 354 |
+
vid_path = osp.join(self.data_root, video)
|
| 355 |
+
vid = decord.VideoReader(vid_path)
|
| 356 |
+
vid_fps = vid.get_avg_fps()
|
| 357 |
+
n_frames = len(vid)
|
| 358 |
+
|
| 359 |
+
if clue_intervals is not None:
|
| 360 |
+
merged_intervals = merge_intervals(clue_intervals)
|
| 361 |
+
|
| 362 |
+
if num_frames > 0 and fps < 0:
|
| 363 |
+
indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
|
| 364 |
+
frame_paths = self.clue_frame_paths(uid, len(indices))
|
| 365 |
+
|
| 366 |
+
elif fps > 0:
|
| 367 |
+
frame_indices = []
|
| 368 |
+
for start, end in merged_intervals:
|
| 369 |
+
start_frame = int(start * vid_fps)
|
| 370 |
+
end_frame = int(end * vid_fps)
|
| 371 |
+
step = vid_fps / fps
|
| 372 |
+
interval_indices = [
|
| 373 |
+
int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
|
| 374 |
+
]
|
| 375 |
+
frame_indices.extend(interval_indices)
|
| 376 |
+
|
| 377 |
+
if len(frame_indices) < 32:
|
| 378 |
+
indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
|
| 379 |
+
else:
|
| 380 |
+
indices = frame_indices
|
| 381 |
+
frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
|
| 382 |
+
|
| 383 |
+
else:
|
| 384 |
+
if num_frames > 0 and fps < 0:
|
| 385 |
+
step_size = len(vid) / (num_frames + 1)
|
| 386 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
| 387 |
+
|
| 388 |
+
frame_paths = self.frame_paths(uid)
|
| 389 |
+
elif fps > 0:
|
| 390 |
+
total_duration = n_frames / vid_fps
|
| 391 |
+
required_frames = int(total_duration * fps)
|
| 392 |
+
step_size = vid_fps / fps
|
| 393 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
| 394 |
+
frame_paths = self.frame_paths_fps(uid, len(indices))
|
| 395 |
+
|
| 396 |
+
# Save and validate frames
|
| 397 |
+
valid_paths = []
|
| 398 |
+
valid_indices = []
|
| 399 |
+
|
| 400 |
+
if not np.all([osp.exists(p) for p in frame_paths]):
|
| 401 |
+
images = [vid[i].asnumpy() for i in indices]
|
| 402 |
+
for i, (img_array, path) in enumerate(zip(images, frame_paths)):
|
| 403 |
+
if osp.exists(path):
|
| 404 |
+
try:
|
| 405 |
+
with Image.open(path) as img:
|
| 406 |
+
img.verify()
|
| 407 |
+
valid_paths.append(path)
|
| 408 |
+
valid_indices.append(indices[i])
|
| 409 |
+
except Exception:
|
| 410 |
+
continue
|
| 411 |
+
else:
|
| 412 |
+
try:
|
| 413 |
+
img = Image.fromarray(img_array)
|
| 414 |
+
img.save(path)
|
| 415 |
+
img.verify()
|
| 416 |
+
valid_paths.append(path)
|
| 417 |
+
valid_indices.append(indices[i])
|
| 418 |
+
except Exception:
|
| 419 |
+
continue
|
| 420 |
+
else:
|
| 421 |
+
for i, path in enumerate(frame_paths):
|
| 422 |
+
try:
|
| 423 |
+
with Image.open(path) as img:
|
| 424 |
+
img.verify()
|
| 425 |
+
valid_paths.append(path)
|
| 426 |
+
valid_indices.append(indices[i])
|
| 427 |
+
except Exception:
|
| 428 |
+
continue
|
| 429 |
+
|
| 430 |
+
return valid_paths, valid_indices, vid_fps
|
| 431 |
+
|
| 432 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
| 433 |
+
|
| 434 |
+
assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
|
| 435 |
+
|
| 436 |
+
tgt_file = eval_file.replace(".xlsx", "_rating.json")
|
| 437 |
+
score_file = eval_file.replace(".xlsx", "_score.xlsx")
|
| 438 |
+
|
| 439 |
+
data = load(eval_file)
|
| 440 |
+
|
| 441 |
+
data_un = data[~pd.isna(data["prediction"])]
|
| 442 |
+
data_pred_na = data[pd.isna(data["prediction"])]
|
| 443 |
+
|
| 444 |
+
data_pred_na["score"] = -1
|
| 445 |
+
|
| 446 |
+
data_un["score"] = data_un.apply(
|
| 447 |
+
lambda row: post_process(
|
| 448 |
+
response=row["prediction"],
|
| 449 |
+
right_answer=row["answer"],
|
| 450 |
+
task_mode=row["task_mode"],
|
| 451 |
+
duration=row["duration"],
|
| 452 |
+
),
|
| 453 |
+
axis=1,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
data = pd.concat([data_pred_na, data_un])
|
| 457 |
+
|
| 458 |
+
rejected_count = (data["score"] == -1).sum()
|
| 459 |
+
|
| 460 |
+
print(
|
| 461 |
+
f"Among {len(data)} questions, "
|
| 462 |
+
f"failed to obtain prediction for {len(data_pred_na)} questions, "
|
| 463 |
+
f"failed to obtain the score for {rejected_count - len(data_pred_na)} questions. "
|
| 464 |
+
f"Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating."
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
dump(data, score_file)
|
| 468 |
+
|
| 469 |
+
rating = get_dimention_rating_mcq_grouding(score_file)
|
| 470 |
+
|
| 471 |
+
dump(rating, tgt_file)
|
| 472 |
+
|
| 473 |
+
return rating
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# 评估时,step_2 评估时,给出 [prompt] + image_paths 就行
|
| 477 |
+
class CGBench_OpenEnded_Mini(VideoBaseDataset):
|
| 478 |
+
|
| 479 |
+
TYPE = "Video-OpenEnded"
|
| 480 |
+
|
| 481 |
+
dataset = "CG-Bench_OpenEnded_Mini"
|
| 482 |
+
|
| 483 |
+
MD5 = "9175791b11afdfa305fdb3e525b7a4ee"
|
| 484 |
+
|
| 485 |
+
SYS = (
|
| 486 |
+
"You will be provided with sampled frames from a video, along with a "
|
| 487 |
+
"question.\n"
|
| 488 |
+
"Your task is to analyze the provided frames and infer the most plausible "
|
| 489 |
+
"answer based on the visual information.\n"
|
| 490 |
+
"If the visual information is ambiguous or insufficient, use the available "
|
| 491 |
+
"context to reason your answer.\n"
|
| 492 |
+
"Only output the answer in the following format:\n\n"
|
| 493 |
+
'```json\n{"result": "answer"}\n```\n\n'
|
| 494 |
+
'The "answer" can be a word, phrase, or sentence that directly responds to '
|
| 495 |
+
"the question.\n\n"
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
def __init__(
|
| 499 |
+
self,
|
| 500 |
+
dataset="CG-Bench_OpenEnded_Mini",
|
| 501 |
+
use_subtitle=False,
|
| 502 |
+
use_subtitle_time=False,
|
| 503 |
+
use_frame_time=False,
|
| 504 |
+
nframe=0,
|
| 505 |
+
fps=-1,
|
| 506 |
+
):
|
| 507 |
+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
| 508 |
+
self.use_subtitle = use_subtitle
|
| 509 |
+
self.use_subtitle_time = use_subtitle_time
|
| 510 |
+
self.use_frame_time = use_frame_time
|
| 511 |
+
self.dataset_name = dataset
|
| 512 |
+
lmu_root = LMUDataRoot()
|
| 513 |
+
self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
|
| 514 |
+
|
| 515 |
+
@classmethod
|
| 516 |
+
def supported_datasets(cls):
|
| 517 |
+
return ["CG-Bench_OpenEnded_Mini"]
|
| 518 |
+
|
| 519 |
+
def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
|
| 520 |
+
|
| 521 |
+
subtitles = []
|
| 522 |
+
|
| 523 |
+
srt_path = osp.join(self.data_root, subtitle_path)
|
| 524 |
+
assert osp.exists(srt_path)
|
| 525 |
+
import pysubs2
|
| 526 |
+
|
| 527 |
+
subs = pysubs2.load(srt_path, encoding="utf-8")
|
| 528 |
+
if not frame_indices:
|
| 529 |
+
for sub in subs:
|
| 530 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 531 |
+
if sub_time:
|
| 532 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 533 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 534 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 535 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 536 |
+
subtitles.append(sub_text)
|
| 537 |
+
else:
|
| 538 |
+
for selected_frame_id in frame_indices:
|
| 539 |
+
cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
|
| 540 |
+
for sub in subs:
|
| 541 |
+
if sub.start < cur_time and sub.end > cur_time:
|
| 542 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 543 |
+
if sub_time:
|
| 544 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 545 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 546 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 547 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 548 |
+
subtitles.append(sub_text)
|
| 549 |
+
|
| 550 |
+
if subtitles:
|
| 551 |
+
subtitles_str = '\n'.join(subtitles)
|
| 552 |
+
return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
|
| 553 |
+
else:
|
| 554 |
+
return ""
|
| 555 |
+
|
| 556 |
+
def prepare_dataset(self, dataset_name="CG-Bench_OpenEnded_Mini", repo_id="CG-Bench/CG-Bench"):
|
| 557 |
+
|
| 558 |
+
def check_integrity(pth):
|
| 559 |
+
data_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 560 |
+
|
| 561 |
+
if not os.path.exists(data_file):
|
| 562 |
+
return False
|
| 563 |
+
|
| 564 |
+
if md5(data_file) != self.MD5:
|
| 565 |
+
return False
|
| 566 |
+
data = load(data_file)
|
| 567 |
+
for video_pth in data["video"]:
|
| 568 |
+
if not osp.exists(osp.join(pth, video_pth)):
|
| 569 |
+
return False
|
| 570 |
+
|
| 571 |
+
return True
|
| 572 |
+
|
| 573 |
+
cache_path = get_cache_path(repo_id)
|
| 574 |
+
|
| 575 |
+
if cache_path is not None and check_integrity(cache_path):
|
| 576 |
+
dataset_path = cache_path
|
| 577 |
+
else:
|
| 578 |
+
|
| 579 |
+
def generate_tsv(pth):
|
| 580 |
+
|
| 581 |
+
tsv_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 582 |
+
|
| 583 |
+
with open(osp.join(pth, "cgbench_mini.json"), "r") as f:
|
| 584 |
+
data_file = pd.DataFrame(json.load(f))
|
| 585 |
+
|
| 586 |
+
data_file = data_file.assign(index=range(len(data_file)))
|
| 587 |
+
data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
|
| 588 |
+
data_file["subtitle_path"] = data_file["video_uid"].apply(
|
| 589 |
+
lambda x: f"cg_subtitles/{x}.srt" if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt")) else ""
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
data_file = data_file[
|
| 593 |
+
[
|
| 594 |
+
"index",
|
| 595 |
+
"video_uid",
|
| 596 |
+
"video",
|
| 597 |
+
"duration",
|
| 598 |
+
"domain",
|
| 599 |
+
"sub_category",
|
| 600 |
+
"subtitle_path",
|
| 601 |
+
"question",
|
| 602 |
+
"answer",
|
| 603 |
+
"clue_intervals",
|
| 604 |
+
"qid",
|
| 605 |
+
]
|
| 606 |
+
]
|
| 607 |
+
|
| 608 |
+
data_file.to_csv(tsv_file, sep="\t", index=False)
|
| 609 |
+
|
| 610 |
+
if modelscope_flag_set():
|
| 611 |
+
from modelscope import dataset_snapshot_download
|
| 612 |
+
|
| 613 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
| 614 |
+
else:
|
| 615 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
|
| 616 |
+
|
| 617 |
+
unzip_hf_zip(dataset_path)
|
| 618 |
+
generate_tsv(dataset_path)
|
| 619 |
+
|
| 620 |
+
tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
|
| 621 |
+
|
| 622 |
+
return dict(data_file=tsv_file, root=dataset_path)
|
| 623 |
+
|
| 624 |
+
def build_prompt(self, line, video_llm):
|
| 625 |
+
|
| 626 |
+
if isinstance(line, int):
|
| 627 |
+
assert line < len(self)
|
| 628 |
+
line = self.data.iloc[line]
|
| 629 |
+
|
| 630 |
+
message = []
|
| 631 |
+
|
| 632 |
+
sys_prompt = self.SYS
|
| 633 |
+
|
| 634 |
+
user_prompt = ""
|
| 635 |
+
|
| 636 |
+
video_path = line["video"]
|
| 637 |
+
|
| 638 |
+
if video_llm:
|
| 639 |
+
message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
|
| 640 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 641 |
+
if self.nframe:
|
| 642 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 643 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 644 |
+
)
|
| 645 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
|
| 646 |
+
fps=vid_fps, sub_time=self.use_subtitle_time)
|
| 647 |
+
else:
|
| 648 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
|
| 649 |
+
else:
|
| 650 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 651 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 652 |
+
)
|
| 653 |
+
message.extend(dict(type="image", value=im) for im in image_paths)
|
| 654 |
+
|
| 655 |
+
if self.use_frame_time:
|
| 656 |
+
user_prompt += get_timestampes(frame_indices, vid_fps)
|
| 657 |
+
|
| 658 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 659 |
+
user_prompt += self.get_subtitles(
|
| 660 |
+
line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
|
| 661 |
+
sub_time=self.use_subtitle_time
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
question = line["question"]
|
| 665 |
+
user_prompt += f"Question: {question}\n\n"
|
| 666 |
+
|
| 667 |
+
message.append(dict(type="text", value=sys_prompt + user_prompt))
|
| 668 |
+
|
| 669 |
+
return message
|
| 670 |
+
|
| 671 |
+
def clue_frame_paths(self, qid, num_frames=8):
|
| 672 |
+
frame_root = osp.join(self.clue_frame_root, qid)
|
| 673 |
+
os.makedirs(frame_root, exist_ok=True)
|
| 674 |
+
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
|
| 675 |
+
|
| 676 |
+
def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
|
| 677 |
+
|
| 678 |
+
if type(uid) is not str:
|
| 679 |
+
uid = str(uid)
|
| 680 |
+
|
| 681 |
+
vid_path = osp.join(self.data_root, video)
|
| 682 |
+
vid = decord.VideoReader(vid_path)
|
| 683 |
+
vid_fps = vid.get_avg_fps()
|
| 684 |
+
n_frames = len(vid)
|
| 685 |
+
|
| 686 |
+
if clue_intervals is not None:
|
| 687 |
+
merged_intervals = merge_intervals(clue_intervals)
|
| 688 |
+
|
| 689 |
+
if num_frames > 0 and fps < 0:
|
| 690 |
+
indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
|
| 691 |
+
frame_paths = self.clue_frame_paths(uid, len(indices))
|
| 692 |
+
|
| 693 |
+
elif fps > 0:
|
| 694 |
+
frame_indices = []
|
| 695 |
+
for start, end in merged_intervals:
|
| 696 |
+
start_frame = int(start * vid_fps)
|
| 697 |
+
end_frame = int(end * vid_fps)
|
| 698 |
+
step = vid_fps / fps
|
| 699 |
+
interval_indices = [
|
| 700 |
+
int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
|
| 701 |
+
]
|
| 702 |
+
frame_indices.extend(interval_indices)
|
| 703 |
+
|
| 704 |
+
if len(frame_indices) < 32:
|
| 705 |
+
indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
|
| 706 |
+
else:
|
| 707 |
+
indices = frame_indices
|
| 708 |
+
frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
|
| 709 |
+
|
| 710 |
+
else:
|
| 711 |
+
if num_frames > 0 and fps < 0:
|
| 712 |
+
step_size = len(vid) / (num_frames + 1)
|
| 713 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
| 714 |
+
frame_paths = self.frame_paths(uid)
|
| 715 |
+
elif fps > 0:
|
| 716 |
+
total_duration = n_frames / vid_fps
|
| 717 |
+
required_frames = int(total_duration * fps)
|
| 718 |
+
step_size = vid_fps / fps
|
| 719 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
| 720 |
+
frame_paths = self.frame_paths_fps(uid, len(indices))
|
| 721 |
+
|
| 722 |
+
valid_paths = []
|
| 723 |
+
valid_indices = []
|
| 724 |
+
|
| 725 |
+
if not np.all([osp.exists(p) for p in frame_paths]):
|
| 726 |
+
images = [vid[i].asnumpy() for i in indices]
|
| 727 |
+
for i, (img_array, path) in enumerate(zip(images, frame_paths)):
|
| 728 |
+
if osp.exists(path):
|
| 729 |
+
try:
|
| 730 |
+
with Image.open(path) as img:
|
| 731 |
+
img.verify()
|
| 732 |
+
valid_paths.append(path)
|
| 733 |
+
valid_indices.append(indices[i])
|
| 734 |
+
except Exception:
|
| 735 |
+
continue
|
| 736 |
+
else:
|
| 737 |
+
try:
|
| 738 |
+
img = Image.fromarray(img_array)
|
| 739 |
+
img.save(path)
|
| 740 |
+
img.verify()
|
| 741 |
+
valid_paths.append(path)
|
| 742 |
+
valid_indices.append(indices[i])
|
| 743 |
+
except Exception:
|
| 744 |
+
continue
|
| 745 |
+
else:
|
| 746 |
+
for i, path in enumerate(frame_paths):
|
| 747 |
+
try:
|
| 748 |
+
with Image.open(path) as img:
|
| 749 |
+
img.verify()
|
| 750 |
+
valid_paths.append(path)
|
| 751 |
+
valid_indices.append(indices[i])
|
| 752 |
+
except Exception:
|
| 753 |
+
continue
|
| 754 |
+
|
| 755 |
+
return valid_paths, valid_indices, vid_fps
|
| 756 |
+
|
| 757 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
| 758 |
+
|
| 759 |
+
from .utils.cgbench import get_dimention_rating_open_ended, post_process_open
|
| 760 |
+
|
| 761 |
+
assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
|
| 762 |
+
|
| 763 |
+
tgt_file = eval_file.replace(".xlsx", "_rating.json")
|
| 764 |
+
score_file = eval_file.replace(".xlsx", "_score.xlsx")
|
| 765 |
+
step_1_tmp_file = eval_file.replace(".xlsx", "_step_1.pkl")
|
| 766 |
+
step_2_tmp_file = eval_file.replace(".xlsx", "_step_2.pkl")
|
| 767 |
+
|
| 768 |
+
data = load(eval_file)
|
| 769 |
+
|
| 770 |
+
data_pred_no_na = data[~pd.isna(data["prediction"])]
|
| 771 |
+
data_pred_na = data[pd.isna(data["prediction"])]
|
| 772 |
+
|
| 773 |
+
data_pred_na["model_result"] = -1
|
| 774 |
+
data_pred_na["step_1_result"] = -1
|
| 775 |
+
data_pred_na["step_2_result"] = -1
|
| 776 |
+
data_pred_na["score"] = -1
|
| 777 |
+
|
| 778 |
+
data_pred_no_na["model_result"] = data_pred_no_na.apply(
|
| 779 |
+
lambda row: post_process_open(
|
| 780 |
+
response=row["prediction"],
|
| 781 |
+
),
|
| 782 |
+
axis=1,
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
data_no_model_result = data_pred_no_na[data_pred_no_na["model_result"] == -1]
|
| 786 |
+
data_step_1 = data_pred_no_na[data_pred_no_na["model_result"] != -1]
|
| 787 |
+
|
| 788 |
+
if judge_kwargs.get("model", None) != "gpt-4o-0806":
|
| 789 |
+
judge_kwargs["model"] = "gpt-4o-0806"
|
| 790 |
+
print("The judge model in cg-bench is gpt-4o-0806!")
|
| 791 |
+
|
| 792 |
+
model_step_1 = build_judge(system_prompt=sys_prompt_open_eval_step_1, **judge_kwargs)
|
| 793 |
+
nproc = judge_kwargs.pop("nproc", 32)
|
| 794 |
+
|
| 795 |
+
lines_step_1 = data_step_1.to_dict("records")
|
| 796 |
+
tups_step_1 = [(model_step_1, line) for line in lines_step_1]
|
| 797 |
+
|
| 798 |
+
keys_step_1 = {line["qid"] for line in lines_step_1}
|
| 799 |
+
|
| 800 |
+
ans = {}
|
| 801 |
+
if osp.exists(step_1_tmp_file):
|
| 802 |
+
ans = load(step_1_tmp_file)
|
| 803 |
+
tups_step_1 = [x for x, i in zip(tups_step_1, keys_step_1) if i not in ans]
|
| 804 |
+
keys_step_1 = [i for i in keys_step_1 if i not in ans]
|
| 805 |
+
|
| 806 |
+
_ = track_progress_rich(
|
| 807 |
+
eval_open_first,
|
| 808 |
+
tups_step_1,
|
| 809 |
+
nproc=nproc,
|
| 810 |
+
keys=keys_step_1,
|
| 811 |
+
save=step_1_tmp_file,
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
step_1_results = load(step_1_tmp_file)
|
| 815 |
+
data_step_1 = save_step_1_steps(data_step_1, step_1_results) # -1, 0, 1, 2
|
| 816 |
+
|
| 817 |
+
data_no_step_1_results = data_step_1[data_step_1["step_1_result"] == -1]
|
| 818 |
+
data_step_1_over = data_step_1[data_step_1["step_1_result"].isin([0, 1])]
|
| 819 |
+
data_step_2 = data_step_1[data_step_1["step_1_result"] == 2]
|
| 820 |
+
|
| 821 |
+
print(judge_kwargs)
|
| 822 |
+
|
| 823 |
+
model_step_2 = build_judge(system_prompt=sys_prompt_open_eval_step_2, **judge_kwargs)
|
| 824 |
+
|
| 825 |
+
lines_step_2 = data_step_2.to_dict("records")
|
| 826 |
+
|
| 827 |
+
tups_step_2 = []
|
| 828 |
+
|
| 829 |
+
for line in tqdm(lines_step_2):
|
| 830 |
+
clue_intervals = eval(line["clue_intervals"])
|
| 831 |
+
lmu_root = LMUDataRoot()
|
| 832 |
+
clue_frame_root = osp.join(lmu_root, "clue_images", self.dataset)
|
| 833 |
+
data_root = self.data_root
|
| 834 |
+
frame_paths, _, _ = save_clue_video_frames(
|
| 835 |
+
data_root,
|
| 836 |
+
clue_frame_root,
|
| 837 |
+
video=line["video"],
|
| 838 |
+
uid=line["qid"],
|
| 839 |
+
clue_intervals=clue_intervals,
|
| 840 |
+
num_frames=32,
|
| 841 |
+
)
|
| 842 |
+
tups_step_2.append((model_step_2, line, frame_paths))
|
| 843 |
+
|
| 844 |
+
keys_step_2 = {line["qid"] for line in lines_step_2}
|
| 845 |
+
|
| 846 |
+
ans = {}
|
| 847 |
+
if osp.exists(step_2_tmp_file):
|
| 848 |
+
ans = load(step_2_tmp_file)
|
| 849 |
+
tups_step_2 = [x for x, i in zip(tups_step_2, keys_step_2) if i not in ans]
|
| 850 |
+
keys_step_2 = [i for i in keys_step_2 if i not in ans]
|
| 851 |
+
|
| 852 |
+
_ = track_progress_rich(
|
| 853 |
+
eval_open_second,
|
| 854 |
+
tups_step_2,
|
| 855 |
+
nproc=nproc,
|
| 856 |
+
keys=keys_step_2,
|
| 857 |
+
save=step_2_tmp_file,
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
step_2_results = load(step_2_tmp_file)
|
| 861 |
+
data_step_2 = save_step_2_steps(data_step_2, step_2_results)
|
| 862 |
+
|
| 863 |
+
data_no_step_2_results = data_step_2[data_step_2["score"] == -1]
|
| 864 |
+
data_step_2_over = data_step_2[data_step_2["score"].isin([0, 1])]
|
| 865 |
+
|
| 866 |
+
data = pd.concat(
|
| 867 |
+
[
|
| 868 |
+
data_pred_na,
|
| 869 |
+
data_no_model_result,
|
| 870 |
+
data_no_step_1_results,
|
| 871 |
+
data_step_1_over,
|
| 872 |
+
data_no_step_2_results,
|
| 873 |
+
data_step_2_over,
|
| 874 |
+
]
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
dump(data, score_file)
|
| 878 |
+
|
| 879 |
+
rating = get_dimention_rating_open_ended(score_file)
|
| 880 |
+
|
| 881 |
+
dump(rating, tgt_file)
|
| 882 |
+
|
| 883 |
+
return rating
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
class CGBench_MCQ_Grounding(VideoBaseDataset):
|
| 887 |
+
|
| 888 |
+
TYPE = "Video-MCQ-Grounding"
|
| 889 |
+
|
| 890 |
+
MD5 = "eaead3d978a689269fefce4ae29c86df"
|
| 891 |
+
|
| 892 |
+
SYS = {
|
| 893 |
+
"long_acc": (
|
| 894 |
+
"You will be provided with sampled frames from a video, along with a "
|
| 895 |
+
"multiple-choice question that includes a question and several answer options.\n"
|
| 896 |
+
"Your task is to analyze the provided frames, infer the most plausible "
|
| 897 |
+
"answer based on the visual information.\n"
|
| 898 |
+
"If the video does not provide enough information, infer the answer based "
|
| 899 |
+
"on the options available and still provide a result. "
|
| 900 |
+
"Therefore, In all cases, an answer must be given.\n"
|
| 901 |
+
"Only output the answer in the following format:\n\n"
|
| 902 |
+
'```json\n{"result": "option"}\n```\n\n'
|
| 903 |
+
'The "option" is the uppercase letter corresponding to your answer.\n\n'
|
| 904 |
+
),
|
| 905 |
+
"clue_acc": (
|
| 906 |
+
"You will be provided with sampled frames from a video, along with a "
|
| 907 |
+
"multiple-choice question that includes a question and several answer options.\n"
|
| 908 |
+
"Your task is to analyze the provided frames, infer the most plausible "
|
| 909 |
+
"answer based on the visual information.\n"
|
| 910 |
+
"If the video does not provide enough information, infer the answer based "
|
| 911 |
+
"on the options available and still provide a result. "
|
| 912 |
+
"Therefore, In all cases, an answer must be given.\n"
|
| 913 |
+
"Only output the answer in the following format:\n\n"
|
| 914 |
+
'```json\n{"result": "option"}\n```\n\n'
|
| 915 |
+
"The 'option' is the uppercase letter corresponding to your answer.\n\n"
|
| 916 |
+
),
|
| 917 |
+
"miou": (
|
| 918 |
+
"You will be provided with uniformly sampled frames from a video and their "
|
| 919 |
+
"timestamps, along with a multiple-choice question that includes a question "
|
| 920 |
+
"and several answer options.\n"
|
| 921 |
+
"Your task is to determine in which intervals the 'clue intervals' exist "
|
| 922 |
+
"that contain visual information needed to answer the question.\n"
|
| 923 |
+
"Only output the answer in the following format:\n\n"
|
| 924 |
+
'```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
|
| 925 |
+
"In this output format, each 'start' and 'end' represents the beginning and "
|
| 926 |
+
"end of an interval in seconds where relevant clues can be found.\n"
|
| 927 |
+
"You must provide at least one interval and at most five intervals. "
|
| 928 |
+
"Intervals exceeding five will NOT be considered valid.\n"
|
| 929 |
+
),
|
| 930 |
+
"miou_wo_frame_time": (
|
| 931 |
+
"You will be provided with uniformly sampled frames from a video, along "
|
| 932 |
+
"with a multiple-choice question that includes a question and several "
|
| 933 |
+
"answer options.\n"
|
| 934 |
+
"Your task is to determine in which intervals the 'clue intervals' exist "
|
| 935 |
+
"that contain visual information needed to answer the question.\n"
|
| 936 |
+
"Only output the answer in the following format:\n\n"
|
| 937 |
+
'```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
|
| 938 |
+
'In this output format, each "start" and "end" represents the start and '
|
| 939 |
+
"end of the video where the relevant clue can be found in the form of a "
|
| 940 |
+
"floating point number between 0 and 1, where 0 represents the start time "
|
| 941 |
+
"of the video and 1 represents the end time of the video.\n"
|
| 942 |
+
"You must provide at least one interval and at most five intervals. "
|
| 943 |
+
"Intervals exceeding five will NOT be considered valid.\n"
|
| 944 |
+
),
|
| 945 |
+
}
|
| 946 |
+
|
| 947 |
+
def __init__(
|
| 948 |
+
self,
|
| 949 |
+
dataset="CG-Bench_MCQ_Grounding",
|
| 950 |
+
use_subtitle=False,
|
| 951 |
+
use_subtitle_time=False,
|
| 952 |
+
use_frame_time=False,
|
| 953 |
+
nframe=0,
|
| 954 |
+
fps=-1,
|
| 955 |
+
):
|
| 956 |
+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
| 957 |
+
self.use_subtitle = use_subtitle
|
| 958 |
+
self.use_subtitle_time = use_subtitle_time
|
| 959 |
+
self.use_frame_time = use_frame_time
|
| 960 |
+
self.dataset_name = dataset
|
| 961 |
+
lmu_root = LMUDataRoot()
|
| 962 |
+
self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
|
| 963 |
+
|
| 964 |
+
@classmethod
|
| 965 |
+
def supported_datasets(cls):
|
| 966 |
+
return ["CG-Bench_MCQ_Grounding"]
|
| 967 |
+
|
| 968 |
+
def clue_frame_paths(self, qid, num_frames=8):
|
| 969 |
+
frame_root = osp.join(self.clue_frame_root, qid)
|
| 970 |
+
os.makedirs(frame_root, exist_ok=True)
|
| 971 |
+
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
|
| 972 |
+
|
| 973 |
+
def clue_frame_paths_fps(self, qid, num_frames=8, fps=-1):
|
| 974 |
+
frame_root = osp.join(self.clue_frame_root, qid)
|
| 975 |
+
os.makedirs(frame_root, exist_ok=True)
|
| 976 |
+
return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
|
| 977 |
+
|
| 978 |
+
def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
|
| 979 |
+
|
| 980 |
+
subtitles = []
|
| 981 |
+
|
| 982 |
+
srt_path = osp.join(self.data_root, subtitle_path)
|
| 983 |
+
assert osp.exists(srt_path)
|
| 984 |
+
import pysubs2
|
| 985 |
+
|
| 986 |
+
subs = pysubs2.load(srt_path, encoding="utf-8")
|
| 987 |
+
if not frame_indices:
|
| 988 |
+
for sub in subs:
|
| 989 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 990 |
+
if sub_time:
|
| 991 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 992 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 993 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 994 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 995 |
+
subtitles.append(sub_text)
|
| 996 |
+
else:
|
| 997 |
+
for selected_frame_id in frame_indices:
|
| 998 |
+
cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
|
| 999 |
+
for sub in subs:
|
| 1000 |
+
if sub.start < cur_time and sub.end > cur_time:
|
| 1001 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 1002 |
+
if sub_time:
|
| 1003 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 1004 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 1005 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 1006 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 1007 |
+
subtitles.append(sub_text)
|
| 1008 |
+
|
| 1009 |
+
if subtitles:
|
| 1010 |
+
subtitles_str = '\n'.join(subtitles)
|
| 1011 |
+
return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
|
| 1012 |
+
else:
|
| 1013 |
+
return ""
|
| 1014 |
+
|
| 1015 |
+
def prepare_dataset(self, dataset_name="CG-Bench_MCQ_Grounding", repo_id="CG-Bench/CG-Bench"):
|
| 1016 |
+
|
| 1017 |
+
def check_integrity(pth):
|
| 1018 |
+
data_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 1019 |
+
|
| 1020 |
+
if not os.path.exists(data_file):
|
| 1021 |
+
return False
|
| 1022 |
+
|
| 1023 |
+
if md5(data_file) != self.MD5:
|
| 1024 |
+
return False
|
| 1025 |
+
data = load(data_file)
|
| 1026 |
+
for video_pth in data["video"]:
|
| 1027 |
+
if not osp.exists(osp.join(pth, video_pth)):
|
| 1028 |
+
return False
|
| 1029 |
+
|
| 1030 |
+
for clue_video_pth in data["clue_video_path"]:
|
| 1031 |
+
if clue_video_pth and not (isinstance(clue_video_pth, float) and np.isnan(clue_video_pth)):
|
| 1032 |
+
if not osp.exists(osp.join(pth, clue_video_pth)):
|
| 1033 |
+
return False
|
| 1034 |
+
|
| 1035 |
+
return True
|
| 1036 |
+
|
| 1037 |
+
cache_path = get_cache_path(repo_id)
|
| 1038 |
+
|
| 1039 |
+
if cache_path is not None and check_integrity(cache_path):
|
| 1040 |
+
dataset_path = cache_path
|
| 1041 |
+
else:
|
| 1042 |
+
|
| 1043 |
+
def generate_tsv(pth):
|
| 1044 |
+
|
| 1045 |
+
tsv_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 1046 |
+
|
| 1047 |
+
task_modes = ["long_acc", "clue_acc", "miou"]
|
| 1048 |
+
all_data = []
|
| 1049 |
+
for task_mode in task_modes:
|
| 1050 |
+
with open(osp.join(pth, "cgbench.json"), "r") as f:
|
| 1051 |
+
data_file = pd.DataFrame(json.load(f))
|
| 1052 |
+
|
| 1053 |
+
data_file = data_file.assign(index=range(len(data_file)))
|
| 1054 |
+
data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
|
| 1055 |
+
data_file["subtitle_path"] = data_file["video_uid"].apply(
|
| 1056 |
+
lambda x: (
|
| 1057 |
+
f"cg_subtitles/{x}.srt"
|
| 1058 |
+
if osp.exists(osp.join(dataset_path, f"cg_subtitles/{x}.srt"))
|
| 1059 |
+
else ""
|
| 1060 |
+
)
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
data_file["clue_video_path"] = ""
|
| 1064 |
+
|
| 1065 |
+
if task_mode in ["clue_acc"]:
|
| 1066 |
+
data_file["clue_video_path"] = data_file["clue_video_path"] = data_file.apply(
|
| 1067 |
+
lambda row: f"cg_clue_videos/{row['qid']}.mp4", axis=1
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
data_file["task_mode"] = task_mode
|
| 1071 |
+
|
| 1072 |
+
if task_mode in ["clue_acc", "long_acc"]:
|
| 1073 |
+
data_file["answer"] = data_file["right_answer"]
|
| 1074 |
+
|
| 1075 |
+
if task_mode == "miou":
|
| 1076 |
+
data_file["answer"] = data_file["clue_intervals"]
|
| 1077 |
+
|
| 1078 |
+
if task_mode in ["long_acc", "miou"]:
|
| 1079 |
+
data_file["clue_intervals"] = ""
|
| 1080 |
+
|
| 1081 |
+
data_file = data_file[
|
| 1082 |
+
[
|
| 1083 |
+
"index",
|
| 1084 |
+
"video_uid",
|
| 1085 |
+
"video",
|
| 1086 |
+
"duration",
|
| 1087 |
+
"domain",
|
| 1088 |
+
"choices",
|
| 1089 |
+
"sub_category",
|
| 1090 |
+
"subtitle_path",
|
| 1091 |
+
"question",
|
| 1092 |
+
"answer",
|
| 1093 |
+
"task_mode",
|
| 1094 |
+
"clue_intervals",
|
| 1095 |
+
"qid",
|
| 1096 |
+
"clue_video_path",
|
| 1097 |
+
]
|
| 1098 |
+
]
|
| 1099 |
+
|
| 1100 |
+
all_data.append(data_file)
|
| 1101 |
+
|
| 1102 |
+
final_data = pd.concat(all_data, ignore_index=True)
|
| 1103 |
+
final_data["index"] = range(len(final_data))
|
| 1104 |
+
final_data.to_csv(tsv_file, sep="\t", index=False)
|
| 1105 |
+
|
| 1106 |
+
if modelscope_flag_set():
|
| 1107 |
+
from modelscope import dataset_snapshot_download
|
| 1108 |
+
|
| 1109 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
| 1110 |
+
else:
|
| 1111 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
|
| 1112 |
+
|
| 1113 |
+
unzip_hf_zip(dataset_path)
|
| 1114 |
+
generate_tsv(dataset_path)
|
| 1115 |
+
|
| 1116 |
+
tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
|
| 1117 |
+
|
| 1118 |
+
return dict(data_file=tsv_file, root=dataset_path)
|
| 1119 |
+
|
| 1120 |
+
def build_prompt(self, line, video_llm):
|
| 1121 |
+
|
| 1122 |
+
if isinstance(line, int):
|
| 1123 |
+
assert line < len(self)
|
| 1124 |
+
line = self.data.iloc[line]
|
| 1125 |
+
|
| 1126 |
+
task_mode = line["task_mode"]
|
| 1127 |
+
|
| 1128 |
+
message = []
|
| 1129 |
+
|
| 1130 |
+
origin_use_subtitle_time = self.use_subtitle_time
|
| 1131 |
+
|
| 1132 |
+
try:
|
| 1133 |
+
if task_mode in ["long_acc", "clue_acc"]:
|
| 1134 |
+
system_prompt = self.SYS[task_mode]
|
| 1135 |
+
elif task_mode == "miou":
|
| 1136 |
+
if self.use_frame_time and not video_llm:
|
| 1137 |
+
system_prompt = self.SYS[task_mode]
|
| 1138 |
+
else:
|
| 1139 |
+
system_prompt = self.SYS["miou_wo_frame_time"]
|
| 1140 |
+
if self.use_subtitle_time is True:
|
| 1141 |
+
self.use_subtitle_time = False
|
| 1142 |
+
|
| 1143 |
+
user_prompt = ""
|
| 1144 |
+
|
| 1145 |
+
if task_mode in ["long_acc", "miou"]:
|
| 1146 |
+
video_path = line["video"]
|
| 1147 |
+
|
| 1148 |
+
if video_llm:
|
| 1149 |
+
message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
|
| 1150 |
+
|
| 1151 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 1152 |
+
if self.nframe:
|
| 1153 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 1154 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 1155 |
+
)
|
| 1156 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
|
| 1157 |
+
fps=vid_fps, sub_time=self.use_subtitle_time)
|
| 1158 |
+
else:
|
| 1159 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
|
| 1160 |
+
else:
|
| 1161 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 1162 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 1163 |
+
)
|
| 1164 |
+
message.extend(dict(type="image", value=im) for im in image_paths)
|
| 1165 |
+
|
| 1166 |
+
if self.use_frame_time:
|
| 1167 |
+
user_prompt += get_timestampes(frame_indices, vid_fps)
|
| 1168 |
+
|
| 1169 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 1170 |
+
user_prompt += self.get_subtitles(
|
| 1171 |
+
line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
|
| 1172 |
+
sub_time=self.use_subtitle_time
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
elif task_mode == "clue_acc":
|
| 1176 |
+
clue_video_path = line["clue_video_path"]
|
| 1177 |
+
video_path = line["video"]
|
| 1178 |
+
|
| 1179 |
+
if video_llm:
|
| 1180 |
+
message.append(dict(type="video", value=osp.join(self.data_root, clue_video_path)))
|
| 1181 |
+
print(message)
|
| 1182 |
+
|
| 1183 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 1184 |
+
if self.nframe:
|
| 1185 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 1186 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 1187 |
+
)
|
| 1188 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
|
| 1189 |
+
fps=vid_fps, sub_time=self.use_subtitle_time)
|
| 1190 |
+
else:
|
| 1191 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
|
| 1192 |
+
else:
|
| 1193 |
+
if self.nframe > 32:
|
| 1194 |
+
self.nframe = 32
|
| 1195 |
+
print("The maximum number of frames is 32 when evaluating clue-based mcq in CG-Bench !")
|
| 1196 |
+
|
| 1197 |
+
clue_intervals = eval(line["clue_intervals"])
|
| 1198 |
+
|
| 1199 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 1200 |
+
video_path, uid=line["qid"], clue_intervals=clue_intervals, num_frames=self.nframe, fps=self.fps
|
| 1201 |
+
)
|
| 1202 |
+
|
| 1203 |
+
message.extend(dict(type="image", value=im) for im in image_paths)
|
| 1204 |
+
|
| 1205 |
+
if self.use_frame_time:
|
| 1206 |
+
user_prompt += get_timestampes(frame_indices, vid_fps)
|
| 1207 |
+
|
| 1208 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 1209 |
+
user_prompt += self.get_subtitles(
|
| 1210 |
+
line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
|
| 1211 |
+
sub_time=self.use_subtitle_time
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
question = line["question"]
|
| 1215 |
+
user_prompt += f"Question: {question}\n\n"
|
| 1216 |
+
|
| 1217 |
+
choices = eval(line["choices"])
|
| 1218 |
+
labels = [chr(ord("A") + i) for i in range(len(choices))]
|
| 1219 |
+
user_prompt += "\n".join([f"{label}:{value}" for label, value in zip(labels, choices)]) + "\n\n"
|
| 1220 |
+
|
| 1221 |
+
message.append(dict(type="text", value=system_prompt + user_prompt))
|
| 1222 |
+
|
| 1223 |
+
return message
|
| 1224 |
+
|
| 1225 |
+
finally:
|
| 1226 |
+
# Ensure that `use_subtitle_time` is always restored to its original value
|
| 1227 |
+
self.use_subtitle_time = origin_use_subtitle_time
|
| 1228 |
+
|
| 1229 |
+
def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
|
| 1230 |
+
|
| 1231 |
+
if type(uid) is not str:
|
| 1232 |
+
uid = str(uid)
|
| 1233 |
+
|
| 1234 |
+
vid_path = osp.join(self.data_root, video)
|
| 1235 |
+
vid = decord.VideoReader(vid_path)
|
| 1236 |
+
vid_fps = vid.get_avg_fps()
|
| 1237 |
+
n_frames = len(vid)
|
| 1238 |
+
|
| 1239 |
+
if clue_intervals is not None:
|
| 1240 |
+
merged_intervals = merge_intervals(clue_intervals)
|
| 1241 |
+
|
| 1242 |
+
if num_frames > 0 and fps < 0:
|
| 1243 |
+
indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
|
| 1244 |
+
frame_paths = self.clue_frame_paths(uid, len(indices))
|
| 1245 |
+
|
| 1246 |
+
elif fps > 0:
|
| 1247 |
+
frame_indices = []
|
| 1248 |
+
for start, end in merged_intervals:
|
| 1249 |
+
start_frame = int(start * vid_fps)
|
| 1250 |
+
end_frame = int(end * vid_fps)
|
| 1251 |
+
step = vid_fps / fps
|
| 1252 |
+
interval_indices = [
|
| 1253 |
+
int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
|
| 1254 |
+
]
|
| 1255 |
+
frame_indices.extend(interval_indices)
|
| 1256 |
+
|
| 1257 |
+
if len(frame_indices) < 32:
|
| 1258 |
+
indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
|
| 1259 |
+
else:
|
| 1260 |
+
indices = frame_indices
|
| 1261 |
+
frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
|
| 1262 |
+
|
| 1263 |
+
else:
|
| 1264 |
+
if num_frames > 0 and fps < 0:
|
| 1265 |
+
step_size = len(vid) / (num_frames + 1)
|
| 1266 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
| 1267 |
+
|
| 1268 |
+
frame_paths = self.frame_paths(uid)
|
| 1269 |
+
elif fps > 0:
|
| 1270 |
+
total_duration = n_frames / vid_fps
|
| 1271 |
+
required_frames = int(total_duration * fps)
|
| 1272 |
+
step_size = vid_fps / fps
|
| 1273 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
| 1274 |
+
frame_paths = self.frame_paths_fps(uid, len(indices))
|
| 1275 |
+
|
| 1276 |
+
# Save and validate frames
|
| 1277 |
+
valid_paths = []
|
| 1278 |
+
valid_indices = []
|
| 1279 |
+
|
| 1280 |
+
if not np.all([osp.exists(p) for p in frame_paths]):
|
| 1281 |
+
images = [vid[i].asnumpy() for i in indices]
|
| 1282 |
+
for i, (img_array, path) in enumerate(zip(images, frame_paths)):
|
| 1283 |
+
if osp.exists(path):
|
| 1284 |
+
try:
|
| 1285 |
+
with Image.open(path) as img:
|
| 1286 |
+
img.verify()
|
| 1287 |
+
valid_paths.append(path)
|
| 1288 |
+
valid_indices.append(indices[i])
|
| 1289 |
+
except Exception:
|
| 1290 |
+
continue
|
| 1291 |
+
else:
|
| 1292 |
+
try:
|
| 1293 |
+
img = Image.fromarray(img_array)
|
| 1294 |
+
img.save(path)
|
| 1295 |
+
img.verify()
|
| 1296 |
+
valid_paths.append(path)
|
| 1297 |
+
valid_indices.append(indices[i])
|
| 1298 |
+
except Exception:
|
| 1299 |
+
continue
|
| 1300 |
+
else:
|
| 1301 |
+
for i, path in enumerate(frame_paths):
|
| 1302 |
+
try:
|
| 1303 |
+
with Image.open(path) as img:
|
| 1304 |
+
img.verify()
|
| 1305 |
+
valid_paths.append(path)
|
| 1306 |
+
valid_indices.append(indices[i])
|
| 1307 |
+
except Exception:
|
| 1308 |
+
continue
|
| 1309 |
+
|
| 1310 |
+
return valid_paths, valid_indices, vid_fps
|
| 1311 |
+
|
| 1312 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
| 1313 |
+
|
| 1314 |
+
assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
|
| 1315 |
+
|
| 1316 |
+
tgt_file = eval_file.replace(".xlsx", "_rating.json")
|
| 1317 |
+
score_file = eval_file.replace(".xlsx", "_score.xlsx")
|
| 1318 |
+
|
| 1319 |
+
data = load(eval_file)
|
| 1320 |
+
|
| 1321 |
+
data_un = data[~pd.isna(data["prediction"])]
|
| 1322 |
+
data_pred_na = data[pd.isna(data["prediction"])]
|
| 1323 |
+
|
| 1324 |
+
data_pred_na["score"] = -1
|
| 1325 |
+
|
| 1326 |
+
data_un["score"] = data_un.apply(
|
| 1327 |
+
lambda row: post_process(
|
| 1328 |
+
response=row["prediction"],
|
| 1329 |
+
right_answer=row["answer"],
|
| 1330 |
+
task_mode=row["task_mode"],
|
| 1331 |
+
duration=row["duration"],
|
| 1332 |
+
),
|
| 1333 |
+
axis=1,
|
| 1334 |
+
)
|
| 1335 |
+
|
| 1336 |
+
data = pd.concat([data_pred_na, data_un])
|
| 1337 |
+
|
| 1338 |
+
rejected_count = (data["score"] == -1).sum()
|
| 1339 |
+
|
| 1340 |
+
print(
|
| 1341 |
+
f"Among {len(data)} questions, "
|
| 1342 |
+
f"failed to obtain prediction for {len(data_pred_na)} questions, "
|
| 1343 |
+
f"failed to obtain the score for {rejected_count - len(data_pred_na)} questions. "
|
| 1344 |
+
f"Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating."
|
| 1345 |
+
)
|
| 1346 |
+
|
| 1347 |
+
dump(data, score_file)
|
| 1348 |
+
|
| 1349 |
+
rating = get_dimention_rating_mcq_grouding(score_file)
|
| 1350 |
+
|
| 1351 |
+
dump(rating, tgt_file)
|
| 1352 |
+
|
| 1353 |
+
return rating
|
| 1354 |
+
|
| 1355 |
+
|
| 1356 |
+
# 评估时,step_2 评估时,给出 [prompt] + image_paths 就行
|
| 1357 |
+
class CGBench_OpenEnded(VideoBaseDataset):
|
| 1358 |
+
|
| 1359 |
+
TYPE = "Video-OpenEnded"
|
| 1360 |
+
|
| 1361 |
+
dataset = "CG-Bench_OpenEnded"
|
| 1362 |
+
|
| 1363 |
+
MD5 = "796035eda0b1e916c517cdc1bc145cfc"
|
| 1364 |
+
|
| 1365 |
+
SYS = (
|
| 1366 |
+
"You will be provided with sampled frames from a video, along with a "
|
| 1367 |
+
"question.\n"
|
| 1368 |
+
"Your task is to analyze the provided frames and infer the most plausible "
|
| 1369 |
+
"answer based on the visual information.\n"
|
| 1370 |
+
"If the visual information is ambiguous or insufficient, use the available "
|
| 1371 |
+
"context to reason your answer.\n"
|
| 1372 |
+
"Only output the answer in the following format:\n\n"
|
| 1373 |
+
'```json\n{"result": "answer"}\n```\n\n'
|
| 1374 |
+
'The "answer" can be a word, phrase, or sentence that directly responds to '
|
| 1375 |
+
"the question.\n\n"
|
| 1376 |
+
)
|
| 1377 |
+
|
| 1378 |
+
def __init__(
|
| 1379 |
+
self,
|
| 1380 |
+
dataset="CG-Bench_OpenEnded",
|
| 1381 |
+
use_subtitle=False,
|
| 1382 |
+
use_subtitle_time=False,
|
| 1383 |
+
use_frame_time=False,
|
| 1384 |
+
nframe=0,
|
| 1385 |
+
fps=-1,
|
| 1386 |
+
):
|
| 1387 |
+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
| 1388 |
+
self.use_subtitle = use_subtitle
|
| 1389 |
+
self.use_subtitle_time = use_subtitle_time
|
| 1390 |
+
self.use_frame_time = use_frame_time
|
| 1391 |
+
self.dataset_name = dataset
|
| 1392 |
+
lmu_root = LMUDataRoot()
|
| 1393 |
+
self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
|
| 1394 |
+
|
| 1395 |
+
@classmethod
|
| 1396 |
+
def supported_datasets(cls):
|
| 1397 |
+
return ["CG-Bench_OpenEnded"]
|
| 1398 |
+
|
| 1399 |
+
def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
|
| 1400 |
+
|
| 1401 |
+
subtitles = []
|
| 1402 |
+
|
| 1403 |
+
srt_path = osp.join(self.data_root, subtitle_path)
|
| 1404 |
+
assert osp.exists(srt_path)
|
| 1405 |
+
import pysubs2
|
| 1406 |
+
|
| 1407 |
+
subs = pysubs2.load(srt_path, encoding="utf-8")
|
| 1408 |
+
if not frame_indices:
|
| 1409 |
+
for sub in subs:
|
| 1410 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 1411 |
+
if sub_time:
|
| 1412 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 1413 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 1414 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 1415 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 1416 |
+
subtitles.append(sub_text)
|
| 1417 |
+
else:
|
| 1418 |
+
for selected_frame_id in frame_indices:
|
| 1419 |
+
cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
|
| 1420 |
+
for sub in subs:
|
| 1421 |
+
if sub.start < cur_time and sub.end > cur_time:
|
| 1422 |
+
sub_text = sub.text.replace("\\N", " ")
|
| 1423 |
+
if sub_time:
|
| 1424 |
+
start_time = milliseconds_to_seconds(sub.start)
|
| 1425 |
+
end_time = milliseconds_to_seconds(sub.end)
|
| 1426 |
+
sub_text = f"[{start_time}, {end_time}] {sub_text}"
|
| 1427 |
+
if sub_text.strip() and sub_text not in subtitles:
|
| 1428 |
+
subtitles.append(sub_text)
|
| 1429 |
+
|
| 1430 |
+
if subtitles:
|
| 1431 |
+
subtitles_str = '\n'.join(subtitles)
|
| 1432 |
+
return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
|
| 1433 |
+
else:
|
| 1434 |
+
return ""
|
| 1435 |
+
|
| 1436 |
+
def prepare_dataset(self, dataset_name="CG-Bench_OpenEnded", repo_id="CG-Bench/CG-Bench"):
|
| 1437 |
+
|
| 1438 |
+
def check_integrity(pth):
|
| 1439 |
+
data_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 1440 |
+
|
| 1441 |
+
if not os.path.exists(data_file):
|
| 1442 |
+
return False
|
| 1443 |
+
|
| 1444 |
+
if md5(data_file) != self.MD5:
|
| 1445 |
+
return False
|
| 1446 |
+
data = load(data_file)
|
| 1447 |
+
for video_pth in data["video"]:
|
| 1448 |
+
if not osp.exists(osp.join(pth, video_pth)):
|
| 1449 |
+
return False
|
| 1450 |
+
|
| 1451 |
+
return True
|
| 1452 |
+
|
| 1453 |
+
cache_path = get_cache_path(repo_id)
|
| 1454 |
+
|
| 1455 |
+
if cache_path is not None and check_integrity(cache_path):
|
| 1456 |
+
dataset_path = cache_path
|
| 1457 |
+
else:
|
| 1458 |
+
|
| 1459 |
+
def generate_tsv(pth):
|
| 1460 |
+
|
| 1461 |
+
tsv_file = osp.join(pth, f"{dataset_name}.tsv")
|
| 1462 |
+
|
| 1463 |
+
with open(osp.join(pth, "cgbench.json"), "r") as f:
|
| 1464 |
+
data_file = pd.DataFrame(json.load(f))
|
| 1465 |
+
|
| 1466 |
+
data_file = data_file.assign(index=range(len(data_file)))
|
| 1467 |
+
data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
|
| 1468 |
+
data_file["subtitle_path"] = data_file["video_uid"].apply(
|
| 1469 |
+
lambda x: f"cg_subtitles/{x}.srt" if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt")) else ""
|
| 1470 |
+
)
|
| 1471 |
+
|
| 1472 |
+
data_file = data_file[
|
| 1473 |
+
[
|
| 1474 |
+
"index",
|
| 1475 |
+
"video_uid",
|
| 1476 |
+
"video",
|
| 1477 |
+
"duration",
|
| 1478 |
+
"domain",
|
| 1479 |
+
"sub_category",
|
| 1480 |
+
"subtitle_path",
|
| 1481 |
+
"question",
|
| 1482 |
+
"answer",
|
| 1483 |
+
"clue_intervals",
|
| 1484 |
+
"qid",
|
| 1485 |
+
]
|
| 1486 |
+
]
|
| 1487 |
+
|
| 1488 |
+
data_file.to_csv(tsv_file, sep="\t", index=False)
|
| 1489 |
+
|
| 1490 |
+
if modelscope_flag_set():
|
| 1491 |
+
from modelscope import dataset_snapshot_download
|
| 1492 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
| 1493 |
+
else:
|
| 1494 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
|
| 1495 |
+
|
| 1496 |
+
unzip_hf_zip(dataset_path)
|
| 1497 |
+
generate_tsv(dataset_path)
|
| 1498 |
+
|
| 1499 |
+
tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
|
| 1500 |
+
|
| 1501 |
+
return dict(data_file=tsv_file, root=dataset_path)
|
| 1502 |
+
|
| 1503 |
+
def build_prompt(self, line, video_llm):
|
| 1504 |
+
|
| 1505 |
+
if isinstance(line, int):
|
| 1506 |
+
assert line < len(self)
|
| 1507 |
+
line = self.data.iloc[line]
|
| 1508 |
+
|
| 1509 |
+
message = []
|
| 1510 |
+
|
| 1511 |
+
sys_prompt = self.SYS
|
| 1512 |
+
|
| 1513 |
+
user_prompt = ""
|
| 1514 |
+
|
| 1515 |
+
video_path = line["video"]
|
| 1516 |
+
|
| 1517 |
+
if video_llm:
|
| 1518 |
+
message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
|
| 1519 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 1520 |
+
if self.nframe:
|
| 1521 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 1522 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 1523 |
+
)
|
| 1524 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
|
| 1525 |
+
fps=vid_fps, sub_time=self.use_subtitle_time)
|
| 1526 |
+
else:
|
| 1527 |
+
user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
|
| 1528 |
+
else:
|
| 1529 |
+
image_paths, frame_indices, vid_fps = self.save_video_frames(
|
| 1530 |
+
video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
|
| 1531 |
+
)
|
| 1532 |
+
message.extend(dict(type="image", value=im) for im in image_paths)
|
| 1533 |
+
|
| 1534 |
+
if self.use_frame_time:
|
| 1535 |
+
user_prompt += get_timestampes(frame_indices, vid_fps)
|
| 1536 |
+
|
| 1537 |
+
if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
|
| 1538 |
+
user_prompt += self.get_subtitles(
|
| 1539 |
+
line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
|
| 1540 |
+
sub_time=self.use_subtitle_time
|
| 1541 |
+
)
|
| 1542 |
+
|
| 1543 |
+
question = line["question"]
|
| 1544 |
+
user_prompt += f"Question: {question}\n\n"
|
| 1545 |
+
|
| 1546 |
+
message.append(dict(type="text", value=sys_prompt + user_prompt))
|
| 1547 |
+
|
| 1548 |
+
return message
|
| 1549 |
+
|
| 1550 |
+
def clue_frame_paths(self, qid, num_frames=8):
|
| 1551 |
+
frame_root = osp.join(self.clue_frame_root, qid)
|
| 1552 |
+
os.makedirs(frame_root, exist_ok=True)
|
| 1553 |
+
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
|
| 1554 |
+
|
| 1555 |
+
def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
|
| 1556 |
+
|
| 1557 |
+
if type(uid) is not str:
|
| 1558 |
+
uid = str(uid)
|
| 1559 |
+
|
| 1560 |
+
vid_path = osp.join(self.data_root, video)
|
| 1561 |
+
vid = decord.VideoReader(vid_path)
|
| 1562 |
+
vid_fps = vid.get_avg_fps()
|
| 1563 |
+
n_frames = len(vid)
|
| 1564 |
+
|
| 1565 |
+
if clue_intervals is not None:
|
| 1566 |
+
merged_intervals = merge_intervals(clue_intervals)
|
| 1567 |
+
|
| 1568 |
+
if num_frames > 0 and fps < 0:
|
| 1569 |
+
indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
|
| 1570 |
+
frame_paths = self.clue_frame_paths(uid, len(indices))
|
| 1571 |
+
|
| 1572 |
+
elif fps > 0:
|
| 1573 |
+
frame_indices = []
|
| 1574 |
+
for start, end in merged_intervals:
|
| 1575 |
+
start_frame = int(start * vid_fps)
|
| 1576 |
+
end_frame = int(end * vid_fps)
|
| 1577 |
+
step = vid_fps / fps
|
| 1578 |
+
interval_indices = [
|
| 1579 |
+
int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
|
| 1580 |
+
]
|
| 1581 |
+
frame_indices.extend(interval_indices)
|
| 1582 |
+
|
| 1583 |
+
if len(frame_indices) < 32:
|
| 1584 |
+
indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
|
| 1585 |
+
else:
|
| 1586 |
+
indices = frame_indices
|
| 1587 |
+
frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
|
| 1588 |
+
|
| 1589 |
+
else:
|
| 1590 |
+
if num_frames > 0 and fps < 0:
|
| 1591 |
+
step_size = len(vid) / (num_frames + 1)
|
| 1592 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
| 1593 |
+
frame_paths = self.frame_paths(uid)
|
| 1594 |
+
elif fps > 0:
|
| 1595 |
+
total_duration = n_frames / vid_fps
|
| 1596 |
+
required_frames = int(total_duration * fps)
|
| 1597 |
+
step_size = vid_fps / fps
|
| 1598 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
| 1599 |
+
frame_paths = self.frame_paths_fps(uid, len(indices))
|
| 1600 |
+
|
| 1601 |
+
valid_paths = []
|
| 1602 |
+
valid_indices = []
|
| 1603 |
+
|
| 1604 |
+
if not np.all([osp.exists(p) for p in frame_paths]):
|
| 1605 |
+
images = [vid[i].asnumpy() for i in indices]
|
| 1606 |
+
for i, (img_array, path) in enumerate(zip(images, frame_paths)):
|
| 1607 |
+
if osp.exists(path):
|
| 1608 |
+
try:
|
| 1609 |
+
with Image.open(path) as img:
|
| 1610 |
+
img.verify()
|
| 1611 |
+
valid_paths.append(path)
|
| 1612 |
+
valid_indices.append(indices[i])
|
| 1613 |
+
except Exception:
|
| 1614 |
+
continue
|
| 1615 |
+
else:
|
| 1616 |
+
try:
|
| 1617 |
+
img = Image.fromarray(img_array)
|
| 1618 |
+
img.save(path)
|
| 1619 |
+
img.verify()
|
| 1620 |
+
valid_paths.append(path)
|
| 1621 |
+
valid_indices.append(indices[i])
|
| 1622 |
+
except Exception:
|
| 1623 |
+
continue
|
| 1624 |
+
else:
|
| 1625 |
+
for i, path in enumerate(frame_paths):
|
| 1626 |
+
try:
|
| 1627 |
+
with Image.open(path) as img:
|
| 1628 |
+
img.verify()
|
| 1629 |
+
valid_paths.append(path)
|
| 1630 |
+
valid_indices.append(indices[i])
|
| 1631 |
+
except Exception:
|
| 1632 |
+
continue
|
| 1633 |
+
|
| 1634 |
+
return valid_paths, valid_indices, vid_fps
|
| 1635 |
+
|
| 1636 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
| 1637 |
+
|
| 1638 |
+
from .utils.cgbench import get_dimention_rating_open_ended, post_process_open
|
| 1639 |
+
|
| 1640 |
+
assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
|
| 1641 |
+
|
| 1642 |
+
tgt_file = eval_file.replace(".xlsx", "_rating.json")
|
| 1643 |
+
score_file = eval_file.replace(".xlsx", "_score.xlsx")
|
| 1644 |
+
step_1_tmp_file = eval_file.replace(".xlsx", "_step_1.pkl")
|
| 1645 |
+
step_2_tmp_file = eval_file.replace(".xlsx", "_step_2.pkl")
|
| 1646 |
+
|
| 1647 |
+
data = load(eval_file)
|
| 1648 |
+
|
| 1649 |
+
data_pred_no_na = data[~pd.isna(data["prediction"])]
|
| 1650 |
+
data_pred_na = data[pd.isna(data["prediction"])]
|
| 1651 |
+
|
| 1652 |
+
data_pred_na["model_result"] = -1
|
| 1653 |
+
data_pred_na["step_1_result"] = -1
|
| 1654 |
+
data_pred_na["step_2_result"] = -1
|
| 1655 |
+
data_pred_na["score"] = -1
|
| 1656 |
+
|
| 1657 |
+
data_pred_no_na["model_result"] = data_pred_no_na.apply(
|
| 1658 |
+
lambda row: post_process_open(
|
| 1659 |
+
response=row["prediction"],
|
| 1660 |
+
),
|
| 1661 |
+
axis=1,
|
| 1662 |
+
)
|
| 1663 |
+
|
| 1664 |
+
if judge_kwargs.get("model", None) != "gpt-4o-0806":
|
| 1665 |
+
judge_kwargs["model"] = "gpt-4o-0806"
|
| 1666 |
+
print("The judge model in cg-bench is gpt-4o-0806!")
|
| 1667 |
+
|
| 1668 |
+
data_no_model_result = data_pred_no_na[data_pred_no_na["model_result"] == -1]
|
| 1669 |
+
data_step_1 = data_pred_no_na[data_pred_no_na["model_result"] != -1]
|
| 1670 |
+
|
| 1671 |
+
model_step_1 = build_judge(system_prompt=sys_prompt_open_eval_step_1, **judge_kwargs)
|
| 1672 |
+
nproc = judge_kwargs.pop('nproc', 32)
|
| 1673 |
+
|
| 1674 |
+
lines_step_1 = data_step_1.to_dict("records")
|
| 1675 |
+
tups_step_1 = [(model_step_1, line) for line in lines_step_1]
|
| 1676 |
+
|
| 1677 |
+
keys_step_1 = {line["qid"] for line in lines_step_1}
|
| 1678 |
+
|
| 1679 |
+
ans = {}
|
| 1680 |
+
if osp.exists(step_1_tmp_file):
|
| 1681 |
+
ans = load(step_1_tmp_file)
|
| 1682 |
+
tups_step_1 = [x for x, i in zip(tups_step_1, keys_step_1) if i not in ans]
|
| 1683 |
+
keys_step_1 = [i for i in keys_step_1 if i not in ans]
|
| 1684 |
+
|
| 1685 |
+
_ = track_progress_rich(
|
| 1686 |
+
eval_open_first,
|
| 1687 |
+
tups_step_1,
|
| 1688 |
+
nproc=nproc,
|
| 1689 |
+
keys=keys_step_1,
|
| 1690 |
+
save=step_1_tmp_file,
|
| 1691 |
+
)
|
| 1692 |
+
|
| 1693 |
+
step_1_results = load(step_1_tmp_file)
|
| 1694 |
+
data_step_1 = save_step_1_steps(data_step_1, step_1_results) # -1, 0, 1, 2
|
| 1695 |
+
|
| 1696 |
+
data_no_step_1_results = data_step_1[data_step_1["step_1_result"] == -1]
|
| 1697 |
+
data_step_1_over = data_step_1[data_step_1["step_1_result"].isin([0, 1])]
|
| 1698 |
+
data_step_2 = data_step_1[data_step_1["step_1_result"] == 2]
|
| 1699 |
+
|
| 1700 |
+
model_step_2 = build_judge(system_prompt=sys_prompt_open_eval_step_2, **judge_kwargs)
|
| 1701 |
+
|
| 1702 |
+
lines_step_2 = data_step_2.to_dict("records")
|
| 1703 |
+
|
| 1704 |
+
tups_step_2 = []
|
| 1705 |
+
|
| 1706 |
+
for line in tqdm(lines_step_2):
|
| 1707 |
+
clue_intervals = eval(line["clue_intervals"])
|
| 1708 |
+
lmu_root = LMUDataRoot()
|
| 1709 |
+
clue_frame_root = osp.join(lmu_root, "clue_images", self.dataset)
|
| 1710 |
+
data_root = self.data_root
|
| 1711 |
+
frame_paths, _, _ = save_clue_video_frames(
|
| 1712 |
+
data_root,
|
| 1713 |
+
clue_frame_root,
|
| 1714 |
+
video=line["video"],
|
| 1715 |
+
uid=line["qid"],
|
| 1716 |
+
clue_intervals=clue_intervals,
|
| 1717 |
+
num_frames=32,
|
| 1718 |
+
)
|
| 1719 |
+
tups_step_2.append((model_step_2, line, frame_paths))
|
| 1720 |
+
|
| 1721 |
+
keys_step_2 = {line["qid"] for line in lines_step_2}
|
| 1722 |
+
|
| 1723 |
+
ans = {}
|
| 1724 |
+
if osp.exists(step_2_tmp_file):
|
| 1725 |
+
ans = load(step_2_tmp_file)
|
| 1726 |
+
tups_step_2 = [x for x, i in zip(tups_step_2, keys_step_2) if i not in ans]
|
| 1727 |
+
keys_step_2 = [i for i in keys_step_2 if i not in ans]
|
| 1728 |
+
|
| 1729 |
+
_ = track_progress_rich(
|
| 1730 |
+
eval_open_second,
|
| 1731 |
+
tups_step_2,
|
| 1732 |
+
nproc=nproc,
|
| 1733 |
+
keys=keys_step_2,
|
| 1734 |
+
save=step_2_tmp_file,
|
| 1735 |
+
)
|
| 1736 |
+
|
| 1737 |
+
step_2_results = load(step_2_tmp_file)
|
| 1738 |
+
data_step_2 = save_step_2_steps(data_step_2, step_2_results)
|
| 1739 |
+
|
| 1740 |
+
data_no_step_2_results = data_step_2[data_step_2["score"] == -1]
|
| 1741 |
+
data_step_2_over = data_step_2[data_step_2["score"].isin([0, 1])]
|
| 1742 |
+
|
| 1743 |
+
data = pd.concat(
|
| 1744 |
+
[
|
| 1745 |
+
data_pred_na,
|
| 1746 |
+
data_no_model_result,
|
| 1747 |
+
data_no_step_1_results,
|
| 1748 |
+
data_step_1_over,
|
| 1749 |
+
data_no_step_2_results,
|
| 1750 |
+
data_step_2_over,
|
| 1751 |
+
]
|
| 1752 |
+
)
|
| 1753 |
+
|
| 1754 |
+
dump(data, score_file)
|
| 1755 |
+
|
| 1756 |
+
rating = get_dimention_rating_open_ended(score_file)
|
| 1757 |
+
|
| 1758 |
+
dump(rating, tgt_file)
|
| 1759 |
+
|
| 1760 |
+
return rating
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/cmmmu.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .image_base import ImageBaseDataset
|
| 2 |
+
import random
|
| 3 |
+
from collections import Counter
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import tempfile
|
| 7 |
+
from ..smp import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_multi_choice_prediction(response, all_choices, index2ans):
|
| 11 |
+
for char in [',', '.', '!', '?', ';', ':', "'"]:
|
| 12 |
+
response = response.strip(char)
|
| 13 |
+
response = " " + response + " " # add space to avoid partial match
|
| 14 |
+
|
| 15 |
+
candidates = []
|
| 16 |
+
|
| 17 |
+
for choice in all_choices: # (A) (B) (C) (D)
|
| 18 |
+
# Add the choice to candidates each time it appears in the response
|
| 19 |
+
candidates.extend([choice for _ in range(response.count(f'({choice})'))])
|
| 20 |
+
|
| 21 |
+
if len(candidates) == 0:
|
| 22 |
+
for choice in all_choices: # A B C D
|
| 23 |
+
# Similarly, add the choice for each occurrence
|
| 24 |
+
candidates.extend([choice for _ in range(response.count(f'{choice}'))])
|
| 25 |
+
|
| 26 |
+
if len(candidates) == 0 and len(response.split()) >= 1:
|
| 27 |
+
for index, ans in index2ans.items():
|
| 28 |
+
# Add index for each occurrence of ans in response
|
| 29 |
+
candidates.extend([index for _ in range(response.count(ans))])
|
| 30 |
+
|
| 31 |
+
# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
|
| 32 |
+
if len(candidates) == 0 and len(response.split()) >= 1:
|
| 33 |
+
for index, ans in index2ans.items():
|
| 34 |
+
if ans in response:
|
| 35 |
+
candidates.append(index)
|
| 36 |
+
# index_ans = False # it's content ans.
|
| 37 |
+
|
| 38 |
+
if len(candidates) == 0: # still not get answer, randomly choose one.
|
| 39 |
+
return random.choice(all_choices)
|
| 40 |
+
# return ''
|
| 41 |
+
else:
|
| 42 |
+
# Count the occurrence of each candidate
|
| 43 |
+
candidate_counts = Counter(candidates)
|
| 44 |
+
|
| 45 |
+
# Select the most frequent candidates
|
| 46 |
+
max_count = max(candidate_counts.values())
|
| 47 |
+
most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
|
| 48 |
+
|
| 49 |
+
# Combine the most frequent candidates in ABCD order
|
| 50 |
+
return ''.join(most_frequent_candidates)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def extract_numbers(string):
|
| 54 |
+
# Pattern for numbers with Chinese commas
|
| 55 |
+
pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
|
| 56 |
+
# Pattern for scientific notation
|
| 57 |
+
pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
|
| 58 |
+
# Pattern for simple numbers without Chinese commas
|
| 59 |
+
pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
|
| 60 |
+
|
| 61 |
+
# Extract numbers with Chinese commas
|
| 62 |
+
numbers_with_commas = re.findall(pattern_commas, string)
|
| 63 |
+
# Extract numbers in scientific notation
|
| 64 |
+
numbers_scientific = re.findall(pattern_scientific, string)
|
| 65 |
+
# Extract simple numbers without Chinese commas
|
| 66 |
+
numbers_simple = re.findall(pattern_simple, string)
|
| 67 |
+
|
| 68 |
+
# Combine all extracted numbers
|
| 69 |
+
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
|
| 70 |
+
return all_numbers
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def check_is_number(string):
|
| 74 |
+
try:
|
| 75 |
+
float(string.replace(',', ''))
|
| 76 |
+
return True
|
| 77 |
+
except ValueError:
|
| 78 |
+
# check if there's comma inside
|
| 79 |
+
return False
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def count_letters(string):
|
| 83 |
+
return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def normalize_str(string, answer):
|
| 87 |
+
# check if characters in the string
|
| 88 |
+
|
| 89 |
+
# if number, numerize it.
|
| 90 |
+
if string is None:
|
| 91 |
+
return [string]
|
| 92 |
+
string = string.strip()
|
| 93 |
+
|
| 94 |
+
is_number = check_is_number(string)
|
| 95 |
+
|
| 96 |
+
if is_number:
|
| 97 |
+
string = string.replace(',', '')
|
| 98 |
+
string = float(string)
|
| 99 |
+
# leave 2 decimal
|
| 100 |
+
string = round(string, 2)
|
| 101 |
+
return [string]
|
| 102 |
+
else: # it's likely to be a string
|
| 103 |
+
if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
|
| 104 |
+
return []
|
| 105 |
+
return [string]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_fill_blank_prediction(response, answer):
|
| 109 |
+
"""get the prediction from the generated response,
|
| 110 |
+
return a list of predicted strings or numbers"""
|
| 111 |
+
|
| 112 |
+
def get_key_subresponses(response):
|
| 113 |
+
response = response.strip("。").strip()
|
| 114 |
+
sub_responses = re.split(r'。|\n', response)
|
| 115 |
+
indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
|
| 116 |
+
'正确答案', '因此', '最后', '答案', '结果']
|
| 117 |
+
key_responses = []
|
| 118 |
+
for index, resp in enumerate(sub_responses):
|
| 119 |
+
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
|
| 120 |
+
if index == len(sub_responses) - 1:
|
| 121 |
+
indicators_of_keys.extend(['='])
|
| 122 |
+
shortest_key_response = None
|
| 123 |
+
# the shortest response that may contain the answer (tail part of the response)
|
| 124 |
+
for indicator in indicators_of_keys:
|
| 125 |
+
if indicator in resp:
|
| 126 |
+
if not shortest_key_response:
|
| 127 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 128 |
+
else:
|
| 129 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
| 130 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 131 |
+
|
| 132 |
+
if shortest_key_response:
|
| 133 |
+
# and it's not trivial
|
| 134 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
| 135 |
+
key_responses.append(shortest_key_response)
|
| 136 |
+
if len(key_responses) == 0: # did not found any
|
| 137 |
+
return [response]
|
| 138 |
+
return key_responses
|
| 139 |
+
|
| 140 |
+
key_responses = get_key_subresponses(response)
|
| 141 |
+
|
| 142 |
+
pred_list = key_responses.copy() # keep the original string response
|
| 143 |
+
for resp in key_responses:
|
| 144 |
+
pred_list.extend(extract_numbers(resp))
|
| 145 |
+
|
| 146 |
+
tmp_pred_list = []
|
| 147 |
+
for i in range(len(pred_list)):
|
| 148 |
+
tmp_pred_list.extend(normalize_str(pred_list[i], answer))
|
| 149 |
+
pred_list = tmp_pred_list
|
| 150 |
+
|
| 151 |
+
# remove duplicates
|
| 152 |
+
pred_list = list(set(pred_list))
|
| 153 |
+
|
| 154 |
+
return pred_list
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_TF_prediction(response):
|
| 158 |
+
"""get the prediction from the generated response,
|
| 159 |
+
return a list of predicted strings or numbers"""
|
| 160 |
+
|
| 161 |
+
def get_key_subresponses(response):
|
| 162 |
+
response = response.strip("。").strip()
|
| 163 |
+
sub_responses = re.split(r'。|\n', response)
|
| 164 |
+
indicators_of_keys = ['是', '为', '所以', '判断',
|
| 165 |
+
'陈述', '说法', '表达', '答案', '结果']
|
| 166 |
+
key_responses = []
|
| 167 |
+
for index, resp in enumerate(sub_responses):
|
| 168 |
+
shortest_key_response = None
|
| 169 |
+
# the shortest response that may contain the answer (tail part of the response)
|
| 170 |
+
for indicator in indicators_of_keys:
|
| 171 |
+
if indicator in resp:
|
| 172 |
+
if not shortest_key_response:
|
| 173 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 174 |
+
else:
|
| 175 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
| 176 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 177 |
+
|
| 178 |
+
if shortest_key_response:
|
| 179 |
+
# and it's not trivial
|
| 180 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
| 181 |
+
key_responses.append(shortest_key_response)
|
| 182 |
+
if len(key_responses) == 0: # did not found any
|
| 183 |
+
return [response]
|
| 184 |
+
return key_responses
|
| 185 |
+
|
| 186 |
+
key_responses = get_key_subresponses(response)
|
| 187 |
+
|
| 188 |
+
pred_list = key_responses.copy() # keep the original string response
|
| 189 |
+
# remove duplicates
|
| 190 |
+
pred_list = list(set(pred_list))
|
| 191 |
+
|
| 192 |
+
return pred_list
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class CMMMU(ImageBaseDataset):
|
| 196 |
+
TYPE = 'VQA'
|
| 197 |
+
|
| 198 |
+
DATASET_URL = {
|
| 199 |
+
'CMMMU_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/CMMMU_VAL.tsv'
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
DATASET_MD5 = {
|
| 203 |
+
'CMMMU_VAL': 'b4727e2fce2415bf646379e60c11a726'
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
def dump_image(self, line):
|
| 207 |
+
os.makedirs(self.img_root, exist_ok=True)
|
| 208 |
+
|
| 209 |
+
tgt_path_z = []
|
| 210 |
+
if isinstance(line['image'], list):
|
| 211 |
+
for i in range(len(line['image'])):
|
| 212 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
|
| 213 |
+
if not read_ok(tgt_path):
|
| 214 |
+
decode_base64_to_image_file(line['image'][i], tgt_path)
|
| 215 |
+
tgt_path_z.append(tgt_path)
|
| 216 |
+
else:
|
| 217 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
| 218 |
+
if not read_ok(tgt_path):
|
| 219 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
| 220 |
+
tgt_path_z.append(tgt_path)
|
| 221 |
+
return tgt_path_z
|
| 222 |
+
|
| 223 |
+
@classmethod
|
| 224 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
| 225 |
+
|
| 226 |
+
suffix = eval_file.split('.')[-1]
|
| 227 |
+
result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
| 228 |
+
|
| 229 |
+
if not osp.exists(result_file):
|
| 230 |
+
data = load(eval_file)
|
| 231 |
+
assert 'answer' in data and 'prediction' in data
|
| 232 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
| 233 |
+
data['answer'] = [str(x) for x in data['answer']]
|
| 234 |
+
|
| 235 |
+
correct_count = 0
|
| 236 |
+
correct_category = {
|
| 237 |
+
'技术与工程': [0, 0],
|
| 238 |
+
'科学': [0, 0],
|
| 239 |
+
'健康与医学': [0, 0],
|
| 240 |
+
'商业': [0, 0],
|
| 241 |
+
'艺术与设计': [0, 0],
|
| 242 |
+
'人文社会科学': [0, 0],
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
for i in tqdm(data.iterrows()):
|
| 246 |
+
line = i[1]
|
| 247 |
+
correct_category[line['category']][0] += 1
|
| 248 |
+
|
| 249 |
+
# Options
|
| 250 |
+
if line['type'] == '选择':
|
| 251 |
+
index2ans = {
|
| 252 |
+
'A': line['option1'],
|
| 253 |
+
'B': line['option2'],
|
| 254 |
+
'C': line['option3'],
|
| 255 |
+
'D': line['option4']
|
| 256 |
+
}
|
| 257 |
+
fact_option = get_multi_choice_prediction(line['prediction'], ['A', 'B', 'C', 'D'], index2ans)
|
| 258 |
+
if fact_option == line['answer']:
|
| 259 |
+
correct_count += 1
|
| 260 |
+
correct_category[line['category']][1] += 1
|
| 261 |
+
|
| 262 |
+
# Binary
|
| 263 |
+
elif line['type'] == '判断':
|
| 264 |
+
positive_keywords = ['正确', '对', '准确', '肯定', '对的']
|
| 265 |
+
negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
|
| 266 |
+
ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
|
| 267 |
+
|
| 268 |
+
def judge_similarity(pred_list, positive_keywords, negative_keywords):
|
| 269 |
+
positive_count = 0
|
| 270 |
+
negative_count = 0
|
| 271 |
+
|
| 272 |
+
for pred in pred_list:
|
| 273 |
+
if any(pos_word in pred for pos_word in positive_keywords):
|
| 274 |
+
positive_count += 1
|
| 275 |
+
elif any(neg_word in pred for neg_word in negative_keywords):
|
| 276 |
+
negative_count += 1
|
| 277 |
+
|
| 278 |
+
if positive_count > negative_count:
|
| 279 |
+
return "对"
|
| 280 |
+
elif negative_count > positive_count:
|
| 281 |
+
return "错"
|
| 282 |
+
else:
|
| 283 |
+
return random.choice(['对', '错'])
|
| 284 |
+
|
| 285 |
+
answer = get_TF_prediction(line['prediction'])
|
| 286 |
+
answer = [word for word in answer if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
|
| 287 |
+
fact_answer = judge_similarity(answer, positive_keywords, negative_keywords)
|
| 288 |
+
if fact_answer == line['answer']:
|
| 289 |
+
correct_count += 1
|
| 290 |
+
correct_category[line['category']][1] += 1
|
| 291 |
+
|
| 292 |
+
# Fill the Blank
|
| 293 |
+
else:
|
| 294 |
+
norm_answers = normalize_str(line['answer'], line['answer'])
|
| 295 |
+
predicted_answer = get_fill_blank_prediction(line['prediction'], line['answer'])
|
| 296 |
+
|
| 297 |
+
for pred in predicted_answer:
|
| 298 |
+
# already normalized
|
| 299 |
+
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
|
| 300 |
+
for norm_ans in norm_answers:
|
| 301 |
+
# only see if the string answer in the string pred
|
| 302 |
+
# print(norm_ans, pred)
|
| 303 |
+
if isinstance(norm_ans, str) and norm_ans in pred:
|
| 304 |
+
correct_count += 1
|
| 305 |
+
correct_category[line['category']][1] += 1
|
| 306 |
+
else: # it's a number
|
| 307 |
+
if pred in norm_answers:
|
| 308 |
+
correct_count += 1
|
| 309 |
+
correct_category[line['category']][1] += 1
|
| 310 |
+
|
| 311 |
+
accuracyz = {}
|
| 312 |
+
accuracyz['总准确率'] = correct_count / len(data)
|
| 313 |
+
for i in correct_category.keys():
|
| 314 |
+
accuracyz[i] = correct_category[i][1] / correct_category[i][0]
|
| 315 |
+
|
| 316 |
+
accuracyz = d2df(accuracyz)
|
| 317 |
+
accuracyz.round(10)
|
| 318 |
+
dump(accuracyz, result_file)
|
| 319 |
+
|
| 320 |
+
result = pd.read_csv(result_file)
|
| 321 |
+
return result
|
| 322 |
+
|
| 323 |
+
def build_prompt(self, line):
|
| 324 |
+
if line['type'] == '选择':
|
| 325 |
+
tgt_path = self.dump_image(line)
|
| 326 |
+
question = line['question']
|
| 327 |
+
options_prompt = 'Options:\n'
|
| 328 |
+
|
| 329 |
+
for i in [['A', '1'], ['B', '2'], ['C', '3'], ['D', '4']]:
|
| 330 |
+
options_prompt += i[0] + '. ' + line['option' + i[1]] + '\n'
|
| 331 |
+
|
| 332 |
+
prompt = (f'问题: {question}\n' + options_prompt
|
| 333 |
+
+ '请回答上述多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。')
|
| 334 |
+
|
| 335 |
+
msgs = []
|
| 336 |
+
if isinstance(tgt_path, list):
|
| 337 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
| 338 |
+
else:
|
| 339 |
+
msgs = [dict(type='image', value=tgt_path)]
|
| 340 |
+
msgs.append(dict(type='text', value=prompt))
|
| 341 |
+
|
| 342 |
+
return msgs
|
| 343 |
+
|
| 344 |
+
elif line['type'] == '判断':
|
| 345 |
+
msgs = super().build_prompt(line)
|
| 346 |
+
assert msgs[-1]['type'] == 'text'
|
| 347 |
+
msgs[-1]['value'] += '\n请回答上述判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。'
|
| 348 |
+
return msgs
|
| 349 |
+
|
| 350 |
+
else:
|
| 351 |
+
msgs = super().build_prompt(line)
|
| 352 |
+
assert msgs[-1]['type'] == 'text'
|
| 353 |
+
msgs[-1]['value'] += '\n请回答上述填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
|
| 354 |
+
return msgs
|
r1-a/response_generation/qwenomni.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import base64
|
| 4 |
+
import uuid # For generating unique filenames
|
| 5 |
+
import time
|
| 6 |
+
import re # For parsing history
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import random
|
| 9 |
+
import concurrent.futures # <-- For ThreadPoolExecutor
|
| 10 |
+
from tqdm import tqdm # <-- For progress bar
|
| 11 |
+
import threading # <-- For potential thread-local data or locks if needed later
|
| 12 |
+
import traceback # <-- For detailed error printing
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
from openai import OpenAI
|
| 17 |
+
from datasets import load_from_disk, Dataset, Features, Value # Ensure Features is imported
|
| 18 |
+
from dotenv import load_dotenv
|
| 19 |
+
|
| 20 |
+
# --- Configuration ---
|
| 21 |
+
load_dotenv()
|
| 22 |
+
|
| 23 |
+
# 1. API Client Setup & Model Rotation Setup
|
| 24 |
+
QWEN_MODEL_LIST = [
|
| 25 |
+
"qwen-omni-turbo",
|
| 26 |
+
"qwen-omni-turbo-latest",
|
| 27 |
+
"qwen-omni-turbo-2025-03-26",
|
| 28 |
+
"qwen-omni-turbo-2025-01-19",
|
| 29 |
+
]
|
| 30 |
+
NUM_MODELS = len(QWEN_MODEL_LIST)
|
| 31 |
+
print(f"Using Qwen models in rotation: {QWEN_MODEL_LIST}")
|
| 32 |
+
|
| 33 |
+
client = OpenAI(
|
| 34 |
+
# api_key=os.getenv("DASHSCOPE_API_KEY"),
|
| 35 |
+
api_key="sk-368bc96f5be74b9bbc880cc6161ab64b", # Replace with your actual key or os.getenv("DASHSCOPE_API_KEY")
|
| 36 |
+
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# 2. Dataset Paths
|
| 40 |
+
INPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_tasks_fully_merged_with_audio/train/final_dataset"
|
| 41 |
+
OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_qwen_rotated" # <-- Adjusted name
|
| 42 |
+
|
| 43 |
+
# 3. Output Audio Configuration
|
| 44 |
+
OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/qwen_omni_rotated" # <-- Adjusted name
|
| 45 |
+
OUTPUT_AUDIO_FORMAT = "wav"
|
| 46 |
+
AVAILABLE_QWEN_VOICES = ["Cherry", "Serena", "Ethan", "Chelsie"]
|
| 47 |
+
OUTPUT_AUDIO_SAMPLERATE = 24000
|
| 48 |
+
|
| 49 |
+
# 4. API Call Settings
|
| 50 |
+
API_RETRY_DELAY = 5
|
| 51 |
+
API_MAX_RETRIES = 3
|
| 52 |
+
MAX_WORKERS = 10 # <-- Set desired number of threads (Be mindful of rate limits!)
|
| 53 |
+
|
| 54 |
+
# 5. Checkpoint Saving Configuration
|
| 55 |
+
CHECKPOINT_INTERVAL = 50 # Save every 500 completed tasks
|
| 56 |
+
|
| 57 |
+
# --- Helper Functions ---
|
| 58 |
+
|
| 59 |
+
def encode_audio_base64(audio_path):
|
| 60 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 61 |
+
print(f"Warning: Input audio file not found or path is empty: {audio_path}")
|
| 62 |
+
return None
|
| 63 |
+
try:
|
| 64 |
+
with open(audio_path, "rb") as audio_file:
|
| 65 |
+
return base64.b64encode(audio_file.read()).decode("utf-8")
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f"Error encoding audio file {audio_path}: {e}")
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
def parse_ultra_history(history_str):
|
| 71 |
+
messages = []
|
| 72 |
+
pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
|
| 73 |
+
matches = pattern.findall(history_str)
|
| 74 |
+
if not matches and history_str and history_str.strip():
|
| 75 |
+
if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
|
| 76 |
+
role = "user"
|
| 77 |
+
content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
|
| 78 |
+
if content: messages.append({"role": role, "content": content})
|
| 79 |
+
elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
|
| 80 |
+
role = "assistant"
|
| 81 |
+
content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
|
| 82 |
+
if content: messages.append({"role": role, "content": content})
|
| 83 |
+
else:
|
| 84 |
+
return []
|
| 85 |
+
else:
|
| 86 |
+
for role_tag, content in matches:
|
| 87 |
+
role = role_tag.lower()
|
| 88 |
+
cleaned_content = content.strip()
|
| 89 |
+
if cleaned_content:
|
| 90 |
+
messages.append({"role": role, "content": cleaned_content})
|
| 91 |
+
return messages
|
| 92 |
+
|
| 93 |
+
# --- API Call Worker Function (Takes model_name) ---
|
| 94 |
+
def call_qwen_omni_api_worker(task_info):
|
| 95 |
+
"""
|
| 96 |
+
Worker function to call Qwen API for a single task using a specific model.
|
| 97 |
+
Returns results including the model used.
|
| 98 |
+
"""
|
| 99 |
+
row_idx = task_info["row_idx"]
|
| 100 |
+
slot_idx = task_info["slot_idx"]
|
| 101 |
+
model_to_use = task_info["model_name"]
|
| 102 |
+
history_messages = task_info["history_messages"]
|
| 103 |
+
prompt_text = task_info["prompt_text"]
|
| 104 |
+
question_audio_path = task_info["question_audio_path"]
|
| 105 |
+
output_audio_filepath = task_info["output_audio_filepath"]
|
| 106 |
+
|
| 107 |
+
retries = 0
|
| 108 |
+
selected_voice = random.choice(AVAILABLE_QWEN_VOICES)
|
| 109 |
+
|
| 110 |
+
while retries < API_MAX_RETRIES:
|
| 111 |
+
try:
|
| 112 |
+
base64_audio_data = encode_audio_base64(question_audio_path)
|
| 113 |
+
if not base64_audio_data:
|
| 114 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping API call due to missing input audio: {question_audio_path}")
|
| 115 |
+
# Return the model name even on error for potential logging
|
| 116 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None, "model_used": model_to_use}
|
| 117 |
+
|
| 118 |
+
input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
|
| 119 |
+
|
| 120 |
+
user_content = []
|
| 121 |
+
user_content.append({
|
| 122 |
+
"type": "input_audio",
|
| 123 |
+
"input_audio": {
|
| 124 |
+
"data": f"data:audio/{input_audio_format};base64,{base64_audio_data}",
|
| 125 |
+
"format": input_audio_format,
|
| 126 |
+
},
|
| 127 |
+
})
|
| 128 |
+
user_content.append({"type": "text", "text": prompt_text})
|
| 129 |
+
messages = history_messages + [{"role": "user", "content": user_content}]
|
| 130 |
+
|
| 131 |
+
completion = client.chat.completions.create(
|
| 132 |
+
model=model_to_use,
|
| 133 |
+
messages=messages,
|
| 134 |
+
modalities=["text", "audio"],
|
| 135 |
+
audio={"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
|
| 136 |
+
stream=True,
|
| 137 |
+
stream_options={"include_usage": True},
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
collected_text = ""
|
| 141 |
+
audio_base64_string = ""
|
| 142 |
+
usage_info = None
|
| 143 |
+
|
| 144 |
+
for chunk in completion:
|
| 145 |
+
if chunk.choices and len(chunk.choices) > 0:
|
| 146 |
+
delta = chunk.choices[0].delta
|
| 147 |
+
if hasattr(delta, 'content') and delta.content:
|
| 148 |
+
collected_text += delta.content
|
| 149 |
+
if hasattr(delta, "audio") and delta.audio:
|
| 150 |
+
if "data" in delta.audio and delta.audio["data"]:
|
| 151 |
+
audio_base64_string += delta.audio["data"]
|
| 152 |
+
if "transcript" in delta.audio and delta.audio["transcript"]:
|
| 153 |
+
collected_text += delta.audio["transcript"]
|
| 154 |
+
elif hasattr(chunk, "usage") and chunk.usage:
|
| 155 |
+
usage_info = chunk.usage
|
| 156 |
+
|
| 157 |
+
if audio_base64_string:
|
| 158 |
+
try:
|
| 159 |
+
wav_bytes = base64.b64decode(audio_base64_string)
|
| 160 |
+
if len(wav_bytes) == 0:
|
| 161 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Decoded audio bytes are empty.")
|
| 162 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
|
| 163 |
+
if len(wav_bytes) % 2 != 0:
|
| 164 |
+
wav_bytes = wav_bytes[:-1] # Truncate for int16
|
| 165 |
+
if len(wav_bytes) == 0:
|
| 166 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Audio bytes became empty after truncation.")
|
| 167 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
|
| 168 |
+
|
| 169 |
+
audio_np = np.frombuffer(wav_bytes, dtype=np.int16)
|
| 170 |
+
os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
|
| 171 |
+
sf.write(output_audio_filepath, audio_np, OUTPUT_AUDIO_SAMPLERATE, format=OUTPUT_AUDIO_FORMAT.upper())
|
| 172 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": output_audio_filepath, "model_used": model_to_use}
|
| 173 |
+
|
| 174 |
+
except base64.binascii.Error as b64_e:
|
| 175 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Decoding base64 failed: {b64_e}")
|
| 176 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
|
| 177 |
+
except ValueError as val_e:
|
| 178 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Interpreting buffer as int16 failed: {val_e} (Bytes: {len(wav_bytes)})")
|
| 179 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
|
| 180 |
+
except Exception as e:
|
| 181 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Processing/saving audio bytes failed: {e}")
|
| 182 |
+
traceback.print_exc()
|
| 183 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
|
| 184 |
+
else:
|
| 185 |
+
print(f"Warning (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): No audio data received in the stream.")
|
| 186 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
retries += 1
|
| 190 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): API Call Attempt {retries}/{API_MAX_RETRIES} failed: {e}")
|
| 191 |
+
if "rate limit" in str(e).lower() or "too many requests" in str(e).lower():
|
| 192 |
+
print("Rate limit likely hit. Consider reducing MAX_WORKERS or increasing delays.")
|
| 193 |
+
time.sleep(API_RETRY_DELAY * 2)
|
| 194 |
+
elif retries < API_MAX_RETRIES:
|
| 195 |
+
time.sleep(API_RETRY_DELAY)
|
| 196 |
+
else:
|
| 197 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Max retries reached. Giving up.")
|
| 198 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Max retries on {model_to_use}]", "saved_audio_path": None, "model_used": model_to_use}
|
| 199 |
+
|
| 200 |
+
return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[UNEXPECTED ERROR on {model_to_use}]", "saved_audio_path": None, "model_used": model_to_use}
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# --- Checkpoint Saving Function (Strictly using original_features) --- # <-- MODIFIED
|
| 204 |
+
def save_checkpoint(data_to_save, output_dir, dataset_features):
|
| 205 |
+
"""Saves the current state of the data list as a Hugging Face Dataset,
|
| 206 |
+
strictly adhering to the provided dataset_features."""
|
| 207 |
+
if not data_to_save:
|
| 208 |
+
print("Checkpoint: No data available to save.")
|
| 209 |
+
return
|
| 210 |
+
|
| 211 |
+
print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
|
| 212 |
+
try:
|
| 213 |
+
# --- REMOVED logic to add 'model_used' feature ---
|
| 214 |
+
|
| 215 |
+
# Create dataset using the original features passed to the function
|
| 216 |
+
# This will raise an error if data_to_save contains keys not in dataset_features
|
| 217 |
+
# or if data types are incompatible after processing.
|
| 218 |
+
# Ensure data_to_save only contains keys present in dataset_features.
|
| 219 |
+
# Filter data_to_save to only include keys present in the original features
|
| 220 |
+
feature_keys = set(dataset_features.keys())
|
| 221 |
+
filtered_data_to_save = []
|
| 222 |
+
for item in data_to_save:
|
| 223 |
+
filtered_item = {k: v for k, v in item.items() if k in feature_keys}
|
| 224 |
+
# Optional: Fill missing keys with None if required by schema, though from_list handles this.
|
| 225 |
+
# for key in feature_keys:
|
| 226 |
+
# if key not in filtered_item:
|
| 227 |
+
# filtered_item[key] = None
|
| 228 |
+
filtered_data_to_save.append(filtered_item)
|
| 229 |
+
|
| 230 |
+
checkpoint_dataset = Dataset.from_list(filtered_data_to_save, features=dataset_features)
|
| 231 |
+
|
| 232 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 233 |
+
checkpoint_dataset.save_to_disk(output_dir)
|
| 234 |
+
print(f"Checkpoint: Saved successfully to {output_dir}")
|
| 235 |
+
|
| 236 |
+
except Exception as ckpt_save_e:
|
| 237 |
+
print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
|
| 238 |
+
print("Detailed error:", traceback.format_exc()) # Print full traceback for save errors
|
| 239 |
+
# Fallback to JSON Lines (does not strictly enforce schema)
|
| 240 |
+
output_jsonl_path = output_dir + "_checkpoint.jsonl"
|
| 241 |
+
print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
|
| 242 |
+
try:
|
| 243 |
+
# Save the original unfiltered data to JSONL for debugging if needed
|
| 244 |
+
with open(output_jsonl_path, 'w', encoding='utf-8') as f:
|
| 245 |
+
for item in data_to_save: # Use original data for JSON fallback
|
| 246 |
+
serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
|
| 247 |
+
f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
|
| 248 |
+
print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
|
| 249 |
+
except Exception as json_save_e:
|
| 250 |
+
print(f"Error saving checkpoint as JSON lines: {json_save_e}")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# --- Main Processing Logic ---
|
| 254 |
+
|
| 255 |
+
print("Checking for existing checkpoint/output dataset...")
|
| 256 |
+
dataset = None
|
| 257 |
+
original_features = None
|
| 258 |
+
|
| 259 |
+
try:
|
| 260 |
+
potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
|
| 261 |
+
potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
|
| 262 |
+
|
| 263 |
+
if os.path.exists(OUTPUT_DATASET_DIR) and \
|
| 264 |
+
(os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
|
| 265 |
+
print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
|
| 266 |
+
try:
|
| 267 |
+
dataset = load_from_disk(OUTPUT_DATASET_DIR)
|
| 268 |
+
original_features = dataset.features
|
| 269 |
+
print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
|
| 270 |
+
print(f"Resumed features: {original_features}") # Log the features
|
| 271 |
+
except Exception as load_ckpt_e:
|
| 272 |
+
print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
|
| 273 |
+
dataset = None
|
| 274 |
+
else:
|
| 275 |
+
print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
|
| 276 |
+
|
| 277 |
+
if dataset is None:
|
| 278 |
+
print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
|
| 279 |
+
if not os.path.exists(INPUT_DATASET_DIR):
|
| 280 |
+
print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
|
| 281 |
+
exit(1)
|
| 282 |
+
dataset = load_from_disk(INPUT_DATASET_DIR)
|
| 283 |
+
original_features = dataset.features
|
| 284 |
+
print(f"Original dataset loaded successfully with {len(dataset)} rows.")
|
| 285 |
+
print(f"Original features: {original_features}") # Log the features
|
| 286 |
+
|
| 287 |
+
except Exception as initial_load_e:
|
| 288 |
+
print(f"FATAL: Error during initial dataset loading: {initial_load_e}")
|
| 289 |
+
traceback.print_exc()
|
| 290 |
+
exit(1)
|
| 291 |
+
breakpoint()
|
| 292 |
+
# Ensure original_features is loaded
|
| 293 |
+
if original_features is None:
|
| 294 |
+
print("FATAL: Failed to load dataset features. Exiting.")
|
| 295 |
+
exit(1)
|
| 296 |
+
|
| 297 |
+
os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
|
| 298 |
+
|
| 299 |
+
# --- Pre-calculation Step (Assign Models Round-Robin) ---
|
| 300 |
+
print("Pre-calculating tasks and assigning models...")
|
| 301 |
+
tasks_to_process = []
|
| 302 |
+
updated_data = list(dataset) # Use mutable list of dicts
|
| 303 |
+
task_creation_counter = 0
|
| 304 |
+
|
| 305 |
+
for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset")):
|
| 306 |
+
needs_processing_in_row = False
|
| 307 |
+
qwen_tasks_in_row = []
|
| 308 |
+
for i in range(1, 4):
|
| 309 |
+
model_key = f"model_{i}"
|
| 310 |
+
response_text_key = f"response_text_{i}"
|
| 311 |
+
model_assigned = row.get(model_key)
|
| 312 |
+
response_text_exists = row.get(response_text_key) is not None
|
| 313 |
+
if model_assigned == "qwen_omni" and not response_text_exists:
|
| 314 |
+
needs_processing_in_row = True
|
| 315 |
+
qwen_tasks_in_row.append(i)
|
| 316 |
+
|
| 317 |
+
if needs_processing_in_row:
|
| 318 |
+
slot_to_process = qwen_tasks_in_row[0]
|
| 319 |
+
i = slot_to_process
|
| 320 |
+
prompt_text_key = f"prompt_text_{i}"
|
| 321 |
+
response_audio_key = f"response_audio_path_{i}" # Define key for clarity
|
| 322 |
+
|
| 323 |
+
question_audio_path = row.get('question_audio')
|
| 324 |
+
if not question_audio_path or not os.path.exists(question_audio_path):
|
| 325 |
+
print(f"Warning (Row {idx}, Slot {i}): Skipping task creation - Missing or non-existent 'question_audio': {question_audio_path}")
|
| 326 |
+
# Ensure error state is marked in updated_data if skipping task creation
|
| 327 |
+
response_text_key_for_error = f"response_text_{i}"
|
| 328 |
+
response_audio_key_for_error = f"response_audio_path_{i}"
|
| 329 |
+
if 0 <= idx < len(updated_data):
|
| 330 |
+
updated_data[idx][response_text_key_for_error] = "[SKIPPED: Missing input audio]"
|
| 331 |
+
updated_data[idx][response_audio_key_for_error] = None
|
| 332 |
+
continue
|
| 333 |
+
|
| 334 |
+
metadata_str = row.get('metadata', "{}")
|
| 335 |
+
source_dataset = row.get('source_dataset')
|
| 336 |
+
metadata = {}
|
| 337 |
+
try:
|
| 338 |
+
if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
|
| 339 |
+
elif isinstance(metadata_str, dict): metadata = metadata_str
|
| 340 |
+
except (json.JSONDecodeError, TypeError): pass
|
| 341 |
+
|
| 342 |
+
history_messages = []
|
| 343 |
+
if source_dataset == 'ultra':
|
| 344 |
+
history_str = metadata.get('history', '')
|
| 345 |
+
if history_str: history_messages = parse_ultra_history(history_str)
|
| 346 |
+
|
| 347 |
+
model_to_use_for_this_task = QWEN_MODEL_LIST[task_creation_counter % NUM_MODELS]
|
| 348 |
+
task_creation_counter += 1
|
| 349 |
+
|
| 350 |
+
unique_id = str(uuid.uuid4()).replace("-", "")
|
| 351 |
+
output_audio_filename = f"qwen_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
|
| 352 |
+
output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
|
| 353 |
+
|
| 354 |
+
task_info = {
|
| 355 |
+
"row_idx": idx,
|
| 356 |
+
"slot_idx": i,
|
| 357 |
+
"model_name": model_to_use_for_this_task,
|
| 358 |
+
"history_messages": history_messages,
|
| 359 |
+
"prompt_text": row.get(prompt_text_key, ""),
|
| 360 |
+
"question_audio_path": question_audio_path,
|
| 361 |
+
"output_audio_filepath": output_audio_filepath,
|
| 362 |
+
}
|
| 363 |
+
tasks_to_process.append(task_info)
|
| 364 |
+
|
| 365 |
+
total_tasks = len(tasks_to_process)
|
| 366 |
+
if total_tasks == 0:
|
| 367 |
+
print("No Qwen tasks found needing processing in the loaded dataset.")
|
| 368 |
+
exit(0)
|
| 369 |
+
|
| 370 |
+
print(f"Found {total_tasks} Qwen tasks to process.")
|
| 371 |
+
model_counts = {model: 0 for model in QWEN_MODEL_LIST}
|
| 372 |
+
for task in tasks_to_process: model_counts[task['model_name']] += 1
|
| 373 |
+
print("Task distribution per model:", model_counts)
|
| 374 |
+
|
| 375 |
+
# --- Threaded Execution with Checkpointing ---
|
| 376 |
+
print(f"Starting processing with up to {MAX_WORKERS} worker threads...")
|
| 377 |
+
start_total_time = time.time()
|
| 378 |
+
tasks_completed = 0
|
| 379 |
+
tasks_failed = 0
|
| 380 |
+
completed_since_last_save = 0
|
| 381 |
+
|
| 382 |
+
# --- REMOVED code block that updated original_features ---
|
| 383 |
+
|
| 384 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 385 |
+
future_to_task = {executor.submit(call_qwen_omni_api_worker, task): task for task in tasks_to_process}
|
| 386 |
+
|
| 387 |
+
for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing tasks"):
|
| 388 |
+
task_info = future_to_task[future]
|
| 389 |
+
row_idx = task_info["row_idx"]
|
| 390 |
+
slot_idx = task_info["slot_idx"]
|
| 391 |
+
result = None
|
| 392 |
+
|
| 393 |
+
try:
|
| 394 |
+
result = future.result()
|
| 395 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 396 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 397 |
+
# --- REMOVED model_used_key and assignment ---
|
| 398 |
+
|
| 399 |
+
if 0 <= row_idx < len(updated_data):
|
| 400 |
+
updated_data[row_idx][response_text_key] = result["response_text"]
|
| 401 |
+
updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
|
| 402 |
+
# --- REMOVED assignment to updated_data[row_idx][model_used_key] ---
|
| 403 |
+
|
| 404 |
+
if result["saved_audio_path"] is None or "ERROR" in result["response_text"]:
|
| 405 |
+
tasks_failed += 1
|
| 406 |
+
else:
|
| 407 |
+
print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
|
| 408 |
+
tasks_failed += 1
|
| 409 |
+
|
| 410 |
+
tasks_completed += 1
|
| 411 |
+
completed_since_last_save += 1
|
| 412 |
+
|
| 413 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 414 |
+
# Pass the unmodified original_features
|
| 415 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 416 |
+
completed_since_last_save = 0
|
| 417 |
+
|
| 418 |
+
except Exception as exc:
|
| 419 |
+
print(f"Error (Row {row_idx}, Slot {slot_idx}): Task generated an exception: {exc}")
|
| 420 |
+
traceback.print_exc()
|
| 421 |
+
response_text_key = f"response_text_{slot_idx}"
|
| 422 |
+
response_audio_key = f"response_audio_path_{slot_idx}"
|
| 423 |
+
# --- REMOVED model_used_key ---
|
| 424 |
+
|
| 425 |
+
if 0 <= row_idx < len(updated_data):
|
| 426 |
+
updated_data[row_idx][response_text_key] = f"[ERROR: Task Exception {type(exc).__name__}]"
|
| 427 |
+
updated_data[row_idx][response_audio_key] = None
|
| 428 |
+
# --- REMOVED assignment to updated_data[row_idx][model_used_key] ---
|
| 429 |
+
|
| 430 |
+
tasks_failed += 1
|
| 431 |
+
tasks_completed += 1
|
| 432 |
+
completed_since_last_save += 1
|
| 433 |
+
|
| 434 |
+
if completed_since_last_save >= CHECKPOINT_INTERVAL:
|
| 435 |
+
# Pass the unmodified original_features
|
| 436 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 437 |
+
completed_since_last_save = 0
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
end_total_time = time.time()
|
| 441 |
+
print("\n--- Processing Complete ---")
|
| 442 |
+
print(f"Total tasks submitted: {total_tasks}")
|
| 443 |
+
print(f"Total tasks processed (returned): {tasks_completed} (Succeeded-ish: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
|
| 444 |
+
print(f"Total processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
|
| 445 |
+
|
| 446 |
+
# --- Final Save ---
|
| 447 |
+
print("\nPerforming final save...")
|
| 448 |
+
# Pass the unmodified original_features
|
| 449 |
+
save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
|
| 450 |
+
|
| 451 |
+
print("\nScript finished.")
|