1f committed
Commit ff53362 · verified · 1 Parent(s): 9cdf7a2

Add files using upload-large-folder tool
r1-a/final_dataset/preference_relative/dataset_info.json ADDED
@@ -0,0 +1,84 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "question_text": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "question_audio": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "source_dataset": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "metadata": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "model_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_name_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_text_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_text_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_audio_path_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "model_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_name_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_text_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_text_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_audio_path_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "model_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_name_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_text_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_text_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_audio_path_3": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
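
This dataset_info.json and the state.json below are the metadata pair that datasets' save_to_disk() writes alongside the Arrow shard. A minimal sketch of reading the folder back, assuming the working directory is the repo root:

from datasets import load_from_disk

# Path is the repo-relative one above; adjust to wherever the repo is checked out.
ds = load_from_disk("r1-a/final_dataset/preference_relative")
print(ds.features)             # 19 string columns: question_*, metadata, model_*, prompt_*, response_*
print(ds[0]["question_text"])  # first row's question
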
r1-a/final_dataset/preference_relative/state.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "6ada4a1b690526f4",
+   "_format_columns": [
+     "question_text",
+     "question_audio",
+     "source_dataset",
+     "metadata",
+     "model_1",
+     "prompt_name_1",
+     "prompt_text_1",
+     "response_text_1",
+     "response_audio_path_1",
+     "model_2",
+     "prompt_name_2",
+     "prompt_text_2",
+     "response_text_2",
+     "response_audio_path_2",
+     "model_3",
+     "prompt_name_3",
+     "prompt_text_3",
+     "response_text_3",
+     "response_audio_path_3"
+   ],
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
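
The "_format_columns" list in state.json records a prior set_format()/with_format() call on the dataset, which load_from_disk() restores. A sketch of the kind of call that produces such an entry (the column subset here is illustrative):

from datasets import load_from_disk

ds = load_from_disk("r1-a/final_dataset/preference_relative")
# Restrict the output columns; datasets persists this choice into
# state.json's "_format_columns" on the next save_to_disk().
ds.set_format(columns=["question_text", "response_text_1", "response_audio_path_1"])
print(ds[0].keys())  # only the selected columns are returned
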
r1-a/final_dataset/preference_relative_processed_shards/logs/shard_0_gpu_0.log ADDED
@@ -0,0 +1,16 @@
+ 2025-05-04 14:41:37,640 - INFO - [Shard 0] - Process started for Shard 0 on GPU 0 (logical device cuda:0)
+ 2025-05-04 14:41:37,640 - INFO - [Shard 0] - Arguments: Namespace(shard_index=0, gpu_id=0, wer_threshold=0.4, pipeline_batch_size=16, map_batch_size=16, num_check_workers=4)
+ 2025-05-04 14:41:37,640 - INFO - [Shard 0] - Loading dataset from /home/chenyifu/audio-r1/r1-a/final_dataset/preference_relative
+ 2025-05-04 16:30:18,197 - ERROR - [Shard 0] - Failed to load dataset:
+ Traceback (most recent call last):
+   File "/home/chenyifu/audio-r1/r1-a/dataset/retts.py", line 380, in main
+     logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
+   File "/home/chenyifu/audio-r1/r1-a/dataset/retts.py", line 380, in main
+     logger.info(f"Full dataset loaded with {full_ds.num_rows} rows.")
+   File "/home/chenyifu/miniconda3/envs/cosyvoice/lib/python3.10/bdb.py", line 90, in trace_dispatch
+     return self.dispatch_line(frame)
+   File "/home/chenyifu/miniconda3/envs/cosyvoice/lib/python3.10/bdb.py", line 115, in dispatch_line
+     if self.quitting: raise BdbQuit
+ bdb.BdbQuit
+ 2025-05-04 16:30:18,199 - WARNING - [Shard 0] - Processing did not complete or failed early. No statistics to log.
+ 2025-05-04 16:30:18,199 - INFO - [Shard 0] - Process for Shard 0 on GPU 0 finished.
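
The failure in this log is not a data problem: bdb.BdbQuit is what Python's debugger machinery raises when a pdb session is quit, so the retts.py run was evidently being traced under a debugger when the user typed `q`. A minimal sketch of how the exception surfaces and can be told apart from a real load error (the function name is hypothetical):

import bdb

def load_dataset_step():
    # Stand-in for the retts.py main() frame seen in the log. When a run is
    # traced under pdb and the user quits, the trace function raises BdbQuit here.
    raise bdb.BdbQuit

try:
    load_dataset_step()
except bdb.BdbQuit:
    # The shard script apparently logs this as "Failed to load dataset",
    # but it is really an aborted debugger session, not a corrupt dataset.
    print("Aborted via debugger quit (bdb.BdbQuit)")
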
r1-a/final_dataset/preference_relative_processed_shards/shard_0/dataset_info.json ADDED
@@ -0,0 +1,100 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "question_text": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "question_audio": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "source_dataset": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "metadata": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "model_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_name_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_text_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_text_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_audio_path_1": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "model_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_name_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_text_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_text_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_audio_path_2": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "model_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_name_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "prompt_text_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_text_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "response_audio_path_3": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "asr_transcription": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "wer": {
+       "dtype": "float32",
+       "_type": "Value"
+     },
+     "is_bad_tts": {
+       "dtype": "bool",
+       "_type": "Value"
+     },
+     "error_message": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
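
The four extra columns here (asr_transcription, wer, is_bad_tts, error_message) come from the TTS-quality check pass; the shard log above shows it was invoked with wer_threshold=0.4. A hedged sketch of how such a shard would be filtered once processing succeeds (column semantics inferred from the names):

from datasets import load_from_disk

shard = load_from_disk("r1-a/final_dataset/preference_relative_processed_shards/shard_0")

# Keep rows whose resynthesized audio transcribes back well enough.
# 0.4 matches the wer_threshold in the shard log; is_bad_tts presumably
# encodes the same cut, so either predicate should agree.
clean = shard.filter(lambda row: not row["is_bad_tts"] and row["wer"] < 0.4)
print(f"kept {len(clean)}/{len(shard)} rows")
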
r1-a/final_dataset/preference_relative_processed_shards/shard_0/state.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "00775948802a271b",
+   "_format_columns": [
+     "asr_transcription",
+     "error_message",
+     "is_bad_tts",
+     "metadata",
+     "model_1",
+     "model_2",
+     "model_3",
+     "prompt_name_1",
+     "prompt_name_2",
+     "prompt_name_3",
+     "prompt_text_1",
+     "prompt_text_2",
+     "prompt_text_3",
+     "question_audio",
+     "question_text",
+     "response_audio_path_1",
+     "response_audio_path_2",
+     "response_audio_path_3",
+     "response_text_1",
+     "response_text_2",
+     "response_text_3",
+     "source_dataset",
+     "wer"
+   ],
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
r1-a/final_dataset/prompt_only_relative_paths/dataset_info.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "source_dataset": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "question_text": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "question_audio": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "metadata": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
r1-a/final_dataset/prompt_only_relative_paths/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "419ded2384418f0a",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
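
The folder name prompt_only_relative_paths suggests that question_audio holds repo-relative paths rather than absolute ones. A hedged sketch of resolving them before use (the repo-root path is an assumption):

import os
from datasets import load_from_disk

REPO_ROOT = "/path/to/repo"  # assumption: wherever this dataset repo is checked out

ds = load_from_disk("r1-a/final_dataset/prompt_only_relative_paths")

def absolutize(row):
    # question_audio is stored as a plain string path; join it onto the repo root
    row["question_audio"] = os.path.join(REPO_ROOT, row["question_audio"])
    return row

ds = ds.map(absolutize)
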
r1-a/response_generation/glm4voice.py ADDED
@@ -0,0 +1,579 @@
+ import os
+ import json
+ import base64
+ import uuid
+ import time
+ import re
+ from io import BytesIO
+ import concurrent.futures
+ from tqdm import tqdm
+ import threading
+ import itertools
+ import traceback  # For detailed error logging
+
+ import numpy as np
+ import soundfile as sf
+ from zhipuai import ZhipuAI
+ # Import specific error type if available and helpful
+ # Attempt to import specific error, handle if it doesn't exist
+ try:
+     from zhipuai.core._errors import APIStatusError
+ except ImportError:
+     # Define a dummy class if the specific error isn't available
+     # This allows the except block to still catch general exceptions
+     # that might represent API status issues if the SDK changes.
+     print("Warning: zhipuai.core._errors.APIStatusError not found. Using generic Exception for status errors.")
+     class APIStatusError(Exception):
+         def __init__(self, message, status_code=None, body=None):
+             super().__init__(message)
+             self.status_code = status_code
+             self.body = body
+             self.message = message  # Add message attribute for consistency
+
+ from datasets import load_from_disk, Dataset
+ from dotenv import load_dotenv
+
+ # --- Configuration (User's Original Settings) ---
+ load_dotenv()
+
+ # 1. API Client Setup
+ GLM_MODEL_NAME = "glm-4-voice"  # <<< User's original model name
+
+ # --- API Key Rotation Setup (User's Original Keys & Logic) ---
+ ZHIPUAI_API_KEYS = [
+     "14a67189b8bc4ee489e83b6247c36d0e.AIPUNrII50wREvsh",
+     "72120787822c4123a9654965ff90e4e6.JS1nuey9MncQscPa",
+     "d41b3b5bb49f4c8680b3836e7fc49bbf.u0jGxYc5sYPeRr5p",
+     "bc9bccd6ddd145fc844a014521c26868.JwsZXHzA3l32dDwz",
+     "0e5a05d709794737923ebd122e07d491.sL67ALh6BiLYaaGW",  # New key
+     "db87c1fda8af4eb8b505f36e791d700d.w5M0Q3ZssT55tvlW",  # New key
+     "1594ac60fbca4973809f4da425238e0c.ZMMfchqbok992Dmu",  # New key
+     "469c0fa3b14e4913b1d14bc5d6f0c858.0KdQjFqdi66VPMnb",
+     "b9b538bb0e134438bacaf922b023d1fd.sogFUUp57UJ8YSd6",
+     "50bb382993a345cfa35833fc89caaa52.oR921jSW8iwzCV22",
+     "44512bbede5940f7964db7694bfc04df.yhDEQyPOXQCqh1Mn",
+     "99aba409b55c432696b9d5f1ff565d30.GmfRNngBOo8qDUbf"
+ ]  # <<< User's original keys
+
+ if not ZHIPUAI_API_KEYS:
+     print("FATAL: No ZHIPUAI_API_KEYS provided in the list.")
+     exit(1)
+
+ # Make sure keys are unique if duplicates were accidental
+ unique_keys = list(dict.fromkeys(ZHIPUAI_API_KEYS))
+ if len(unique_keys) != len(ZHIPUAI_API_KEYS):
+     print(f"Warning: Duplicate API keys found and removed. Using {len(unique_keys)} unique keys.")
+ ZHIPUAI_API_KEYS = unique_keys
+
+ key_cycler = itertools.cycle(ZHIPUAI_API_KEYS)
+ key_lock = threading.RLock()  # Reentrant: the worker calls get_next_active_key() while already holding the lock
+ disabled_keys = set()  # Shared set to store disabled keys
+
+ class AllKeysDisabledError(Exception):
+     """Custom exception raised when all API keys are disabled."""
+     pass
+
+ def get_next_active_key():
+     """
+     Thread-safely gets the next API key from the cycle, skipping disabled keys.
+     Raises AllKeysDisabledError if all keys are disabled.
+     (User's Original Logic)
+     """
+     with key_lock:
+         initial_key_count = len(ZHIPUAI_API_KEYS)
+         checked_count = 0
+         while checked_count < initial_key_count:
+             potential_key = next(key_cycler)
+             if potential_key not in disabled_keys:
+                 return potential_key
+             checked_count += 1
+             # Prevent infinite loop if somehow cycle changes mid-operation (shouldn't happen)
+             if checked_count > initial_key_count * 2:
+                 print("Warning: Potential issue in get_next_active_key cycle detection.")
+                 break
+         # If we exit the loop, all keys have been checked and are disabled
+         if len(disabled_keys) == initial_key_count:
+             raise AllKeysDisabledError("All API keys have been disabled.")
+         else:
+             # This case should ideally not be reached if logic is sound
+             # but indicates a potential problem finding an active key
+             print(f"Warning: Could not find an active key after checking {checked_count}. Disabled: {len(disabled_keys)}/{initial_key_count}")
+             raise RuntimeError("Failed to find an active API key.")
+ # --- End API Key Rotation Setup ---
+
+ # 2. Dataset Paths (User's Original Paths)
+ INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"  # <<< User's original path
+ OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_glm"  # <<< User's original path
+
+ # 3. Output Audio Configuration (User's Original Settings)
+ OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/glm_voice"  # <<< User's original path
+ OUTPUT_AUDIO_FORMAT = "wav"  # <<< User's original setting
+ OUTPUT_AUDIO_SAMPLERATE = 44100  # <<< User's original setting
+
+ # 4. API Call Settings (User's Original Settings)
+ API_RETRY_DELAY = 5  # <<< User's original setting
+ API_MAX_RETRIES = 3  # <<< User's original setting
+ MAX_WORKERS = 10  # <<< User's original setting
+
+ # --- Helper Functions (User's Original Functions) ---
+ def encode_audio_base64(audio_path):
+     # ... (implementation unchanged from user's script) ...
+     if not audio_path or not os.path.exists(audio_path):
+         print(f"Warning: Input audio file not found or path is empty: {audio_path}")
+         return None
+     try:
+         with open(audio_path, "rb") as audio_file:
+             return base64.b64encode(audio_file.read()).decode("utf-8")
+     except Exception as e:
+         print(f"Error encoding audio file {audio_path}: {e}")
+         return None
+
+ def parse_ultra_history(history_str):
+     # ... (implementation unchanged from user's script) ...
+     messages = []
+     pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
+     matches = pattern.findall(history_str)
+     if not matches:
+         return []  # Return empty list if no matches, as per user's original code
+     for role_tag, content in matches:
+         role = role_tag.lower()
+         cleaned_content = content.strip()
+         if cleaned_content:
+             messages.append({"role": role, "content": cleaned_content})
+     return messages
+
+ # --- Modified API Call Worker Function (Handles Key Disabling & History Flattening) ---
+ def call_glm_voice_api_worker(task_info):
+     """
+     Worker function to call GLM Voice API, handling key disabling for error 1113,
+     and flattening history into the user prompt with clear markers.
+     (Incorporates Method 2 flattening into user's worker structure)
+     """
+     row_idx = task_info["row_idx"]
+     slot_idx = task_info["slot_idx"]
+     current_api_key = task_info["api_key"]
+     history_messages = task_info["history_messages"]  # Original parsed history
+     prompt_text = task_info["prompt_text"]  # The user's current text request
+     question_audio_path = task_info["question_audio_path"]
+     output_audio_filepath = task_info["output_audio_filepath"]
+
+     retries = 0
+     local_glm_client = None
+
+     while retries < API_MAX_RETRIES:
+         # --- Initialize or Re-initialize client (User's Original Logic) ---
+         if local_glm_client is None or getattr(local_glm_client, 'api_key', None) != current_api_key:
+             try:
+                 with key_lock:
+                     if current_api_key in disabled_keys:
+                         print(f"Info (Row {row_idx}, Slot {slot_idx}): Assigned key ...{current_api_key[-6:]} was disabled before use, getting new key.")
+                         current_api_key = get_next_active_key()
+                         task_info["api_key"] = current_api_key  # Update task_info potentially for logging?
+                 print(f"  [Thread-{threading.get_ident()}] Initializing client for Row {row_idx}, Slot {slot_idx} (Key: ...{current_api_key[-6:]})")
+                 local_glm_client = ZhipuAI(api_key=current_api_key)
+             except AllKeysDisabledError:
+                 print(f"FATAL (Row {row_idx}, Slot {slot_idx}): All API keys are disabled. Cannot proceed with task.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: All Keys Disabled]", "saved_audio_path": None}
+             except Exception as client_init_e:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Failed to initialize ZhipuAI client with key ...{current_api_key[-6:]}: {client_init_e}")
+                 retries += 1
+                 time.sleep(API_RETRY_DELAY)
+                 continue
+
+         # --- Attempt API Call ---
+         try:
+             # 1. Prepare Input Audio (User's Original Logic)
+             base64_audio_data = encode_audio_base64(question_audio_path)
+             if not base64_audio_data:
+                 # This is a data error, not an API error, fail the task immediately (User's Original Logic)
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GLM API call - missing input audio.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
+             input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
+
+
+             # 2. *** Flatten History and Construct Combined User Text Prompt (Method 2 Implementation) ***
+             text_parts = []
+             if history_messages:
+                 print(f"  (Row {row_idx}, Slot {slot_idx}) Flattening history ({len(history_messages)} turns) into prompt.")
+                 text_parts.append("--- Start of Conversation History ---")
+                 for msg in history_messages:
+                     role_tag = "[User]" if msg['role'] == 'user' else "[Assistant]"
+                     # Ensure content is string, handle potential non-string data defensively
+                     content_str = str(msg.get('content', '')).strip()
+                     if content_str:  # Avoid adding empty messages
+                         text_parts.append(f"{role_tag}: {content_str}")
+                 text_parts.append("--- End of Conversation History ---")
+                 text_parts.append("\n--- Current Task ---")  # Clear separator
+                 # Explicit instruction referencing history and audio
+                 text_parts.append("Based on the conversation history above and the accompanying audio input, please respond to the following request:")
+             else:
+                 # No history, just provide the current prompt directly
+                 print(f"  (Row {row_idx}, Slot {slot_idx}) No history found. Using prompt directly.")
+                 text_parts.append("--- Current Task ---")
+                 text_parts.append("Please respond to the following request based on the accompanying audio input:")
+
+             # Add the user's actual current request text
+             if prompt_text:  # Only add if not empty
+                 text_parts.append(prompt_text.strip())
+
+             combined_user_text = "\n".join(text_parts)
+             # --- End Flattening Logic ---
+
+
+             # 3. Construct User Message Content List (Text + Audio)
+             user_content_list = [
+                 {"type": "text", "text": combined_user_text},  # Use the combined text
+                 {"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
+             ]
+
+             # 4. Construct Final Messages List (Only the single combined user message)
+             # This replaces the user's original 'messages = history_messages + [{"role": "user", "content": user_content_list}]'
+             messages = [{"role": "user", "content": user_content_list}]
+
+
+             # 5. Make API Call (User's Original Logic)
+             # Optional: print(f"Debug (Row {row_idx}, Slot {slot_idx}): Sending messages structure:\n{json.dumps(messages, indent=2, ensure_ascii=False)}")
+             response = local_glm_client.chat.completions.create(
+                 model=GLM_MODEL_NAME,
+                 messages=messages,  # Send the single, combined user message
+                 stream=False
+                 # Add other parameters like temperature if the user had them originally (they didn't)
+             )
+
+             # 6. Process SUCCESSFUL Response (User's Original Logic -unchanged-)
+             if response and response.choices:
+                 message = response.choices[0].message
+                 collected_text = message.content
+                 audio_info = getattr(message, 'audio', None)  # Use getattr for safety as per user's original code
+                 if audio_info and 'data' in audio_info:
+                     audio_base64_string = audio_info['data']
+                     try:
+                         decoded_data = base64.b64decode(audio_base64_string)
+                         if len(decoded_data) == 0:  # Check after decode (User's Original Check)
+                             print(f"Warning (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM returned empty audio data.")
+                             return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
+
+                         os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
+                         # Soundfile saving logic (User's Original Logic -unchanged-)
+                         with BytesIO(decoded_data) as bio:
+                             try:
+                                 audio_data, samplerate = sf.read(bio, dtype='int16')
+                             except Exception:
+                                 bio.seek(0)  # Rewind buffer before trying float
+                                 try:
+                                     audio_data_float, samplerate = sf.read(bio, dtype='float32')
+                                     # Convert float to int16
+                                     audio_data = (audio_data_float * 32767).astype(np.int16)
+                                 except Exception as sf_read_err_float:
+                                     print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Soundfile failed to read audio data: {sf_read_err_float}")
+                                     # Return text, audio failed
+                                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
+
+                         # Use detected samplerate, fallback to configured rate if detection failed
+                         write_samplerate = samplerate if samplerate > 0 else OUTPUT_AUDIO_SAMPLERATE
+                         sf.write(output_audio_filepath, audio_data, write_samplerate)
+
+                         # TASK SUCCEEDED!
+                         return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": output_audio_filepath}
+
+                     except base64.binascii.Error as b64_e:
+                         print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM b64 decode failed: {b64_e}")
+                         # Return text, audio failed
+                         return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
+                     except Exception as e:
+                         print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Saving GLM audio failed: {e}")
+                         # Return text, audio failed
+                         return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
+                 else:  # No audio in successful text response (User's Original Logic)
+                     print(f"Warning (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): No audio data in GLM response.")
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None}
+             else:  # Invalid/empty successful response (User's Original Logic)
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Invalid/empty GLM API response. Response: {response}")
+                 # Treat as a retryable error for the task
+                 retries += 1
+                 time.sleep(API_RETRY_DELAY)
+                 continue  # Go to next iteration of while loop
+
+         # --- Handle API Errors (User's Original Logic -unchanged-) ---
+         except APIStatusError as e:
+             # --- Log the error details ---
+             print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): APIStatusError Encountered")
+             print(f"  Status Code: {getattr(e, 'status_code', 'N/A')}")  # Use getattr for safety
+             error_details = getattr(e, 'body', getattr(e, 'message', str(e)))
+             print(f"  Error Details: {error_details}")
+             # --- End Logging ---
+
+             # Check for the specific "account overdue" error (User's Original Logic)
+             is_overdue_error = False
+             status_code = getattr(e, 'status_code', None)
+             # Adjust check to handle both 429 and potential 400 errors with code 1113 in body
+             if status_code == 429 or (status_code == 400 and '1113' in str(error_details)):
+                 try:
+                     error_body = {}
+                     # Try parsing if details look like JSON
+                     if isinstance(error_details, (str, bytes)) and error_details.strip().startswith('{'):
+                         error_body = json.loads(error_details)
+                     elif isinstance(error_details, dict):
+                         error_body = error_details  # If body is already a dict
+
+                     if isinstance(error_body, dict) and str(error_body.get("error", {}).get("code", "")) == "1113":
+                         is_overdue_error = True
+                 except (json.JSONDecodeError, AttributeError):
+                     # Can't parse body or access attributes, assume not the specific error for safety
+                     pass
+                 except Exception as parse_err:
+                     print(f"Warning: Error parsing API error body: {parse_err}")
+
+             if is_overdue_error:
+                 key_to_disable = current_api_key
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Account overdue (1113) for Key ...{key_to_disable[-6:]}. Disabling key.")
+                 with key_lock:
+                     disabled_keys.add(key_to_disable)
+                     print(f"  Disabled keys count: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
+
+                 # Don't increment retries here, try getting a new key immediately
+                 try:
+                     current_api_key = get_next_active_key()  # Get a new key
+                     print(f"  (Row {row_idx}, Slot {slot_idx}) Switched to new key ...{current_api_key[-6:]} for next attempt.")
+                     local_glm_client = None  # Force re-initialization with new key
+                     continue  # Go immediately to the next iteration of the while loop with the new key
+                 except AllKeysDisabledError:
+                     print(f"FATAL (Row {row_idx}, Slot {slot_idx}): All API keys are disabled after key ...{key_to_disable[-6:]} failed. Cannot retry task.")
+                     # Return failure for this task as no keys are left
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: All Keys Disabled]", "saved_audio_path": None}
+
+             else:
+                 # Other APIStatusError (rate limit, server error, etc.) - treat as retryable
+                 retries += 1
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): GLM API Call Attempt {retries}/{API_MAX_RETRIES} failed: HTTP {status_code}, {error_details}")
+                 if retries < API_MAX_RETRIES:
+                     time.sleep(API_RETRY_DELAY)
+                     # Continue loop to retry with the *same* key (unless it was just disabled above)
+                     continue
+                 else:
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Max retries reached after API error.")
+                     # Return failure for the task
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries after status error]", "saved_audio_path": None}
+
+         except Exception as e:
+             # Handle other unexpected errors during API call or processing (User's Original Logic)
+             retries += 1
+             print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Unexpected Error Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
+             print(traceback.format_exc())  # Print traceback for unexpected errors
+             if retries < API_MAX_RETRIES:
+                 time.sleep(API_RETRY_DELAY)
+                 continue  # Continue loop to retry
+             else:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}, Key ...{current_api_key[-6:]}): Max retries reached after unexpected error.")
+                 # Return failure for the task
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries after unexpected error]", "saved_audio_path": None}
+
+     # If loop finishes without returning, max retries were hit (User's Original Logic)
+     print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts (may include key switches).")
+     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
+
+
+ # --- Main Processing Logic (User's Original Logic -unchanged-) ---
+
+ print("Loading dataset...")
+ try:
+     dataset = load_from_disk(INPUT_DATASET_DIR)
+     print(f"Dataset loaded successfully with {len(dataset)} rows from {INPUT_DATASET_DIR}.")
+ except Exception as e:
+     print(f"FATAL: Error loading dataset from {INPUT_DATASET_DIR}: {e}")
+     exit(1)
+
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
+
+ # --- Pre-calculation Step for GLM (User's Original Logic -unchanged-) ---
+ print("Pre-calculating GLM tasks and assigning initial API keys...")
+ tasks_to_process = []
+ original_data = list(dataset)  # Convert to list for easier updates later
+ initial_keys_available = True
+
+ for idx, row in enumerate(tqdm(original_data, desc="Scanning dataset for GLM tasks")):
+     if not initial_keys_available:
+         # Stop scanning if we know no keys are left
+         print("Stopping task scanning as no active keys are available.")
+         break
+
+     for i in range(1, 4):
+         model_key = f"model_{i}"
+         response_text_key = f"response_text_{i}"
+         prompt_text_key = f"prompt_text_{i}"
+         model_assigned = row.get(model_key)
+         # Check if response exists and is not empty string (User's original check was just existence)
+         response_text_exists = row.get(response_text_key) is not None and str(row.get(response_text_key)).strip() != ""
+
+
+         if model_assigned == "glm_voice" and not response_text_exists:  # Check using configured model name
+             question_audio_path = row.get('question_audio')
+             # Add check if audio path exists on disk
+             if not question_audio_path or not os.path.exists(question_audio_path):
+                 print(f"Warning (Row {idx}, Slot {i}): Skipping GLM task - Missing or invalid 'question_audio' path: {question_audio_path}")
+                 continue  # Skip this slot if audio is missing
+
+             # --- Get initial active API key (User's Original Logic) ---
+             try:
+                 assigned_key = get_next_active_key()
+             except AllKeysDisabledError:
+                 print("FATAL: All API keys are disabled during initial task scanning. Cannot proceed.")
+                 initial_keys_available = False
+                 break  # Stop processing this row
+             except Exception as key_err:
+                 print(f"FATAL: Error getting initial API key: {key_err}. Stopping.")
+                 initial_keys_available = False
+                 break
+             # ---
+
+             metadata_str = row.get('metadata', "{}")
+             source_dataset = row.get('source_dataset')
+             metadata = {}
+             try:
+                 # Handle case where metadata might already be a dict or is a JSON string
+                 if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
+                 elif isinstance(metadata_str, dict): metadata = metadata_str
+             except json.JSONDecodeError:
+                 print(f"Warning (Row {idx}): Could not parse metadata string: {metadata_str[:100]}...")
+                 pass  # Continue with empty metadata
+
+             # Parse history here - it will be flattened later in the worker
+             history_messages = []
+             if source_dataset == 'ultra':
+                 history_str = metadata.get('history', '')
+                 if history_str: history_messages = parse_ultra_history(history_str)
+
+             unique_id = str(uuid.uuid4()).replace("-", "")
+             output_audio_filename = f"glm_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
+             output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
+
+             task_info = {
+                 "row_idx": idx,
+                 "slot_idx": i,
+                 "api_key": assigned_key,  # Initial key
+                 "history_messages": history_messages,  # Pass the original parsed history
+                 "prompt_text": row.get(prompt_text_key, ""),
+                 "question_audio_path": question_audio_path,
+                 "output_audio_filepath": output_audio_filepath,
+             }
+             tasks_to_process.append(task_info)
+             # Process only the first unfilled GLM slot found per row (User's Implicit Logic)
+             break  # Stop checking slots for this row
+
+     if not initial_keys_available: break  # Exit outer loop too
+
+ total_tasks = len(tasks_to_process)
+ if total_tasks == 0:
+     if not initial_keys_available:
+         print("No tasks processed because all initial keys were disabled.")
+     else:
+         print("No GLM Voice tasks found needing processing.")
+     exit(0)
+
+ print(f"Found {total_tasks} GLM Voice tasks to process using initially {len(ZHIPUAI_API_KEYS)} API keys.")
+ if len(disabled_keys) > 0:  # Should be 0 here, but for safety
+     print(f"Note: {len(disabled_keys)} keys already marked as disabled (should not happen at this stage).")
+
+
+ # --- Threaded Execution for GLM (User's Original Logic -unchanged-) ---
+ print(f"Starting GLM processing with up to {MAX_WORKERS} worker threads...")
+ start_total_time = time.time()
+ results = {}
+ tasks_completed = 0
+ tasks_failed = 0
+ executor_shutdown = False  # Flag to stop submitting new tasks if all keys die
+
+ # Use context manager for ThreadPoolExecutor
+ with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+     # Create futures mapping back to task info for easier result merging
+     future_to_task = {executor.submit(call_glm_voice_api_worker, task): task for task in tasks_to_process}
+
+     for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GLM tasks"):
+         task = future_to_task[future]  # Get the original task info associated with this future
+         row_idx = task["row_idx"]
+         slot_idx = task["slot_idx"]
+         try:
+             result = future.result()  # Get the result from the worker
+             results[(row_idx, slot_idx)] = result  # Store result using (row, slot) tuple as key
+
+             # Check if the task failed because all keys got disabled during its execution
+             if result["response_text"] == "[ERROR: All Keys Disabled]" and not executor_shutdown:
+                 print("\n--- CRITICAL: All Keys Disabled detected during execution. Stopping submission of new tasks. ---")
+                 # Potentially cancel remaining futures if possible/desired
+                 # Note: Standard ThreadPoolExecutor doesn't easily support cancelling submitted tasks
+                 # We will just let running tasks finish but won't submit new ones if we had that logic.
+                 # For now, just set flag and log.
+                 executor_shutdown = True  # Prevent theoretical resubmission logic
+                 tasks_failed += 1  # Count this task as failed
+             # Check for other errors in the result text or missing audio path
+             elif result["saved_audio_path"] is None or "[ERROR" in result["response_text"]:
+                 tasks_failed += 1
+             tasks_completed += 1
+
+         except Exception as exc:  # Catch exceptions raised *by* the future (e.g., if worker itself crashes)
+             print(f"Error (Row {row_idx}, Slot {slot_idx}): GLM Task generated an unhandled exception: {exc}")
+             print(traceback.format_exc())
+             # Store an error result
+             results[(row_idx, slot_idx)] = {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[ERROR: Worker Crash - {type(exc).__name__}]", "saved_audio_path": None}
+             tasks_failed += 1
+             tasks_completed += 1
+         # No finally block needed here unless cleaning up future_to_task is desired
+
+
+ end_total_time = time.time()
+ print("\n--- GLM Processing Complete ---")
+ print(f"Total GLM tasks attempted: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
+ print(f"Final disabled key count: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
+ print(f"Total GLM processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
+
+
+ # --- Merge Results back into the dataset structure (User's Original Logic -unchanged-) ---
+ print("Merging GLM results...")
+ updated_data = original_data  # Use the list created earlier
+ for (row_idx, slot_idx), result in tqdm(results.items(), desc="Merging GLM results"):
+     response_text_key = f"response_text_{slot_idx}"
+     response_audio_key = f"response_audio_path_{slot_idx}"
+     # Check index validity before updating
+     if 0 <= row_idx < len(updated_data):
+         # Ensure the item at the index is a dictionary (it should be if loaded from dataset)
+         if isinstance(updated_data[row_idx], dict):
+             updated_data[row_idx][response_text_key] = result["response_text"]
+             updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
+         else:
+             print(f"Warning: Item at index {row_idx} is not a dictionary. Skipping merge for Slot {slot_idx}.")
+     else:
+         print(f"Warning: Invalid row index {row_idx} encountered during GLM result merge.")
+
+ # --- Save the final updated dataset (User's Original Logic -unchanged, including fallback) ---
+ if updated_data:
+     print(f"\nSaving updated dataset with GLM results to {OUTPUT_DATASET_DIR}...")
+     try:
+         # Use the features from the original loaded dataset if available
+         updated_dataset = Dataset.from_list(updated_data, features=dataset.features if dataset else None)
+         updated_dataset.save_to_disk(OUTPUT_DATASET_DIR)
+         print("Updated dataset saved successfully.")
+     except Exception as final_save_e:
+         print(f"Error saving final dataset using datasets lib: {final_save_e}")
+         print(f"Final disabled key count at save: {len(disabled_keys)}/{len(ZHIPUAI_API_KEYS)}")
+         print("Attempting to save as JSON lines as fallback...")
+         # Fallback to JSON Lines (User's original fallback logic)
+         output_jsonl_path = OUTPUT_DATASET_DIR.rstrip('/') + ".jsonl"  # Ensure no trailing slash before adding extension
+         try:
+             with open(output_jsonl_path, 'w', encoding='utf-8') as f:
+                 for item in updated_data:
+                     # Attempt to make item JSON serializable
+                     serializable_item = {}
+                     for k, v in item.items():
+                         if isinstance(v, (str, int, float, bool, list, dict)) or v is None:
+                             serializable_item[k] = v
+                         elif isinstance(v, np.ndarray):
+                             serializable_item[k] = v.tolist()  # Convert numpy arrays
+                         else:
+                             serializable_item[k] = str(v)  # Convert other types to string as fallback
+                     f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
+             print(f"Fallback save successful to {output_jsonl_path}")
+         except Exception as json_save_e:
+             print(f"Error saving as JSON lines: {json_save_e}")
+
+ else:
+     print("No data was available to save (potentially all keys disabled early or no tasks processed).")
r1-a/response_generation/gpt4o.py ADDED
@@ -0,0 +1,464 @@
+ import os
+ import json
+ import base64
+ import uuid
+ import time
+ import re
+ import random
+ import concurrent.futures
+ from tqdm import tqdm
+ import threading
+ import traceback  # For detailed error logging
+
+ import requests  # Use requests library for HTTP calls
+ # Make sure numpy is imported if needed for potential fallback serialization
+ import numpy as np
+ from datasets import load_from_disk, Dataset
+ from dotenv import load_dotenv
+
+ # --- Configuration ---
+ load_dotenv()
+
+ # 1. API Client Setup
+ GPT4O_MODEL_NAME = "gpt4o"  # How it's identified in your dataset's model columns
+ API_MODEL_NAME = "gpt-4o-audio-preview"  # Actual model name for the API call
+ API_ENDPOINT = "https://api.vansai.cn/v1/chat/completions"
+ try:
+     # Assuming a single key for this service based on the original script
+     API_TOKEN = "sk-uOJ27X9jNsYh1PDx1e665b0f92434bEc9bD53bE6D3BaD29a"
+     if not API_TOKEN:
+         raise ValueError("AIGCBEST_API_KEY environment variable not set.")
+     print("AIGCBEST API Key loaded.")
+ except Exception as e:
+     print(f"FATAL: Error getting API Key: {e}")
+     exit(1)
+
+ # 2. Dataset Paths
+ INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
+ OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o"
+
+ # 3. Output Audio Configuration
+ OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_2"
+ OUTPUT_AUDIO_FORMAT = "wav"  # API will be requested to return wav
+ AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse']
+
+ # 4. API Call Settings
+ API_TIMEOUT = 120
+ API_RETRY_DELAY = 5
+ API_MAX_RETRIES = 3  # Max attempts *for the task*
+ MAX_WORKERS = 8  # Adjust based on API rate limits and system resources
+
+ # 5. Checkpoint Saving Configuration  # <-- NEW
+ CHECKPOINT_INTERVAL = 500  # Save every 500 completed tasks
+
+ # --- Helper Functions (encode_audio_base64 and parse_ultra_history remain the same) ---
+
+ def encode_audio_base64(audio_path):
+     if not audio_path or not os.path.exists(audio_path):
+         print(f"Warning: Input audio file not found or path is empty: {audio_path}")
+         return None
+     try:
+         with open(audio_path, "rb") as audio_file:
+             return base64.b64encode(audio_file.read()).decode("utf-8")
+     except Exception as e:
+         print(f"Error encoding audio file {audio_path}: {e}")
+         return None
+
+ def parse_ultra_history(history_str):
+     messages = []
+     pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
+     matches = pattern.findall(history_str)
+     if not matches:
+         return []
+     for role_tag, content in matches:
+         role = role_tag.lower()
+         cleaned_content = content.strip()
+         if cleaned_content:
+             messages.append({"role": role, "content": cleaned_content})
+     return messages
+
+ # --- Modified API Call Worker Function for GPT-4o (Reduced Prints) ---
+ def call_gpt4o_api_worker(task_info):
+     """
+     Worker function to call the custom GPT-4o API for a single task.
+     """
+     row_idx = task_info["row_idx"]
+     slot_idx = task_info["slot_idx"]
+     history_messages = task_info["history_messages"]
+     prompt_text = task_info["prompt_text"]
+     question_text = task_info["question_text"]
+     question_audio_path = task_info["question_audio_path"]
+     output_audio_filepath = task_info["output_audio_filepath"]
+
+     retries = 0
+     headers = {
+         'Accept': 'application/json',
+         'Authorization': f'Bearer {API_TOKEN}',  # Use the single loaded token
+         'Content-Type': 'application/json'
+     }
+     selected_voice = random.choice(AVAILABLE_VOICES)
+     # print(f"  [Thread-{threading.get_ident()}] Processing Row {row_idx}, Slot {slot_idx} (GPT4o Voice: {selected_voice})")  # Optional log
+
+     while retries < API_MAX_RETRIES:
+         try:
+             # 1. Prepare Input Audio
+             base64_audio_data = encode_audio_base64(question_audio_path)
+             if not base64_audio_data:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
+
+             input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
+
+             # 2. Construct User Message Content
+             combined_text = f"{prompt_text}"
+             user_content_list = [
+                 {"type": "text", "text": combined_text},
+                 {"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
+             ]
+             messages = history_messages + [{"role": "user", "content": user_content_list}]
+
+             # 4. Construct Payload
+             payload = {
+                 "model": API_MODEL_NAME,
+                 "modalities": ["text", "audio"],
+                 "audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
+                 "messages": messages
+             }
+
+             # 5. Make API Call
+             response = requests.post(
+                 API_ENDPOINT,
+                 headers=headers,
+                 json=payload,
+                 timeout=API_TIMEOUT
+             )
+
+             # 6. Process Response
+             if response.status_code == 200:
+                 try:
+                     response_data = response.json()
+                     # Make parsing more robust
+                     choices = response_data.get('choices')
+                     if not choices or not isinstance(choices, list) or len(choices) == 0:
+                         raise ValueError("Invalid or empty 'choices' field in response.")
+
+                     message_content = choices[0].get('message', {})
+                     if not message_content:
+                         raise ValueError("Missing 'message' field in the first choice.")
+
+                     audio_info = message_content.get('audio', {})
+                     if not isinstance(audio_info, dict): audio_info = {}  # Handle case where audio might be null or not a dict
+
+                     audio_base64_string = audio_info.get('data', '')
+                     # Try getting text from 'content' if 'transcript' is missing/empty in 'audio'
+                     collected_text = audio_info.get('transcript', '').strip()
+                     if not collected_text:
+                         text_content_list = message_content.get('content', [])
+                         if isinstance(text_content_list, list):
+                             for item in text_content_list:
+                                 if isinstance(item, dict) and item.get("type") == "text":
+                                     collected_text = item.get("text", "").strip()
+                                     break  # Take the first text part found
+                         # Still no text? Try the top-level message content directly if it's a string
+                         elif isinstance(message_content.get('content'), str):
+                             collected_text = message_content['content'].strip()
+
+                     if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
+                     if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
+
+                     saved_audio_path = None
+                     if audio_base64_string:
+                         try:
+                             wav_bytes = base64.b64decode(audio_base64_string)
+                             if len(wav_bytes) == 0:
+                                 print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
+                             else:
+                                 os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
+                                 with open(output_audio_filepath, "wb") as f:
+                                     f.write(wav_bytes)
+                                 saved_audio_path = output_audio_filepath
+                                 # print(f"  Audio saved to: {output_audio_filepath}")  # Less verbose log
+                         except base64.binascii.Error as b64_err:
+                             print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
+                         except Exception as e:
+                             print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
+
+                     # TASK SUCCEEDED (even if audio saving failed, text might be valid)
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
+
+                 except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
+                     print(f"  Response Text (start): {response.text[:500]}...")
+                     retries += 1
+                     print(f"  Retrying task ({retries}/{API_MAX_RETRIES})...")
+                     time.sleep(API_RETRY_DELAY)
+                     continue
+                 except Exception as e:  # Catch-all for unexpected errors during processing
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
+                     print(traceback.format_exc())
+                     retries += 1
+                     print(f"  Retrying task ({retries}/{API_MAX_RETRIES})...")
+                     time.sleep(API_RETRY_DELAY)
+                     continue
+
+             else:  # Handle non-200 status codes
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
+                 retries += 1
+                 if retries < API_MAX_RETRIES:
+                     print(f"  Retrying task ({retries}/{API_MAX_RETRIES})...")
+                     time.sleep(API_RETRY_DELAY)
+                     continue  # Go to next iteration of while loop
+                 else:
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
+
+
+         except requests.exceptions.Timeout:
+             retries += 1
+             print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
+             if retries < API_MAX_RETRIES:
+                 time.sleep(API_RETRY_DELAY)
+                 continue
+             else:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
+
+         except requests.exceptions.RequestException as e:
+             retries += 1
+             print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
+             if retries < API_MAX_RETRIES:
+                 time.sleep(API_RETRY_DELAY)
+                 continue
+             else:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
+
+         except Exception as e:
+             retries += 1
+             print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
+             print(traceback.format_exc())
+             if retries < API_MAX_RETRIES:
+                 time.sleep(API_RETRY_DELAY)
+                 continue
+             else:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
+
+     # If loop finishes without returning, max retries were hit
+     print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts.")
+     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
+
+ # --- Checkpoint Saving Function ---  # <-- NEW (Copied from previous response)
+ def save_checkpoint(data_to_save, output_dir, dataset_features):
+     """Saves the current state of the data to disk."""
+     if not data_to_save:
+         print("Checkpoint: No data available to save.")
+         return
+
+     # Ensure output directory exists before saving
+     os.makedirs(output_dir, exist_ok=True)
+
+     print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
+     try:
+         # Convert list of dicts back to Dataset object
+         checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
+         checkpoint_dataset.save_to_disk(output_dir)
+         print(f"Checkpoint: Saved successfully to {output_dir}")
+     except Exception as ckpt_save_e:
+         print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
+         # Fallback to JSON Lines (optional, but good practice)
+         output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl")  # Save inside the dir
+         print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
+         try:
+             with open(output_jsonl_path, 'w', encoding='utf-8') as f:
+                 for item in data_to_save:
+                     # Basic serialization handling for common types like numpy arrays
+                     serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
+                     f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
+             print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
+         except Exception as json_save_e:
+             print(f"Error saving checkpoint as JSON lines: {json_save_e}")
+
+
+ # --- Main Processing Logic ---
+
+ print("Checking for existing checkpoint/output dataset...")
+ dataset = None
+ original_features = None  # Initialize
+
+ try:
+     # Check whether the output directory exists and looks like a Hugging Face
+     # datasets directory (dataset_info.json or state.json are the usual indicator files)
+     potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
+     potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
+
+     if os.path.exists(OUTPUT_DATASET_DIR) and \
+        (os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
+
+         print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
+         try:
+             dataset = load_from_disk(OUTPUT_DATASET_DIR)
+             original_features = dataset.features  # Take the features of the already-saved dataset
+             print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
+         except Exception as load_ckpt_e:
+             print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
+             print("Falling back to loading original input dataset.")
+             dataset = None  # Ensure we proceed to load original if checkpoint load failed
+     else:
+         print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
+         # If no checkpoint, ensure dataset is None so original loading happens
+
+     # If dataset is still None (no checkpoint was found, or loading it failed)
+     if dataset is None:
+         print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
+         dataset = load_from_disk(INPUT_DATASET_DIR)
+         original_features = dataset.features
+         print(f"Original dataset loaded successfully with {len(dataset)} rows.")
+
+ except Exception as initial_load_e:
+     print(f"FATAL: Error during initial dataset loading (original or checkpoint): {initial_load_e}")
+     print(traceback.format_exc())  # Print the detailed error
+     exit(1)
+
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
+
+ # --- Pre-calculation Step for GPT-4o ---
+ print("Pre-calculating GPT-4o tasks...")
+ tasks_to_process = []
+ # Use a list of dictionaries, which is mutable and easier for direct updates
+ updated_data = list(dataset)
+
+ for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset for GPT-4o tasks")):
+     for i in range(1, 4):
+         model_key = f"model_{i}"
+         response_text_key = f"response_text_{i}"
+         prompt_text_key = f"prompt_text_{i}"
+         response_audio_key = f"response_audio_path_{i}"  # Key for storing the *new* audio path
+
+         model_assigned = row.get(model_key)
+         response_text_exists = row.get(response_text_key) is not None
+
+         # Check for the specific model name used in the dataset
+         if model_assigned == GPT4O_MODEL_NAME and not response_text_exists:
+             question_audio_path = row.get('question_audio')
+             if not question_audio_path or not os.path.exists(question_audio_path):  # Check path validity here
+                 print(f"Warning (Row {idx}, Slot {i}): Skipping GPT-4o task - Missing or invalid 'question_audio' path: {question_audio_path}")
+                 # Pre-fill error? Let's just skip task creation for now.
+                 # If needed: updated_data[idx][response_text_key] = "[ERROR: Missing input audio]"
+                 # If needed: updated_data[idx][response_audio_key] = None
+                 continue  # Skip this task
+
+             metadata_str = row.get('metadata', "{}")
+             source_dataset = row.get('source_dataset')
+             metadata = {}
+             try:
+                 if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
+                 elif isinstance(metadata_str, dict): metadata = metadata_str
+             except json.JSONDecodeError: pass
+
+             history_messages = []
+             if source_dataset == 'ultra':
+                 history_str = metadata.get('history', '')
+                 if history_str: history_messages = parse_ultra_history(history_str)
+
+             unique_id = str(uuid.uuid4()).replace("-", "")
+             output_audio_filename = f"gpt4o_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
+             output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
+
+             task_info = {
+                 "row_idx": idx,
+                 "slot_idx": i,
+                 # No API key needed here as it's global/single
+                 "history_messages": history_messages,
+                 "prompt_text": row.get(prompt_text_key, ""),
+                 "question_text": row.get('question_text', ""),  # Pass question text
+                 "question_audio_path": question_audio_path,
+                 "output_audio_filepath": output_audio_filepath,
+             }
+             tasks_to_process.append(task_info)
+             # Decide if you process all slots or just the first unfilled one
+             # break  # Uncomment this line if you only want the *first* unfilled gpt4o slot per row processed
+
+ total_tasks = len(tasks_to_process)
+ if total_tasks == 0:
+     print("No GPT-4o tasks found needing processing.")
+     exit(0)
+
+ print(f"Found {total_tasks} GPT-4o tasks to process.")
+
+ # --- Threaded Execution with Checkpointing for GPT-4o ---  # <-- MODIFIED SECTION
+ print(f"Starting GPT-4o processing with up to {MAX_WORKERS} worker threads...")
+ start_total_time = time.time()
+ # results = {}  # No longer needed
+ tasks_completed = 0
+ tasks_failed = 0
+ completed_since_last_save = 0  # <-- Counter for checkpointing
+
+ # Use context manager for ThreadPoolExecutor
+ with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+     future_to_task = {executor.submit(call_gpt4o_api_worker, task): task for task in tasks_to_process}
+
+     for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GPT-4o tasks"):
+         task_info = future_to_task[future]  # Get original task info
+         row_idx = task_info["row_idx"]
+         slot_idx = task_info["slot_idx"]
+         result = None  # Define result scope
+
+         try:
+             result = future.result()
+             # --- Direct Update and Checkpointing Logic ---
+             response_text_key = f"response_text_{slot_idx}"
+             response_audio_key = f"response_audio_path_{slot_idx}"
+
+             if 0 <= row_idx < len(updated_data):
+                 updated_data[row_idx][response_text_key] = result["response_text"]
+                 updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
+                 if result["saved_audio_path"] is None or "[ERROR" in result["response_text"]:  # Check for error marker
+                     tasks_failed += 1
+             else:
+                 print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
+                 tasks_failed += 1  # Count as failed if index is bad
+
+             tasks_completed += 1
+             completed_since_last_save += 1  # Increment checkpoint counter
424
+
425
+ # Check if it's time to save a checkpoint
426
+ if completed_since_last_save >= CHECKPOINT_INTERVAL:
427
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
428
+ completed_since_last_save = 0 # Reset counter
429
+
430
+ except Exception as exc: # Catch exceptions raised *by* the future/worker if not handled inside
431
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): GPT-4o Task generated an unhandled exception: {exc}")
432
+ print(traceback.format_exc())
433
+ # Attempt to record error in the main data structure
434
+ response_text_key = f"response_text_{slot_idx}"
435
+ response_audio_key = f"response_audio_path_{slot_idx}"
436
+ if 0 <= row_idx < len(updated_data):
437
+ updated_data[row_idx][response_text_key] = f"[ERROR: Worker Crash - {exc}]"
438
+ updated_data[row_idx][response_audio_key] = None
439
+ else:
440
+ print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
441
+
442
+ tasks_failed += 1
443
+ tasks_completed += 1 # Count as completed (though failed)
444
+ completed_since_last_save += 1 # Also increment for checkpointing
445
+
446
+ # Check if it's time to save a checkpoint even after an error
447
+ if completed_since_last_save >= CHECKPOINT_INTERVAL:
448
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
449
+ completed_since_last_save = 0 # Reset counter
450
+
451
+ end_total_time = time.time()
452
+ print("\n--- GPT-4o Processing Complete ---")
453
+ print(f"Total GPT-4o tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
454
+ print(f"Total GPT-4o processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
455
+
456
+
457
+ # --- Final Save ---
458
+ # Save one last time to ensure any remaining processed items (< CHECKPOINT_INTERVAL) are saved
459
+ print("\nPerforming final save...")
460
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
461
+
462
+ print("\nScript finished.")
463
+
464
+ # --- (Removed the old merging and saving logic as it's now handled by save_checkpoint) ---
r1-a/response_generation/gpt4o_mini.py ADDED
@@ -0,0 +1,464 @@
1
+ import os
2
+ import json
3
+ import base64
4
+ import uuid
5
+ import time
6
+ import re
7
+ import random
8
+ import concurrent.futures
9
+ from tqdm import tqdm
10
+ import threading
11
+ import traceback # For detailed error logging
12
+
13
+ import requests # Use requests library for HTTP calls
14
+ # Make sure numpy is imported if needed for potential fallback serialization
15
+ import numpy as np
16
+ from datasets import load_from_disk, Dataset
17
+ from dotenv import load_dotenv
18
+
19
+ # --- Configuration ---
20
+ load_dotenv()
21
+
22
+ # 1. API Client Setup
23
+ GPT4O_MODEL_NAME = "freeze_omni" # How it's identified in your dataset's model columns
24
+ API_MODEL_NAME = "gpt-4o-mini-audio-preview" # Actual model name for the API call
25
+ API_ENDPOINT = "https://api2.aigcbest.top/v1/chat/completions"
26
+ try:
27
+ # Assuming a single key for this service; read it from the environment rather than hardcoding a secret
28
+ API_TOKEN = os.getenv("AIGCBEST_API_KEY")
29
+ if not API_TOKEN:
30
+ raise ValueError("AIGCBEST_API_KEY environment variable not set.")
31
+ print("AIGCBEST API Key loaded.")
32
+ except Exception as e:
33
+ print(f"FATAL: Error getting API Key: {e}")
34
+ exit(1)
35
+
36
+ # 2. Dataset Paths
37
+ INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks"
38
+ OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o_mini"
39
+
40
+ # 3. Output Audio Configuration
41
+ OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_mini"
42
+ OUTPUT_AUDIO_FORMAT = "wav" # API will be requested to return wav
43
+ AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse']
44
+
45
+ # 4. API Call Settings
46
+ API_TIMEOUT = 240
47
+ API_RETRY_DELAY = 5
48
+ API_MAX_RETRIES = 3 # Max attempts *for the task*
49
+ MAX_WORKERS = 8 # Adjust based on API rate limits and system resources
50
+
51
+ # 5. Checkpoint Saving Configuration # <-- NEW
52
+ CHECKPOINT_INTERVAL = 500 # Save every 500 completed tasks
53
+
54
+ # --- Helper Functions (encode_audio_base64 and parse_ultra_history remain the same) ---
55
+
56
+ def encode_audio_base64(audio_path):
57
+ if not audio_path or not os.path.exists(audio_path):
58
+ print(f"Warning: Input audio file not found or path is empty: {audio_path}")
59
+ return None
60
+ try:
61
+ with open(audio_path, "rb") as audio_file:
62
+ return base64.b64encode(audio_file.read()).decode("utf-8")
63
+ except Exception as e:
64
+ print(f"Error encoding audio file {audio_path}: {e}")
65
+ return None
66
+
67
+ def parse_ultra_history(history_str):
68
+ messages = []
69
+ pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
70
+ matches = pattern.findall(history_str)
71
+ if not matches:
72
+ return []
73
+ for role_tag, content in matches:
74
+ role = role_tag.lower()
75
+ cleaned_content = content.strip()
76
+ if cleaned_content:
77
+ messages.append({"role": role, "content": cleaned_content})
78
+ return messages
79
+
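+ # Illustrative example (input format assumed from the regex above):
+ #   parse_ultra_history("[USER] Hi there [ASSISTANT] Hello!")
+ #   -> [{"role": "user", "content": "Hi there"},
+ #       {"role": "assistant", "content": "Hello!"}]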
80
+ # --- Modified API Call Worker Function for GPT-4o (Reduced Prints) ---
81
+ def call_gpt4o_api_worker(task_info):
82
+ """
83
+ Worker function to call the custom GPT-4o API for a single task.
84
+ """
85
+ row_idx = task_info["row_idx"]
86
+ slot_idx = task_info["slot_idx"]
87
+ history_messages = task_info["history_messages"]
88
+ prompt_text = task_info["prompt_text"]
89
+ question_text = task_info["question_text"]
90
+ question_audio_path = task_info["question_audio_path"]
91
+ output_audio_filepath = task_info["output_audio_filepath"]
92
+
93
+ retries = 0
94
+ headers = {
95
+ 'Accept': 'application/json',
96
+ 'Authorization': f'Bearer {API_TOKEN}', # Use the single loaded token
97
+ 'Content-Type': 'application/json'
98
+ }
99
+ selected_voice = random.choice(AVAILABLE_VOICES)
100
+ # print(f" [Thread-{threading.get_ident()}] Processing Row {row_idx}, Slot {slot_idx} (GPT4o Voice: {selected_voice})") # Optional log
101
+
102
+ while retries < API_MAX_RETRIES:
103
+ try:
104
+ # 1. Prepare Input Audio
105
+ base64_audio_data = encode_audio_base64(question_audio_path)
106
+ if not base64_audio_data:
107
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
108
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
109
+
110
+ input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
111
+
112
+ # 2. Construct User Message Content
113
+ combined_text = f"{prompt_text}"
114
+ user_content_list = [
115
+ {"type": "text", "text": combined_text},
116
+ {"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
117
+ ]
118
+ messages = history_messages + [{"role": "user", "content": user_content_list}]
119
+
120
+ # 3. Construct Payload
121
+ payload = {
122
+ "model": API_MODEL_NAME,
123
+ "modalities": ["text", "audio"],
124
+ "audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
125
+ "messages": messages
126
+ }
127
+
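+ # Expected success-response shape, inferred from the parsing logic below
+ # (the exact provider schema is an assumption, not documented here):
+ #   {"choices": [{"message": {"content": ...,
+ #                             "audio": {"data": "<base64 wav>", "transcript": "<text>"}}}]}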
128
+ # 4. Make API Call
129
+ response = requests.post(
130
+ API_ENDPOINT,
131
+ headers=headers,
132
+ json=payload,
133
+ timeout=API_TIMEOUT
134
+ )
135
+
136
+ # 5. Process Response
137
+ if response.status_code == 200:
138
+ try:
139
+ response_data = response.json()
140
+ # Make parsing more robust
141
+ choices = response_data.get('choices')
142
+ if not choices or not isinstance(choices, list) or len(choices) == 0:
143
+ raise ValueError("Invalid or empty 'choices' field in response.")
144
+
145
+ message_content = choices[0].get('message', {})
146
+ if not message_content:
147
+ raise ValueError("Missing 'message' field in the first choice.")
148
+
149
+ audio_info = message_content.get('audio', {})
150
+ if not isinstance(audio_info, dict): audio_info = {} # Handle case where audio might be null or not a dict
151
+
152
+ audio_base64_string = audio_info.get('data', '')
153
+ # Try getting text from 'content' if 'transcript' is missing/empty in 'audio'
154
+ collected_text = audio_info.get('transcript', '').strip()
155
+ if not collected_text:
156
+ text_content_list = message_content.get('content', [])
157
+ if isinstance(text_content_list, list):
158
+ for item in text_content_list:
159
+ if isinstance(item, dict) and item.get("type") == "text":
160
+ collected_text = item.get("text", "").strip()
161
+ break # Take the first text part found
162
+ # Still no text? Try the top-level message content directly if it's a string
163
+ elif isinstance(message_content.get('content'), str):
164
+ collected_text = message_content['content'].strip()
165
+
166
+ if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
167
+ if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
168
+
169
+ saved_audio_path = None
170
+ if audio_base64_string:
171
+ try:
172
+ wav_bytes = base64.b64decode(audio_base64_string)
173
+ if len(wav_bytes) == 0:
174
+ print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
175
+ else:
176
+ os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
177
+ with open(output_audio_filepath, "wb") as f:
178
+ f.write(wav_bytes)
179
+ saved_audio_path = output_audio_filepath
180
+ # print(f" Audio saved to: {output_audio_filepath}") # Less verbose log
181
+ except base64.binascii.Error as b64_err:
182
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
183
+ except Exception as e:
184
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
185
+
186
+ # TASK SUCCEEDED (even if audio saving failed, text might be valid)
187
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
188
+
189
+ except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
190
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
191
+ print(f" Response Text (start): {response.text[:500]}...")
192
+ retries += 1
193
+ print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
194
+ time.sleep(API_RETRY_DELAY)
195
+ continue
196
+ except Exception as e: # Catch-all for unexpected errors during processing
197
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
198
+ print(traceback.format_exc())
199
+ retries += 1
200
+ print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
201
+ time.sleep(API_RETRY_DELAY)
202
+ continue
203
+
204
+ else: # Handle non-200 status codes
205
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
206
+ retries += 1
207
+ if retries < API_MAX_RETRIES:
208
+ print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
209
+ time.sleep(API_RETRY_DELAY)
210
+ continue # Go to next iteration of while loop
211
+ else:
212
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
213
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
214
+
215
+
216
+ except requests.exceptions.Timeout:
217
+ retries += 1
218
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
219
+ if retries < API_MAX_RETRIES:
220
+ time.sleep(API_RETRY_DELAY)
221
+ continue
222
+ else:
223
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
224
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
225
+
226
+ except requests.exceptions.RequestException as e:
227
+ retries += 1
228
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
229
+ if retries < API_MAX_RETRIES:
230
+ time.sleep(API_RETRY_DELAY)
231
+ continue
232
+ else:
233
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
234
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
235
+
236
+ except Exception as e:
237
+ retries += 1
238
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
239
+ print(traceback.format_exc())
240
+ if retries < API_MAX_RETRIES:
241
+ time.sleep(API_RETRY_DELAY)
242
+ continue
243
+ else:
244
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
245
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
246
+
247
+ # If loop finishes without returning, max retries were hit
248
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts.")
249
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
250
+
251
+ # --- Checkpoint Saving Function --- # <-- NEW (copied from the previous script)
252
+ def save_checkpoint(data_to_save, output_dir, dataset_features):
253
+ """Saves the current state of the data to disk."""
254
+ if not data_to_save:
255
+ print("Checkpoint: No data available to save.")
256
+ return
257
+
258
+ # Ensure output directory exists before saving
259
+ os.makedirs(output_dir, exist_ok=True)
260
+
261
+ print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
262
+ try:
263
+ # Convert list of dicts back to Dataset object
264
+ checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
265
+ checkpoint_dataset.save_to_disk(output_dir)
266
+ print(f"Checkpoint: Saved successfully to {output_dir}")
267
+ except Exception as ckpt_save_e:
268
+ print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
269
+ # Fallback to JSON Lines (optional, but good practice)
270
+ output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl") # Save inside the dir
271
+ print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
272
+ try:
273
+ with open(output_jsonl_path, 'w', encoding='utf-8') as f:
274
+ for item in data_to_save:
275
+ # Basic serialization handling for common types like numpy arrays
276
+ serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
277
+ f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
278
+ print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
279
+ except Exception as json_save_e:
280
+ print(f"Error saving checkpoint as JSON lines: {json_save_e}")
281
+
282
+
283
+ # --- Main Processing Logic ---
284
+
285
+ print("Checking for existing checkpoint/output dataset...")
286
+ dataset = None
287
+ original_features = None # Initialize
288
+
289
+ try:
290
+ # Check whether the output directory exists and looks like a Hugging Face datasets directory
291
+ # (dataset_info.json or state.json are the usual indicator files)
292
+ potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
293
+ potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
294
+
295
+ if os.path.exists(OUTPUT_DATASET_DIR) and \
296
+ (os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
297
+
298
+ print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
299
+ try:
300
+ dataset = load_from_disk(OUTPUT_DATASET_DIR)
301
+ original_features = dataset.features # Capture the features of the previously saved dataset
302
+ print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
303
+ except Exception as load_ckpt_e:
304
+ print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
305
+ print("Falling back to loading original input dataset.")
306
+ dataset = None # Ensure we proceed to load original if checkpoint load failed
307
+ else:
308
+ print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
309
+ # If no checkpoint, ensure dataset is None so original loading happens
310
+
311
+ # If dataset is still None (no checkpoint was found, or loading it failed)
312
+ if dataset is None:
313
+ print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
314
+ dataset = load_from_disk(INPUT_DATASET_DIR)
315
+ original_features = dataset.features
316
+ print(f"Original dataset loaded successfully with {len(dataset)} rows.")
317
+
318
+ except Exception as initial_load_e:
319
+ print(f"FATAL: Error during initial dataset loading (original or checkpoint): {initial_load_e}")
320
+ print(traceback.format_exc()) # Print the detailed error traceback
321
+ exit(1)
322
+
323
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
324
+
325
+ # --- Pre-calculation Step for GPT-4o ---
326
+ print("Pre-calculating GPT-4o tasks...")
327
+ tasks_to_process = []
328
+ # Use a list of dictionaries, which is mutable and easier for direct updates
329
+ updated_data = list(dataset)
330
+
331
+ for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset for GPT-4o tasks")):
332
+ for i in range(1, 4):
333
+ model_key = f"model_{i}"
334
+ response_text_key = f"response_text_{i}"
335
+ prompt_text_key = f"prompt_text_{i}"
336
+ response_audio_key = f"response_audio_path_{i}" # Key for storing the *new* audio path
337
+
338
+ model_assigned = row.get(model_key)
339
+ response_text_exists = row.get(response_text_key) is not None
340
+
341
+ # Check for the specific model name used in the dataset
342
+ if model_assigned == GPT4O_MODEL_NAME and not response_text_exists:
343
+ question_audio_path = row.get('question_audio')
344
+ if not question_audio_path or not os.path.exists(question_audio_path): # Check path validity here
345
+ print(f"Warning (Row {idx}, Slot {i}): Skipping GPT-4o task - Missing or invalid 'question_audio' path: {question_audio_path}")
346
+ # Pre-fill error? Let's just skip task creation for now.
347
+ # If needed: updated_data[idx][response_text_key] = "[ERROR: Missing input audio]"
348
+ # If needed: updated_data[idx][response_audio_key] = None
349
+ continue # Skip this task
350
+
351
+ metadata_str = row.get('metadata', "{}")
352
+ source_dataset = row.get('source_dataset')
353
+ metadata = {}
354
+ try:
355
+ if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
356
+ elif isinstance(metadata_str, dict): metadata = metadata_str
357
+ except json.JSONDecodeError: pass
358
+
359
+ history_messages = []
360
+ if source_dataset == 'ultra':
361
+ history_str = metadata.get('history', '')
362
+ if history_str: history_messages = parse_ultra_history(history_str)
363
+
364
+ unique_id = str(uuid.uuid4()).replace("-", "")
365
+ output_audio_filename = f"gpt4o_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
366
+ output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
367
+
368
+ task_info = {
369
+ "row_idx": idx,
370
+ "slot_idx": i,
371
+ # No API key needed here as it's global/single
372
+ "history_messages": history_messages,
373
+ "prompt_text": row.get(prompt_text_key, ""),
374
+ "question_text": row.get('question_text', ""), # Pass question text
375
+ "question_audio_path": question_audio_path,
376
+ "output_audio_filepath": output_audio_filepath,
377
+ }
378
+ tasks_to_process.append(task_info)
379
+ # Decide if you process all slots or just the first unfilled one
380
+ # break # Uncomment this line if you only want the *first* unfilled gpt4o slot per row processed
381
+
382
+ total_tasks = len(tasks_to_process)
383
+ if total_tasks == 0:
384
+ print("No GPT-4o tasks found needing processing.")
385
+ exit(0)
386
+
387
+ print(f"Found {total_tasks} GPT-4o tasks to process.")
388
+
389
+ # --- Threaded Execution with Checkpointing for GPT-4o --- # <-- MODIFIED SECTION
390
+ print(f"Starting GPT-4o processing with up to {MAX_WORKERS} worker threads...")
391
+ start_total_time = time.time()
392
+ # results = {} # No longer needed
393
+ tasks_completed = 0
394
+ tasks_failed = 0
395
+ completed_since_last_save = 0 # <-- Counter for checkpointing
396
+
397
+ # Use context manager for ThreadPoolExecutor
398
+ with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
399
+ future_to_task = {executor.submit(call_gpt4o_api_worker, task): task for task in tasks_to_process}
400
+
401
+ for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing GPT-4o tasks"):
402
+ task_info = future_to_task[future] # Get original task info
403
+ row_idx = task_info["row_idx"]
404
+ slot_idx = task_info["slot_idx"]
405
+ result = None # Define result scope
406
+
407
+ try:
408
+ result = future.result()
409
+ # --- Direct Update and Checkpointing Logic ---
410
+ response_text_key = f"response_text_{slot_idx}"
411
+ response_audio_key = f"response_audio_path_{slot_idx}"
412
+
413
+ if 0 <= row_idx < len(updated_data):
414
+ updated_data[row_idx][response_text_key] = result["response_text"]
415
+ updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
416
+ if result["saved_audio_path"] is None or "[ERROR" in result["response_text"]: # Check for error marker
417
+ tasks_failed += 1
418
+ else:
419
+ print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
420
+ tasks_failed += 1 # Count as failed if index is bad
421
+
422
+ tasks_completed += 1
423
+ completed_since_last_save += 1 # Increment checkpoint counter
424
+
425
+ # Check if it's time to save a checkpoint
426
+ if completed_since_last_save >= CHECKPOINT_INTERVAL:
427
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
428
+ completed_since_last_save = 0 # Reset counter
429
+
430
+ except Exception as exc: # Catch exceptions raised *by* the future/worker if not handled inside
431
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): GPT-4o Task generated an unhandled exception: {exc}")
432
+ print(traceback.format_exc())
433
+ # Attempt to record error in the main data structure
434
+ response_text_key = f"response_text_{slot_idx}"
435
+ response_audio_key = f"response_audio_path_{slot_idx}"
436
+ if 0 <= row_idx < len(updated_data):
437
+ updated_data[row_idx][response_text_key] = f"[ERROR: Worker Crash - {exc}]"
438
+ updated_data[row_idx][response_audio_key] = None
439
+ else:
440
+ print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
441
+
442
+ tasks_failed += 1
443
+ tasks_completed += 1 # Count as completed (though failed)
444
+ completed_since_last_save += 1 # Also increment for checkpointing
445
+
446
+ # Check if it's time to save a checkpoint even after an error
447
+ if completed_since_last_save >= CHECKPOINT_INTERVAL:
448
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
449
+ completed_since_last_save = 0 # Reset counter
450
+
451
+ end_total_time = time.time()
452
+ print("\n--- GPT-4o Processing Complete ---")
453
+ print(f"Total GPT-4o tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
454
+ print(f"Total GPT-4o processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
455
+
456
+
457
+ # --- Final Save ---
458
+ # Save one last time to ensure any remaining processed items (< CHECKPOINT_INTERVAL) are saved
459
+ print("\nPerforming final save...")
460
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
461
+
462
+ print("\nScript finished.")
463
+
464
+ # --- (Removed the old merging and saving logic as it's now handled by save_checkpoint) ---
r1-a/response_generation/gpt5o_retry.py ADDED
@@ -0,0 +1,461 @@
1
+ import os
2
+ import json
3
+ import base64
4
+ import uuid
5
+ import time
6
+ import re
7
+ import random
8
+ import concurrent.futures
9
+ from tqdm import tqdm
10
+ import threading
11
+ import traceback # For detailed error logging
12
+
13
+ import requests # Use requests library for HTTP calls
14
+ import numpy as np # Import numpy for potential fallback serialization
15
+ from datasets import load_from_disk, Dataset
16
+ from dotenv import load_dotenv
17
+
18
+ # --- Configuration ---
19
+ load_dotenv()
20
+
21
+ # --- !!! KEY CONFIGURATION FOR RETRY SCRIPT !!! ---
22
+
23
+ # 1. Identify the model you are retrying
24
+ TARGET_MODEL_NAME = "gpt4o" # Or "qwen_omni" if retrying Qwen
25
+
26
+ # 2. Set the INPUT/OUTPUT dataset directory to the PREVIOUS script's OUTPUT directory
27
+ # This is where the partially processed data (with errors) resides.
28
+ # The script will LOAD from here and SAVE back to here.
29
+ DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_gpt4o" # Adjust if needed
30
+
31
+ # 3. Set the audio output directory (can be the same as before)
32
+ OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/gpt4o_2" # Adjust if needed
33
+
34
+ # 4. API Configuration (Specific to the model being retried)
35
+ API_MODEL_NAME = "gpt-4o-audio-preview" # Actual model name for the API call
36
+ API_ENDPOINT = "https://api2.aigcbest.top/v1/chat/completions"
37
+ try:
38
+ API_TOKEN = "sk-D6jMssP7AZw3ZU6LEZaljdNMO1zif6wzef6XVh4kOgZAhQzI" # Use the correct key
39
+ if not API_TOKEN:
40
+ raise ValueError("API_TOKEN environment variable not set.")
41
+ print(f"{TARGET_MODEL_NAME} API Key loaded.")
42
+ except Exception as e:
43
+ print(f"FATAL: Error getting API Key: {e}")
44
+ exit(1)
45
+
46
+ # 5. Output Audio Configuration (Specific to the model being retried)
47
+ OUTPUT_AUDIO_FORMAT = "wav"
48
+ AVAILABLE_VOICES = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse'] # GPT-4o voices
49
+
50
+ # 6. API Call Settings
51
+ API_TIMEOUT = 120
52
+ API_RETRY_DELAY = 5
53
+ API_MAX_RETRIES = 3
54
+ MAX_WORKERS = 8
55
+
56
+ # 7. Checkpoint Saving Configuration
57
+ CHECKPOINT_INTERVAL = 50 # Save every 50 *retried* tasks completed
58
+
59
+ # --- Error Markers to Look For ---
60
+ # These prefixes indicate a failed task that needs retrying
61
+ ERROR_MARKERS = ("[API ERROR", "[ERROR")
62
+
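+ # str.startswith accepts a tuple, so a single check covers every marker prefix, e.g.:
+ #   "[API ERROR: Timeout]".startswith(ERROR_MARKERS)  -> True
+ #   "A normal answer".startswith(ERROR_MARKERS)       -> False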
63
+ # --- Helper Functions (encode_audio_base64, parse_ultra_history - unchanged) ---
64
+
65
+ def encode_audio_base64(audio_path):
66
+ if not audio_path or not os.path.exists(audio_path):
67
+ print(f"Warning: Input audio file not found or path is empty: {audio_path}")
68
+ return None
69
+ try:
70
+ with open(audio_path, "rb") as audio_file:
71
+ return base64.b64encode(audio_file.read()).decode("utf-8")
72
+ except Exception as e:
73
+ print(f"Error encoding audio file {audio_path}: {e}")
74
+ return None
75
+
76
+ def parse_ultra_history(history_str):
77
+ messages = []
78
+ pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
79
+ matches = pattern.findall(history_str)
80
+ if not matches:
81
+ return []
82
+ for role_tag, content in matches:
83
+ role = role_tag.lower()
84
+ cleaned_content = content.strip()
85
+ if cleaned_content:
86
+ messages.append({"role": role, "content": cleaned_content})
87
+ return messages
88
+
89
+ # --- API Call Worker Function (Use the correct one for the target model - GPT-4o version shown) ---
90
+ # --- (This function call_gpt4o_api_worker is copied directly from the previous script) ---
91
+ def call_gpt4o_api_worker(task_info):
92
+ """
93
+ Worker function to call the custom GPT-4o API for a single task.
94
+ (Identical to the function in the previous script)
95
+ """
96
+ row_idx = task_info["row_idx"]
97
+ slot_idx = task_info["slot_idx"]
98
+ history_messages = task_info["history_messages"]
99
+ prompt_text = task_info["prompt_text"]
100
+ question_text = task_info["question_text"]
101
+ question_audio_path = task_info["question_audio_path"]
102
+ output_audio_filepath = task_info["output_audio_filepath"]
103
+
104
+ retries = 0
105
+ headers = {
106
+ 'Accept': 'application/json',
107
+ 'Authorization': f'Bearer {API_TOKEN}', # Use the single loaded token
108
+ 'Content-Type': 'application/json'
109
+ }
110
+ selected_voice = random.choice(AVAILABLE_VOICES)
111
+
112
+ while retries < API_MAX_RETRIES:
113
+ try:
114
+ # 1. Prepare Input Audio
115
+ base64_audio_data = encode_audio_base64(question_audio_path)
116
+ if not base64_audio_data:
117
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping GPT4o API call - missing input audio.")
118
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None}
119
+
120
+ input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
121
+
122
+ # 2. Construct User Message Content
123
+ combined_text = f"{prompt_text}"
124
+ user_content_list = [
125
+ {"type": "text", "text": combined_text},
126
+ {"type": "input_audio", "input_audio": {"data": base64_audio_data, "format": input_audio_format}}
127
+ ]
128
+ messages = history_messages + [{"role": "user", "content": user_content_list}]
129
+
130
+ # 3. Construct Payload
131
+ payload = {
132
+ "model": API_MODEL_NAME,
133
+ "modalities": ["text", "audio"],
134
+ "audio": {"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
135
+ "messages": messages
136
+ }
137
+
138
+ # 4. Make API Call
139
+ response = requests.post(
140
+ API_ENDPOINT,
141
+ headers=headers,
142
+ json=payload,
143
+ timeout=API_TIMEOUT
144
+ )
145
+
146
+ # 5. Process Response
147
+ if response.status_code == 200:
148
+ try:
149
+ response_data = response.json()
150
+ choices = response_data.get('choices')
151
+ if not choices or not isinstance(choices, list) or len(choices) == 0:
152
+ raise ValueError("Invalid or empty 'choices' field in response.")
153
+ message_content = choices[0].get('message', {})
154
+ if not message_content:
155
+ raise ValueError("Missing 'message' field in the first choice.")
156
+ audio_info = message_content.get('audio', {})
157
+ if not isinstance(audio_info, dict): audio_info = {}
158
+
159
+ audio_base64_string = audio_info.get('data', '')
160
+ collected_text = audio_info.get('transcript', '').strip()
161
+ if not collected_text:
162
+ text_content_list = message_content.get('content', [])
163
+ if isinstance(text_content_list, list):
164
+ for item in text_content_list:
165
+ if isinstance(item, dict) and item.get("type") == "text":
166
+ collected_text = item.get("text", "").strip()
167
+ break
168
+ elif isinstance(message_content.get('content'), str):
169
+ collected_text = message_content['content'].strip()
170
+
171
+ if not collected_text: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No text content found after checking multiple fields.")
172
+ if not audio_base64_string: print(f"Warning (Row {row_idx}, Slot {slot_idx}): No audio data found.")
173
+
174
+ saved_audio_path = None
175
+ if audio_base64_string:
176
+ try:
177
+ wav_bytes = base64.b64decode(audio_base64_string)
178
+ if len(wav_bytes) == 0:
179
+ print(f"Warning (Row {row_idx}, Slot {slot_idx}): Decoded audio bytes are empty.")
180
+ else:
181
+ os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
182
+ with open(output_audio_filepath, "wb") as f:
183
+ f.write(wav_bytes)
184
+ saved_audio_path = output_audio_filepath
185
+ except base64.binascii.Error as b64_err:
186
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Decoding base64 audio data failed: {b64_err}")
187
+ except Exception as e:
188
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Saving audio file failed: {e}")
189
+
190
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text, "saved_audio_path": saved_audio_path}
191
+
192
+ except (json.JSONDecodeError, IndexError, KeyError, TypeError, ValueError) as e:
193
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Parsing successful API response failed: {type(e).__name__} - {e}")
194
+ print(f" Response Text (start): {response.text[:500]}...")
195
+ retries += 1
196
+ print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
197
+ time.sleep(API_RETRY_DELAY)
198
+ continue
199
+ except Exception as e:
200
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected error processing response: {e}")
201
+ print(traceback.format_exc())
202
+ retries += 1
203
+ print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
204
+ time.sleep(API_RETRY_DELAY)
205
+ continue
206
+
207
+ else: # Handle non-200 status codes
208
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): API returned status {response.status_code}. Response: {response.text[:500]}...")
209
+ retries += 1
210
+ if retries < API_MAX_RETRIES:
211
+ print(f" Retrying task ({retries}/{API_MAX_RETRIES})...")
212
+ time.sleep(API_RETRY_DELAY)
213
+ continue
214
+ else:
215
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after status {response.status_code}.")
216
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Status {response.status_code}]", "saved_audio_path": None}
217
+
218
+ except requests.exceptions.Timeout:
219
+ retries += 1
220
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): API Call Attempt {retries}/{API_MAX_RETRIES} timed out after {API_TIMEOUT}s.")
221
+ if retries < API_MAX_RETRIES:
222
+ time.sleep(API_RETRY_DELAY)
223
+ continue
224
+ else:
225
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after timeout.")
226
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Timeout]", "saved_audio_path": None}
227
+ except requests.exceptions.RequestException as e:
228
+ retries += 1
229
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Network/Request Error Attempt {retries}/{API_MAX_RETRIES}: {e}")
230
+ if retries < API_MAX_RETRIES:
231
+ time.sleep(API_RETRY_DELAY)
232
+ continue
233
+ else:
234
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after network error.")
235
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Network Error]", "saved_audio_path": None}
236
+ except Exception as e:
237
+ retries += 1
238
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Unexpected Error in Worker Loop Attempt {retries}/{API_MAX_RETRIES}: {type(e).__name__} - {e}")
239
+ print(traceback.format_exc())
240
+ if retries < API_MAX_RETRIES:
241
+ time.sleep(API_RETRY_DELAY)
242
+ continue
243
+ else:
244
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Max retries reached after unexpected error.")
245
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Unexpected Worker Error]", "saved_audio_path": None}
246
+
247
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Task failed after {API_MAX_RETRIES} attempts (Worker Loop Exited).")
248
+ return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[API ERROR: Max retries reached]", "saved_audio_path": None}
249
+
250
+
251
+ # --- Checkpoint Saving Function (Unchanged) ---
252
+ def save_checkpoint(data_to_save, output_dir, dataset_features):
253
+ """Saves the current state of the data to disk."""
254
+ if not data_to_save:
255
+ print("Checkpoint: No data available to save.")
256
+ return
257
+ os.makedirs(output_dir, exist_ok=True)
258
+ print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
259
+ try:
260
+ checkpoint_dataset = Dataset.from_list(data_to_save, features=dataset_features)
261
+ checkpoint_dataset.save_to_disk(output_dir)
262
+ print(f"Checkpoint: Saved successfully to {output_dir}")
263
+ except Exception as ckpt_save_e:
264
+ print(f"Error saving checkpoint dataset using datasets lib: {ckpt_save_e}")
265
+ output_jsonl_path = os.path.join(output_dir, "checkpoint_data.jsonl")
266
+ print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
267
+ try:
268
+ with open(output_jsonl_path, 'w', encoding='utf-8') as f:
269
+ for item in data_to_save:
270
+ serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
271
+ f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
272
+ print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
273
+ except Exception as json_save_e:
274
+ print(f"Error saving checkpoint as JSON lines: {json_save_e}")
275
+
276
+ # --- Main Processing Logic (Retry Focus) ---
277
+
278
+ print(f"--- Starting Retry Script for {TARGET_MODEL_NAME} ---")
279
+ print(f"Loading dataset to retry from: {DATASET_DIR}")
280
+
281
+ try:
282
+ # Attempt to load the dataset from the specified directory
283
+ if not os.path.exists(DATASET_DIR) or \
284
+ not (os.path.exists(os.path.join(DATASET_DIR, "dataset_info.json")) or \
285
+ os.path.exists(os.path.join(DATASET_DIR, "state.json"))):
286
+ print(f"FATAL: Dataset directory not found or invalid: {DATASET_DIR}")
287
+ print("Please ensure this path points to the OUTPUT directory of the previous script run.")
288
+ exit(1)
289
+
290
+ dataset = load_from_disk(DATASET_DIR)
291
+ original_features = dataset.features # Store features for saving
292
+ print(f"Dataset loaded successfully with {len(dataset)} rows.")
293
+
294
+ except Exception as e:
295
+ print(f"FATAL: Error loading dataset from {DATASET_DIR}: {e}")
296
+ print(traceback.format_exc())
297
+ exit(1)
298
+
299
+ # Ensure audio output directory exists
300
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
301
+
302
+ # --- Pre-calculation Step for Retrying Failed Tasks ---
303
+ print(f"Scanning dataset for failed {TARGET_MODEL_NAME} tasks to retry...")
304
+ tasks_to_process = []
305
+ # Use a list of dictionaries, which is mutable and easier for direct updates
306
+ updated_data = list(dataset) # Load data into memory for modification
307
+
308
+ for idx, row in enumerate(tqdm(updated_data, desc=f"Scanning for failed {TARGET_MODEL_NAME} tasks")):
309
+ for i in range(1, 4): # Check slots 1, 2, 3
310
+ model_key = f"model_{i}"
311
+ response_text_key = f"response_text_{i}"
312
+ prompt_text_key = f"prompt_text_{i}"
313
+ response_audio_key = f"response_audio_path_{i}"
314
+
315
+ model_assigned = row.get(model_key)
316
+ response_text_value = row.get(response_text_key)
317
+
318
+ # --- Core Retry Logic ---
319
+ # Check if the model assigned matches the one we are retrying
320
+ if model_assigned == TARGET_MODEL_NAME:
321
+ # Check if the response text indicates an error
322
+ is_error = False
323
+ if isinstance(response_text_value, str):
324
+ cleaned_text = response_text_value.strip()
325
+ if cleaned_text.startswith(ERROR_MARKERS): # Check if it starts with any error prefix
326
+ is_error = True
327
+ # Optional: You might also want to retry if text is None or empty,
328
+ # but the primary goal is retrying explicit errors.
329
+ # elif response_text_value is None or response_text_value == "":
330
+ # is_error = True # Uncomment if needed
331
+
332
+ if is_error:
333
+ print(f"\nInfo (Row {idx}, Slot {i}): Found failed task to retry. Current text: '{str(response_text_value)[:100]}...'") # Log finding
334
+
335
+ # --- Gather info needed for the task (same as original script) ---
336
+ question_audio_path = row.get('question_audio')
337
+ if not question_audio_path or not os.path.exists(question_audio_path):
338
+ print(f"Warning (Row {idx}, Slot {i}): Skipping retry - Missing or invalid 'question_audio' path: {question_audio_path}")
339
+ # Keep the old error message in updated_data for this case
340
+ continue # Skip this specific task retry
341
+
342
+ metadata_str = row.get('metadata', "{}")
343
+ source_dataset = row.get('source_dataset')
344
+ metadata = {}
345
+ try:
346
+ if metadata_str and isinstance(metadata_str, str): metadata = json.loads(metadata_str)
347
+ elif isinstance(metadata_str, dict): metadata = metadata_str
348
+ except json.JSONDecodeError: pass
349
+
350
+ history_messages = []
351
+ if source_dataset == 'ultra':
352
+ history_str = metadata.get('history', '')
353
+ if history_str: history_messages = parse_ultra_history(history_str)
354
+
355
+ unique_id = str(uuid.uuid4()).replace("-", "")
356
+ # Generate a *new* filename for the potential audio output
357
+ output_audio_filename = f"{TARGET_MODEL_NAME}_retry_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
358
+ output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
359
+
360
+ task_info = {
361
+ "row_idx": idx,
362
+ "slot_idx": i,
363
+ "history_messages": history_messages,
364
+ "prompt_text": row.get(prompt_text_key, ""),
365
+ "question_text": row.get('question_text', ""),
366
+ "question_audio_path": question_audio_path,
367
+ "output_audio_filepath": output_audio_filepath,
368
+ }
369
+ tasks_to_process.append(task_info)
370
+ # Decide if you want to retry all failed slots in a row or just the first one found
371
+ # break # Uncomment if you only want to retry the FIRST failed slot per row
372
+
373
+ total_tasks = len(tasks_to_process)
374
+ if total_tasks == 0:
375
+ print(f"No failed {TARGET_MODEL_NAME} tasks found needing reprocessing in {DATASET_DIR}.")
376
+ exit(0)
377
+
378
+ print(f"Found {total_tasks} failed {TARGET_MODEL_NAME} tasks to retry.")
379
+
380
+ # --- Threaded Execution with Checkpointing (Identical structure to previous script) ---
381
+ print(f"Starting reprocessing with up to {MAX_WORKERS} worker threads...")
382
+ start_total_time = time.time()
383
+ tasks_completed = 0
384
+ tasks_failed_retries = 0 # Count failures during the *retry* attempt
385
+ completed_since_last_save = 0
386
+
387
+ with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
388
+ # Ensure the correct worker function is called based on TARGET_MODEL_NAME
389
+ api_worker_function = call_gpt4o_api_worker # Default to GPT-4o
390
+ # Add logic here if TARGET_MODEL_NAME could be Qwen
391
+ # if TARGET_MODEL_NAME == "qwen_omni":
392
+ # api_worker_function = call_qwen_omni_api_worker # Assuming you have this function defined/imported
393
+
394
+ future_to_task = {executor.submit(api_worker_function, task): task for task in tasks_to_process}
395
+
396
+ for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Reprocessing tasks"):
397
+ task_info = future_to_task[future]
398
+ row_idx = task_info["row_idx"]
399
+ slot_idx = task_info["slot_idx"]
400
+ result = None
401
+
402
+ try:
403
+ result = future.result()
404
+ # --- Direct Update and Checkpointing Logic ---
405
+ response_text_key = f"response_text_{slot_idx}"
406
+ response_audio_key = f"response_audio_path_{slot_idx}"
407
+
408
+ if 0 <= row_idx < len(updated_data):
409
+ # Update the data in memory
410
+ updated_data[row_idx][response_text_key] = result["response_text"]
411
+ updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
412
+ # Check if the *retry* attempt failed
413
+ if result["saved_audio_path"] is None or str(result["response_text"]).strip().startswith(ERROR_MARKERS):
414
+ tasks_failed_retries += 1
415
+ print(f"Warning (Row {row_idx}, Slot {i}): Retry attempt failed. Result: {str(result['response_text'])[:100]}...")
416
+ else:
417
+ print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
418
+ tasks_failed_retries += 1
419
+
420
+ tasks_completed += 1
421
+ completed_since_last_save += 1
422
+
423
+ # Checkpoint saving
424
+ if completed_since_last_save >= CHECKPOINT_INTERVAL:
425
+ # Save the updated data back to the SAME directory
426
+ save_checkpoint(updated_data, DATASET_DIR, original_features)
427
+ completed_since_last_save = 0
428
+
429
+ except Exception as exc:
430
+ print(f"Error (Row {row_idx}, Slot {slot_idx}): Retry Task generated an unhandled exception: {exc}")
431
+ print(traceback.format_exc())
432
+ response_text_key = f"response_text_{slot_idx}"
433
+ response_audio_key = f"response_audio_path_{slot_idx}"
434
+ if 0 <= row_idx < len(updated_data):
435
+ updated_data[row_idx][response_text_key] = f"[ERROR: Retry Worker Crash - {exc}]" # Mark as worker crash during retry
436
+ updated_data[row_idx][response_audio_key] = None
437
+ else:
438
+ print(f"Warning: Invalid row index {row_idx} encountered during exception handling merge.")
439
+
440
+ tasks_failed_retries += 1
441
+ tasks_completed += 1
442
+ completed_since_last_save += 1
443
+
444
+ # Checkpoint saving after error
445
+ if completed_since_last_save >= CHECKPOINT_INTERVAL:
446
+ save_checkpoint(updated_data, DATASET_DIR, original_features)
447
+ completed_since_last_save = 0
448
+
449
+ end_total_time = time.time()
450
+ print("\n--- Reprocessing Complete ---")
451
+ print(f"Total tasks retried: {tasks_completed}")
452
+ print(f" Succeeded on retry: {tasks_completed - tasks_failed_retries}")
453
+ print(f" Failed on retry: {tasks_failed_retries}")
454
+ print(f"Total reprocessing time: {(end_total_time - start_total_time)/60:.2f} minutes")
455
+
456
+ # --- Final Save ---
457
+ # Save the final state of the updated data back to the original location
458
+ print("\nPerforming final save of the reprocessed dataset...")
459
+ save_checkpoint(updated_data, DATASET_DIR, original_features)
460
+
461
+ print(f"\nRetry script finished. Updated dataset saved in: {DATASET_DIR}")
r1-a/response_generation/kimi.py ADDED
@@ -0,0 +1,532 @@
1
+ import os
2
+ import json
3
+ import time
4
+ import re # For parsing history
5
+ import uuid # For generating unique filenames
6
+ import torch # Kimi might return tensors
7
+ import soundfile as sf # For saving Kimi audio output
+ import numpy as np # For dtype checks/conversion when saving audio
8
+ import sys
9
+ from datasets import load_from_disk, Dataset, Features, Audio, Value
10
+ from dotenv import load_dotenv
11
+ import datetime # For ETA formatting
12
+ from tqdm import tqdm # Import tqdm
13
+ import traceback # For detailed error printing
14
+
15
+ # --- Kimi-Audio Project Path Setup ---
16
+ # <--- *** IMPORTANT: Update this path to the PARENT directory containing the 'kimia_infer' folder *** --->
17
+ kimia_project_parent_dir = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio"
18
+
19
+ # Check if the path exists and add it to sys.path
20
+ if os.path.isdir(kimia_project_parent_dir):
21
+ if kimia_project_parent_dir not in sys.path:
22
+ sys.path.insert(0, kimia_project_parent_dir)
23
+ print(f"Added '{kimia_project_parent_dir}' to Python path.")
24
+ # Try importing KimiAudio only after potentially adding the path
25
+ try:
26
+ from kimia_infer.api.kimia import KimiAudio # Kimi model class
27
+ except ImportError as import_err:
28
+ print(f"Error: Could not import KimiAudio from '{kimia_project_parent_dir}'.")
29
+ print(f"ImportError: {import_err}")
30
+ print("Please ensure the 'kimia_infer' directory exists within the specified path and check dependencies.")
31
+ exit(1)
32
+ else:
33
+ print(f"Error: Kimi project parent directory not found: '{kimia_project_parent_dir}'")
34
+ print("Please update the 'kimia_project_parent_dir' variable in the script.")
35
+ exit(1)
36
+
37
+ # --- Configuration ---
38
+ load_dotenv() # Load environment variables if needed (e.g., API keys, though not typical for local Kimi)
39
+
40
+ # 1. Model & Tokenizer Setup (Kimi Specific)
41
+ KIMI_MODEL_NAME = "kimi_audio" # Identifier used in your dataset's model_N columns
42
+ KIMI_MODEL_PATH = "/home/chenyifu/audio-r1/r1-a/response_generation/Kimi-Audio/checkpoint/Kimi-Audio-7B-Instruct" # Path to your Kimi model checkpoint directory
43
+ # KIMI_DEVICE = 'cuda' # KimiAudio class likely handles device selection based on availability. Verify its internal logic if issues arise.
44
+ # KIMI_DTYPE = torch.bfloat16 # KimiAudio likely handles dtype internally.
45
+
46
+ # 2. Dataset Paths
47
+ INPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_sampling_tasks" # Original source
48
+ OUTPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_tasks_with_kimi" # Where Kimi processed data is saved/resumed from
49
+
50
+ # 3. Output Audio Configuration (Kimi Specific)
51
+ OUTPUT_AUDIO_ROOT_DIR = "/home/chenyifu/audio-r1/r1-a/generated_audio/kimi" # Where Kimi generated audio files are saved
52
+ OUTPUT_AUDIO_FORMAT = "wav"
53
+ OUTPUT_AUDIO_SAMPLERATE = 24000 # Kimi example uses 24kHz output. Confirm this matches your model's expected/native output SR.
54
+
55
+ # 4. Kimi Call Settings (Based on example, adjust as needed)
56
+ KIMI_SAMPLING_PARAMS = {
57
+ "audio_temperature": 0.8,
58
+ "audio_top_k": 10,
59
+ "text_temperature": 0.0, # 0.0 for deterministic text, increase for more variety
60
+ "text_top_k": 5, # Relevant if text_temperature > 0
61
+ "audio_repetition_penalty": 1.0,
62
+ "audio_repetition_window_size": 64,
63
+ "text_repetition_penalty": 1.0,
64
+ "text_repetition_window_size": 16,
65
+ # "max_new_tokens": 128 # Add if needed and supported by KimiAudio.generate
66
+ }
67
+ KIMI_OUTPUT_TYPE = "both" # Generate both audio and text
68
+
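+ # These settings are forwarded verbatim as keyword arguments to the model, e.g.
+ # (illustrative; mirrors the call in call_kimi_model below):
+ #   wav, text = model.generate(messages, **KIMI_SAMPLING_PARAMS, output_type=KIMI_OUTPUT_TYPE)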
69
+ # 5. Periodic Save Settings
70
+ SAVE_EVERY_N_SAMPLES = 50 # Save after processing this many samples
71
+
72
+ # --- Helper Functions ---
73
+
74
+ def format_time(seconds):
75
+ """Formats seconds into a human-readable string H:MM:SS"""
76
+ if seconds < 0:
77
+ return "N/A"
78
+ return str(datetime.timedelta(seconds=int(seconds)))
79
+
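+ # Example: format_time(3725) -> "1:02:05"; negative inputs return "N/A"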
80
+ # REMOVED load_audio_minicpm - Kimi takes the path directly
81
+
82
+ def parse_ultra_history(history_str):
83
+ """Parses the specific history string format from ultra metadata for Kimi."""
84
+ messages = []
85
+ # Relaxed pattern to capture content even if tags are slightly off or whitespace varies
86
+ pattern = re.compile(r"\[\s*(USER|ASSISTANT)\s*\]\s*([\s\S]*?)(?=\s*\[\s*(?:USER|ASSISTANT)\s*\]|$)")
87
+ matches = pattern.findall(history_str)
88
+ if not matches and history_str and history_str.strip():
89
+ # Simple fallback if standard pattern fails but there's content
90
+ if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
91
+ role = "user"
92
+ content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
93
+ if content: messages.append({"role": role, "message_type": "text", "content": content}) # Add Kimi message_type
94
+ elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
95
+ role = "assistant"
96
+ content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
97
+ if content: messages.append({"role": role, "message_type": "text", "content": content}) # Add Kimi message_type
98
+ else:
99
+ print(f"Warning: Could not parse history string format: {history_str[:100]}...")
100
+ return messages # Return whatever was parsed, even if empty
101
+
102
+ for role_tag, content in matches:
103
+ role = role_tag.strip().lower()
104
+ cleaned_content = content.strip()
105
+ if cleaned_content:
106
+ # IMPORTANT: Add message_type='text' for Kimi history
107
+ messages.append({"role": role, "message_type": "text", "content": cleaned_content})
108
+ return messages
109
+
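+ # Illustrative parse (the sample string below is made up):
+ #   parse_ultra_history("[USER] How far is the Moon? [ASSISTANT] About 384,400 km.")
+ #   -> [{'role': 'user', 'message_type': 'text', 'content': 'How far is the Moon?'},
+ #      {'role': 'assistant', 'message_type': 'text', 'content': 'About 384,400 km.'}]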
110
+
111
+ # --- Kimi Model Interaction Function ---
112
+ def call_kimi_model(model, messages_input, sampling_params, output_audio_filepath, output_sample_rate):
113
+ """Calls the Kimi-Audio model, saves audio, returns text and audio path."""
114
+ try:
115
+ # 1. Ensure Output Directory Exists
116
+ os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
117
+
118
+ # 2. Call Kimi's Generate Function
119
+ wav_output, text_output = model.generate(
120
+ messages_input,
121
+ **sampling_params,
122
+ output_type=KIMI_OUTPUT_TYPE # Use 'both'
123
+ )
124
+
125
+ # 3. Process and Save Audio Output
126
+ saved_audio_path = None
127
+ if wav_output is not None and isinstance(wav_output, torch.Tensor) and wav_output.numel() > 0: # Check if tensor is not empty
128
+ try:
129
+ # Ensure tensor is on CPU, reshape (if needed, often view(-1)), convert to numpy
130
+ # Check KimiAudio output format - might already be 1D or need specific shape
131
+ audio_data = wav_output.detach().cpu().view(-1).numpy()
132
+
133
+ # Ensure data is float32 or int16 as supported by soundfile/WAV
134
+ if audio_data.dtype != 'float32':
135
+ # Attempt conversion, potentially scale if it's int
136
+ # print(f" Info: Converting Kimi audio output from {audio_data.dtype} to float32 for saving.")
137
+ if np.issubdtype(audio_data.dtype, np.integer):
+ # Scale integer PCM into the [-1.0, 1.0] float range using the dtype's own limits
+ # (e.g. int16 is divided by 32767), so the saved WAV is not clipped
+ max_val = float(np.iinfo(audio_data.dtype).max)
+ audio_data = audio_data.astype(np.float32) / max_val
+ else:
+ audio_data = audio_data.astype(np.float32)
144
+
145
+
146
+ sf.write(output_audio_filepath, audio_data, output_sample_rate)
147
+
148
+ # Check if file was actually created and has size
149
+ if os.path.exists(output_audio_filepath) and os.path.getsize(output_audio_filepath) > 100: # Check for a reasonable size threshold
150
+ saved_audio_path = output_audio_filepath
151
+ else:
152
+ print(f" Error: Kimi generate finished but output audio file seems empty or too small at {output_audio_filepath}")
153
+ if os.path.exists(output_audio_filepath):
154
+ try: os.remove(output_audio_filepath)
155
+ except OSError as rm_err: print(f" Warning: Could not remove empty/small file {output_audio_filepath}: {rm_err}")
156
+ except ImportError:
157
+ print("Error: NumPy library not found. Please install it (`pip install numpy`)")
158
+ return "[ERROR: NumPy Missing]", None # Indicate failure clearly
159
+ except Exception as sf_err:
160
+ print(f" Error saving Kimi audio output to {output_audio_filepath}: {sf_err}")
161
+ traceback.print_exc()
162
+ if os.path.exists(output_audio_filepath):
163
+ try: os.remove(output_audio_filepath)
164
+ except OSError as rm_err: print(f" Warning: Could not remove potentially corrupt file {output_audio_filepath}: {rm_err}")
165
+ elif wav_output is None:
166
+ print(" Warning: Kimi model did not return an audio tensor (wav_output is None).")
167
+ elif isinstance(wav_output, torch.Tensor) and wav_output.numel() == 0:
168
+ print(" Warning: Kimi model returned an empty audio tensor.")
169
+ else:
170
+ print(f" Warning: Kimi model returned unexpected audio output type: {type(wav_output)}. Expected torch.Tensor.")
171
+
172
+
173
+ # 4. Process Text Output
174
+ response_text_cleaned = ""
175
+ if isinstance(text_output, str):
176
+ response_text_cleaned = text_output.strip()
177
+ elif text_output is not None:
178
+ response_text_cleaned = str(text_output).strip() # Convert just in case
179
+ else:
180
+ # If text is None but audio might exist, use a specific marker
181
+ if saved_audio_path:
182
+ response_text_cleaned = "[Audio Generated, No Text Output]"
183
+ else:
184
+ response_text_cleaned = "[ERROR: No Text Output]"
185
+
186
+
187
+ # Return text (even if audio failed) and the path (or None)
188
+ return response_text_cleaned, saved_audio_path
189
+
190
+ except Exception as e:
191
+ print(f"\n --- Error during Kimi model call ---")
192
+ # Avoid printing potentially huge message list directly
193
+ first_message = messages_input[0] if messages_input else "N/A"
194
+ last_message_content = messages_input[-1]['content'] if messages_input else "N/A"
195
+ if isinstance(last_message_content, str) and len(last_message_content) > 100 :
196
+ last_message_preview = last_message_content[:100] + "..."
197
+ else:
198
+ last_message_preview = last_message_content
199
+
200
+ print(f" Input Messages Info: Count={len(messages_input)}, First={first_message}, Last Content Preview='{last_message_preview}'")
201
+ print(f" Exception Type: {type(e).__name__}")
202
+ print(f" Error Details: {e}")
203
+ print(" Traceback:")
204
+ traceback.print_exc()
205
+ print(" --- End Error Details ---")
206
+
207
+ # Attempt cleanup of potentially incomplete output file
208
+ if 'output_audio_filepath' in locals() and os.path.exists(output_audio_filepath):
209
+ try:
210
+ os.remove(output_audio_filepath)
211
+ except OSError as rm_err:
212
+ print(f" Warning: Could not remove file {output_audio_filepath} after error: {rm_err}")
213
+ # Return clear error markers
214
+ return "[ERROR: Kimi Model Call Failed]", None
215
+
216
+ # --- Dataset Saving Function (Modified for Kimi context) ---
217
+ def save_checkpoint(data_list, features, output_dir, fallback_dir=None):
218
+ """Saves the current state of the data list as a Hugging Face Dataset."""
219
+ if not data_list:
220
+ print("\nSkipping checkpoint save: data list is empty.")
221
+ return
222
+
223
+ print(f"\nSaving checkpoint with {len(data_list)} rows to {output_dir}...")
224
+ try:
225
+ # Ensure the list contains dictionaries
226
+ data_to_save = [dict(item) for item in data_list]
227
+
228
+ # --- Feature Check/Adaptation (Optional but recommended) ---
229
+ # Sometimes saving fails if data types changed unexpectedly (e.g., None -> str)
230
+ # It's safer to create the Dataset *without* features first, then cast
231
+ temp_dataset = Dataset.from_list(data_to_save)
232
+ # Now cast to the original features, allowing potential None/type mismatches
233
+ # This might raise warnings but is often more robust than direct from_list with features
234
+ updated_dataset = temp_dataset.cast(features)
235
+ # --- End Feature Check ---
236
+
237
+ # Ensure output directory exists before saving
238
+ os.makedirs(output_dir, exist_ok=True)
239
+ updated_dataset.save_to_disk(output_dir)
240
+ print("Checkpoint saved successfully.")
241
+
242
+ except Exception as e:
243
+ print(f"Error saving checkpoint dataset using save_to_disk to {output_dir}: {e}")
244
+ traceback.print_exc()
245
+ if fallback_dir:
246
+ # Use Kimi-specific name in fallback path
247
+ fallback_path = os.path.join(fallback_dir, f"updated_{KIMI_MODEL_NAME}_data_checkpoint_{int(time.time())}.jsonl")
248
+ print(f"Attempting to save data as JSON Lines fallback to: {fallback_path}")
249
+ try:
250
+ os.makedirs(fallback_dir, exist_ok=True)
251
+ with open(fallback_path, 'w', encoding='utf-8') as f:
252
+ # Rebuild the list of dicts here so the fallback still works even if the
+ # conversion above raised before data_to_save was assigned
+ data_to_save = [dict(item) for item in data_list]
+ for item in data_to_save:
254
+ # Ensure all values are serializable
255
+ serializable_item = {}
256
+ for k, v in item.items():
257
+ if isinstance(v, (datetime.datetime, datetime.date)):
258
+ serializable_item[k] = v.isoformat()
259
+ elif isinstance(v, bytes):
260
+ serializable_item[k] = v.decode('utf-8', errors='ignore')
261
+ elif isinstance(v, torch.Tensor): # Handle potential tensors if not caught earlier
262
+ print(f" Warning: Found unexpected Tensor for key '{k}' in fallback save. Converting to list.")
263
+ serializable_item[k] = v.tolist()
264
+ elif not isinstance(v, (str, int, float, bool, list, dict, type(None))):
265
+ print(f" Warning: Converting non-standard type {type(v)} for key '{k}' to string for JSON fallback.")
266
+ serializable_item[k] = str(v)
267
+ else:
268
+ serializable_item[k] = v
269
+ try:
270
+ f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
271
+ except TypeError as json_type_err:
272
+ print(f" Skipping row due to JSON serialization error: {json_type_err} in item part: {k}={v}")
273
+ print("Fallback JSON Lines checkpoint saved successfully.")
274
+ except Exception as json_e:
275
+ print(f"Error saving fallback JSON Lines checkpoint: {json_e}")
276
+
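+ # Typical call, matching how the main loop below uses it:
+ #   save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)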
277
+
278
+ # =============================================
279
+ # --- Main Processing Logic ---
280
+ # =============================================
281
+
282
+ # --- STEP 1: Dataset Loading (Modified for Resumption) ---
283
+ print("="*30)
284
+ print("STEP 1: Loading Dataset")
285
+ print("="*30)
286
+ dataset = None
287
+ original_features = None # Initialize
288
+
289
+ # Check if the Kimi-specific output directory exists
290
+ if os.path.exists(OUTPUT_DATASET_DIR):
291
+ print(f"Found existing Kimi processed dataset directory at: {OUTPUT_DATASET_DIR}")
292
+ print("Attempting to load it to resume processing...")
293
+ try:
294
+ dataset = load_from_disk(OUTPUT_DATASET_DIR)
295
+ original_features = dataset.features # Get features from the loaded dataset
296
+ print(f"Resumed Kimi dataset loaded successfully with {len(dataset)} rows.")
297
+ print(f"Features from resumed dataset: {original_features}")
298
+ except Exception as e:
299
+ print(f"Warning: Error loading existing Kimi dataset from {OUTPUT_DATASET_DIR}: {e}")
300
+ traceback.print_exc()
301
+ print("Will attempt to load the original input dataset instead.")
302
+ dataset = None # Reset dataset variable
303
+ else:
304
+ print(f"No existing Kimi processed dataset found at {OUTPUT_DATASET_DIR}.")
305
+ print("Will attempt to load the original input dataset.")
306
+
307
+
308
+ # If dataset is still None, load from the original input directory
309
+ if dataset is None:
310
+ print(f"\nLoading original input dataset from: {INPUT_DATASET_DIR}")
311
+ if not os.path.exists(INPUT_DATASET_DIR):
312
+ print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
313
+ exit(1)
314
+ try:
315
+ dataset = load_from_disk(INPUT_DATASET_DIR)
316
+ original_features = dataset.features # Get features from the input dataset
317
+ print(f"Original input dataset loaded successfully with {len(dataset)} rows.")
318
+ print(f"Features from input dataset: {original_features}")
319
+ except Exception as e:
320
+ print(f"FATAL: Error loading original input dataset from {INPUT_DATASET_DIR}: {e}")
321
+ traceback.print_exc()
322
+ exit(1)
323
+
324
+ # --- Ensure dataset and features were loaded ---
325
+ if dataset is None or original_features is None:
326
+ print("FATAL: Failed to load any dataset. Exiting.")
327
+ exit(1)
328
+ # --- End Dataset Loading ---
329
+
330
+
331
+ # --- STEP 2: Pre-computation - Identify Kimi Tasks ---
332
+ print("\n" + "="*30)
333
+ print(f"STEP 2: Identifying '{KIMI_MODEL_NAME}' Tasks to Process")
334
+ print("="*30)
335
+ pkusafe_tasks_indices = []
336
+ other_tasks_indices = []
337
+
338
+ # Iterate through the loaded dataset structure
339
+ for idx, row in enumerate(dataset):
340
+ source_dataset = row.get('source_dataset')
341
+ processed_in_row = False # Flag to ensure we only pick one Kimi slot per row
342
+ for i in range(1, 4): # Check slots 1, 2, 3
343
+ model_key = f"model_{i}"
344
+ response_text_key = f"response_text_{i}"
345
+ # Check if the slot is assigned to Kimi and is NOT yet filled (text response missing)
346
+ is_target_model_task = row.get(model_key) == KIMI_MODEL_NAME
347
+ is_unfilled = not row.get(response_text_key) # True if None or empty string
348
+
349
+ if is_target_model_task and is_unfilled and not processed_in_row:
350
+ task_info = (idx, i) # Store tuple of (original_row_index, slot_index)
351
+ if source_dataset == 'pkusafe':
352
+ pkusafe_tasks_indices.append(task_info)
353
+ else:
354
+ other_tasks_indices.append(task_info)
355
+ processed_in_row = True # Mark row as having a task identified
356
+
357
+ # Combine lists, prioritizing pkusafe
358
+ tasks_to_process_indices = pkusafe_tasks_indices + other_tasks_indices
359
+ total_tasks_to_process = len(tasks_to_process_indices)
360
+
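+ # Each entry is a (row_idx, slot_idx) pair, e.g. [(12, 2), (40, 1), ...] (illustrative values)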
361
+ print(f"Found {len(pkusafe_tasks_indices)} 'pkusafe' tasks and {len(other_tasks_indices)} other tasks requiring '{KIMI_MODEL_NAME}' processing in the loaded dataset.")
362
+ print(f"Total tasks remaining to process: {total_tasks_to_process}")
363
+
364
+ if total_tasks_to_process == 0:
365
+ print(f"\nNo remaining tasks to process for {KIMI_MODEL_NAME} based on the loaded dataset.")
366
+ # Optionally, perform a final save for consistency
367
+ # print("Performing a final save to ensure consistency...")
368
+ # final_data_list = [dict(row) for row in dataset]
369
+ # fallback_save_dir_final = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
370
+ # save_checkpoint(final_data_list, original_features, OUTPUT_DATASET_DIR, fallback_save_dir_final)
371
+ print("Exiting.")
372
+ exit(0)
373
+ # --- End Pre-computation Step ---
374
+
375
+
376
+ # --- STEP 3: Load Kimi Model ---
377
+ print("\n" + "="*30)
378
+ print(f"STEP 3: Loading {KIMI_MODEL_NAME} Model")
379
+ print("="*30)
380
+ try:
381
+ # Load Kimi model using the class imported earlier
382
+ model = KimiAudio(model_path=KIMI_MODEL_PATH, load_detokenizer=True) # Assuming detokenizer is needed based on example
383
+ print(f"{KIMI_MODEL_NAME} model loaded successfully from {KIMI_MODEL_PATH}.")
384
+ except NameError:
385
+ print("FATAL: KimiAudio class not defined. Import likely failed earlier.")
386
+ exit(1)
387
+ except Exception as e:
388
+ print(f"Error loading {KIMI_MODEL_NAME} model from {KIMI_MODEL_PATH}: {e}")
389
+ traceback.print_exc()
390
+ exit(1)
391
+
392
+
393
+ # --- STEP 4: Prepare for Processing ---
394
+ print("\n" + "="*30)
395
+ print(f"STEP 4: Preparing for {KIMI_MODEL_NAME} Processing")
396
+ print("="*30)
397
+ # Create output directories if they don't exist
398
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
399
+ os.makedirs(OUTPUT_DATASET_DIR, exist_ok=True)
400
+ # Define and create fallback directory for Kimi
401
+ fallback_save_dir = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), f"{KIMI_MODEL_NAME}_checkpoints_fallback")
402
+ os.makedirs(fallback_save_dir, exist_ok=True)
403
+ print(f"Audio outputs will be saved in: {OUTPUT_AUDIO_ROOT_DIR}")
404
+ print(f"Dataset checkpoints will be saved in: {OUTPUT_DATASET_DIR}")
405
+ print(f"Fallback checkpoints (JSONL) in: {fallback_save_dir}")
406
+
407
+
408
+ # Create a mutable list of dictionaries from the loaded dataset for updates
409
+ updated_data = [dict(row) for row in dataset] # Convert each row to a dictionary
410
+
411
+ tasks_processed_count = 0 # Count successful completions for average time calculation
412
+ start_total_time = time.time()
413
+
414
+
415
+ # --- STEP 5: Start Processing Loop ---
416
+ print("\n" + "="*30)
417
+ print(f"STEP 5: Starting {KIMI_MODEL_NAME} Processing Loop ({total_tasks_to_process} Tasks)")
418
+ print("="*30)
419
+ # Use tqdm for the progress bar, iterating over the identified task indices
420
+ pbar = tqdm(enumerate(tasks_to_process_indices), total=total_tasks_to_process, desc=f"Processing {KIMI_MODEL_NAME} Tasks")
421
+ for loop_idx, (row_idx, slot_i) in pbar:
422
+ # Get the row data *from our mutable list* using the original index
423
+ row = updated_data[row_idx] # This is already a dictionary
424
+
425
+ # Set description in tqdm dynamically
426
+ pbar.set_description(f"Processing Row {row_idx}, Slot {slot_i}")
427
+
428
+ prompt_text_key = f"prompt_text_{slot_i}"
429
+ response_text_key = f"response_text_{slot_i}"
430
+ response_audio_key = f"response_audio_path_{slot_i}"
431
+ model_key = f"model_{slot_i}"
432
+
433
+ # --- Sanity Check ---
434
+ if row.get(model_key) != KIMI_MODEL_NAME:
435
+ tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Model is '{row.get(model_key)}', not '{KIMI_MODEL_NAME}'.")
436
+ continue
437
+ if row.get(response_text_key):
438
+ tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Already has response text '{str(row.get(response_text_key))[:50]}...'.")
439
+ continue
440
+
441
+ # --- Prepare Kimi Model Inputs ---
442
+ prompt_text = row.get(prompt_text_key, "")
443
+ question_audio_path = row.get('question_audio')
444
+ metadata_str = row.get('metadata', "{}")
445
+ source_dataset = row.get('source_dataset')
446
+
447
+ # Check for essential input audio path validity
448
+ if not question_audio_path or not os.path.exists(question_audio_path):
449
+ tqdm.write(f" Error: Input audio path missing or invalid for Row {row_idx}: '{question_audio_path}'. Skipping model call.")
450
+ updated_data[row_idx][response_text_key] = "[ERROR: Missing Input Audio]"
451
+ updated_data[row_idx][response_audio_key] = None
452
+ continue # Move to the next task in the loop
453
+
454
+ # --- Construct Kimi `messages` list ---
455
+ kimi_messages = []
456
+
457
+ # 1. Parse History (if any)
458
+ if source_dataset == 'ultra' and metadata_str:
459
+ try:
460
+ metadata = json.loads(metadata_str)
461
+ history_str = metadata.get('history', '')
462
+ if history_str:
463
+ # Ensure history messages have 'message_type': 'text'
464
+ history_messages_parsed = parse_ultra_history(history_str)
465
+ kimi_messages.extend(history_messages_parsed)
466
+ except json.JSONDecodeError:
467
+ tqdm.write(f" Warning: Could not parse metadata JSON for row {row_idx}")
468
+ except Exception as hist_e:
469
+ tqdm.write(f" Warning: Error processing history for row {row_idx}: {hist_e}")
470
+ # Add elif blocks here for history parsing from other datasets if needed
471
+
472
+ # 2. Add Current User Turn (Text Prompt + Audio Path)
473
+ # Add text prompt first, if it exists and is not empty
474
+ if prompt_text and prompt_text.strip():
475
+ kimi_messages.append({"role": "user", "message_type": "text", "content": prompt_text.strip()})
476
+ # Add the user audio query using its path
477
+ kimi_messages.append({"role": "user", "message_type": "audio", "content": question_audio_path})
478
+
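+ # At this point kimi_messages looks like (illustrative sketch; values are examples):
+ #   [ ...optional parsed history turns...,
+ #     {"role": "user", "message_type": "text",  "content": "Answer briefly."},
+ #     {"role": "user", "message_type": "audio", "content": "/path/to/question.wav"} ]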
479
+ # Generate unique output audio filename
480
+ unique_id = str(uuid.uuid4())
481
+ output_audio_filename = f"{KIMI_MODEL_NAME}_row{row_idx}_slot{slot_i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
482
+ output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
483
+
484
+ # --- Call Kimi Model ---
485
+ # tqdm.write(f" Calling {KIMI_MODEL_NAME} for Row {row_idx}, Slot {slot_i}...") # Less verbose log
486
+ call_start_time = time.time()
487
+ response_text, saved_audio_path = call_kimi_model(
488
+ model,
489
+ kimi_messages,
490
+ KIMI_SAMPLING_PARAMS,
491
+ output_audio_filepath,
492
+ OUTPUT_AUDIO_SAMPLERATE
493
+ )
494
+ call_end_time = time.time()
495
+ audio_basename = os.path.basename(str(saved_audio_path)) if saved_audio_path else "None"
496
+ tqdm.write(f" Row {row_idx}, Slot {slot_i}: Finished in {call_end_time - call_start_time:.2f}s. Text: '{str(response_text)[:50]}...', Audio: {audio_basename}")
497
+
498
+ # Store results back into the main data list (updated_data)
499
+ updated_data[row_idx][response_text_key] = response_text # Store text/error marker
500
+ updated_data[row_idx][response_audio_key] = saved_audio_path # Store path or None
501
+
502
+ # Increment success counter based on successful generation (e.g., text isn't an error marker)
503
+ # Consider if audio generation failure should also mark task as failed.
504
+ # Current logic counts success if text seems okay.
505
+ if response_text is not None and not response_text.startswith("[ERROR"):
506
+ tasks_processed_count += 1
507
+
508
+ # --- Periodic Saving ---
509
+ processed_count_in_loop = loop_idx + 1
510
+ if processed_count_in_loop % SAVE_EVERY_N_SAMPLES == 0 or processed_count_in_loop == total_tasks_to_process:
511
+ save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
512
+
513
+ # --- STEP 6: Final Summary and Save ---
514
+ end_total_time = time.time()
515
+ print("\n" + "="*30)
516
+ print(f"STEP 6: {KIMI_MODEL_NAME} Processing Complete - Summary")
517
+ print("="*30)
518
+ print(f"Total tasks identified for processing in this run: {total_tasks_to_process}")
519
+ print(f"Total tasks successfully processed (generated text): {tasks_processed_count}") # Update definition if needed
520
+ total_duration = end_total_time - start_total_time
521
+ print(f"Total processing time for this run: {format_time(total_duration)}")
522
+ if tasks_processed_count > 0:
523
+ avg_time = total_duration / tasks_processed_count
524
+ print(f"Average time per successfully processed task in this run: {avg_time:.2f} seconds")
525
+ else:
526
+ print("Average time per task: N/A (no tasks successfully processed in this run)")
527
+
528
+ # --- Final Save ---
529
+ print("\nPerforming final save of the dataset...")
530
+ save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
531
+
532
+ print("\nScript finished.")
r1-a/response_generation/minicpm.py ADDED
@@ -0,0 +1,519 @@
1
+ import os
2
+ import json
3
+ import time
4
+ import re # For parsing history
5
+ import uuid # For generating unique filenames
6
+ import random # For random voice selection
7
+ import torch # For MiniCPM-o
8
+ import librosa # For audio loading
9
+ from transformers import AutoModel, AutoTokenizer # For MiniCPM-o
10
+ from datasets import load_from_disk, Dataset, Features, Audio, Value # Import necessary types
11
+ from dotenv import load_dotenv
12
+ import datetime # For ETA formatting
13
+ from tqdm import tqdm # Import tqdm
14
+ import traceback # For detailed error printing
15
+
16
+ # --- Configuration ---
17
+ load_dotenv()
18
+
19
+ # 1. Model & Tokenizer Setup
20
+ MINICPMO_MODEL_NAME = "minicpm" # Name used in the dataset to identify tasks for this model
21
+ MINICPMO_HF_ID = 'openbmb/MiniCPM-o-2_6'
22
+ MINICPMO_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
23
+ MINICPMO_DTYPE = torch.bfloat16
24
+ MINICPMO_ATTN_IMPL = 'sdpa'
25
+
26
+ # 2. Dataset Paths
27
+ INPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_sampling_tasks" # Original source
28
+ OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_minicpmo" # Where processed data is saved/resumed from
29
+
30
+ # 3. Output Audio Configuration
31
+ OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/minicpmo" # Where generated audio files are saved
32
+ OUTPUT_AUDIO_FORMAT = "wav"
33
+ OUTPUT_AUDIO_SAMPLERATE = 16000
34
+
35
+ # --- !!! IMPORTANT: Update these paths to your actual reference voice files !!! ---
36
+ REF_VOICE_PATHS = {
37
+ "female": "/root/autodl-tmp/audio-r1/r1-a/response_generation/minicpm/MiniCPM-o/assets/input_examples/assistant_female_voice.wav",
38
+ "male": "/root/autodl-tmp/audio-r1/r1-a/response_generation/minicpm/MiniCPM-o/assets/input_examples/assistant_male_voice.wav",
39
+ "default_female": "/root/autodl-tmp/audio-r1/r1-a/response_generation/minicpm/MiniCPM-o/assets/input_examples/assistant_default_female_voice.wav"
40
+ }
41
+ # --- End Reference Voice Paths ---
42
+ # Check voice paths exist early
43
+ for key, path in REF_VOICE_PATHS.items():
44
+ if not os.path.exists(path):
45
+ print(f"FATAL ERROR: Reference voice file not found for '{key}': {path}")
46
+ print("Please ensure the reference voice files exist at the specified paths in REF_VOICE_PATHS.")
47
+ exit(1) # Exit early if critical files are missing
48
+
49
+ AVAILABLE_MINICPMO_VOICES = list(REF_VOICE_PATHS.keys())
50
+
51
+ # 4. MiniCPM-o Call Settings
52
+ MODEL_MAX_NEW_TOKENS = 128
53
+ MODEL_TEMPERATURE = 0.3
54
+ MODEL_SAMPLING = True
55
+
56
+ # 5. Periodic Save Settings
57
+ SAVE_EVERY_N_SAMPLES = 50 # Save after processing this many samples
58
+
59
+ # --- Helper Functions ---
60
+
61
+ def format_time(seconds):
62
+ """Formats seconds into a human-readable string H:MM:SS"""
63
+ if seconds < 0:
64
+ return "N/A"
65
+ return str(datetime.timedelta(seconds=int(seconds)))
66
+
67
+ def load_audio_minicpm(audio_path, target_sr=16000):
68
+ """Loads audio using librosa, handling potential errors."""
69
+ if not audio_path or not os.path.exists(audio_path):
70
+ # print(f"Warning: Audio file not found or path is empty: {audio_path}") # Less verbose
71
+ return None
72
+ try:
73
+ audio_array, sr = librosa.load(audio_path, sr=None, mono=True)
74
+ if sr != target_sr:
75
+ # print(f" Resampling audio from {sr} Hz to {target_sr} Hz...") # Less verbose
76
+ audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)
77
+ return audio_array
78
+ except Exception as e:
79
+ print(f"\nWarning: Error loading/processing audio file {audio_path}: {e}")
80
+ return None
81
+
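+ # Usage sketch (the path is an example):
+ #   arr = load_audio_minicpm("/path/to/question.wav", target_sr=16000)
+ # Returns a mono float numpy array resampled to 16 kHz, or None if loading failed.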
82
+ def parse_ultra_history(history_str):
83
+ """Parses the specific history string format from ultra metadata."""
84
+ messages = []
85
+ # Relaxed pattern to capture content even if tags are slightly off or whitespace varies
86
+ pattern = re.compile(r"\[\s*(USER|ASSISTANT)\s*\]\s*([\s\S]*?)(?=\s*\[\s*(?:USER|ASSISTANT)\s*\]|$)")
87
+ matches = pattern.findall(history_str)
88
+ if not matches and history_str and history_str.strip():
89
+ # Simple fallback if standard pattern fails but there's content
90
+ if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
91
+ role = "user"
92
+ content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
93
+ if content: messages.append({"role": role, "content": content})
94
+ elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
95
+ role = "assistant"
96
+ content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
97
+ if content: messages.append({"role": role, "content": content})
98
+ else:
99
+ print(f"Warning: Could not parse history string format: {history_str[:100]}...")
100
+ return messages # Return whatever was parsed, even if empty
101
+
102
+ for role_tag, content in matches:
103
+ role = role_tag.strip().lower()
104
+ cleaned_content = content.strip()
105
+ if cleaned_content:
106
+ messages.append({"role": role, "content": cleaned_content})
107
+ # else: # Removed warning for empty content for brevity
108
+ # print(f"Warning: Empty content found for role {role_tag} in history.")
109
+ return messages
110
+
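+ # Illustrative parse (the sample string is made up; note: no message_type here,
+ # unlike the Kimi variant of this helper):
+ #   parse_ultra_history("[USER] Hi there [ASSISTANT] Hello!")
+ #   -> [{'role': 'user', 'content': 'Hi there'}, {'role': 'assistant', 'content': 'Hello!'}]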
111
+
112
+ # --- MiniCPM-o Model Interaction Function ---
113
+ def call_minicpmo_model(model, tokenizer, history_messages, prompt_text, question_audio_path, output_audio_filepath):
114
+ """Calls the local MiniCPM-o model, saves audio, returns text and audio path."""
115
+ try:
116
+ # 1. Select and Load Random Reference Voice
117
+ selected_voice_key = random.choice(AVAILABLE_MINICPMO_VOICES)
118
+ ref_voice_path = REF_VOICE_PATHS[selected_voice_key]
119
+ ref_audio_array = load_audio_minicpm(ref_voice_path, target_sr=OUTPUT_AUDIO_SAMPLERATE)
120
+ if ref_audio_array is None:
121
+ print(f" Error: Failed to load reference voice: {ref_voice_path}")
122
+ return None, None # Signal failure
123
+
124
+ # 2. Generate System Prompt
125
+ sys_prompt = model.get_sys_prompt(ref_audio=ref_audio_array, mode='audio_assistant', language='en')
126
+
127
+ # 3. Load User Question Audio
128
+ user_audio_array = load_audio_minicpm(question_audio_path, target_sr=OUTPUT_AUDIO_SAMPLERATE)
129
+ if user_audio_array is None:
130
+ print(f" Error: Failed to load user question audio: {question_audio_path}")
131
+ return None, None # Signal failure
132
+
133
+ # 4. Construct User Message
134
+ user_message_content = []
135
+ if prompt_text and prompt_text.strip():
136
+ user_message_content.append(prompt_text.strip())
137
+ # user_audio_array is guaranteed non-None here: a load failure already
+ # returned (None, None) above, so the audio can be appended unconditionally
+ user_message_content.append(user_audio_array) # Add audio array
144
+
145
+ user_message = {'role': 'user', 'content': user_message_content}
146
+
147
+ # 5. Construct Full Message List
148
+ msgs = [sys_prompt] + history_messages + [user_message]
149
+
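+ # msgs now looks like (illustrative sketch; values are examples):
+ #   [ sys_prompt,
+ #     {'role': 'user', 'content': 'earlier text turn'}, ...,
+ #     {'role': 'user', 'content': ['optional prompt text', <audio numpy array>]} ]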
150
+ # 6. Ensure Output Directory Exists
151
+ os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
152
+
153
+ # 7. Call Model's Chat Function
154
+ response_obj = model.chat(
155
+ msgs=msgs,
156
+ tokenizer=tokenizer,
157
+ sampling=MODEL_SAMPLING,
158
+ max_new_tokens=MODEL_MAX_NEW_TOKENS,
159
+ use_tts_template=True,
160
+ generate_audio=True,
161
+ temperature=MODEL_TEMPERATURE,
162
+ output_audio_path=output_audio_filepath # Model saves the audio directly
163
+ )
164
+
165
+ # --- Extract text from the response object ---
166
+ response_text = None
167
+ if hasattr(response_obj, 'text'):
168
+ response_text = response_obj.text
169
+ elif hasattr(response_obj, 'content'):
170
+ response_text = response_obj.content
171
+ elif isinstance(response_obj, str):
172
+ response_text = response_obj
173
+ else:
174
+ print(f" Warning: Could not automatically extract text from model response object of type {type(response_obj)}. Response object dir: {dir(response_obj)}")
175
+ response_text = "[ERROR: Could not extract text]"
176
+
177
+ # Ensure response_text is a string before stripping
178
+ response_text_cleaned = ""
179
+ if isinstance(response_text, str):
180
+ response_text_cleaned = response_text.strip()
181
+ elif response_text is not None:
182
+ response_text_cleaned = str(response_text).strip()
183
+
184
+ # 8. Check if audio file was actually created by the model
185
+ if os.path.exists(output_audio_filepath) and os.path.getsize(output_audio_filepath) > 0: # Check size too
186
+ # Success: Return text and the path where the model saved the audio
187
+ return response_text_cleaned, output_audio_filepath
188
+ else:
189
+ print(f" Error: Model finished but output audio file not found or empty at {output_audio_filepath}")
190
+ # Attempt cleanup if file exists but is empty
191
+ if os.path.exists(output_audio_filepath):
192
+ try:
193
+ os.remove(output_audio_filepath)
194
+ except OSError as rm_err:
195
+ print(f" Warning: Could not remove empty file {output_audio_filepath}: {rm_err}")
196
+ return response_text_cleaned, None # Return text (if any) but signal audio failure
197
+
198
+ except Exception as e:
199
+ print(f"\n --- Error during MiniCPM-o model call for {os.path.basename(question_audio_path)} ---")
200
+ print(f" Exception Type: {type(e).__name__}")
201
+ print(f" Error Details: {e}")
202
+ print(" Traceback:")
203
+ traceback.print_exc()
204
+ print(" --- End Error Details ---")
205
+ # Attempt cleanup of potentially incomplete output file
206
+ if 'output_audio_filepath' in locals() and os.path.exists(output_audio_filepath):
207
+ try:
208
+ os.remove(output_audio_filepath)
209
+ except OSError as rm_err:
210
+ print(f" Warning: Could not remove file {output_audio_filepath} after error: {rm_err}")
211
+ return None, None # Indicate failure
212
+
213
+ # --- Dataset Saving Function ---
214
+ def save_checkpoint(data_list, features, output_dir, fallback_dir=None):
215
+ """Saves the current state of the data list as a Hugging Face Dataset."""
216
+ if not data_list:
217
+ print("\nSkipping checkpoint save: data list is empty.")
218
+ return
219
+
220
+ print(f"\nSaving checkpoint with {len(data_list)} rows to {output_dir}...")
221
+ try:
222
+ # Ensure the list contains dictionaries, not Dataset rows or other objects
223
+ data_to_save = [dict(item) for item in data_list]
224
+ # Create dataset from the current list of dictionaries using original features
225
+ updated_dataset = Dataset.from_list(data_to_save, features=features)
226
+ # Ensure output directory exists before saving
227
+ os.makedirs(output_dir, exist_ok=True)
228
+ updated_dataset.save_to_disk(output_dir)
229
+ print("Checkpoint saved successfully.")
230
+ except Exception as e:
231
+ print(f"Error saving checkpoint dataset using save_to_disk: {e}")
232
+ traceback.print_exc()
233
+ if fallback_dir:
234
+ fallback_path = os.path.join(fallback_dir, f"updated_minicpmo_data_checkpoint_{int(time.time())}.jsonl")
235
+ print(f"Attempting to save data as JSON Lines fallback to: {fallback_path}")
236
+ try:
237
+ os.makedirs(fallback_dir, exist_ok=True)
238
+ with open(fallback_path, 'w', encoding='utf-8') as f:
239
+ # Rebuild here so the fallback works even if the conversion above raised early
+ data_to_save = [dict(item) for item in data_list]
+ for item in data_to_save:
240
+ # Ensure all values are serializable
241
+ serializable_item = {}
242
+ for k, v in item.items():
243
+ if isinstance(v, (datetime.datetime, datetime.date)):
244
+ serializable_item[k] = v.isoformat()
245
+ elif isinstance(v, bytes):
246
+ serializable_item[k] = v.decode('utf-8', errors='ignore')
247
+ # Add handling for specific non-serializable types if they appear
248
+ elif not isinstance(v, (str, int, float, bool, list, dict, type(None))):
249
+ print(f" Warning: Converting non-standard type {type(v)} for key '{k}' to string for JSON fallback.")
250
+ serializable_item[k] = str(v)
251
+ else:
252
+ serializable_item[k] = v
253
+ try:
254
+ f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
255
+ except TypeError as json_type_err:
256
+ print(f" Skipping row due to JSON serialization error: {json_type_err} in item part: {k}={v}")
257
+ print("Fallback JSON Lines checkpoint saved successfully.")
258
+ except Exception as json_e:
259
+ print(f"Error saving fallback JSON Lines checkpoint: {json_e}")
260
+
261
+
262
+ # =============================================
263
+ # --- Main Processing Logic ---
264
+ # =============================================
265
+
266
+ # --- Dataset Loading (Modified for Resumption) ---
267
+ print("="*30)
268
+ print("STEP 1: Loading Dataset")
269
+ print("="*30)
270
+ dataset = None
271
+ original_features = None # Initialize
272
+
273
+ if os.path.exists(OUTPUT_DATASET_DIR):
274
+ print(f"Found existing processed dataset directory at: {OUTPUT_DATASET_DIR}")
275
+ print("Attempting to load it to resume processing...")
276
+ try:
277
+ # Note: checkpoints are written back to this directory later, so it must be writable.
278
+ dataset = load_from_disk(OUTPUT_DATASET_DIR)
279
+ original_features = dataset.features # Get features from the loaded dataset
280
+ print(f"Resumed dataset loaded successfully with {len(dataset)} rows.")
281
+ print(f"Features from resumed dataset: {original_features}")
282
+ except Exception as e:
283
+ print(f"Warning: Error loading existing dataset from {OUTPUT_DATASET_DIR}: {e}")
284
+ traceback.print_exc()
285
+ print("Will attempt to load the original input dataset instead.")
286
+ dataset = None # Reset dataset variable
287
+ else:
288
+ print(f"No existing processed dataset found at {OUTPUT_DATASET_DIR}.")
289
+ print("Will attempt to load the original input dataset.")
290
+
291
+
292
+ # If dataset is still None (either output dir didn't exist or loading it failed), load from input
293
+ if dataset is None:
294
+ print(f"\nLoading original input dataset from: {INPUT_DATASET_DIR}")
295
+ if not os.path.exists(INPUT_DATASET_DIR):
296
+ print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
297
+ exit(1)
298
+ try:
299
+ dataset = load_from_disk(INPUT_DATASET_DIR)
300
+ original_features = dataset.features # Get features from the input dataset
301
+ print(f"Original input dataset loaded successfully with {len(dataset)} rows.")
302
+ print(f"Features from input dataset: {original_features}")
303
+ except Exception as e:
304
+ print(f"FATAL: Error loading original input dataset from {INPUT_DATASET_DIR}: {e}")
305
+ traceback.print_exc()
306
+ exit(1)
307
+
308
+ # --- Ensure dataset was loaded ---
309
+ if dataset is None or original_features is None:
310
+ print("FATAL: Failed to load any dataset. Exiting.")
311
+ exit(1)
312
+ # --- End Dataset Loading Modification ---
313
+
314
+
315
+ # --- Pre-computation Step: Identify and Prioritize Tasks ---
316
+ print("\n" + "="*30)
317
+ print("STEP 2: Identifying Tasks to Process")
318
+ print("="*30)
319
+ # This logic runs on whichever dataset was loaded above (the original input
+ # or the partially processed output) and identifies tasks where the model
+ # slot is 'minicpm' and the response text is still missing.
322
+ pkusafe_tasks_indices = []
323
+ other_tasks_indices = []
324
+
325
+ # Iterate through the loaded dataset structure
326
+ for idx, row in enumerate(dataset):
327
+ source_dataset = row.get('source_dataset')
328
+ processed_in_row = False # Flag to ensure we only pick one slot per row initially
329
+ for i in range(1, 4): # Check slots 1, 2, 3
330
+ model_key = f"model_{i}"
331
+ response_text_key = f"response_text_{i}"
332
+ # Check if the slot is assigned to minicpm and is NOT yet filled
333
+ is_minicpm_task = row.get(model_key) == MINICPMO_MODEL_NAME
334
+ # Crucially, check if the response text field is missing or empty in the loaded data
335
+ is_unfilled = not row.get(response_text_key) # True if None or empty string
336
+
337
+ if is_minicpm_task and is_unfilled and not processed_in_row:
338
+ task_info = (idx, i) # Store tuple of (original_row_index, slot_index)
339
+ if source_dataset == 'pkusafe':
340
+ pkusafe_tasks_indices.append(task_info)
341
+ else:
342
+ other_tasks_indices.append(task_info)
343
+ processed_in_row = True # Mark as processed for this row for task identification
344
+
345
+ # Combine lists, prioritizing pkusafe
346
+ tasks_to_process_indices = pkusafe_tasks_indices + other_tasks_indices
347
+ total_tasks_to_process = len(tasks_to_process_indices)
348
+
349
+ print(f"Found {len(pkusafe_tasks_indices)} 'pkusafe' tasks and {len(other_tasks_indices)} other tasks requiring '{MINICPMO_MODEL_NAME}' processing in the loaded dataset.")
350
+ print(f"Total tasks remaining to process: {total_tasks_to_process}")
351
+
352
+ if total_tasks_to_process == 0:
353
+ print("\nNo remaining tasks to process for MiniCPM-o based on the loaded dataset.")
354
+ # Optionally, perform a final save here if you want ensure the output dir reflects the 'completed' state
355
+ # print("Performing a final save to ensure consistency...")
356
+ # final_data_list = [dict(row) for row in dataset] # Convert dataset rows back to dicts
357
+ # fallback_save_dir_final = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), "minicpmo_checkpoints_fallback")
358
+ # save_checkpoint(final_data_list, original_features, OUTPUT_DATASET_DIR, fallback_save_dir_final)
359
+ print("Exiting.")
360
+ exit(0)
361
+ # --- End Pre-computation Step ---
362
+
363
+
364
+ # --- Load Model (Only if tasks exist) ---
365
+ print("\n" + "="*30)
366
+ print("STEP 3: Loading Model")
367
+ print("="*30)
368
+ print(f"Loading MiniCPM-o model ({MINICPMO_HF_ID}) and tokenizer...")
369
+ try:
370
+ model = AutoModel.from_pretrained(
371
+ MINICPMO_HF_ID,
372
+ trust_remote_code=True,
373
+ attn_implementation=MINICPMO_ATTN_IMPL,
374
+ torch_dtype=MINICPMO_DTYPE
375
+ )
376
+ model = model.eval().to(MINICPMO_DEVICE)
377
+ tokenizer = AutoTokenizer.from_pretrained(MINICPMO_HF_ID, trust_remote_code=True)
378
+
379
+ print("Initializing TTS...")
380
+ model.init_tts()
381
+ model.tts.float() # Use float32 for TTS stability
382
+ print(f"Model and TTS initialized successfully on {MINICPMO_DEVICE}.")
383
+ except Exception as e:
384
+ print(f"Error loading MiniCPM-o model or tokenizer: {e}")
385
+ traceback.print_exc()
386
+ exit(1)
387
+
388
+
389
+ # --- Prepare for Processing ---
390
+ # Create output directory for MiniCPM-o audio if it doesn't exist
391
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
392
+ # Ensure the main output dataset directory exists for saving checkpoints
393
+ os.makedirs(OUTPUT_DATASET_DIR, exist_ok=True)
394
+ # Define and create fallback directory
395
+ fallback_save_dir = os.path.join(os.path.dirname(OUTPUT_DATASET_DIR), "minicpmo_checkpoints_fallback")
396
+ os.makedirs(fallback_save_dir, exist_ok=True)
397
+
398
+
399
+ # Create a mutable list of dictionaries from the loaded dataset for updates
400
+ # This is crucial as Hugging Face datasets are typically immutable
401
+ updated_data = [dict(row) for row in dataset] # Convert each row to a dictionary
402
+
403
+ tasks_processed_count = 0 # Count successful completions for average time calculation
404
+ start_total_time = time.time()
405
+
406
+ print("\n" + "="*30)
407
+ print(f"STEP 4: Starting MiniCPM-o Processing for {total_tasks_to_process} Tasks")
408
+ print("="*30)
409
+ # Use tqdm for the progress bar, iterating over the identified task indices
410
+ pbar = tqdm(enumerate(tasks_to_process_indices), total=total_tasks_to_process, desc="Processing MiniCPM-o Tasks")
411
+ for loop_idx, (row_idx, slot_i) in pbar:
412
+ # Get the row data *from our mutable list* using the original index
413
+ row = updated_data[row_idx] # This is already a dictionary
414
+
415
+ # Set description in tqdm dynamically
416
+ pbar.set_description(f"Processing Row {row_idx}, Slot {slot_i}")
417
+
418
+ prompt_text_key = f"prompt_text_{slot_i}"
419
+ response_text_key = f"response_text_{slot_i}"
420
+ response_audio_key = f"response_audio_path_{slot_i}"
421
+ model_key = f"model_{slot_i}" # Get model key for verification
422
+
423
+ # --- Sanity Check: Ensure this is still a valid MiniCPM-o task ---
424
+ # (This might be redundant if identification was perfect, but good for safety)
425
+ if row.get(model_key) != MINICPMO_MODEL_NAME:
426
+ tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Model is no longer '{MINICPMO_MODEL_NAME}'.")
427
+ continue
428
+ if row.get(response_text_key): # Check again if it got filled somehow concurrently (unlikely here)
429
+ tqdm.write(f" Skipping Row {row_idx}, Slot {slot_i}: Already has response text '{str(row.get(response_text_key))[:50]}...'.")
430
+ continue
431
+
432
+ # --- Prepare Model Inputs ---
433
+ prompt_text = row.get(prompt_text_key, "")
434
+ question_audio_path = row.get('question_audio')
435
+ metadata_str = row.get('metadata', "{}")
436
+ source_dataset = row.get('source_dataset') # Used for history parsing
437
+
438
+ # Basic check for essential input audio
439
+ if not question_audio_path or not os.path.exists(question_audio_path):
440
+ tqdm.write(f" Error: Input audio path missing or invalid for Row {row_idx}: '{question_audio_path}'. Skipping model call.")
441
+ # Update the specific row in the list (mark as failed/skipped)
442
+ updated_data[row_idx][response_text_key] = "[ERROR: Missing Input Audio]"
443
+ updated_data[row_idx][response_audio_key] = None
444
+ continue # Move to the next task
445
+
446
+ # Parse History
447
+ history_messages = []
448
+ if source_dataset == 'ultra' and metadata_str:
449
+ try:
450
+ metadata = json.loads(metadata_str)
451
+ history_str = metadata.get('history', '')
452
+ if history_str:
453
+ history_messages = parse_ultra_history(history_str)
454
+ except json.JSONDecodeError:
455
+ tqdm.write(f" Warning: Could not parse metadata JSON for row {row_idx}")
456
+ except Exception as hist_e:
457
+ tqdm.write(f" Warning: Error processing history for row {row_idx}: {hist_e}")
458
+ # Add elif blocks here if other datasets have different history formats in metadata
459
+
460
+ # Generate unique output audio filename
461
+ unique_id = str(uuid.uuid4())
462
+ output_audio_filename = f"minicpmo_row{row_idx}_slot{slot_i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
463
+ output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
464
+
465
+ # --- Call Model ---
466
+ # tqdm.write(f" Calling model for Row {row_idx}, Slot {slot_i} (Source: {source_dataset}). Output: {output_audio_filepath}") # More verbose
467
+ call_start_time = time.time()
468
+ response_text, saved_audio_path = call_minicpmo_model(
469
+ model,
470
+ tokenizer,
471
+ history_messages,
472
+ prompt_text,
473
+ question_audio_path,
474
+ output_audio_filepath
475
+ )
476
+ call_end_time = time.time()
477
+ tqdm.write(f" Row {row_idx}, Slot {slot_i}: Finished in {call_end_time - call_start_time:.2f}s. Text: '{str(response_text)[:50]}...', Audio: {os.path.basename(str(saved_audio_path))}")
478
+
479
+
480
+ # Store results directly into the list item (updated_data)
481
+ updated_data[row_idx][response_text_key] = response_text if response_text is not None else "[ERROR: Model Call Failed]"
482
+ updated_data[row_idx][response_audio_key] = saved_audio_path # Will be None if audio saving/generation failed
483
+
484
+ if response_text is not None and saved_audio_path is not None: # Count as successfully processed only if both text and audio are generated
485
+ tasks_processed_count += 1
486
+
487
+ # --- Periodic Saving ---
488
+ # Save after processing N samples (using loop_idx + 1 because index is 0-based)
489
+ # Also save on the very last iteration
490
+ processed_count_in_loop = loop_idx + 1
491
+ if processed_count_in_loop % SAVE_EVERY_N_SAMPLES == 0 or processed_count_in_loop == total_tasks_to_process:
492
+ save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
493
+
494
+ # Optional small delay if needed for hardware cooling, etc.
495
+ # time.sleep(0.1)
496
+
497
+
498
+ # --- Final Summary ---
499
+ end_total_time = time.time()
500
+ print("\n" + "="*30)
501
+ print("STEP 5: Processing Complete - Summary")
502
+ print("="*30)
503
+ print(f"Total tasks identified for processing in this run: {total_tasks_to_process}")
504
+ print(f"Total tasks successfully processed (generated text & audio) in this run: {tasks_processed_count}")
505
+ total_duration = end_total_time - start_total_time
506
+ print(f"Total processing time for this run: {format_time(total_duration)}")
507
+ if tasks_processed_count > 0:
508
+ avg_time = total_duration / tasks_processed_count
509
+ print(f"Average time per successfully processed task in this run: {avg_time:.2f} seconds")
510
+ else:
511
+ print("Average time per task: N/A (no tasks successfully processed in this run)")
512
+
513
+ # --- Final Save ---
514
+ # This ensures the very last state is saved, even if the last iteration didn't trigger the periodic save exactly.
515
+ # It might be redundant if SAVE_EVERY_N_SAMPLES aligns perfectly, but it's safe to include.
516
+ print("\nPerforming final save of the dataset...")
517
+ save_checkpoint(updated_data, original_features, OUTPUT_DATASET_DIR, fallback_save_dir)
518
+
519
+ print("\nScript finished.")
r1-a/response_generation/minicpm/MiniCPM-o/.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ *.bk
2
+ __pycache__
3
+ .DS_Store
r1-a/response_generation/minicpm/MiniCPM-o/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 OpenBMB
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
r1-a/response_generation/minicpm/MiniCPM-o/README.md ADDED
The diff for this file is too large to render. See raw diff
 
r1-a/response_generation/minicpm/MiniCPM-o/README_zh.md ADDED
@@ -0,0 +1,2524 @@
+ <div align="center">
+
+ <img src="./assets/MiniCPM-o.png" width="300em" ></img>
+
+ **A GPT-4o Level MLLM for Vision, Speech, and Multimodal Live Streaming on End-side Devices**
+
+ <strong>Chinese |
+ [English](./README.md)</strong>
+
+
+
+ <span style="display: inline-flex; align-items: center; margin-right: 2px;">
+ <a href="docs/wechat.md" target="_blank"> WeChat Community</a> &nbsp;|
+ </span>
+ <span style="display: inline-flex; align-items: center; margin-left: 2px;">
+ MiniCPM-V <a href="docs/best_practice_summary_zh.md" target="_blank">&nbsp; 📖 Best Practices</a>
+ </span>
+
+
+ <p align="center">
+ MiniCPM-o 2.6 <a href="https://huggingface.co/openbmb/MiniCPM-o-2_6">🤗</a> <a href="https://minicpm-omni-webdemo-us.modelbest.cn/"> 🤖</a> | MiniCPM-V 2.6 <a href="https://huggingface.co/openbmb/MiniCPM-V-2_6">🤗</a> <a href="http://120.92.209.146:8887/">🤖</a> |
+ 📄 Technical Report [<a href="https://openbmb.notion.site/MiniCPM-o-2-6-GPT-4o-188ede1b7a558084b3aedd669cb80730">Chinese</a>/<a href="https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9">English</a>]
+ </p>
+
+ </div>
+
+
+ **MiniCPM-o** is the latest series of end-side multimodal LLMs upgraded from MiniCPM-V. The models in this series accept images, video, text, and audio as input and produce high-quality text and speech output in an end-to-end fashion. Aiming at strong performance and efficient deployment, we have released six versions of the model since February 2024. The most notable models in the series currently include:
+
+
+ - **MiniCPM-o 2.6**: 🔥🔥🔥 The latest and most capable model in the MiniCPM-o series. With a total of 8B parameters, it **reaches GPT-4o-202405-level vision, speech, and multimodal live-streaming capabilities**, making it one of the best-performing and most modality-rich models in the open-source community. In the new voice mode, MiniCPM-o 2.6 **supports bilingual (Chinese/English) speech conversation with configurable voices, along with advanced abilities such as emotion/speed/style control, end-to-end voice cloning, and role play**. It also further improves the **OCR, trustworthy behavior, multilingual support, and video-understanding capabilities** of MiniCPM-V 2.6. Thanks to its leading visual token density, MiniCPM-o 2.6 is **the first MLLM to support multimodal live streaming on end-side devices such as the iPad**.
+
+ - **MiniCPM-V 2.6**: The best-performing model in the MiniCPM-V series. With a total of 8B parameters, it **surpasses GPT-4V** on single-image, multi-image, and video understanding. It outperforms **GPT-4o mini, Gemini 1.5 Pro, and Claude 3.5 Sonnet** on single-image understanding, and was the first MLLM to support real-time video understanding on end-side devices such as the iPad.
+
+
+ ## Changelog <!-- omit in toc -->
+
+ #### 📌 Pinned
+
+ * [2025.03.01] 🚀🚀🚀 RLAIF-V, the alignment technique behind the MiniCPM-o series, was accepted by CVPR 2025! Its [code](https://github.com/RLHF-V/RLAIF-V), [data](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset), and [paper](https://arxiv.org/abs/2405.17220) are all open-sourced.
+
+ * [2025.01.24] 📢📢📢 The MiniCPM-o 2.6 technical report is released! Read it [here](https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9).
+
+ * [2025.01.23] 💡💡💡 MiniCPM-o 2.6 is now integrated into [Align-Anything](https://github.com/PKU-Alignment/align-anything), a framework developed by the PKU team for aligning omni-modal LLMs, with support for DPO and SFT fine-tuning on the vision and audio modalities. Give it a try!
+
+ * [2025.01.19] 📢 **Attention!** We are working on merging MiniCPM-o 2.6 support into the official repositories of llama.cpp, ollama, and vLLM, but this is not finished yet. For now, please deploy with our forks: [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md), [ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md), [vllm](https://github.com/OpenBMB/MiniCPM-o?tab=readme-ov-file#efficient-inference-with-llamacpp-ollama-vllm). **Using the official repositories before the merge is complete may lead to unexpected issues.**
+
+ * [2025.01.19] ⭐️⭐️⭐️ MiniCPM-o reached #1 on GitHub Trending and #2 on Hugging Face Trending!
+
+ * [2025.01.17] We updated the usage of the MiniCPM-o 2.6 int4 quantized version and fixed the model-initialization issue. Try it [here](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4)!
+
+ * [2025.01.13] 🔥🔥🔥 We open-sourced MiniCPM-o 2.6, which reaches GPT-4o-202405-level vision, speech, and multimodal live-streaming capabilities, further improves the many highlight features of MiniCPM-V 2.6, and adds many fun new capabilities. Give it a try!
+
+ * [2024.08.17] 🚀🚀🚀 The llama.cpp [official repository](https://github.com/ggerganov/llama.cpp) now officially supports MiniCPM-V 2.6! See GGUF versions in various sizes [here](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf).
+
+ * [2024.08.06] 🔥🔥🔥 We open-sourced MiniCPM-V 2.6, which outperforms GPT-4V on single-image, multi-image, and video understanding. It further improves the highlight features of MiniCPM-Llama3-V 2.5 and, for the first time, supports real-time video understanding on the iPad. Give it a try!
+
+ * [2024.08.03] The MiniCPM-Llama3-V 2.5 technical report is released! Read it [here](https://arxiv.org/abs/2408.01800).
+
+ * [2024.05.23] 🔥🔥🔥 MiniCPM-V reached #1 on GitHub Trending and Hugging Face Trending! The MiniCPM-Llama3-V 2.5 demo was recommended by Hugging Face's official Gradio account. Try it [here](https://huggingface.co/spaces/openbmb/MiniCPM-Llama3-V-2_5)!
+
+
+ <br>
+
+ <details>
+ <summary>Click to view the full changelog.</summary>
+
+ * [2024.08.15] MiniCPM-V 2.6 now supports multi-image SFT. See the [fine-tuning documentation](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune) for details.
+ * [2024.08.14] MiniCPM-V 2.6 can now be [fine-tuned](https://github.com/modelscope/ms-swift/issues/1613) with the SWIFT framework!
+ * [2024.08.10] 🚀🚀🚀 The llama.cpp [official repository](https://github.com/ggerganov/llama.cpp) now officially supports MiniCPM-Llama3-V 2.5! See GGUF versions in various sizes [here](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main).
+ * [2024.07.19] MiniCPM-Llama3-V 2.5 now supports [vLLM](#vllm-部署-)!
+ * [2024.06.03] You can now run GPU serial inference across multiple low-VRAM GPUs (12G/16G). See this [documentation](https://github.com/OpenBMB/MiniCPM-V/blob/main/docs/inference_on_multiple_gpus.md) for configuration details.
+ * [2024.05.28] 💫 We now support LoRA fine-tuning for MiniCPM-Llama3-V 2.5; more memory-usage statistics can be found [here](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune#model-fine-tuning-memory-usage-statistics).
+ * [2024.05.28] 💥 MiniCPM-Llama3-V 2.5 is now fully supported in llama.cpp and ollama! **Please pull our latest forks**: [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-v2.5/examples/minicpmv/README.md) & [ollama](https://github.com/OpenBMB/ollama/tree/minicpm-v2.5/examples/minicpm-v2.5). We have also released GGUF versions in various sizes [here](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf/tree/main). Note that **the official repositories do not yet support MiniCPM-Llama3-V 2.5**; we are actively working on merging these features into the official llama.cpp & ollama repositories. Stay tuned!
+ * [2024.05.25] MiniCPM-Llama3-V 2.5 now [supports streaming output and custom system prompts](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5#usage). Give it a try!
+ * [2024.05.24] We released the MiniCPM-Llama3-V 2.5 [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf), which supports [llama.cpp](#llamacpp-部署) inference with smooth end-side decoding at 6-8 tokens/s. Give it a try!
+ * [2024.05.23] 🔍 We added a comprehensive comparison between Phi-3-vision-128k-instruct and MiniCPM-Llama3-V 2.5, covering benchmark evaluations, multilingual capability, and inference efficiency 🌟📊🌍🚀. See the details [here](./docs/compare_with_phi-3_vision.md).
+ * [2024.05.20] We open-sourced MiniCPM-Llama3-V 2.5 with stronger OCR capability and support for 30+ languages, bringing GPT-4V-level multimodal capability to end-side devices for the first time! We provide [efficient inference](#手机端部署) and [easy fine-tuning](./finetune/readme.md) support. Give it a try!
+ * [2024.04.23] MiniCPM-V 2.0 now supports [vLLM](#vllm-部署-). Give it a try!
+ * [2024.04.18] We added a MiniCPM-V 2.0 [demo](https://huggingface.co/spaces/openbmb/MiniCPM-V-2) on Hugging Face Space. Give it a try!
+ * [2024.04.17] MiniCPM-V 2.0 now supports deploying a local [WebUI demo](#本地webui-demo部署). Give it a try!
+ * [2024.04.15] MiniCPM-V 2.0 can now be [fine-tuned](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v-2最佳实践.md) with the SWIFT framework, with streaming output supported!
+ * [2024.04.12] We open-sourced MiniCPM-V 2.0, which sets a new state of the art for open-source models on OCRBench, matches Gemini Pro in scene-text recognition, and surpasses larger models such as Qwen-VL-Chat 10B, CogVLM-Chat 17B, and Yi-VL 34B on the <a href="https://rank.opencompass.org.cn/leaderboard-multimodal">OpenCompass</a> leaderboard, which aggregates 11 mainstream multimodal benchmarks! Read the MiniCPM-V 2.0 technical blog <a href="https://openbmb.vercel.app/minicpm-v-2">here</a>.
+ * [2024.03.14] MiniCPM-V now supports [fine-tuning](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v最佳实践.md) with the SWIFT framework. Thanks to [Jintao](https://github.com/Jintao-Huang) for the contribution!
+ * [2024.03.01] MiniCPM-V can now be deployed on Mac!
+ * [2024.02.01] We open-sourced MiniCPM-V and OmniLMM-12B, which support efficient end-side deployment and leading multimodal capability at comparable scale, respectively!
+ </details>
+
+
+ ## Table of Contents <!-- omit in toc -->
+
+ - [MiniCPM-o 2.6](#minicpm-o-26)
+ - [MiniCPM-V 2.6](#minicpm-v-26)
+ - [Chat with Our Demo on Gradio 🤗](#chat-with-our-demo-on-gradio-)
+ - [Inference](#推理)
+ - [Model Zoo](#模型库)
+ - [Multi-turn Chat](#多轮对话)
+ - [Multi-image Chat](#多图对话)
+ - [Few-shot In-context Chat](#少样本上下文对话)
+ - [Video Chat](#视频对话)
+ - [Speech Chat](#语音对话)
+ - [Mimick](#mimick)
+ - [Speech Chat with Configurable Voices](#可配置声音的语音对话)
+ - [More Speech Tasks](#更多语音任务)
+ - [Multimodal Live Streaming](#多模态流式交互)
+ - [Multi-GPU Inference](#多卡推理)
+ - [Inference on Mac](#mac-推理)
+ - [Efficient Inference with llama.cpp, ollama, vLLM](#基于-llamacppollamavllm-的高效推理)
+ - [Fine-tuning](#微调)
+ - [FAQs](#faqs)
+ - [Limitations](#模型局限性)
+
+ ## MiniCPM-o 2.6
+
+
+ MiniCPM-o 2.6 is the latest and most capable model in the MiniCPM-o series. Built on SigLip-400M, Whisper-medium-300M, ChatTTS-200M, and Qwen2.5-7B with a total of 8B parameters, it is trained and run in an end-to-end fashion. Compared with MiniCPM-V 2.6, it delivers a significant performance improvement and adds new capabilities for real-time speech conversation and multimodal live streaming. Notable features of MiniCPM-o 2.6 include:
+
+
+ - 🔥 **Leading visual capability.**
+ MiniCPM-o 2.6 achieves an average score of 70.2 on the OpenCompass leaderboard (a comprehensive evaluation over 8 mainstream multimodal benchmarks). **At the 8B scale, it surpasses mainstream proprietary models such as GPT-4o-202405, Gemini 1.5 Pro, and Claude 3.5 Sonnet on single-image understanding.** It also **outperforms GPT-4V and Claude 3.5 Sonnet** on multi-image and video understanding, and shows strong in-context learning capability.
+
+ - 🎙 **Outstanding speech capability.**
+ MiniCPM-o 2.6 **supports bilingual (Chinese/English) real-time speech conversation with configurable voices**. It **outperforms GPT-4o-realtime** on speech-understanding tasks (such as ASR and STT) and achieves the **best speech-generation performance among open-source models** in both the semantic and acoustic evaluations of speech conversation. It also supports advanced abilities such as emotion/speed/style control, voice cloning, and role play.
+
+ - 🎬 **Strong multimodal live-streaming capability.**
+ As a new feature, MiniCPM-o 2.6 can **accept continuous video and audio streams and interact with users through real-time speech**. On StreamingBench, a comprehensive benchmark for real-time video understanding, omni-source (video and audio) understanding, and multimodal contextual understanding, MiniCPM-o 2.6 achieves the best result in the open-source community and **surpasses GPT-4o-202408 and Claude 3.5 Sonnet**.
+
+ - 💪 **Strong OCR capability and more.**
+ MiniCPM-o 2.6 further improves the many visual-understanding capabilities of MiniCPM-V 2.6. It can process images with any aspect ratio and up to 1.8 million pixels (e.g., 1344x1344). It achieves the **best result among models under 25B on OCRBench, surpassing proprietary models such as GPT-4o-202405**. Based on the latest [RLHF-V](https://rlhf-v.github.io/), [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/), and [VisCPM](https://github.com/OpenBMB/VisCPM) techniques, it exhibits **trustworthy multimodal behavior**, surpasses GPT-4o and Claude 3.5 on MMHal-Bench, and supports **30+ languages** including English, Chinese, German, French, Italian, and Korean.
+
+ - 🚀 **Superior efficiency.**
+ In addition to its user-friendly size, MiniCPM-o 2.6 exhibits **state-of-the-art visual token density** (i.e., the number of pixels encoded by each visual token). **It needs only 640 tokens to process a 1.8-megapixel image, 75% fewer than most models.** This directly improves inference speed, first-token latency, memory usage, and power consumption, so MiniCPM-o 2.6 supports efficient **multimodal live streaming** on end-side devices such as the iPad.
+
+
+ - 💫 **Easy to use.**
+ MiniCPM-o 2.6 can be used in many ways: (1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md) for efficient CPU inference on local devices, (2) quantized models in [int4](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) and [GGUF](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) formats in 16 sizes, (3) [vLLM](#基于-llamacppollamavllm-的高效推理) for high-throughput, memory-efficient inference, (4) fine-tuning for new domains and tasks with the [LLaMA-Factory](./docs/llamafactory_train_and_infer.md) framework, (5) a quick local WebUI demo with [Gradio](#本地-webui-demo-), and (6) an online server-hosted [demo](https://minicpm-omni-webdemo-us.modelbest.cn/). A minimal Python inference sketch follows this list.
+
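+ To make this concrete, here is a minimal single-image chat sketch. It assumes the `trust_remote_code` chat interface documented on the Hugging Face model card; the exact loading flags and the `model.chat` signature should be checked against the model card, so treat the call below as illustrative rather than authoritative.
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
+
+ # Load the model with its custom remote code (assumed interface; see the model card).
+ model = AutoModel.from_pretrained(
+     "openbmb/MiniCPM-o-2_6",
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+ ).eval().cuda()
+ tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)
+
+ # A single-turn, single-image question.
+ image = Image.open("example.png").convert("RGB")
+ msgs = [{"role": "user", "content": [image, "What is in this image?"]}]
+
+ # The chat() helper comes from the model's remote code (assumed here).
+ answer = model.chat(msgs=msgs, tokenizer=tokenizer)
+ print(answer)
+ ```
+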
+ **Model architecture.**
+
+ - **End-to-end omni-modal architecture.** The encoder/decoder modules of the different modalities are connected and trained in an **end-to-end** fashion to fully exploit rich multimodal knowledge. The whole model is trained end to end with a cross-entropy (CE) loss.
+ - **Omni-modal streaming mechanism.** (1) We turn the offline encoders/decoders of the different modalities into online modules suitable for **streaming input/output**. (2) We design a **time-division multiplexing mechanism** for omni-modal stream processing in the LLM backbone: the parallel information streams of the different modalities are split and recombined into a periodic sequence of time slices (see the toy sketch below the architecture figure).
+ - **Configurable voice scheme.** We design a new multimodal system prompt that combines a traditional text system prompt with a **speech system prompt that specifies the model's voice**. At inference time, the voice style can be flexibly controlled through text or a speech sample, enabling advanced abilities such as end-to-end voice cloning and voice creation.
+
+ <div align="center">
+ <img src="./assets/minicpm-o-26-framework-v2.png" width=80%>
+ </div>
+
+ <br>
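+
+ To illustrate the time-division multiplexing idea, here is a toy sketch. It is not the model's actual implementation: the one-second period, the `<unit>` boundary marker, and the token representation are all illustrative assumptions.
+
+ ```python
+ def time_division_multiplex(video_stream, audio_stream, period=1.0):
+     """Chop two parallel (timestamp, token) streams into fixed time slices
+     and interleave them into one sequence an LLM can consume online."""
+     horizon = max(video_stream[-1][0], audio_stream[-1][0])
+     sequence, t = [], 0.0
+     while t <= horizon:
+         # gather the tokens of each modality that fall inside this slice
+         v = [tok for ts, tok in video_stream if t <= ts < t + period]
+         a = [tok for ts, tok in audio_stream if t <= ts < t + period]
+         # one periodic slice: a boundary marker, then video, then audio tokens
+         sequence += ["<unit>"] + v + a
+         t += period
+     return sequence
+
+ # Toy usage: two parallel streams become one interleaved sequence.
+ video = [(0.2, "v0"), (0.9, "v1"), (1.4, "v2")]
+ audio = [(0.1, "a0"), (0.5, "a1"), (1.2, "a2")]
+ print(time_division_multiplex(video, audio))
+ # ['<unit>', 'v0', 'v1', 'a0', 'a1', '<unit>', 'v2', 'a2']
+ ```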
+
+
+
+ ### Evaluation <!-- omit in toc -->
+
+ <div align="center">
+ <img src="./assets/radar.jpg" width=80%>
+ </div>
+
+ <details>
+ <summary>Click to view detailed evaluation results of visual understanding.</summary>
+
+ **Image understanding**
+
+ <div align="center">
+ <table style="margin: 0px auto;">
+ <thead>
+ <tr>
+ <th align="left">Model</th>
+ <th>Size</th>
+ <th>Token Density<sup>+</sup></th>
+ <th>OpenCompass</th>
+ <th>OCRBench</th>
+ <th>MathVista mini</th>
+ <th>ChartQA</th>
+ <th>MMVet</th>
+ <th>MMStar</th>
+ <th>MME</th>
+ <th>MMB1.1 test</th>
+ <th>AI2D</th>
+ <th>MMMU val</th>
+ <th>HallusionBench</th>
+ <th>TextVQA val</th>
+ <th>DocVQA test</th>
+ <th>MathVerse mini</th>
+ <th>MathVision</th>
+ <th>MMHal Score</th>
+ </tr>
+ </thead>
+ <tbody align="center">
+ <tr>
+ <td colspan="19" align="left"><strong>Proprietary</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT-4o-20240513</td>
+ <td>-</td>
+ <td>1088</td>
+ <td><u>69.9</u></td>
+ <td>736</td>
+ <td>61.3</td>
+ <td>85.7</td>
+ <td><strong>69.1</strong></td>
+ <td>63.9</td>
+ <td>2328.7</td>
+ <td>82.2</td>
+ <td>84.6</td>
+ <td><strong>69.2</strong></td>
+ <td><strong>55.0</strong></td>
+ <td>-</td>
+ <td>92.8</td>
+ <td><strong>50.2</strong></td>
+ <td><strong>30.4</strong></td>
+ <td><u>3.6</u></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Claude3.5-Sonnet</td>
+ <td>-</td>
+ <td>750</td>
+ <td>67.9</td>
+ <td>788</td>
+ <td>61.6</td>
+ <td><strong>90.8</strong></td>
+ <td>66.0</td>
+ <td>62.2</td>
+ <td>1920.0</td>
+ <td>78.5</td>
+ <td>80.2</td>
+ <td><u>65.9</u></td>
+ <td>49.9</td>
+ <td>-</td>
+ <td><strong>95.2</strong></td>
+ <td>-</td>
+ <td>-</td>
+ <td>3.4</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
+ <td>-</td>
+ <td>-</td>
+ <td>64.4</td>
+ <td>754</td>
+ <td>57.7</td>
+ <td>81.3</td>
+ <td>64.0</td>
+ <td>59.1</td>
+ <td>2110.6</td>
+ <td>73.9</td>
+ <td>79.1</td>
+ <td>60.6</td>
+ <td>45.6</td>
+ <td>73.5</td>
+ <td>86.5</td>
+ <td>-</td>
+ <td>19.2</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT-4o-mini-20240718</td>
+ <td>-</td>
+ <td>1088</td>
+ <td>64.1</td>
+ <td>785</td>
+ <td>52.4</td>
+ <td>-</td>
+ <td>66.9</td>
+ <td>54.8</td>
+ <td>2003.4</td>
+ <td>76.0</td>
+ <td>77.8</td>
+ <td>60.0</td>
+ <td>46.1</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>3.3</td>
+ </tr>
+ <tr>
+ <td colspan="19" align="left"><strong>Open Source</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Cambrian-34B</td>
+ <td>34B</td>
+ <td><u>1820</u></td>
+ <td>58.3</td>
+ <td>591</td>
+ <td>50.3</td>
+ <td>75.6</td>
+ <td>53.2</td>
+ <td>54.2</td>
+ <td>2049.9</td>
+ <td>77.8</td>
+ <td>79.5</td>
+ <td>50.4</td>
+ <td>41.6</td>
+ <td>76.7</td>
+ <td>75.5</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GLM-4V-9B</td>
+ <td>13B</td>
+ <td>784</td>
+ <td>59.1</td>
+ <td>776</td>
+ <td>51.1</td>
+ <td>-</td>
+ <td>58.0</td>
+ <td>54.8</td>
+ <td>2018.8</td>
+ <td>67.9</td>
+ <td>71.2</td>
+ <td>46.9</td>
+ <td>45.0</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Pixtral-12B</td>
+ <td>12B</td>
+ <td>256</td>
+ <td>61.0</td>
+ <td>685</td>
+ <td>56.9</td>
+ <td>81.8</td>
+ <td>58.5</td>
+ <td>54.5</td>
+ <td>-</td>
+ <td>72.7</td>
+ <td>79.0</td>
+ <td>51.1</td>
+ <td>47.0</td>
+ <td>75.7</td>
+ <td>90.7</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">DeepSeek-VL2-27B (4B)</td>
+ <td>27B</td>
+ <td>672</td>
+ <td>66.4</td>
+ <td>809</td>
+ <td>63.9</td>
+ <td>86.0</td>
+ <td>60.0</td>
+ <td>61.9</td>
+ <td>2253.0</td>
+ <td>81.2</td>
+ <td>83.8</td>
+ <td>54.0</td>
+ <td>45.3</td>
+ <td><u>84.2</u></td>
+ <td>93.3</td>
+ <td>-</td>
+ <td>-</td>
+ <td>3.0</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Qwen2-VL-7B</td>
+ <td>8B</td>
+ <td>784</td>
+ <td>67.1</td>
+ <td><u>866</u></td>
+ <td>58.2</td>
+ <td>83.0</td>
+ <td>62.0</td>
+ <td>60.7</td>
+ <td>2326.0</td>
+ <td>81.8</td>
+ <td>83.0</td>
+ <td>54.1</td>
+ <td>50.6</td>
+ <td><strong>84.3</strong></td>
+ <td><u>94.5</u></td>
+ <td>31.9</td>
+ <td>16.3</td>
+ <td>3.2</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">LLaVA-OneVision-72B</td>
+ <td>72B</td>
+ <td>182</td>
+ <td>68.1</td>
+ <td>741</td>
+ <td>67.5</td>
+ <td>83.7</td>
+ <td>60.6</td>
+ <td><strong>65.8</strong></td>
+ <td>2261.0</td>
+ <td><strong>85.0</strong></td>
+ <td><u>85.6</u></td>
+ <td>56.8</td>
+ <td>49.0</td>
+ <td>80.5</td>
+ <td>91.3</td>
+ <td>39.1</td>
+ <td>-</td>
+ <td>3.5</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">InternVL2.5-8B</td>
+ <td>8B</td>
+ <td>706</td>
+ <td>68.3</td>
+ <td>822</td>
+ <td><u>64.4</u></td>
+ <td>84.8</td>
+ <td>62.8</td>
+ <td>62.8</td>
+ <td>2344.0</td>
+ <td><u>83.6</u></td>
+ <td>84.5</td>
+ <td>56.0</td>
+ <td>50.1</td>
+ <td>79.1</td>
+ <td>93.0</td>
+ <td>39.5</td>
+ <td>19.7</td>
+ <td>3.4</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
+ <td>8B</td>
+ <td><strong>2822</strong></td>
+ <td>65.2</td>
+ <td>852*</td>
+ <td>60.6</td>
+ <td>79.4</td>
+ <td>60.0</td>
+ <td>57.5</td>
+ <td><u>2348.4*</u></td>
+ <td>78.0</td>
+ <td>82.1</td>
+ <td>49.8*</td>
+ <td>48.1*</td>
+ <td>80.1</td>
+ <td>90.8</td>
+ <td>25.7</td>
+ <td>18.3</td>
+ <td>3.6</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
+ <td>8B</td>
+ <td><strong>2822</strong></td>
+ <td><strong>70.2</strong></td>
+ <td><strong>897*</strong></td>
+ <td><strong>71.9*</strong></td>
+ <td><u>86.9*</u></td>
+ <td><u>67.5</u></td>
+ <td><u>64.0</u></td>
+ <td><strong>2372.0*</strong></td>
+ <td>80.5</td>
+ <td><strong>85.8</strong></td>
+ <td>50.4*</td>
+ <td><u>51.9</u></td>
+ <td>82.0</td>
+ <td>93.5</td>
+ <td><u>41.4*</u></td>
+ <td><u>23.1*</u></td>
+ <td><strong>3.8</strong></td>
+ </tr>
+ </tbody>
+ </table>
+ </div>
+ * We evaluate these benchmarks with chain-of-thought prompting; for MME, chain of thought is used only on the Cognition task.
+ + Token Density: the number of pixels encoded by each visual token at the maximum resolution, i.e., pixels at the maximum resolution / number of visual tokens.
+
+ Note: the Token Density of proprietary models is estimated from their API pricing.
+
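+ As a quick sanity check of the Token Density definition, the figures quoted in the feature list above reproduce the 2822 reported for MiniCPM-V 2.6 and MiniCPM-o 2.6:
+
+ ```python
+ # Token density = pixels at maximum resolution / number of visual tokens.
+ max_pixels = 1344 * 1344   # ~1.8 megapixels, the maximum resolution cited above
+ visual_tokens = 640        # tokens used for a 1.8-megapixel image, per the text
+ print(round(max_pixels / visual_tokens))  # 2822, matching the Token Density column
+ ```
+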
+ **Multi-image and video understanding**
+
+ <div align="center">
+
+ <table style="margin: 0px auto;">
+ <thead>
+ <tr>
+ <th align="left">Model</th>
+ <th>Size</th>
+ <th>BLINK val</th>
+ <th>Mantis Eval</th>
+ <th>MIRB</th>
+ <th>Video-MME (wo / w subs)</th>
+ </tr>
+ </thead>
+ <tbody align="center">
+ <tr>
+ <td colspan="6" align="left"><strong>Proprietary</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT-4o-20240513</td>
+ <td>-</td>
+ <td><strong>68</strong></td>
+ <td>-</td>
+ <td>-</td>
+ <td><strong>71.9/77.2</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT4V</td>
+ <td>-</td>
+ <td>54.6</td>
+ <td>62.7</td>
+ <td>53.1</td>
+ <td>59.9/63.3</td>
+ </tr>
+ <tr>
+ <td colspan="6" align="left"><strong>Open-source</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">LLaVA-NeXT-Interleave 14B</td>
+ <td>14B</td>
+ <td>52.6</td>
+ <td>66.4</td>
+ <td>30.2</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">LLaVA-OneVision-72B</td>
+ <td>72B</td>
+ <td>55.4</td>
+ <td><strong>77.6</strong></td>
+ <td>-</td>
+ <td><u>66.2/69.5</u></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MANTIS 8B</td>
+ <td>8B</td>
+ <td>49.1</td>
+ <td>59.5</td>
+ <td>34.8</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Qwen2-VL-7B</td>
+ <td>8B</td>
+ <td>53.2</td>
+ <td>69.6*</td>
+ <td><strong>67.6*</strong></td>
+ <td>63.3/69.0</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">InternVL2.5-8B</td>
+ <td>8B</td>
+ <td>54.8</td>
+ <td>67.7</td>
+ <td>52.5</td>
+ <td>64.2/66.9</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
+ <td>8B</td>
+ <td>53</td>
+ <td>69.1</td>
+ <td>53.8</td>
+ <td>60.9/63.6</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
+ <td>8B</td>
+ <td><u>56.7</u></td>
+ <td><u>71.9</u></td>
+ <td><u>58.6</u></td>
+ <td>63.9/67.9</td>
+ </tr>
+ </tbody>
+ </table>
+
+ </div>
+ * Evaluation results of the officially released open-source model weights.
+
+ </details>
+
+
+ <details>
+ <summary>Click to view detailed evaluation results of speech understanding and generation.</summary>
+
+ **Speech understanding**
+
+ <div align="center">
+ <table style="margin: 0px auto;">
+ <thead>
+ <tr>
+ <th align="left">Task</th>
+ <th>Size</th>
+ <th colspan="3">ASR (zh)</th>
+ <th colspan="3">ASR (en)</th>
+ <th colspan="2">AST</th>
+ <th>Emotion</th>
+ </tr>
+ <tr>
+ <th align="left">Metric</th>
+ <td></td>
+ <th colspan="3">CER↓</th>
+ <th colspan="3">WER↓</th>
+ <th colspan="2">BLEU↑</th>
+ <th>ACC↑</th>
+ </tr>
+ <tr>
+ <th align="left">Dataset</th>
+ <td></td>
+ <th>AISHELL-1</th>
+ <th>Fleurs zh</th>
+ <th>WenetSpeech test-net</th>
+ <th>LibriSpeech test-clean</th>
+ <th>GigaSpeech</th>
+ <th>TED-LIUM</th>
+ <th>CoVoST en2zh</th>
+ <th>CoVoST zh2en</th>
+ <th>MELD emotion</th>
+ </tr>
+ </thead>
+ <tbody align="center">
+ <tr>
+ <td colspan="11" align="left"><strong>Proprietary</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT-4o-Realtime</td>
+ <td>-</td>
+ <td>7.3*</td>
+ <td><u>5.4*</u></td>
+ <td>28.9*</td>
+ <td>2.6*</td>
+ <td>12.9*</td>
+ <td>4.8*</td>
+ <td>37.1*</td>
+ <td>15.7*</td>
+ <td>33.2*</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
+ <td>-</td>
+ <td>4.5*</td>
+ <td>5.9*</td>
+ <td>14.3*</td>
+ <td>2.9*</td>
+ <td>10.6*</td>
+ <td><strong>3.0*</strong></td>
+ <td><u>47.3*</u></td>
+ <td>22.6*</td>
+ <td>48.4*</td>
+ </tr>
+ <tr>
+ <td colspan="11" align="left"><strong>Open-Source</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Qwen2-Audio-7B</td>
+ <td>8B</td>
+ <td>-</td>
+ <td>7.5</td>
+ <td>-</td>
+ <td><strong>1.6</strong></td>
+ <td>-</td>
+ <td>-</td>
+ <td>45.2</td>
+ <td><u>24.4</u></td>
+ <td><strong>55.3</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Qwen2-Audio-7B-Instruct</td>
+ <td>8B</td>
+ <td>2.6*</td>
+ <td>6.9*</td>
+ <td><u>10.3*</u></td>
+ <td>3.1*</td>
+ <td><u>9.7</u>*</td>
+ <td>5.9*</td>
+ <td>39.5*</td>
+ <td>22.9*</td>
+ <td>17.4*</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GLM-4-Voice-Base</td>
+ <td>9B</td>
+ <td><u>2.5</u></td>
+ <td>-</td>
+ <td>-</td>
+ <td>2.8</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
+ <td>8B</td>
+ <td><strong>1.6</strong></td>
+ <td><strong>4.4</strong></td>
+ <td><strong>6.9</strong></td>
+ <td><u>1.7</u></td>
+ <td><strong>8.7</strong></td>
+ <td><strong>3.0</strong></td>
+ <td><strong>48.2</strong></td>
+ <td><strong>27.2</strong></td>
+ <td><u>52.4</u></td>
+ </tr>
+ </tbody>
+ </table>
+ </div>
+ * Evaluation results of the officially released open-source model weights.<br><br>
+
+ **Speech generation**
+
+ <div align="center">
+ <table style="margin: 0px auto;">
+ <thead>
+ <tr>
+ <th align="left">Task</th>
+ <th>Size</th>
+ <th colspan="9">SpeechQA</th>
+ </tr>
+ <tr>
+ <th align="left">Metric</th>
+ <th></th>
+ <th colspan="3">ACC↑</th>
+ <th>G-Eval (10 point)↑</th>
+ <th>Semantic ELO score↑</th>
+ <th>Acoustic ELO score↑</th>
+ <th>Overall ELO score↑</th>
+ <th>UTMOS↑</th>
+ <th>ASR-WER↓</th>
+ </tr>
+ <tr>
+ <th align="left">Dataset</th>
+ <th></th>
+ <th>Speech Llama Q.</th>
+ <th>Speech Web Q.</th>
+ <th>Speech Trivia QA</th>
+ <th>Speech AlpacaEval</th>
+ <th colspan="5">AudioArena</th>
+ </tr>
+ </thead>
+ <tbody align="center">
+ <tr>
+ <td colspan="11" align="left"><strong>Proprietary</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT-4o-Realtime</td>
+ <td></td>
+ <td><strong>71.7</strong></td>
+ <td><strong>51.6</strong></td>
+ <td><strong>69.7</strong></td>
+ <td><strong>7.4</strong></td>
+ <td><strong>1157</strong></td>
+ <td><strong>1203</strong></td>
+ <td><strong>1200</strong></td>
+ <td><strong>4.2</strong></td>
+ <td><strong>2.3</strong></td>
+ </tr>
+ <tr>
+ <td colspan="11" align="left"><strong>Open-Source</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GLM-4-Voice</td>
+ <td>9B</td>
+ <td>50.0</td>
+ <td>32.0</td>
+ <td>36.4</td>
+ <td><u>5.1</u></td>
+ <td>999</td>
+ <td>1147</td>
+ <td>1035</td>
+ <td><u>4.1</u></td>
+ <td><u>11.7</u></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Llama-Omni</td>
+ <td>8B</td>
+ <td>45.3</td>
+ <td>22.9</td>
+ <td>10.7</td>
+ <td>3.9</td>
+ <td>960</td>
+ <td>878</td>
+ <td>897</td>
+ <td>3.2</td>
+ <td>24.3</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">VITA-1.5</td>
+ <td>8B</td>
+ <td>46.7</td>
+ <td>28.1</td>
+ <td>23.3</td>
+ <td>2.0</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ <td>-</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Moshi</td>
+ <td>7B</td>
+ <td>43.7</td>
+ <td>23.8</td>
+ <td>16.7</td>
+ <td>2.4</td>
+ <td>871</td>
+ <td>808</td>
+ <td>875</td>
+ <td>2.8</td>
+ <td>8.2</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Mini-Omni</td>
+ <td>1B</td>
+ <td>22.0</td>
+ <td>12.8</td>
+ <td>6.9</td>
+ <td>2.5</td>
+ <td>926</td>
+ <td>803</td>
+ <td>865</td>
+ <td>3.4</td>
+ <td>10.0</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
+ <td>8B</td>
+ <td><u>61.0</u></td>
+ <td><u>40.0</u></td>
+ <td><u>40.2</u></td>
+ <td><u>5.1</u></td>
+ <td><u>1088</u></td>
+ <td><u>1163</u></td>
+ <td><u>1131</u></td>
+ <td><strong>4.2</strong></td>
+ <td>9.8</td>
+ </tr>
+ </tbody>
+ </table>
+ </div>
+ All results are from <a href="https://github.com/OpenBMB/UltraEval-Audio" target="_blank">AudioEvals</a>.<br><br>
+
+ **End-to-end voice cloning**
+
+ <div align="center">
+ <table style="margin: 0px auto;">
+ <thead>
+ <tr>
+ <th align="left">Task</th>
+ <th colspan="2">TTS</th>
+ </tr>
+ <tr>
+ <th align="left">Metric</th>
+ <th>SIMO↑</th>
+ <th>SIMO↑</th>
+ </tr>
+ <tr>
+ <th align="left">Dataset</th>
+ <th>Seed-TTS test-zh</th>
+ <th>Seed-TTS test-en</th>
+ </tr>
+ </thead>
+ <tbody align="center">
+ <tr>
+ <td nowrap="nowrap" align="left">F5-TTS</td>
+ <td><strong>76</strong></td>
+ <td><strong>67</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">CosyVoice</td>
+ <td><u>75</u></td>
+ <td><u>64</u></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">FireRedTTS</td>
+ <td>63</td>
+ <td>46</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
+ <td>57</td>
+ <td>47</td>
+ </tr>
+ </tbody>
+ </table>
+ </div>
+
+ </details>
+
+ <details>
+ <summary>Click to view detailed evaluation results of multimodal live streaming.</summary>
+
+ **Multimodal live streaming**: results on StreamingBench
+
+ <table style="margin: 0px auto;">
+ <thead>
+ <tr>
+ <th align="left">Model</th>
+ <th>Size</th>
+ <th>Real-Time Video Understanding</th>
+ <th>Omni-Source Understanding</th>
+ <th>Contextual Understanding</th>
+ <th>Overall</th>
+ </tr>
+ </thead>
+ <tbody align="center">
+ <tr>
+ <td colspan="6" align="left"><strong>Proprietary</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
+ <td>-</td>
+ <td><u>77.4</u></td>
+ <td><strong>67.8</strong></td>
+ <td><strong>51.1</strong></td>
+ <td><strong>70.3</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">GPT-4o-202408</td>
+ <td>-</td>
+ <td>74.5</td>
+ <td>51.0</td>
+ <td><u>48.0</u></td>
+ <td>64.1</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Claude-3.5-Sonnet</td>
+ <td>-</td>
+ <td>74.0</td>
+ <td>41.4</td>
+ <td>37.8</td>
+ <td>59.7</td>
+ </tr>
+ <tr>
+ <td colspan="6" align="left"><strong>Open-source</strong></td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">VILA-1.5</td>
+ <td>8B</td>
+ <td>61.5</td>
+ <td>37.5</td>
+ <td>26.7</td>
+ <td>49.5</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">LongVA</td>
+ <td>7B</td>
+ <td>63.1</td>
+ <td>35.9</td>
+ <td>30.2</td>
+ <td>50.7</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">LLaVA-Next-Video-34B</td>
+ <td>34B</td>
+ <td>69.8</td>
+ <td>41.7</td>
+ <td>34.3</td>
+ <td>56.7</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">Qwen2-VL-7B</td>
+ <td>8B</td>
+ <td>71.2</td>
+ <td>40.7</td>
+ <td>33.1</td>
+ <td>57.0</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">InternVL2-8B</td>
+ <td>8B</td>
+ <td>70.1</td>
+ <td>42.7</td>
+ <td>34.1</td>
+ <td>57.0</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">VITA-1.5</td>
+ <td>8B</td>
+ <td>70.9</td>
+ <td>40.8</td>
+ <td>35.8</td>
+ <td>57.4</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">LLaVA-OneVision-7B</td>
+ <td>8B</td>
+ <td>74.3</td>
+ <td>40.8</td>
+ <td>31.0</td>
+ <td>58.4</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">InternLM-XC2.5-OL-7B</td>
+ <td>8B</td>
+ <td>75.4</td>
+ <td>46.2</td>
+ <td>33.6</td>
+ <td>60.8</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
+ <td>8B</td>
+ <td>72.4</td>
+ <td>40.2</td>
+ <td>33.4</td>
+ <td>57.7</td>
+ </tr>
+ <tr>
+ <td nowrap="nowrap" align="left">MiniCPM-o 2.6</td>
+ <td>8B</td>
+ <td><strong>79.9</strong></td>
+ <td><u>53.4</u></td>
+ <td>38.5</td>
+ <td><u>66.0</u></td>
+ </tr>
+ </tbody>
+ </table>
+
+ </details>
+
+
+ ### 典型示例 <!-- omit in toc -->
1021
+
1022
+ 以下为 MiniCPM-o 2.6 的 iPad Pro 实机演示和 web demo 演示样例:
1023
+
1024
+
1025
+ <div align="center">
1026
+ <a href="https://www.youtube.com/watch?v=vRIMbxJzStY&t=2s"><img src="./assets/minicpmo2_6/2dot6_o_demo_video_img.png", width=70%></a>
1027
+ </div>
1028
+ <br>
1029
+
1030
+
1031
+
1032
+ <div style="display: flex; flex-direction: column; align-items: center;">
1033
+ <img src="assets/minicpmo2_6/minicpmo2_6_math_intersect.png" alt="math" style="margin-bottom: 5px;">
1034
+ <img src="assets/minicpmo2_6/minicpmo2_6_diagram_train_NN.png" alt="diagram" style="margin-bottom: 5px;">
1035
+ <img src="assets/minicpmo2_6/minicpmo2_6_multi-image_bike.png" alt="bike" style="margin-bottom: 5px;">
1036
+ </div>
1037
+
1038
+
1039
+ <details>
1040
+ <summary>Click to view more details of MiniCPM-V 2.6</summary>
1041
+
1042
+
1043
+ ## MiniCPM-V 2.6
1044
+
1045
+ **MiniCPM-V 2.6** 是 MiniCPM-V 系列中最新、性能最佳的模型。该模型基于 SigLip-400M 和 Qwen2-7B 构建,共 8B 参数。与 MiniCPM-Llama3-V 2.5 相比,MiniCPM-V 2.6 性能提升显著,并引入了多图和视频理解的新功能。MiniCPM-V 2.6 的主要特点包括:
1046
+
1047
+
1048
+ - 🔥 **领先的性能。**
1049
+ MiniCPM-V 2.6 在最新版本 OpenCompass 榜单上(综合 8 个主流多模态评测基准)平均得分 65.2,**以8B量级的大小在单图理解方面超越了 GPT-4o mini、GPT-4V、Gemini 1.5 Pro 和 Claude 3.5 Sonnet 等主流商用闭源多模态大模型**。
1050
+
1051
+ - 🖼️ **多图理解和上下文学习。**
1052
+ MiniCPM-V 2.6 还支持**多图对话和推理**。它在 Mantis-Eval、BLINK、Mathverse mv 和 Sciverse mv 等主流多图评测基准中取得了**最佳水平**,并展现出了优秀的上下文学习能力。
1053
+
1054
+ - 🎬 **视频理解。**
1055
+ MiniCPM-V 2.6 还可以**接受视频输入**,进行对话和提供涵盖时序和空间信息的详细视频描述。模型在 有/无字幕 评测场景下的 Video-MME 表现均超过了 **GPT-4V、Claude 3.5 Sonnet 和 LLaVA-NeXT-Video-34B**等商用闭源模型。
1056
+
1057
+ - 💪 **强大的 OCR 能力及其他功能。**
1058
+ MiniCPM-V 2.6 可以处理任意长宽比的图像,像素数可达 180 万(如 1344x1344)。在 OCRBench 上取得**最佳水平,超过 GPT-4o、GPT-4V 和 Gemini 1.5 Pro 等商用闭源模型**。基于最新的 [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/) 和 [VisCPM](https://github.com/OpenBMB/VisCPM) 技术,其具备了**可信的多模态行为**,在 Object HalBench 上的幻觉率显著低于 GPT-4o 和 GPT-4V,并支持英语、中文、德语、法语、意大利语、韩语等**多种语言**。
1059
+
1060
+ - 🚀 **卓越的效率。**
1061
+ 除了对个人用户友好的模型大小,MiniCPM-V 2.6 还表现出**最先进的视觉 token 密度**(即每个视觉 token 编码的像素数量)。它**仅需 640 个 token 即可处理 180 万像素图像,比大多数模型少 75%**。这一特性优化了模型的推理速度、首 token 延迟、内存占用和功耗。因此,MiniCPM-V 2.6 可以支持 iPad 等终端设备上的高效**实时视频理解**。
1062
+
1063
+ - 💫 **易于使用。**
1064
+ MiniCPM-V 2.6 可以通过多种方式轻松使用:(1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md) 和 [ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md) 支持在本地设备上进行高效的 CPU 推理,(2) [int4](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) 和 [GGUF](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) 格式的量化模型,有 16 种尺寸,(3) [vLLM](#vllm-部署-) 支持高吞吐量和内存高效的推理,(4) 针对新领域和任务进行微调,(5) 使用 [Gradio](#本地-webui-demo-) 快速设置本地 WebUI 演示,(6) 在线[demo](http://120.92.209.146:8887/)即可体验。
1065
+
1066
+ ### 性能评估 <!-- omit in toc -->
1067
+ <div align="center">
1068
+ <img src=assets/radar_final.png width=90% />
1069
+ </div>
1070
+
1071
+ <details>
1072
+ <summary>点击查看 OpenCompass, MME, MMVet, OCRBench, MMMU, MathVista, MMB, AI2D, TextVQA, DocVQA, HallusionBench, Object HalBench 上的单图评测结果详情。 </summary>
1073
+ <div align="center">
1074
+
1075
+ <table style="margin: 0px auto;">
1076
+ <thead>
1077
+ <tr>
1078
+ <th align="left">Model</th>
1079
+ <th>Size</th>
1080
+ <th>Token Density<sup>+</sup></th>
1081
+ <th>OpenCompass</th>
1082
+ <th>MME</th>
1083
+ <th>MMVet</th>
1084
+ <th>OCRBench</th>
1085
+ <th>MMMU val</th>
1086
+ <th>MathVista mini</th>
1087
+ <th>MMB1.1 test</th>
1088
+ <th>AI2D</th>
1089
+ <th>TextVQA val</th>
1090
+ <th>DocVQA test</th>
1091
+ <th>HallusionBench</th>
1092
+ <th>Object HalBench</th>
1093
+ </tr>
1094
+ </thead>
1095
+ <tbody align="center">
1096
+ <tr>
1097
+ <td colspan="15" align="left"><strong>Proprietary</strong></td>
1098
+ </tr>
1099
+ <tr>
1100
+ <td nowrap="nowrap" align="left">GPT-4o</td>
1101
+ <td>-</td>
1102
+ <td>1088</td>
1103
+ <td>69.9</td>
1104
+ <td>2328.7</td>
1105
+ <td>69.1</td>
1106
+ <td>736</td>
1107
+ <td>69.2</td>
1108
+ <td>61.3</td>
1109
+ <td>82.2</td>
1110
+ <td>84.6</td>
1111
+ <td>-</td>
1112
+ <td>92.8</td>
1113
+ <td>55.0</td>
1114
+ <td>17.6</td>
1115
+ </tr>
1116
+ <tr>
1117
+ <td nowrap="nowrap" align="left">Claude 3.5 Sonnet</td>
1118
+ <td>-</td>
1119
+ <td>750</td>
1120
+ <td>67.9</td>
1121
+ <td>1920.0</td>
1122
+ <td>66.0</td>
1123
+ <td>788</td>
1124
+ <td>65.9</td>
1125
+ <td>61.6</td>
1126
+ <td>78.5</td>
1127
+ <td>80.2</td>
1128
+ <td>-</td>
1129
+ <td>95.2</td>
1130
+ <td>49.9</td>
1131
+ <td>13.8</td>
1132
+ </tr>
1133
+ <tr>
1134
+ <td nowrap="nowrap" align="left">Gemini 1.5 Pro</td>
1135
+ <td>-</td>
1136
+ <td>-</td>
1137
+ <td>64.4</td>
1138
+ <td>2110.6</td>
1139
+ <td>64.0</td>
1140
+ <td>754</td>
1141
+ <td>60.6</td>
1142
+ <td>57.7</td>
1143
+ <td>73.9</td>
1144
+ <td>79.1</td>
1145
+ <td>73.5</td>
1146
+ <td>86.5</td>
1147
+ <td>45.6</td>
1148
+ <td>-</td>
1149
+ </tr>
1150
+ <tr>
1151
+ <td nowrap="nowrap" align="left">GPT-4o mini</td>
1152
+ <td>-</td>
1153
+ <td>1088</td>
1154
+ <td>64.1</td>
1155
+ <td>2003.4</td>
1156
+ <td>66.9</td>
1157
+ <td>785</td>
1158
+ <td>60.0</td>
1159
+ <td>52.4</td>
1160
+ <td>76.0</td>
1161
+ <td>77.8</td>
1162
+ <td>-</td>
1163
+ <td>-</td>
1164
+ <td>46.1</td>
1165
+ <td>12.4</td>
1166
+ </tr>
1167
+ <tr>
1168
+ <td nowrap="nowrap" align="left">GPT-4V</td>
1169
+ <td>-</td>
1170
+ <td>1088</td>
1171
+ <td>63.5</td>
1172
+ <td>2070.2</td>
1173
+ <td>67.5</td>
1174
+ <td>656</td>
1175
+ <td>61.7</td>
1176
+ <td>54.7</td>
1177
+ <td>79.8</td>
1178
+ <td>78.6</td>
1179
+ <td>78.0</td>
1180
+ <td>87.2</td>
1181
+ <td>43.9</td>
1182
+ <td>14.2</td>
1183
+ </tr>
1184
+ <tr>
1185
+ <td nowrap="nowrap" align="left">Step-1V</td>
1186
+ <td>-</td>
1187
+ <td>-</td>
1188
+ <td>59.5</td>
1189
+ <td>2206.4</td>
1190
+ <td>63.3</td>
1191
+ <td>625</td>
1192
+ <td>49.9</td>
1193
+ <td>44.8</td>
1194
+ <td>78.0</td>
1195
+ <td>79.2</td>
1196
+ <td>71.6</td>
1197
+ <td>-</td>
1198
+ <td>48.4</td>
1199
+ <td>-</td>
1200
+ </tr>
1201
+ <tr>
1202
+ <td nowrap="nowrap" align="left">Qwen-VL-Max</td>
1203
+ <td>-</td>
1204
+ <td>784</td>
1205
+ <td>58.3</td>
1206
+ <td>2281.7</td>
1207
+ <td>61.8</td>
1208
+ <td>684</td>
1209
+ <td>52.0</td>
1210
+ <td>43.4</td>
1211
+ <td>74.6</td>
1212
+ <td>75.7</td>
1213
+ <td>79.5</td>
1214
+ <td>93.1</td>
1215
+ <td>41.2</td>
1216
+ <td>13.4</td>
1217
+ </tr>
1218
+ <tr>
1219
+ <td colspan="15" align="left"><strong>Open-source</strong></td>
1220
+ </tr>
1221
+ <tr>
1222
+ <td nowrap="nowrap" align="left">LLaVA-NeXT-Yi-34B</td>
1223
+ <td>34B</td>
1224
+ <td>157</td>
1225
+ <td>55.0</td>
1226
+ <td>2006.5</td>
1227
+ <td>50.7</td>
1228
+ <td>574</td>
1229
+ <td>48.8</td>
1230
+ <td>40.4</td>
1231
+ <td>77.8</td>
1232
+ <td>78.9</td>
1233
+ <td>69.3</td>
1234
+ <td>-</td>
1235
+ <td>34.8</td>
1236
+ <td>12.6</td>
1237
+ </tr>
1238
+ <tr>
1239
+ <td nowrap="nowrap" align="left">Mini-Gemini-HD-34B</td>
1240
+ <td>34B</td>
1241
+ <td>157</td>
1242
+ <td>-</td>
1243
+ <td>2141</td>
1244
+ <td>59.3</td>
1245
+ <td>518</td>
1246
+ <td>48.0</td>
1247
+ <td>43.3</td>
1248
+ <td>-</td>
1249
+ <td>80.5</td>
1250
+ <td>74.1</td>
1251
+ <td>78.9</td>
1252
+ <td>-</td>
1253
+ <td>-</td>
1254
+ </tr>
1255
+ <tr>
1256
+ <td nowrap="nowrap" align="left">Cambrian-34B</td>
1257
+ <td>34B</td>
1258
+ <td>1820</td>
1259
+ <td>58.3</td>
1260
+ <td>2049.9</td>
1261
+ <td>53.2</td>
1262
+ <td>591</td>
1263
+ <td>50.4</td>
1264
+ <td>50.3</td>
1265
+ <td>77.8</td>
1266
+ <td>79.5</td>
1267
+ <td>76.7</td>
1268
+ <td>75.5</td>
1269
+ <td>41.6</td>
1270
+ <td>14.7</td>
1271
+ </tr>
1272
+ <tr>
1273
+ <td nowrap="nowrap" align="left">GLM-4V-9B</td>
1274
+ <td>13B</td>
1275
+ <td>784</td>
1276
+ <td>59.1</td>
1277
+ <td>2018.8</td>
1278
+ <td>58.0</td>
1279
+ <td>776</td>
1280
+ <td>46.9</td>
1281
+ <td>51.1</td>
1282
+ <td>67.9</td>
1283
+ <td>71.2</td>
1284
+ <td>-</td>
1285
+ <td>-</td>
1286
+ <td>45.0</td>
1287
+ <td>-</td>
1288
+ </tr>
1289
+ <tr>
1290
+ <td nowrap="nowrap" align="left">InternVL2-8B</td>
1291
+ <td>8B</td>
1292
+ <td>706</td>
1293
+ <td>64.1</td>
1294
+ <td>2215.1</td>
1295
+ <td>54.3</td>
1296
+ <td>794</td>
1297
+ <td><strong>51.2</strong></td>
1298
+ <td>58.3</td>
1299
+ <td><strong>79.4</strong></td>
1300
+ <td><strong>83.6</strong></td>
1301
+ <td>77.4</td>
1302
+ <td><strong>91.6</strong></td>
1303
+ <td>45.0</td>
1304
+ <td>21.3</td>
1305
+ </tr>
1306
+ <tr>
1307
+ <td nowrap="nowrap" align="left">MiniCPM-Llama-V 2.5</td>
1308
+ <td>8B</td>
1309
+ <td>1882</td>
1310
+ <td>58.8</td>
1311
+ <td>2024.6</td>
1312
+ <td>52.8</td>
1313
+ <td>725</td>
1314
+ <td>45.8</td>
1315
+ <td>54.3</td>
1316
+ <td>72.0</td>
1317
+ <td>78.4</td>
1318
+ <td>76.6</td>
1319
+ <td>84.8</td>
1320
+ <td>42.4</td>
1321
+ <td>10.3</td>
1322
+ </tr>
1323
+ <tr>
1324
+ <td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
1325
+ <td>8B</td>
1326
+ <td><strong>2822</strong></td>
1327
+ <td><strong>65.2</strong></td>
1328
+ <td><strong>2348.4</strong>*</td>
1329
+ <td><strong>60.0</strong></td>
1330
+ <td><strong>852</strong>*</td>
1331
+ <td>49.8*</td>
1332
+ <td><strong>60.6</strong></td>
1333
+ <td>78.0</td>
1334
+ <td>82.1</td>
1335
+ <td><strong>80.1<strong></td>
1336
+ <td>90.8</td>
1337
+ <td><strong>48.1</strong>*</td>
1338
+ <td><strong>8.2</strong></td>
1339
+ </tr>
1340
+ </tbody>
1341
+ </table>
1342
+
1343
+ </div>
1344
+ * 我们使用思维链提示词来评估这些基准。
1345
+
1346
+ <sup>+</sup> Token Density:每个视觉 token 在最大分辨率下编码的像素数,即最大分辨率下的像素数 / 视觉 token 数。
1347
+
1348
+ 注意:闭源模型的 Token Density 由 API 收费方式估算得到。
1349
+ </details>
1350
+
1351
+
1352
+ <details>
1353
+ <summary>点击查看 Mantis Eval, BLINK, Mathverse mv, Sciverse mv, MIRB 上的多图评测结果详情。</summary>
1354
+ <div align="center">
1355
+
1356
+ <table style="margin: 0px auto;">
1357
+ <thead>
1358
+ <tr>
1359
+ <th align="left">Model</th>
1360
+ <th>Size</th>
1361
+ <th>Mantis Eval</th>
1362
+ <th>BLINK val</th>
1363
+ <th>Mathverse mv</th>
1364
+ <th>Sciverse mv</th>
1365
+ <th>MIRB</th>
1366
+ </tr>
1367
+ </thead>
1368
+ <tbody align="center">
1369
+ <tr>
1370
+ <td colspan="7" align="left"><strong>Proprietary</strong></td>
1371
+ </tr>
1372
+ <tr>
1373
+ <td nowrap="nowrap" align="left">GPT-4V</td>
1374
+ <td>-</td>
1375
+ <td>62.7</td>
1376
+ <td>54.6</td>
1377
+ <td>60.3</td>
1378
+ <td>66.9</td>
1379
+ <td>53.1</td>
1380
+ </tr>
1381
+ <tr>
1382
+ <td nowrap="nowrap" align="left">LLaVA-NeXT-Interleave-14B</td>
1383
+ <td>14B</td>
1384
+ <td>66.4</td>
1385
+ <td>52.6</td>
1386
+ <td>32.7</td>
1387
+ <td>30.2</td>
1388
+ <td>-</td>
1389
+ </tr>
1390
+ <tr>
1391
+ <td colspan="7" align="left"><strong>Open-source</strong></td>
1392
+ </tr>
1393
+ <tr>
1394
+ <td nowrap="nowrap" align="left">Emu2-Chat</td>
1395
+ <td>37B</td>
1396
+ <td>37.8</td>
1397
+ <td>36.2</td>
1398
+ <td>-</td>
1399
+ <td>27.2</td>
1400
+ <td>-</td>
1401
+ </tr>
1402
+ <tr>
1403
+ <td nowrap="nowrap" align="left">CogVLM</td>
1404
+ <td>17B</td>
1405
+ <td>45.2</td>
1406
+ <td>41.1</td>
1407
+ <td>-</td>
1408
+ <td>-</td>
1409
+ <td>-</td>
1410
+ </tr>
1411
+ <tr>
1412
+ <td nowrap="nowrap" align="left">VPG-C</td>
1413
+ <td>7B</td>
1414
+ <td>52.4</td>
1415
+ <td>43.1</td>
1416
+ <td>24.3</td>
1417
+ <td>23.1</td>
1418
+ <td>-</td>
1419
+ </tr>
1420
+ <tr>
1421
+ <td nowrap="nowrap" align="left">VILA 8B</td>
1422
+ <td>8B</td>
1423
+ <td>51.2</td>
1424
+ <td>39.3</td>
1425
+ <td>-</td>
1426
+ <td>36.5</td>
1427
+ <td>-</td>
1428
+ </tr>
1429
+ <tr>
1430
+ <td nowrap="nowrap" align="left">InternLM-XComposer-2.5</td>
1431
+ <td>8B</td>
1432
+ <td>53.1*</td>
1433
+ <td>48.9</td>
1434
+ <td>32.1*</td>
1435
+ <td>-</td>
1436
+ <td>42.5</td>
1437
+ </tr>
1438
+ <tr>
1439
+ <td nowrap="nowrap" align="left">InternVL2-8B</td>
1440
+ <td>8B</td>
1441
+ <td>59.0*</td>
1442
+ <td>50.9</td>
1443
+ <td>30.5*</td>
1444
+ <td>34.4*</td>
1445
+ <td><strong>56.9*</strong></td>
1446
+ </tr>
1447
+ <tr>
1448
+ <td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
1449
+ <td>8B</td>
1450
+ <td><strong>69.1</strong></td>
1451
+ <td><strong>53.0</strong></td>
1452
+ <td><strong>84.9</strong></td>
1453
+ <td><strong>74.9</strong></td>
1454
+ <td>53.8</td>
1455
+ </tr>
1456
+ </tbody>
1457
+ </table>
1458
+
1459
+
1460
+ </div>
1461
+ * 正式开源模型权重的评测结果。
1462
+ </details>
1463
+
1464
+ <details>
1465
+ <summary>点击查看 Video-MME 和 Video-ChatGPT 上的视频评测结果详情。</summary>
1466
+ <div align="center">
1467
+
1468
+ <table style="margin: 0px auto;">
1469
+ <thead>
1470
+ <tr>
1471
+ <th align="left">Model</th>
1472
+ <th>Size</th>
1473
+ <th colspan="2">Video-MME</th>
1474
+ <th colspan="5">Video-ChatGPT</th>
1475
+ </tr>
1476
+ <tr>
1477
+ <th align="left"></th>
1478
+ <th></th>
1479
+ <th>w/o subs</th>
1480
+ <th>w subs</th>
1481
+ <th>Correctness</th>
1482
+ <th>Detail</th>
1483
+ <th>Context</th>
1484
+ <th>Temporal</th>
1485
+ <th>Consistency</th>
1486
+ </tr>
1487
+ </thead>
1488
+ <tbody align="center">
1489
+ <tr>
1490
+ <td colspan="9" align="left"><strong>Proprietary</strong></td>
1491
+ </tr>
1492
+ <tr>
1493
+ <td nowrap="nowrap" align="left">Claude 3.5 Sonnet</td>
1494
+ <td>-</td>
1495
+ <td>60.0</td>
1496
+ <td>62.9</td>
1497
+ <td>-</td>
1498
+ <td>-</td>
1499
+ <td>-</td>
1500
+ <td>-</td>
1501
+ <td>-</td>
1502
+ </tr>
1503
+ <tr>
1504
+ <td nowrap="nowrap" align="left">GPT-4V</td>
1505
+ <td>-</td>
1506
+ <td>59.9</td>
1507
+ <td>63.3</td>
1508
+ <td>-</td>
1509
+ <td>-</td>
1510
+ <td>-</td>
1511
+ <td>-</td>
1512
+ <td>-</td>
1513
+ </tr>
1514
+ <tr>
1515
+ <td colspan="9" align="left"><strong>Open-source</strong></td>
1516
+ </tr>
1517
+ <tr>
1518
+ <td nowrap="nowrap" align="left">LLaVA-NeXT-7B</td>
1519
+ <td>7B</td>
1520
+ <td>-</td>
1521
+ <td>-</td>
1522
+ <td>3.39</td>
1523
+ <td>3.29</td>
1524
+ <td>3.92</td>
1525
+ <td>2.60</td>
1526
+ <td>3.12</td>
1527
+ </tr>
1528
+ <tr>
1529
+ <td nowrap="nowrap" align="left">LLaVA-NeXT-34B</td>
1530
+ <td>34B</td>
1531
+ <td>-</td>
1532
+ <td>-</td>
1533
+ <td>3.29</td>
1534
+ <td>3.23</td>
1535
+ <td>3.83</td>
1536
+ <td>2.51</td>
1537
+ <td>3.47</td>
1538
+ </tr>
1539
+ <tr>
1540
+ <td nowrap="nowrap" align="left">CogVLM2-Video</td>
1541
+ <td>12B</td>
1542
+ <td>-</td>
1543
+ <td>-</td>
1544
+ <td>3.49</td>
1545
+ <td><strong>3.46</strong></td>
1546
+ <td>3.23</td>
1547
+ <td><strong>2.98</strong></td>
1548
+ <td><strong>3.64</strong></td>
1549
+ </tr>
1550
+ <tr>
1551
+ <td nowrap="nowrap" align="left">LongVA</td>
1552
+ <td>7B</td>
1553
+ <td>52.4</td>
1554
+ <td>54.3</td>
1555
+ <td>3.05</td>
1556
+ <td>3.09</td>
1557
+ <td>3.77</td>
1558
+ <td>2.44</td>
1559
+ <td><strong>3.64</strong></td>
1560
+ </tr>
1561
+ <tr>
1562
+ <td nowrap="nowrap" align="left">InternVL2-8B</td>
1563
+ <td>8B</td>
1564
+ <td>54.0</td>
1565
+ <td>56.9</td>
1566
+ <td>-</td>
1567
+ <td>-</td>
1568
+ <td>-</td>
1569
+ <td>-</td>
1570
+ <td>-</td>
1571
+ </tr>
1572
+ <tr>
1573
+ <td nowrap="nowrap" align="left">InternLM-XComposer-2.5</td>
1574
+ <td>8B</td>
1575
+ <td>55.8</td>
1576
+ <td>-</td>
1577
+ <td>-</td>
1578
+ <td>-</td>
1579
+ <td>-</td>
1580
+ <td>-</td>
1581
+ <td>-</td>
1582
+ </tr>
1583
+ <tr>
1584
+ <td nowrap="nowrap" align="left">LLaVA-NeXT-Video</td>
1585
+ <td>32B</td>
1586
+ <td>60.2</td>
1587
+ <td>63.0</td>
1588
+ <td>3.48</td>
1589
+ <td>3.37</td>
1590
+ <td><strong>3.95</strong></td>
1591
+ <td>2.64</td>
1592
+ <td>3.28</td>
1593
+ </tr>
1594
+ <tr>
1595
+ <td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
1596
+ <td>8B</td>
1597
+ <td><strong>60.9</strong></td>
1598
+ <td><strong>63.6</strong></td>
1599
+ <td><strong>3.59</strong></td>
1600
+ <td>3.28</td>
1601
+ <td>3.93</td>
1602
+ <td>2.73</td>
1603
+ <td>3.62</td>
1604
+ </tr>
1605
+ </tbody>
1606
+ </table>
1607
+ </div>
1608
+ </details>
1609
+
1610
+
1611
+ <details>
1612
+ <summary>点击查看 TextVQA, VizWiz, VQAv2, OK-VQA上的少样本评测结果详情。</summary>
1613
+ <div align="center">
1614
+
1615
+ <table style="margin: 0px auto;">
1616
+ <thead>
1617
+ <tr>
1618
+ <th align="left">Model</th>
1619
+ <th>Size</th>
1620
+ <th>Shot</th>
1621
+ <th>TextVQA val</th>
1622
+ <th>VizWiz test-dev</th>
1623
+ <th>VQAv2 test-dev</th>
1624
+ <th>OK-VQA val</th>
1625
+ </tr>
1626
+ </thead>
1627
+ <tbody align="center">
1628
+ <tr>
1629
+ <td align="left" nowrap="nowrap" rowspan="3">Flamingo</td>
1630
+ <td rowspan="3">80B</td>
1631
+ <td>0*</td>
1632
+ <td>35.0</td>
1633
+ <td>31.6</td>
1634
+ <td>56.3</td>
1635
+ <td>40.6</td>
1636
+ </tr>
1637
+ <tr>
1638
+ <td>4</td>
1639
+ <td>36.5</td>
1640
+ <td>39.6</td>
1641
+ <td>63.1</td>
1642
+ <td><strong>57.4</strong></td>
1643
+ </tr>
1644
+ <tr>
1645
+ <td>8</td>
1646
+ <td>37.3</td>
1647
+ <td>44.8</td>
1648
+ <td>65.6</td>
1649
+ <td>57.5</td>
1650
+ </tr>
1651
+ <tr>
1652
+ <td align="left" nowrap="nowrap" rowspan="3">IDEFICS</td>
1653
+ <td rowspan="3">80B</td>
1654
+ <td>0*</td>
1655
+ <td>30.9</td>
1656
+ <td>36.0</td>
1657
+ <td>60.0</td>
1658
+ <td>45.2</td>
1659
+ </tr>
1660
+ <tr>
1661
+ <td>4</td>
1662
+ <td>34.3</td>
1663
+ <td>40.4</td>
1664
+ <td>63.6</td>
1665
+ <td>52.4</td>
1666
+ </tr>
1667
+ <tr>
1668
+ <td>8</td>
1669
+ <td>35.7</td>
1670
+ <td>46.1</td>
1671
+ <td>64.8</td>
1672
+ <td>55.1</td>
1673
+ </tr>
1674
+ <tr>
1675
+ <td align="left" nowrap="nowrap" rowspan="3">OmniCorpus</td>
1676
+ <td rowspan="3">7B</td>
1677
+ <td>0*</td>
1678
+ <td>43.0</td>
1679
+ <td>49.8</td>
1680
+ <td>63.2</td>
1681
+ <td>45.5</td>
1682
+ </tr>
1683
+ <tr>
1684
+ <td>4</td>
1685
+ <td>45.4</td>
1686
+ <td>51.3</td>
1687
+ <td>64.5</td>
1688
+ <td>46.5</td>
1689
+ </tr>
1690
+ <tr>
1691
+ <td>8</td>
1692
+ <td>45.6</td>
1693
+ <td>52.2</td>
1694
+ <td>64.7</td>
1695
+ <td>46.6</td>
1696
+ </tr>
1697
+ <tr>
1698
+ <td align="left" nowrap="nowrap" rowspan="3">Emu2</td>
1699
+ <td rowspan="3">37B</td>
1700
+ <td>0</td>
1701
+ <td>26.4</td>
1702
+ <td>40.4</td>
1703
+ <td>33.5</td>
1704
+ <td>26.7</td>
1705
+ </tr>
1706
+ <tr>
1707
+ <td>4</td>
1708
+ <td>48.2</td>
1709
+ <td>54.6</td>
1710
+ <td>67.0</td>
1711
+ <td>53.2</td>
1712
+ </tr>
1713
+ <tr>
1714
+ <td>8</td>
1715
+ <td>49.3</td>
1716
+ <td>54.7</td>
1717
+ <td>67.8</td>
1718
+ <td>54.1</td>
1719
+ </tr>
1720
+ <tr>
1721
+ <td align="left" nowrap="nowrap" rowspan="2">MM1</td>
1722
+ <td rowspan="2">30B</td>
1723
+ <td>0</td>
1724
+ <td>26.2</td>
1725
+ <td>40.4</td>
1726
+ <td>48.9</td>
1727
+ <td>26.7</td>
1728
+ </tr>
1729
+ <tr>
1730
+ <td>8</td>
1731
+ <td>49.3</td>
1732
+ <td>54.7</td>
1733
+ <td><strong>70.9</strong></td>
1734
+ <td>54.1</td>
1735
+ </tr>
1736
+ <tr>
1737
+ <td align="left" nowrap="nowrap" rowspan="3">MiniCPM-V 2.6<sup>+</sup></td>
1738
+ <td rowspan="3">8B</td>
1739
+ <td>0</td>
1740
+ <td>43.9</td>
1741
+ <td>33.8</td>
1742
+ <td>45.4</td>
1743
+ <td>23.9</td>
1744
+ </tr>
1745
+ <tr>
1746
+ <td>4</td>
1747
+ <td>63.6</td>
1748
+ <td>60.5</td>
1749
+ <td>65.5</td>
1750
+ <td>50.1</td>
1751
+ </tr>
1752
+ <tr>
1753
+ <td>8</td>
1754
+ <td><strong>64.6</strong></td>
1755
+ <td><strong>63.4</strong></td>
1756
+ <td>68.2</td>
1757
+ <td>51.4</td>
1758
+ </tr>
1759
+ </tbody>
1760
+ </table>
1761
+
1762
+
1763
+ </div>
1764
+ * 零样本性能按照 Flamingo 的设置评估,即 zero image shot 加 two additional text shots。
1765
+
1766
+ <sup>+</sup> 我们在没有进行监督微调 (SFT) 的情况下评估预训练的模型权重 (ckpt)。
1767
+ </details>
1768
+
1769
+ ### 典型示例 <!-- omit in toc -->
1770
+
1771
+ <div style="display: flex; flex-direction: column; align-items: center;">
1772
+ <img src="assets/minicpmv2_6/multi_img-bike.png" alt="Bike" style="margin-bottom: 5px;">
1773
+ <img src="assets/minicpmv2_6/multi_img-menu.png" alt="Menu" style="margin-bottom: 5px;">
1774
+ <img src="assets/minicpmv2_6/multi_img-code.png" alt="Code" style="margin-bottom: 5px;">
1775
+ <img src="assets/minicpmv2_6/ICL-Mem.png" alt="Mem" style="margin-bottom: 5px;">
1776
+ <img src="assets/minicpmv2_6/multiling-medal.png" alt="medal" style="margin-bottom: 10px;">
1777
+ </div>
1778
+ <details>
1779
+ <summary>点击查看更多示例。</summary>
1780
+ <div style="display: flex; flex-direction: column; align-items: center;">
1781
+ <img src="assets/minicpmv2_6/ICL-elec.png" alt="elec" style="margin-bottom: 5px;">
1782
+ <img src="assets/minicpmv2_6/multiling-olympic.png" alt="Menu" style="margin-bottom: 10px;">
1783
+ </div>
1784
+ </details>
1785
+
1786
+ 我们将 MiniCPM-V 2.6 部署在iPad Pro上,并录制了以下演示视频。
1787
+
1788
+ <table align="center">
1789
+ <p align="center">
1790
+ <img src="assets/gif_cases/ai.gif" width=32%/>
1791
+ &nbsp;&nbsp;&nbsp;&nbsp;
1792
+ <img src="assets/gif_cases/beer.gif" width=32%/>
1793
+ </p>
1794
+ </table>
1795
+
1796
+ <table align="center">
1797
+ <p align="center">
1798
+ <video src="https://github.com/user-attachments/assets/21f4b818-ede1-4822-920e-91281725c830" width="360" /> </video>
1799
+ <!-- <video src="https://github.com/user-attachments/assets/c835f757-206b-4d9c-8e36-70d67b453628" width="360" /> </video> -->
1800
+ </p>
1801
+ </table>
1802
+
1803
+ </details>
1804
+
1805
+ ## 历史版本模型 <!-- omit in toc -->
1806
+
1807
+
1808
+ | 模型 | 介绍信息和使用教程 |
1809
+ |:----------------------|:-------------------:|
1810
+ | MiniCPM-Llama3-V 2.5 | [文档](./docs/minicpm_llama3_v2dot5.md) |
1811
+ | MiniCPM-V 2.0 | [文档](./docs/minicpm_v2.md) |
1812
+ | MiniCPM-V 1.0 | [文档](./docs/minicpm_v1.md) |
1813
+ | OmniLMM-12B | [文档](./omnilmm.md) |
1814
+
1815
+
1816
+ ## Chat with Our Demo on Gradio 🤗
1817
+
1818
+ 我们提供由 Hugging Face Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> 支持的在线和本地 Demo。Gradio 是目前最流行的模型部署框架,支持流式输出、进度条等常用功能。
1819
+
1820
+ ### Online Demo <!-- omit in toc -->
1821
+
1822
+ 欢迎试用 Online Demo: [MiniCPM-V 2.6](http://120.92.209.146:8887/) | [MiniCPM-Llama3-V 2.5](https://huggingface.co/spaces/openbmb/MiniCPM-Llama3-V-2_5) | [MiniCPM-V 2.0](https://huggingface.co/spaces/openbmb/MiniCPM-V-2) 。
1823
+
1824
+ ### 本地 WebUI Demo <!-- omit in toc -->
1825
+
1826
+ 您可以使用以下命令轻松构建自己的本地 WebUI Demo。更详细的部署教程请参考[文档](https://modelbest.feishu.cn/wiki/RnjjwnUT7idMSdklQcacd2ktnyN)。
1827
+
1828
+ **实时流式视频/语音通话demo:**
1829
+ 1. 启动model server:
1830
+ ```shell
1831
+ pip install -r requirements_o2.6.txt
1832
+
1833
+ python web_demos/minicpm-o_2.6/model_server.py
1834
+ ```
1835
+ 请确保 `transformers==4.44.2`,其他版本目前可能会有兼容性问题,我们正在解决。
1836
+ 如果你使用的是较低版本的 PyTorch,可能会遇到错误 `"weight_norm_fwd_first_dim_kernel" not implemented for 'BFloat16'`,请在模型初始化后添加 `self.minicpmo_model.tts.float()`。
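+
+ 下面给出一个最小示意(假设:在你的 model server 代码中以变量 `minicpmo_model` 持有模型,变量名仅为说明用的假设),展示 `tts.float()` 应放在模型初始化之后:
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ minicpmo_model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
+     attn_implementation='sdpa', torch_dtype=torch.bfloat16)
+ minicpmo_model = minicpmo_model.eval().cuda()
+ minicpmo_model.init_tts()
+ minicpmo_model.tts.float()  # 仅将 TTS 子模块转为 float32,规避低版本 PyTorch 的 BFloat16 weight_norm 报错
+ ```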
1837
+
1838
+ 2. 启动web server:
1839
+ ```shell
1840
+ # Make sure Node.js and PNPM are installed.
1841
+ sudo apt-get update
1842
+ sudo apt-get install nodejs npm
1843
+ npm install -g pnpm
1844
+
1845
+
1846
+ cd web_demos/minicpm-o_2.6/web_server
1847
+ # 为 https 创建自签名证书;浏览器申请摄像头和麦克风权限须使用 https。
1848
+ bash ./make_ssl_cert.sh # output key.pem and cert.pem
1849
+
1850
+ pnpm install # install requirements
1851
+ pnpm run dev # start server
1852
+ ```
1853
+ 浏览器打开 `https://localhost:8088/`,开始体验实时流式视频/语音通话。
1854
+
1855
+ **Chatbot图文对话demo:**
1856
+ ```shell
1857
+ pip install -r requirements_o2.6.txt
1858
+
1859
+ python web_demos/minicpm-o_2.6/chatbot_web_demo_o2.6.py
1860
+ ```
1861
+ 浏览器打开 `http://localhost:8000/`,开始体验图文对话 Chatbot。
1862
+
1863
+
1864
+ ## 推理
1865
+
1866
+ ### 模型库
1867
+
1868
+ | 模型 | 设备 | 资源 | &emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; 简介 | 下载链接 |
1869
+ |:--------------|:-:|:----------:|:-------------------|:---------------:|
1870
+ | MiniCPM-o 2.6| GPU | 18 GB | 最新版本,提供端侧 GPT-4o 级的视觉、语音、多模态流式交互能力。 | [🤗](https://huggingface.co/openbmb/MiniCPM-o-2_6) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6) |
1871
+ | MiniCPM-o 2.6 gguf | CPU | 8 GB | gguf 版本,更低的内存占用和更高的推理效率。 | [🤗](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6-gguf) |
1872
+ | MiniCPM-o 2.6 int4 | GPU | 9 GB | int4量化版,更低显存占用。 | [🤗](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6-int4) |
1873
+ | MiniCPM-V 2.6| GPU | 17 GB | 提供出色的端侧单图、多图、视频理解能力。 | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2_6) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6) |
1874
+ | MiniCPM-V 2.6 gguf | CPU | 6 GB | gguf 版本,更低的内存占用和更高的推理效率。 | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6-gguf) |
1875
+ | MiniCPM-V 2.6 int4 | GPU | 7 GB | int4量化版,更低显存占用。 | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6-int4) |
1876
+
1877
+ 更多[历史版本模型](#legacy-models)
1878
+
1879
+
1880
+ ### 多轮对话
1881
+ 请确保 `transformers==4.44.2`,其他版本目前可能会有兼容性问题。
1882
+
1883
+ ```shell
1884
+ pip install -r requirements_o2.6.txt
1885
+ ```
1886
+
1887
+ <div align="center">
1888
+ <img src="assets/minicpmo2_6/show_demo.jpg" width="500px">
1889
+ </div>
1890
+
1891
+
1892
+ ```python
1893
+ import torch
1894
+ from PIL import Image
1895
+ from transformers import AutoModel, AutoTokenizer
1896
+
1897
+ torch.manual_seed(100)
1898
+
1899
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
1900
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
1901
+ model = model.eval().cuda()
1902
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
1903
+
1904
+ image = Image.open('./assets/minicpmo2_6/show_demo.jpg').convert('RGB')
1905
+
1906
+ # First round chat
1907
+ question = "What is the landform in the picture?"
1908
+ msgs = [{'role': 'user', 'content': [image, question]}]
1909
+
1910
+ answer = model.chat(
1911
+ msgs=msgs,
1912
+ tokenizer=tokenizer
1913
+ )
1914
+ print(answer)
1915
+
1916
+ # Second round chat, pass history context of multi-turn conversation
1917
+ msgs.append({"role": "assistant", "content": [answer]})
1918
+ msgs.append({"role": "user", "content": ["What should I pay attention to when traveling here?"]})
1919
+
1920
+ answer = model.chat(
1921
+ msgs=msgs,
1922
+ tokenizer=tokenizer
1923
+ )
1924
+ print(answer)
1925
+ ```
1926
+
1927
+ 你可以得到如下推理结果:
1928
+
1929
+ ```
1930
+ "The landform in the picture is a mountain range. The mountains appear to be karst formations, characterized by their steep, rugged peaks and smooth, rounded shapes. These types of mountains are often found in regions with limestone bedrock and are shaped by processes such as erosion and weathering. The reflection of the mountains in the water adds to the scenic beauty of the landscape."
1931
+
1932
+ "When traveling to this scenic location, it's important to pay attention to the weather conditions, as the area appears to be prone to fog and mist, especially during sunrise or sunset. Additionally, ensure you have proper footwear for navigating the potentially slippery terrain around the water. Lastly, respect the natural environment by not disturbing the local flora and fauna."
1933
+ ```
1934
+
1935
+ #### 多图对话
1936
+ <details>
1937
+ <summary> 点击查看 MiniCPM-o 2.6 多图输入的 Python 代码。 </summary>
1938
+
1939
+ ```python
1940
+ import torch
1941
+ from PIL import Image
1942
+ from transformers import AutoModel, AutoTokenizer
1943
+
1944
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
1945
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
1946
+ model = model.eval().cuda()
1947
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
1948
+
1949
+ image1 = Image.open('image1.jpg').convert('RGB')
1950
+ image2 = Image.open('image2.jpg').convert('RGB')
1951
+ question = 'Compare image 1 and image 2, tell me about the differences between image 1 and image 2.'
1952
+
1953
+ msgs = [{'role': 'user', 'content': [image1, image2, question]}]
1954
+
1955
+ answer = model.chat(
1956
+ msgs=msgs,
1957
+ tokenizer=tokenizer
1958
+ )
1959
+ print(answer)
1960
+ ```
1961
+ </details>
1962
+
1963
+ #### 少样本上下文对话
1964
+ <details>
1965
+ <summary> 点击查看 MiniCPM-o 2.6 少样本上下文对话的 Python 代码。 </summary>
1966
+
1967
+ ```python
1968
+ import torch
1969
+ from PIL import Image
1970
+ from transformers import AutoModel, AutoTokenizer
1971
+
1972
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
1973
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
1974
+ model = model.eval().cuda()
1975
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
1976
+
1977
+ question = "production date"
1978
+ image1 = Image.open('example1.jpg').convert('RGB')
1979
+ answer1 = "2023.08.04"
1980
+ image2 = Image.open('example2.jpg').convert('RGB')
1981
+ answer2 = "2007.04.24"
1982
+ image_test = Image.open('test.jpg').convert('RGB')
1983
+
1984
+ msgs = [
1985
+ {'role': 'user', 'content': [image1, question]}, {'role': 'assistant', 'content': [answer1]},
1986
+ {'role': 'user', 'content': [image2, question]}, {'role': 'assistant', 'content': [answer2]},
1987
+ {'role': 'user', 'content': [image_test, question]}
1988
+ ]
1989
+
1990
+ answer = model.chat(
1991
+ msgs=msgs,
1992
+ tokenizer=tokenizer
1993
+ )
1994
+ print(answer)
1995
+ ```
1996
+ </details>
1997
+
1998
+ #### 视频对话
1999
+ <details>
2000
+ <summary> 点击查看 MiniCPM-o 2.6 视频输入的 Python 代码。 </summary>
2001
+
2002
+ ```python
2003
+ import torch
2004
+ from PIL import Image
2005
+ from transformers import AutoModel, AutoTokenizer
2006
+ from decord import VideoReader, cpu # pip install decord
2007
+
2008
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
2009
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
2010
+ model = model.eval().cuda()
2011
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
2012
+
2013
+ MAX_NUM_FRAMES=64 # if cuda OOM set a smaller number
2014
+
2015
+ def encode_video(video_path):
2016
+ def uniform_sample(l, n):
2017
+ gap = len(l) / n
2018
+ idxs = [int(i * gap + gap / 2) for i in range(n)]
2019
+ return [l[i] for i in idxs]
2020
+
2021
+ vr = VideoReader(video_path, ctx=cpu(0))
2022
+ sample_fps = round(vr.get_avg_fps() / 1)  # sample 1 frame per second (the divisor is the target sampling FPS)
2023
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
2024
+ if len(frame_idx) > MAX_NUM_FRAMES:
2025
+ frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
2026
+ frames = vr.get_batch(frame_idx).asnumpy()
2027
+ frames = [Image.fromarray(v.astype('uint8')) for v in frames]
2028
+ print('num frames:', len(frames))
2029
+ return frames
2030
+
2031
+ video_path="video_test.mp4"
2032
+ frames = encode_video(video_path)
2033
+ question = "Describe the video"
2034
+ msgs = [
2035
+ {'role': 'user', 'content': frames + [question]},
2036
+ ]
2037
+
2038
+ # Set decode params for video
2039
+ params = {}
2040
+ params["use_image_id"] = False
2041
+ params["max_slice_nums"] = 2 # use 1 if cuda OOM and video resolution > 448*448
2042
+
2043
+ answer = model.chat(
2044
+ msgs=msgs,
2045
+ tokenizer=tokenizer,
2046
+ **params
2047
+ )
2048
+ print(answer)
2049
+ ```
2050
+ </details>
2051
+
2052
+
2053
+ #### 语音对话
2054
+ <details> <summary> 初始化模型 </summary>
2055
+
2056
+ ```python
2057
+ import torch
2058
+ import librosa
2059
+ from transformers import AutoModel, AutoTokenizer
2060
+
2061
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
2062
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
2063
+ model = model.eval().cuda()
2064
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
2065
+
2066
+ model.init_tts()
2067
+ model.tts.float()
2068
+ ```
2069
+
2070
+ </details>
2071
+
2072
+ ##### Mimick
2073
+
2074
+ <details> <summary> 点击查看 MiniCPM-o 2.6 端到端语音理解生成的 Python 代码。 </summary>
2075
+
2076
+ - `Mimick` 任务反映了模型的端到端语音建模能力。模型接受音频输入,输出语音识别(ASR)转录结果,并随后以高相似度重建原始音频。重建音频与原始音频的相似度越高,表明模型具备越强的语音端到端建模基础能力。
2077
+ ```python
2078
+ mimick_prompt = "Please repeat each user's speech, including voice style and speech content."
2079
+ audio_input, _ = librosa.load('xxx.wav', sr=16000, mono=True)
2080
+ msgs = [{'role': 'user', 'content': [mimick_prompt,audio_input]}]
2081
+ res = model.chat(
2082
+ msgs=msgs,
2083
+ tokenizer=tokenizer,
2084
+ sampling=True,
2085
+ max_new_tokens=128,
2086
+ use_tts_template=True,
2087
+ temperature=0.3,
2088
+ generate_audio=True,
2089
+ output_audio_path='output.wav', # save the tts result to output_audio_path
2090
+ )
2091
+ ```
2092
+
2093
+ </details>
2094
+
2095
+ ##### 可配置声音的语音对话
2096
+ <details> <summary> 点击查看个性化配置 MiniCPM-o 2.6 对话声音的 Python 代码。</summary>
2097
+
2098
+ ```python
2099
+ ref_audio, _ = librosa.load('./assets/voice_01.wav', sr=16000, mono=True) # load the reference audio
2100
+
2101
+ # Audio RolePlay: with this mode, the model will role-play the character based on the audio prompt.
2102
+ sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='audio_roleplay', language='en')
2103
+ user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]}
2104
+
2105
+ # Audio Assistant: with this mode, the model will speak with the voice in ref_audio as an AI assistant.
2106
+ # sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='audio_assistant', language='en')
2107
+ # user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]} # Try to ask something!
2108
+ ```
2109
+ ```python
2110
+ msgs = [sys_prompt, user_question]
2111
+ res = model.chat(
2112
+ msgs=msgs,
2113
+ tokenizer=tokenizer,
2114
+ sampling=True,
2115
+ max_new_tokens=128,
2116
+ use_tts_template=True,
2117
+ generate_audio=True,
2118
+ temperature=0.3,
2119
+ output_audio_path='result.wav',
2120
+ )
2121
+
2122
+ # round two
2123
+ msgs.append({'role': 'assistant', 'content': res})  # append in place; list.append returns None, so do not reassign
2124
+ user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]}
2125
+ msgs.append(user_question)
2126
+ res = model.chat(
2127
+ msgs=msgs,
2128
+ tokenizer=tokenizer,
2129
+ sampling=True,
2130
+ max_new_tokens=128,
2131
+ use_tts_template=True,
2132
+ generate_audio=True,
2133
+ temperature=0.3,
2134
+ output_audio_path='result_round_2.wav',
2135
+ )
2136
+ print(res)
2137
+ ```
2138
+
2139
+ </details>
2140
+
2141
+ ##### 更多语音任务
2142
+ <details>
2143
+ <summary> 点击查看 MiniCPM-o 2.6 完成更多语音任务的 Python 代码。 </summary>
2144
+
2145
+ ```python
2146
+ '''
2147
+ Audio Understanding Task Prompt:
2148
+ Speech:
2149
+ ASR with ZH(same as AST en2zh): 请仔细听这段音频片段,并将其内容逐字记录。
2150
+ ASR with EN(same as AST zh2en): Please listen to the audio snippet carefully and transcribe the content.
2151
+ Speaker Analysis: Based on the speaker's content, speculate on their gender, condition, age range, and health status.
2152
+ General Audio:
2153
+ Audio Caption: Summarize the main content of the audio.
2154
+ Sound Scene Tagging: Utilize one keyword to convey the audio's content or the associated scene.
2155
+ '''
2156
+ task_prompt = "\n"
2157
+ audio_input, _ = librosa.load('xxx.wav', sr=16000, mono=True)
2158
+
2159
+ msgs = [{'role': 'user', 'content': [task_prompt,audio_input]}]
2160
+
2161
+ res = model.chat(
2162
+ msgs=msgs,
2163
+ tokenizer=tokenizer,
2164
+ sampling=True,
2165
+ max_new_tokens=128,
2166
+ use_tts_template=True,
2167
+ generate_audio=True,
2168
+ temperature=0.3,
2169
+ output_audio_path='result.wav',
2170
+ )
2171
+ print(res)
2172
+ ```
2173
+ ```python
2174
+ '''
2175
+ Speech Generation Task Prompt:
2176
+ Human Instruction-to-Speech: see https://voxinstruct.github.io/VoxInstruct/
2177
+ Example:
2178
+ # 在新闻中,一个年轻男性兴致勃勃地说:“祝福亲爱的祖国母亲美丽富强!”他用低音调和低音量,慢慢地说出了这句话。
2179
+ # Delighting in a surprised tone, an adult male with low pitch and low volume comments:"One even gave my little dog a biscuit" This dialogue takes place at a leisurely pace, delivering a sense of excitement and surprise in the context.
2180
+
2181
+ Voice Cloning or Voice Creation: with this mode, the model will act as a TTS model.
2182
+ '''
2183
+ # Human Instruction-to-Speech:
2184
+ task_prompt = '' #Try to make some Human Instruction-to-Speech prompt
2185
+ msgs = [{'role': 'user', 'content': [task_prompt]}] # you can try to use the same audio question
2186
+
2187
+ # Voice Cloning mode: with this mode, the model acts as a TTS model.
2188
+ # sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='voice_cloning', language='en')
2189
+ # text_prompt = f"Please read the text below."
2190
+ # user_question = {'role': 'user', 'content': [text_prompt, "content that you want to read"]} # using same voice in sys_prompt to read the text. (Voice Cloning)
2191
+ # user_question = {'role': 'user', 'content': [text_prompt, librosa.load('xxx.wav', sr=16000, mono=True)[0]]} # using same voice in sys_prompt to read 'xxx.wav'. (Voice Creation)
2192
+
2193
+ # msgs = [sys_prompt, user_question]  # uncomment together with the Voice Cloning / Voice Creation lines above; otherwise the Human Instruction-to-Speech msgs defined earlier are used
2194
+ res = model.chat(
2195
+ msgs=msgs,
2196
+ tokenizer=tokenizer,
2197
+ sampling=True,
2198
+ max_new_tokens=128,
2199
+ use_tts_template=True,
2200
+ generate_audio=True,
2201
+ temperature=0.3,
2202
+ output_audio_path='result.wav',
2203
+ )
2204
+
2205
+
2206
+ ```
2207
+
2208
+ </details>
2209
+
2210
+ #### 多模态流式交互
2211
+ <details>
2212
+ <summary> 点击查看 MiniCPM-o 2.6 多模态流式交互的 Python 代码。 </summary>
2213
+
2214
+ ```python
2215
+ import math
2216
+ import numpy as np
2217
+ from PIL import Image
2218
+ from moviepy.editor import VideoFileClip
2219
+ import tempfile
2220
+ import librosa
2221
+ import soundfile as sf
2222
+ import torch
2223
+ from transformers import AutoModel, AutoTokenizer
2224
+
2225
+ def get_video_chunk_content(video_path, flatten=True):
2226
+ video = VideoFileClip(video_path)
2227
+ print('video_duration:', video.duration)
2228
+
2229
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
2230
+ temp_audio_file_path = temp_audio_file.name
2231
+ video.audio.write_audiofile(temp_audio_file_path, codec="pcm_s16le", fps=16000)
2232
+ audio_np, sr = librosa.load(temp_audio_file_path, sr=16000, mono=True)
2233
+ num_units = math.ceil(video.duration)
2234
+
2235
+ # 1 frame + 1s audio chunk
2236
+ contents= []
2237
+ for i in range(num_units):
2238
+ frame = video.get_frame(i+1)
2239
+ image = Image.fromarray((frame).astype(np.uint8))
2240
+ audio = audio_np[sr*i:sr*(i+1)]
2241
+ if flatten:
2242
+ contents.extend(["<unit>", image, audio])
2243
+ else:
2244
+ contents.append(["<unit>", image, audio])
2245
+
2246
+ return contents
2247
+
2248
+
2249
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True,
2250
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16)
2251
+ model = model.eval().cuda()
2252
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
2253
+
2254
+ model.init_tts()
2255
+
2256
+ # If you are using an older version of PyTorch, you might encounter the error '"weight_norm_fwd_first_dim_kernel" not implemented for BFloat16'. In that case, convert the TTS module to float32:
2257
+ # model.tts.float()
2258
+
2259
+ # https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/assets/Skiing.mp4
2260
+ video_path="assets/Skiing.mp4"
2261
+ sys_msg = model.get_sys_prompt(mode='omni', language='en')
2262
+ # if use voice clone prompt, please set ref_audio
2263
+ # ref_audio_path = '/path/to/ref_audio'
2264
+ # ref_audio, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
2265
+ # sys_msg = model.get_sys_prompt(ref_audio=ref_audio, mode='omni', language='en')
2266
+
2267
+ contents = get_video_chunk_content(video_path)
2268
+ msg = {"role":"user", "content": contents}
2269
+ msgs = [sys_msg, msg]
2270
+
2271
+ # please set generate_audio=True and output_audio_path to save the tts result
2272
+ generate_audio = True
2273
+ output_audio_path = 'output.wav'
2274
+
2275
+ res = model.chat(
2276
+ msgs=msgs,
2277
+ tokenizer=tokenizer,
2278
+ sampling=True,
2279
+ temperature=0.5,
2280
+ max_new_tokens=4096,
2281
+ omni_input=True, # please set omni_input=True when omni inference
2282
+ use_tts_template=True,
2283
+ generate_audio=generate_audio,
2284
+ output_audio_path=output_audio_path,
2285
+ max_slice_nums=1,
2286
+ use_image_id=False,
2287
+ return_dict=True
2288
+ )
2289
+ print(res)
2290
+ ```
2291
+ </details>
2292
+
2293
+ <details>
2294
+ <summary> 点击查看多模态流式推理设置。 </summary>
2295
+
2296
+ 注意:流式推理存在轻微的性能下降,因为音频编码并非全局的。
2297
+ ```python
2298
+ # A new conversation needs reset_session() first; this clears the kv-cache.
2299
+ model.reset_session()
2300
+
2301
+ contents = get_video_chunk_content(video_path, flatten=False)
2302
+ session_id = '123'
2303
+ generate_audio = True
2304
+
2305
+ # 1. prefill system prompt
2306
+ res = model.streaming_prefill(
2307
+ session_id=session_id,
2308
+ msgs=[sys_msg],
2309
+ tokenizer=tokenizer
2310
+ )
2311
+
2312
+ # 2. prefill video/audio chunks
2313
+ for content in contents:
2314
+ msgs = [{"role":"user", "content": content}]
2315
+ res = model.streaming_prefill(
2316
+ session_id=session_id,
2317
+ msgs=msgs,
2318
+ tokenizer=tokenizer
2319
+ )
2320
+
2321
+ # 3. generate
2322
+ res = model.streaming_generate(
2323
+ session_id=session_id,
2324
+ tokenizer=tokenizer,
2325
+ temperature=0.5,
2326
+ generate_audio=generate_audio
2327
+ )
2328
+
2329
+ audios = []
2330
+ text = ""
2331
+
2332
+ if generate_audio:
2333
+ for r in res:
2334
+ audio_wav = r.audio_wav
2335
+ sampling_rate = r.sampling_rate
2336
+ txt = r.text
2337
+
2338
+ audios.append(audio_wav)
2339
+ text += txt
2340
+
2341
+ res = np.concatenate(audios)
2342
+ sf.write("output.wav", res, samplerate=sampling_rate)
2343
+ print("text:", text)
2344
+ print("audio saved to output.wav")
2345
+ else:
2346
+ for r in res:
2347
+ text += r['text']
2348
+ print("text:", text)
2349
+ ```
2350
+
2351
+ </details>
2352
+
2353
+
2354
+ ### 多卡推理
2355
+ 您可以通过将模型的层分布在多个低显存显卡(12 GB 或 16 GB)上来运行 MiniCPM-Llama3-V 2.5。请查看该[教程](https://github.com/OpenBMB/MiniCPM-V/blob/main/docs/inference_on_multiple_gpus.md),详细了解如何使用多张低显存显卡载入模型并进行推理。
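+
+ 下面是一个基于 `device_map` 自动切分的最小示意(假设:已安装 accelerate,使用两张 12 GB 显卡,`max_memory` 的取值仅为示例;实际切分与载入方式请以上述教程为准):
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ model = AutoModel.from_pretrained(
+     'openbmb/MiniCPM-Llama3-V-2_5',
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+     device_map='auto',  # 由 accelerate 按各卡剩余显存自动放置模型各层
+     max_memory={0: '11GiB', 1: '11GiB'},  # 示例值:为每张 12 GB 显卡预留少量余量
+ )
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)
+ model.eval()
+ ```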
2356
+
2357
+
2358
+ ### Mac 推理
2359
+ <details>
2360
+ <summary>点击查看 MiniCPM-Llama3-V 2.5 / MiniCPM-V 2.0 基于Mac MPS运行 (Apple silicon 或 AMD GPUs)的示例。 </summary>
2361
+
2362
+ ```python
2363
+ # test.py Need more than 16GB memory to run.
2364
+ import torch
2365
+ from PIL import Image
2366
+ from transformers import AutoModel, AutoTokenizer
2367
+
2368
+ model = AutoModel.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True, low_cpu_mem_usage=True)
2369
+ model = model.to(device='mps')
2370
+
2371
+ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)
2372
+ model.eval()
2373
+
2374
+ image = Image.open('./assets/hk_OCR.jpg').convert('RGB')
2375
+ question = 'Where is this photo taken?'
2376
+ msgs = [{'role': 'user', 'content': question}]
2377
+
2378
+ answer, context, _ = model.chat(
2379
+ image=image,
2380
+ msgs=msgs,
2381
+ context=None,
2382
+ tokenizer=tokenizer,
2383
+ sampling=True
2384
+ )
2385
+ print(answer)
2386
+ ```
2387
+ 运行:
2388
+ ```shell
2389
+ PYTORCH_ENABLE_MPS_FALLBACK=1 python test.py
2390
+ ```
2391
+ </details>
2392
+
2393
+
2394
+ ### 基于 llama.cpp、ollama、vLLM 的高效推理
2395
+
2396
+ llama.cpp 用法请参考[我们的fork llama.cpp](https://github.com/OpenBMB/llama.cpp/tree/minicpmv-main/examples/llava/README-minicpmv2.6.md), 在iPad上可以支持 16~18 token/s 的流畅推理(测试环境:iPad Pro + M4)。
2397
+
2398
+ ollama 用法请参考[我们的fork ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md), 在iPad上可以支持 16~18 token/s 的流畅推理(测试环境:iPad Pro + M4)。
2399
+
2400
+ <details>
2401
+ <summary>点击查看, vLLM 现已官方支持MiniCPM-o 2.6、MiniCPM-V 2.6、MiniCPM-Llama3-V 2.5 和 MiniCPM-V 2.0。 </summary>
2402
+ 1. 安装 vLLM(>=0.7.1):
2403
+
2404
+ ```shell
2405
+ pip install vllm
2406
+ ```
2407
+
2408
+ 2. 运行示例代码(注意:如果使用本地路径的模型,请确保模型代码已更新到 Hugging Face 上的最新版;列表后附有一个最小调用示意):
2409
+
2410
+ * [图文示例](https://docs.vllm.ai/en/latest/getting_started/examples/vision_language.html)
2411
+ * [音频示例](https://docs.vllm.ai/en/latest/getting_started/examples/audio_language.html)
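+
+ 下面是一个离线图文推理的最小示意(假设:`example.jpg` 与图像占位符写法仅为示例,实际 prompt 模板与参数请以上方官方图文示例为准):
+ ```python
+ from PIL import Image
+ from vllm import LLM, SamplingParams
+
+ llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True, max_model_len=4096)
+ image = Image.open("example.jpg").convert("RGB")  # 假设的本地图片
+ prompt = "(<image>./</image>)\n请描述这张图片。"  # 假设的占位符写法
+ outputs = llm.generate(
+     {"prompt": prompt, "multi_modal_data": {"image": image}},
+     SamplingParams(temperature=0.7, max_tokens=128),
+ )
+ print(outputs[0].outputs[0].text)
+ ```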
2412
+
2413
+ </details>
2414
+
2415
+
2416
+ ## 微调
2417
+
2418
+ ### 简易微调 <!-- omit in toc -->
2419
+
2420
+ 我们支持使用 Huggingface Transformers 库简易地微调 MiniCPM-o 2.6、MiniCPM-V 2.6、MiniCPM-Llama3-V 2.5 和 MiniCPM-V 2.0 模型。
2421
+
2422
+ [参考文档](./finetune/readme.md)
2423
+
2424
+
2425
+ ### 使用 Align-Anything <!-- omit in toc -->
2426
+
2427
+ 我们支持使用北大团队开发的 [Align-Anything](https://github.com/PKU-Alignment/align-anything) 框架微调 MiniCPM-o 系列模型,同时支持 DPO 和 SFT 在视觉和音频模态上的微调。Align-Anything 是一个用于对齐全模态大模型的高度可扩展框架,开源了[数据集、模型和评测](https://huggingface.co/datasets/PKU-Alignment/align-anything)。它支持了 30+ 开源基准,40+ 模型,以及包含SFT、SimPO、RLHF在内的多种算法,并提供了 30+ 直接可运行的脚本,适合初学者快速上手。
2428
+
2429
+ 最佳实践: [MiniCPM-o 2.6](https://github.com/PKU-Alignment/align-anything/tree/main/scripts).
2430
+
2431
+
2432
+ ### 使用 LLaMA-Factory <!-- omit in toc -->
2433
+
2434
+ 我们支持使用 LLaMA-Factory 微调 MiniCPM-o 2.6 和 MiniCPM-V 2.6。LLaMA-Factory 提供了一套灵活的微调方案,支持对 200 多个大型语言模型(LLM)进行 LoRA/Full/QLoRA 微调,无需编写代码,通过内置的 Web 用户界面 LLaMABoard 即可完成训练/推理/评估。它支持 sft/ppo/dpo/kto 等多种训练方法,还支持 Galore/BAdam/LLaMA-Pro/Pissa/LongLoRA 等高级算法。
2435
+
2436
+ 最佳实践: [MiniCPM-o 2.6 | MiniCPM-V 2.6](./docs/llamafactory_train_and_infer.md).
2437
+
2438
+
2439
+ ### 使用 SWIFT 框架 <!-- omit in toc -->
2440
+
2441
+ 我们支持使用 SWIFT 框架微调 MiniCPM-V 系列模型。SWIFT 支持近 200 种大语言模型和多模态大模型的训练、推理、评测和部署,既支持 PEFT 提供的轻量级训练方案,也支持完整 Adapters 库中的最新训练技术,如 NEFTune、LoRA+、LLaMA-PRO 等。
2442
+
2443
+ 参考文档:[MiniCPM-V 1.0](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v最佳实践.md),[MiniCPM-V 2.0](https://github.com/modelscope/swift/blob/main/docs/source/Multi-Modal/minicpm-v-2最佳实践.md),[MiniCPM-V 2.6](https://github.com/modelscope/ms-swift/issues/1613)。
2444
+
2445
+ ## FAQs
2446
+ 点击查看 [FAQs](./docs/faqs.md)
2447
+
2448
+
2449
+ ## 模型局限性
2450
+
2451
+ 我们实验发现 MiniCPM-o 2.6 存在一些显著的局限性,需要进一步研究和改进:
2452
+ - **不稳定的语音输出。** 语音生成可能会受到背景噪音和无意义声音的影响,表现不稳定。
2453
+ - **重复响应。** 当遇到连续相似的用户请求时,模型往往会重复相同的回答。
2454
+ - **Web Demo 延迟较高。** 用户在使用远程服务器上部署的 web demo 时可能会产生较高延迟。我们推荐用户在本地部署来获得更低延迟的体验。
2455
+
2456
+
2457
+ ## 模型协议 <!-- omit in toc -->
2458
+
2459
+ * 本仓库中代码依照 [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) 协议开源
2460
+ * MiniCPM-o/V 模型权重的使用则需要遵循 [“MiniCPM模型商用许可协议.md”](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%E6%A8%A1%E5%9E%8B%E5%95%86%E7%94%A8%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.md)。
2461
+ * MiniCPM 模型权重对学术研究完全开放,在填写[“问卷”](https://modelbest.feishu.cn/share/base/form/shrcnpV5ZT9EJ6xYjh3Kx0J6v8g)进行登记后亦允许免费商业使用。
2462
+
2463
+ ## 声明 <!-- omit in toc -->
2464
+
2465
+ 作为多模态大模型,MiniCPM-o/V 系列模型(包括 OmniLMM)通过学习大量的多模态数据来生成内容,但它无法理解、表达个人观点或价值判断,它所输出的任何内容都不代表模型开发者的观点和立场。
2466
+
2467
+ 因此用户在使用本项目的系列模型生成的内容时,应自行负责对其进行评估和验证。如果由于使用本项目的系列开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
2468
+
2469
+
2470
+ ## 机构 <!-- omit in toc -->
2471
+
2472
+ 本项目由以下机构共同开发:
2473
+
2474
+ - <img src="assets/thunlp.png" width="28px"> [清华大学自然语言处理实验室](https://nlp.csai.tsinghua.edu.cn/)
2475
+ - <img src="assets/modelbest.png" width="28px"> [面壁智能](https://modelbest.cn/)
2476
+
2477
+ ## 🌟 Star History <!-- omit in toc -->
2478
+
2479
+
2480
+ <!-- <table align="center">
2481
+ <p align="center">
2482
+ <img src="assets/star_history.svg"/>
2483
+ </p>
2484
+ </table> -->
2485
+
2486
+ <picture>
2487
+ <source
2488
+ media="(prefers-color-scheme: dark)"
2489
+ srcset="
2490
+ https://api.star-history.com/svg?repos=OpenBMB/MiniCPM-o&type=Date&theme=dark
2491
+ "
2492
+ />
2493
+ <source
2494
+ media="(prefers-color-scheme: light)"
2495
+ srcset="
2496
+ https://api.star-history.com/svg?repos=OpenBMB/MiniCPM-o&type=Date
2497
+ "
2498
+ />
2499
+ <img
2500
+ alt="Star History Chart"
2501
+ src="https://api.star-history.com/svg?repos=OpenBMB/MiniCPM-o&type=Date"
2502
+ />
2503
+ </picture>
2504
+
2505
+ ## 支持技术和其他多模态项目 <!-- omit in toc -->
2506
+
2507
+ 👏 欢迎了解 MiniCPM-o/V 背后的支持技术和更多我们的多模态项目!
2508
+
2509
+ [VisCPM](https://github.com/OpenBMB/VisCPM/tree/main) | [RLHF-V](https://github.com/RLHF-V/RLHF-V) | [LLaVA-UHD](https://github.com/thunlp/LLaVA-UHD) | [RLAIF-V](https://github.com/RLHF-V/RLAIF-V)
2510
+
2511
+
2512
+
2513
+ ## 引用 <!-- omit in toc -->
2514
+
2515
+ 如果您觉得我们模型/代码/论文有帮助,请给我们 ⭐ 和 引用 📝,感谢!
2516
+
2517
+ ```bib
2518
+ @article{yao2024minicpm,
2519
+ title={MiniCPM-V: A GPT-4V Level MLLM on Your Phone},
2520
+ author={Yao, Yuan and Yu, Tianyu and Zhang, Ao and Wang, Chongyi and Cui, Junbo and Zhu, Hongji and Cai, Tianchi and Li, Haoyu and Zhao, Weilin and He, Zhihui and others},
2521
+ journal={arXiv preprint arXiv:2408.01800},
2522
+ year={2024}
2523
+ }
2524
+ ```
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/cgbench.py ADDED
@@ -0,0 +1,1760 @@
1
+ from huggingface_hub import snapshot_download
2
+ from ..smp import *
3
+ from .video_base import VideoBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from .utils.cgbench import *
6
+ from ..utils import track_progress_rich
7
+
8
+
9
+ class CGBench_MCQ_Grounding_Mini(VideoBaseDataset):
10
+
11
+ dataset = "CG-Bench_MCQ_Grounding_Mini"
12
+
13
+ TYPE = "Video-MCQ-Grounding"
14
+
15
+ MD5 = "54ed3e90a51a6fb375c92b319a715f72"
16
+
17
+ SYS = {
18
+ "long_acc": (
19
+ "You will be provided with sampled frames from a video, along with a "
20
+ "multiple-choice question that includes a question and several answer options.\n"
21
+ "Your task is to analyze the provided frames, infer the most plausible "
22
+ "answer based on the visual information.\n"
23
+ "If the video does not provide enough information, infer the answer based "
24
+ "on the options available and still provide a result. "
25
+ "Therefore, In all cases, an answer must be given.\n"
26
+ "Only output the answer in the following format:\n\n"
27
+ '```json\n{"result": "option"}\n```\n\n'
28
+ 'The "option" is the uppercase letter corresponding to your answer.\n\n'
29
+ ),
30
+ "clue_acc": (
31
+ "You will be provided with sampled frames from a video, along with a "
32
+ "multiple-choice question that includes a question and several answer options.\n"
33
+ "Your task is to analyze the provided frames, infer the most plausible "
34
+ "answer based on the visual information.\n"
35
+ "If the video does not provide enough information, infer the answer based "
36
+ "on the options available and still provide a result. "
37
+ "Therefore, In all cases, an answer must be given.\n"
38
+ "Only output the answer in the following format:\n\n"
39
+ '```json\n{"result": "option"}\n```\n\n'
40
+ "The 'option' is the uppercase letter corresponding to your answer.\n\n"
41
+ ),
42
+ "miou": (
43
+ "You will be provided with uniformly sampled frames from a video and their "
44
+ "timestamps, along with a multiple-choice question that includes a question "
45
+ "and several answer options.\n"
46
+ "Your task is to determine in which intervals the 'clue intervals' exist "
47
+ "that contain visual information needed to answer the question.\n"
48
+ "Only output the answer in the following format:\n\n"
49
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
50
+ "In this output format, each 'start' and 'end' represents the beginning and "
51
+ "end of an interval in seconds where relevant clues can be found.\n"
52
+ "You must provide at least one interval and at most five intervals. "
53
+ "Intervals exceeding five will NOT be considered valid.\n"
54
+ ),
55
+ "miou_wo_frame_time": (
56
+ "You will be provided with uniformly sampled frames from a video, along "
57
+ "with a multiple-choice question that includes a question and several "
58
+ "answer options.\n"
59
+ "Your task is to determine in which intervals the 'clue intervals' exist "
60
+ "that contain visual information needed to answer the question.\n"
61
+ "Only output the answer in the following format:\n\n"
62
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
63
+ 'In this output format, each "start" and "end" represents the start and '
64
+ "end of the video where the relevant clue can be found in the form of a "
65
+ "floating point number between 0 and 1, where 0 represents the start time "
66
+ "of the video and 1 represents the end time of the video.\n"
67
+ "You must provide at least one interval and at most five intervals. "
68
+ "Intervals exceeding five will NOT be considered valid.\n"
69
+ ),
70
+ }
71
+
72
+ def __init__(
73
+ self,
74
+ dataset="CG-Bench_MCQ_Grounding_Mini",
75
+ use_subtitle=False,
76
+ use_subtitle_time=False,
77
+ use_frame_time=False,
78
+ nframe=0,
79
+ fps=-1,
80
+ ):
81
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
82
+ self.use_subtitle = use_subtitle
83
+ self.use_subtitle_time = use_subtitle_time
84
+ self.use_frame_time = use_frame_time
85
+ self.dataset_name = dataset
86
+ lmu_root = LMUDataRoot()
87
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
88
+
89
+ @classmethod
90
+ def supported_datasets(cls):
91
+ return ["CG-Bench_MCQ_Grounding_Mini"]
92
+
93
+ def clue_frame_paths(self, qid, num_frames=8):
94
+ frame_root = osp.join(self.clue_frame_root, qid)
95
+ os.makedirs(frame_root, exist_ok=True)
96
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
97
+
98
+ def clue_frame_paths_fps(self, qid, num_frames=8, fps=-1):
99
+ frame_root = osp.join(self.clue_frame_root, qid)
100
+ os.makedirs(frame_root, exist_ok=True)
101
+ return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
102
+
103
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
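+ # Collect deduplicated subtitle lines (optionally prefixed with [start, end] in seconds); when frame_indices is given, keep only subtitles that overlap the sampled frames.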
104
+
105
+ subtitles = []
106
+
107
+ srt_path = osp.join(self.data_root, subtitle_path)
108
+ assert osp.exists(srt_path)
109
+ import pysubs2
110
+
111
+ subs = pysubs2.load(srt_path, encoding="utf-8")
112
+ if not frame_indices:
113
+ for sub in subs:
114
+ sub_text = sub.text.replace("\\N", " ")
115
+ if sub_time:
116
+ start_time = milliseconds_to_seconds(sub.start)
117
+ end_time = milliseconds_to_seconds(sub.end)
118
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
119
+ if sub_text.strip() and sub_text not in subtitles:
120
+ subtitles.append(sub_text)
121
+ else:
122
+ for selected_frame_id in frame_indices:
123
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
124
+ for sub in subs:
125
+ if sub.start < cur_time and sub.end > cur_time:
126
+ sub_text = sub.text.replace("\\N", " ")
127
+ if sub_time:
128
+ start_time = milliseconds_to_seconds(sub.start)
129
+ end_time = milliseconds_to_seconds(sub.end)
130
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
131
+ if sub_text.strip() and sub_text not in subtitles:
132
+ subtitles.append(sub_text)
133
+
134
+ if subtitles:
135
+ subtitles_str = '\n'.join(subtitles)
136
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
137
+ else:
138
+ return ""
139
+
140
+ def prepare_dataset(self, dataset_name="CG-Bench_MCQ_Grounding_Mini", repo_id="CG-Bench/CG-Bench"):
141
+
142
+ def check_integrity(pth):
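+ # The cached dataset is reused only if the TSV exists, its MD5 matches, and every referenced video file is present.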
143
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
144
+
145
+ if not os.path.exists(data_file):
146
+ return False
147
+
148
+ if md5(data_file) != self.MD5:
149
+ return False
150
+ data = load(data_file)
151
+ for video_pth in data["video"]:
152
+ if not osp.exists(osp.join(pth, video_pth)):
153
+ return False
154
+
155
+ return True
156
+
157
+ cache_path = get_cache_path(repo_id)
158
+
159
+ if cache_path is not None and check_integrity(cache_path):
160
+ dataset_path = cache_path
161
+ else:
162
+
163
+ def generate_tsv(pth):
164
+
165
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
166
+
167
+ task_modes = ["long_acc", "clue_acc", "miou"]
168
+ all_data = []
169
+ for task_mode in task_modes:
170
+ with open(osp.join(pth, "cgbench_mini.json"), "r") as f:
171
+ data_file = pd.DataFrame(json.load(f))
172
+
173
+ data_file = data_file.assign(index=range(len(data_file)))
174
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
175
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
176
+ lambda x: (
177
+ f"cg_subtitles/{x}.srt"
178
+ if osp.exists(osp.join(dataset_path, f"cg_subtitles/{x}.srt"))
179
+ else ""
180
+ )
181
+ )
182
+
183
+ data_file["clue_video_path"] = ""
184
+
185
+ if task_mode in ["clue_acc"]:
186
+ data_file["clue_video_path"] = data_file["clue_video_path"] = data_file.apply(
187
+ lambda row: f"cg_clue_videos/{row['qid']}.mp4", axis=1
188
+ )
189
+
190
+ data_file["task_mode"] = task_mode
191
+
192
+ if task_mode in ["clue_acc", "long_acc"]:
193
+ data_file["answer"] = data_file["right_answer"]
194
+
195
+ if task_mode == "miou":
196
+ data_file["answer"] = data_file["clue_intervals"]
197
+
198
+ if task_mode in ["long_acc", "miou"]:
199
+ data_file["clue_intervals"] = ""
200
+
201
+ data_file = data_file[
202
+ [
203
+ "index",
204
+ "video_uid",
205
+ "video",
206
+ "duration",
207
+ "domain",
208
+ "choices",
209
+ "sub_category",
210
+ "subtitle_path",
211
+ "question",
212
+ "answer",
213
+ "task_mode",
214
+ "clue_intervals",
215
+ "qid",
216
+ "clue_video_path",
217
+ ]
218
+ ]
219
+
220
+ all_data.append(data_file)
221
+
222
+ final_data = pd.concat(all_data, ignore_index=True)
223
+ final_data["index"] = range(len(final_data))
224
+ final_data.to_csv(tsv_file, sep="\t", index=False)
225
+
226
+ if modelscope_flag_set():
227
+ from modelscope import dataset_snapshot_download
228
+
229
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
230
+ else:
231
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
232
+
233
+ unzip_hf_zip(dataset_path)
234
+ generate_tsv(dataset_path)
235
+
236
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
237
+
238
+ return dict(data_file=tsv_file, root=dataset_path)
239
+
240
+ def build_prompt(self, line, video_llm):
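+ # Assemble the multimodal message for one sample: pick video/frames according to task_mode, then append optional timestamps, subtitles, the question and its options.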
241
+
242
+ if isinstance(line, int):
243
+ assert line < len(self)
244
+ line = self.data.iloc[line]
245
+
246
+ task_mode = line["task_mode"]
247
+
248
+ message = []
249
+
250
+ origin_use_subtitle_time = self.use_subtitle_time
251
+
252
+ try:
253
+ if task_mode in ["long_acc", "clue_acc"]:
254
+ system_prompt = self.SYS[task_mode]
255
+ elif task_mode == "miou":
256
+ if self.use_frame_time and not video_llm:
257
+ system_prompt = self.SYS[task_mode]
258
+ else:
259
+ system_prompt = self.SYS["miou_wo_frame_time"]
260
+ if self.use_subtitle_time is True:
261
+ self.use_subtitle_time = False
262
+
263
+ user_prompt = ""
264
+
265
+ if task_mode in ["long_acc", "miou"]:
266
+ video_path = line["video"]
267
+
268
+ if video_llm:
269
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
270
+
271
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
272
+ if self.nframe:
273
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
274
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
275
+ )
276
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
277
+ fps=vid_fps, sub_time=self.use_subtitle_time)
278
+ else:
279
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
280
+ else:
281
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
282
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
283
+ )
284
+ message.extend(dict(type="image", value=im) for im in image_paths)
285
+
286
+ if self.use_frame_time:
287
+ user_prompt += get_timestampes(frame_indices, vid_fps)
288
+
289
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
290
+ user_prompt += self.get_subtitles(
291
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
292
+ sub_time=self.use_subtitle_time
293
+ )
294
+
295
+ elif task_mode == "clue_acc":
296
+ clue_video_path = line["clue_video_path"]
297
+ video_path = line["video"]
298
+
299
+ if video_llm:
300
+ message.append(dict(type="video", value=osp.join(self.data_root, clue_video_path)))
301
+ print(message)
302
+
303
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
304
+ if self.nframe:
305
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
306
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
307
+ )
308
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
309
+ fps=vid_fps, sub_time=self.use_subtitle_time)
310
+ else:
311
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
312
+ else:
313
+ if self.nframe > 32:
314
+ self.nframe = 32
315
+ print("The maximum number of frames is 32 when evaluating clue-based mcq in CG-Bench !")
316
+
317
+ clue_intervals = eval(line["clue_intervals"])
318
+
319
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
320
+ video_path, uid=line["qid"], clue_intervals=clue_intervals, num_frames=self.nframe, fps=self.fps
321
+ )
322
+
323
+ message.extend(dict(type="image", value=im) for im in image_paths)
324
+
325
+ if self.use_frame_time:
326
+ user_prompt += get_timestampes(frame_indices, vid_fps)
327
+
328
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
329
+ user_prompt += self.get_subtitles(
330
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
331
+ sub_time=self.use_subtitle_time
332
+ )
333
+
334
+ question = line["question"]
335
+ user_prompt += f"Question: {question}\n\n"
336
+
337
+ choices = eval(line["choices"])
338
+ labels = [chr(ord("A") + i) for i in range(len(choices))]
339
+ user_prompt += "\n".join([f"{label}:{value}" for label, value in zip(labels, choices)]) + "\n\n"
340
+
341
+ message.append(dict(type="text", value=system_prompt + user_prompt))
342
+
343
+ return message
344
+
345
+ finally:
346
+ # Ensure that `use_subtitle_time` is always restored to its original value
347
+ self.use_subtitle_time = origin_use_subtitle_time
348
+
349
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
350
+
351
+ if type(uid) is not str:
352
+ uid = str(uid)
353
+
354
+ vid_path = osp.join(self.data_root, video)
355
+ vid = decord.VideoReader(vid_path)
356
+ vid_fps = vid.get_avg_fps()
357
+ n_frames = len(vid)
358
+
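+ # With clue intervals: restrict sampling to the merged clue spans (fixed count or fixed fps, with a 32-frame floor in the fps mode).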
359
+ if clue_intervals is not None:
360
+ merged_intervals = merge_intervals(clue_intervals)
361
+
362
+ if num_frames > 0 and fps < 0:
363
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
364
+ frame_paths = self.clue_frame_paths(uid, len(indices))
365
+
366
+ elif fps > 0:
367
+ frame_indices = []
368
+ for start, end in merged_intervals:
369
+ start_frame = int(start * vid_fps)
370
+ end_frame = int(end * vid_fps)
371
+ step = vid_fps / fps
372
+ interval_indices = [
373
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
374
+ ]
375
+ frame_indices.extend(interval_indices)
376
+
377
+ if len(frame_indices) < 32:
378
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
379
+ else:
380
+ indices = frame_indices
381
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
382
+
383
+ else:
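+ # No clue intervals: sample uniformly over the whole video, either a fixed number of frames (num_frames) or at a fixed rate (fps).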
384
+ if num_frames > 0 and fps < 0:
385
+ step_size = len(vid) / (num_frames + 1)
386
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
387
+
388
+ frame_paths = self.frame_paths(uid)
389
+ elif fps > 0:
390
+ total_duration = n_frames / vid_fps
391
+ required_frames = int(total_duration * fps)
392
+ step_size = vid_fps / fps
393
+ indices = [int(i * step_size) for i in range(required_frames)]
394
+ frame_paths = self.frame_paths_fps(uid, len(indices))
395
+
396
+ # Save and validate frames
397
+ valid_paths = []
398
+ valid_indices = []
399
+
400
+ if not np.all([osp.exists(p) for p in frame_paths]):
401
+ images = [vid[i].asnumpy() for i in indices]
402
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
403
+ if osp.exists(path):
404
+ try:
405
+ with Image.open(path) as img:
406
+ img.verify()
407
+ valid_paths.append(path)
408
+ valid_indices.append(indices[i])
409
+ except Exception:
410
+ continue
411
+ else:
412
+ try:
413
+ img = Image.fromarray(img_array)
414
+ img.save(path)
415
+ img.verify()
416
+ valid_paths.append(path)
417
+ valid_indices.append(indices[i])
418
+ except Exception:
419
+ continue
420
+ else:
421
+ for i, path in enumerate(frame_paths):
422
+ try:
423
+ with Image.open(path) as img:
424
+ img.verify()
425
+ valid_paths.append(path)
426
+ valid_indices.append(indices[i])
427
+ except Exception:
428
+ continue
429
+
430
+ return valid_paths, valid_indices, vid_fps
431
+
432
+ def evaluate(self, eval_file, **judge_kwargs):
433
+
434
+ assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
435
+
436
+ tgt_file = eval_file.replace(".xlsx", "_rating.json")
437
+ score_file = eval_file.replace(".xlsx", "_score.xlsx")
438
+
439
+ data = load(eval_file)
440
+
441
+ data_un = data[~pd.isna(data["prediction"])]
442
+ data_pred_na = data[pd.isna(data["prediction"])]
443
+
444
+ data_pred_na["score"] = -1
445
+
446
+ data_un["score"] = data_un.apply(
447
+ lambda row: post_process(
448
+ response=row["prediction"],
449
+ right_answer=row["answer"],
450
+ task_mode=row["task_mode"],
451
+ duration=row["duration"],
452
+ ),
453
+ axis=1,
454
+ )
455
+
456
+ data = pd.concat([data_pred_na, data_un])
457
+
458
+ rejected_count = (data["score"] == -1).sum()
459
+
460
+ print(
461
+ f"Among {len(data)} questions, "
462
+ f"failed to obtain prediction for {len(data_pred_na)} questions, "
463
+ f"failed to obtain the score for {rejected_count - len(data_pred_na)} questions. "
464
+ f"Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating."
465
+ )
466
+
467
+ dump(data, score_file)
468
+
469
+ rating = get_dimention_rating_mcq_grouding(score_file)
470
+
471
+ dump(rating, tgt_file)
472
+
473
+ return rating
474
+
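+ # Usage sketch (assumption: environment prepared per VLMEvalKit conventions; file names are illustrative):
+ # dataset = CGBench_MCQ_Grounding_Mini(nframe=32, use_frame_time=True)
+ # message = dataset.build_prompt(0, video_llm=False)  # multimodal message for sample 0
+ # rating = dataset.evaluate("your_model_CG-Bench_MCQ_Grounding_Mini.xlsx")  # score saved predictions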
475
+
476
+ # For step-2 evaluation, passing [prompt] + image_paths is sufficient.
477
+ class CGBench_OpenEnded_Mini(VideoBaseDataset):
478
+
479
+ TYPE = "Video-OpenEnded"
480
+
481
+ dataset = "CG-Bench_OpenEnded_Mini"
482
+
483
+ MD5 = "9175791b11afdfa305fdb3e525b7a4ee"
484
+
485
+ SYS = (
486
+ "You will be provided with sampled frames from a video, along with a "
487
+ "question.\n"
488
+ "Your task is to analyze the provided frames and infer the most plausible "
489
+ "answer based on the visual information.\n"
490
+ "If the visual information is ambiguous or insufficient, use the available "
491
+ "context to reason your answer.\n"
492
+ "Only output the answer in the following format:\n\n"
493
+ '```json\n{"result": "answer"}\n```\n\n'
494
+ 'The "answer" can be a word, phrase, or sentence that directly responds to '
495
+ "the question.\n\n"
496
+ )
497
+
498
+ def __init__(
499
+ self,
500
+ dataset="CG-Bench_OpenEnded_Mini",
501
+ use_subtitle=False,
502
+ use_subtitle_time=False,
503
+ use_frame_time=False,
504
+ nframe=0,
505
+ fps=-1,
506
+ ):
507
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
508
+ self.use_subtitle = use_subtitle
509
+ self.use_subtitle_time = use_subtitle_time
510
+ self.use_frame_time = use_frame_time
511
+ self.dataset_name = dataset
512
+ lmu_root = LMUDataRoot()
513
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
514
+
515
+ @classmethod
516
+ def supported_datasets(cls):
517
+ return ["CG-Bench_OpenEnded_Mini"]
518
+
519
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
520
+
521
+ subtitles = []
522
+
523
+ srt_path = osp.join(self.data_root, subtitle_path)
524
+ assert osp.exists(srt_path)
525
+ import pysubs2
526
+
527
+ subs = pysubs2.load(srt_path, encoding="utf-8")
528
+ if not frame_indices:
529
+ for sub in subs:
530
+ sub_text = sub.text.replace("\\N", " ")
531
+ if sub_time:
532
+ start_time = milliseconds_to_seconds(sub.start)
533
+ end_time = milliseconds_to_seconds(sub.end)
534
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
535
+ if sub_text.strip() and sub_text not in subtitles:
536
+ subtitles.append(sub_text)
537
+ else:
538
+ for selected_frame_id in frame_indices:
539
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
540
+ for sub in subs:
541
+ if sub.start < cur_time and sub.end > cur_time:
542
+ sub_text = sub.text.replace("\\N", " ")
543
+ if sub_time:
544
+ start_time = milliseconds_to_seconds(sub.start)
545
+ end_time = milliseconds_to_seconds(sub.end)
546
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
547
+ if sub_text.strip() and sub_text not in subtitles:
548
+ subtitles.append(sub_text)
549
+
550
+ if subtitles:
551
+ subtitles_str = '\n'.join(subtitles)
552
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
553
+ else:
554
+ return ""
555
+
556
+ def prepare_dataset(self, dataset_name="CG-Bench_OpenEnded_Mini", repo_id="CG-Bench/CG-Bench"):
557
+
558
+ def check_integrity(pth):
559
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
560
+
561
+ if not os.path.exists(data_file):
562
+ return False
563
+
564
+ if md5(data_file) != self.MD5:
565
+ return False
566
+ data = load(data_file)
567
+ for video_pth in data["video"]:
568
+ if not osp.exists(osp.join(pth, video_pth)):
569
+ return False
570
+
571
+ return True
572
+
573
+ cache_path = get_cache_path(repo_id)
574
+
575
+ if cache_path is not None and check_integrity(cache_path):
576
+ dataset_path = cache_path
577
+ else:
578
+
579
+ def generate_tsv(pth):
580
+
581
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
582
+
583
+ with open(osp.join(pth, "cgbench_mini.json"), "r") as f:
584
+ data_file = pd.DataFrame(json.load(f))
585
+
586
+ data_file = data_file.assign(index=range(len(data_file)))
587
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
588
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
589
+ lambda x: f"cg_subtitles/{x}.srt" if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt")) else ""
590
+ )
591
+
592
+ data_file = data_file[
593
+ [
594
+ "index",
595
+ "video_uid",
596
+ "video",
597
+ "duration",
598
+ "domain",
599
+ "sub_category",
600
+ "subtitle_path",
601
+ "question",
602
+ "answer",
603
+ "clue_intervals",
604
+ "qid",
605
+ ]
606
+ ]
607
+
608
+ data_file.to_csv(tsv_file, sep="\t", index=False)
609
+
610
+ if modelscope_flag_set():
611
+ from modelscope import dataset_snapshot_download
612
+
613
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
614
+ else:
615
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
616
+
617
+ unzip_hf_zip(dataset_path)
618
+ generate_tsv(dataset_path)
619
+
620
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
621
+
622
+ return dict(data_file=tsv_file, root=dataset_path)
623
+
624
+ def build_prompt(self, line, video_llm):
625
+
626
+ if isinstance(line, int):
627
+ assert line < len(self)
628
+ line = self.data.iloc[line]
629
+
630
+ message = []
631
+
632
+ sys_prompt = self.SYS
633
+
634
+ user_prompt = ""
635
+
636
+ video_path = line["video"]
637
+
638
+ if video_llm:
639
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
640
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
641
+ if self.nframe:
642
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
643
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
644
+ )
645
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
646
+ fps=vid_fps, sub_time=self.use_subtitle_time)
647
+ else:
648
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
649
+ else:
650
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
651
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
652
+ )
653
+ message.extend(dict(type="image", value=im) for im in image_paths)
654
+
655
+ if self.use_frame_time:
656
+ user_prompt += get_timestampes(frame_indices, vid_fps)
657
+
658
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
659
+ user_prompt += self.get_subtitles(
660
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
661
+ sub_time=self.use_subtitle_time
662
+ )
663
+
664
+ question = line["question"]
665
+ user_prompt += f"Question: {question}\n\n"
666
+
667
+ message.append(dict(type="text", value=sys_prompt + user_prompt))
668
+
669
+ return message
670
+
671
+ def clue_frame_paths(self, qid, num_frames=8):
672
+ frame_root = osp.join(self.clue_frame_root, qid)
673
+ os.makedirs(frame_root, exist_ok=True)
674
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
675
+
676
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
677
+
678
+ if type(uid) is not str:
679
+ uid = str(uid)
680
+
681
+ vid_path = osp.join(self.data_root, video)
682
+ vid = decord.VideoReader(vid_path)
683
+ vid_fps = vid.get_avg_fps()
684
+ n_frames = len(vid)
685
+
686
+ if clue_intervals is not None:
687
+ merged_intervals = merge_intervals(clue_intervals)
688
+
689
+ if num_frames > 0 and fps < 0:
690
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
691
+ frame_paths = self.clue_frame_paths(uid, len(indices))
692
+
693
+ elif fps > 0:
694
+ frame_indices = []
695
+ for start, end in merged_intervals:
696
+ start_frame = int(start * vid_fps)
697
+ end_frame = int(end * vid_fps)
698
+ step = vid_fps / fps
699
+ interval_indices = [
700
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
701
+ ]
702
+ frame_indices.extend(interval_indices)
703
+
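+ # fps-based sampling over short clue intervals can yield very few frames, so
+ # fall back to evenly sampling 32 frames across the merged intervals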
704
+ if len(frame_indices) < 32:
705
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
706
+ else:
707
+ indices = frame_indices
708
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
709
+
710
+ else:
711
+ if num_frames > 0 and fps < 0:
712
+ step_size = len(vid) / (num_frames + 1)
713
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
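+ # e.g. an 8-frame request on a 900-frame video gives step_size = 100 and
+ # indices [100, 200, ..., 800]: evenly spaced, skipping both endpoints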
714
+ frame_paths = self.frame_paths(uid)
715
+ elif fps > 0:
716
+ total_duration = n_frames / vid_fps
717
+ required_frames = int(total_duration * fps)
718
+ step_size = vid_fps / fps
719
+ indices = [int(i * step_size) for i in range(required_frames)]
720
+ frame_paths = self.frame_paths_fps(uid, len(indices))
721
+
722
+ valid_paths = []
723
+ valid_indices = []
724
+
725
+ if not np.all([osp.exists(p) for p in frame_paths]):
726
+ images = [vid[i].asnumpy() for i in indices]
727
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
728
+ if osp.exists(path):
729
+ try:
730
+ with Image.open(path) as img:
731
+ img.verify()
732
+ valid_paths.append(path)
733
+ valid_indices.append(indices[i])
734
+ except Exception:
735
+ continue
736
+ else:
737
+ try:
738
+ img = Image.fromarray(img_array)
739
+ img.save(path)
740
+ with Image.open(path) as saved: # verify() must run on a freshly opened file; on an in-memory image it is a no-op
+ saved.verify()
741
+ valid_paths.append(path)
742
+ valid_indices.append(indices[i])
743
+ except Exception:
744
+ continue
745
+ else:
746
+ for i, path in enumerate(frame_paths):
747
+ try:
748
+ with Image.open(path) as img:
749
+ img.verify()
750
+ valid_paths.append(path)
751
+ valid_indices.append(indices[i])
752
+ except Exception:
753
+ continue
754
+
755
+ return valid_paths, valid_indices, vid_fps
756
+
757
+ def evaluate(self, eval_file, **judge_kwargs):
758
+
759
+ from .utils.cgbench import get_dimention_rating_open_ended, post_process_open
760
+
761
+ assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
762
+
763
+ tgt_file = eval_file.replace(".xlsx", "_rating.json")
764
+ score_file = eval_file.replace(".xlsx", "_score.xlsx")
765
+ step_1_tmp_file = eval_file.replace(".xlsx", "_step_1.pkl")
766
+ step_2_tmp_file = eval_file.replace(".xlsx", "_step_2.pkl")
767
+
768
+ data = load(eval_file)
769
+
770
+ data_pred_no_na = data[~pd.isna(data["prediction"])].copy()  # .copy() so the column assignments below do not modify a view of `data`
771
+ data_pred_na = data[pd.isna(data["prediction"])].copy()
772
+
773
+ data_pred_na["model_result"] = -1
774
+ data_pred_na["step_1_result"] = -1
775
+ data_pred_na["step_2_result"] = -1
776
+ data_pred_na["score"] = -1
777
+
778
+ data_pred_no_na["model_result"] = data_pred_no_na.apply(
779
+ lambda row: post_process_open(
780
+ response=row["prediction"],
781
+ ),
782
+ axis=1,
783
+ )
784
+
785
+ data_no_model_result = data_pred_no_na[data_pred_no_na["model_result"] == -1]
786
+ data_step_1 = data_pred_no_na[data_pred_no_na["model_result"] != -1]
787
+
788
+ if judge_kwargs.get("model", None) != "gpt-4o-0806":
789
+ judge_kwargs["model"] = "gpt-4o-0806"
790
+ print("The judge model in cg-bench is gpt-4o-0806!")
791
+
792
+ model_step_1 = build_judge(system_prompt=sys_prompt_open_eval_step_1, **judge_kwargs)
793
+ nproc = judge_kwargs.pop("nproc", 32)
794
+
795
+ lines_step_1 = data_step_1.to_dict("records")
796
+ tups_step_1 = [(model_step_1, line) for line in lines_step_1]
797
+
798
+ keys_step_1 = [line["qid"] for line in lines_step_1]  # a list (not a set) keeps the qids aligned with tups_step_1 for the zip below
799
+
800
+ ans = {}
801
+ if osp.exists(step_1_tmp_file):
802
+ ans = load(step_1_tmp_file)
803
+ tups_step_1 = [x for x, i in zip(tups_step_1, keys_step_1) if i not in ans]
804
+ keys_step_1 = [i for i in keys_step_1 if i not in ans]
805
+
806
+ _ = track_progress_rich(
807
+ eval_open_first,
808
+ tups_step_1,
809
+ nproc=nproc,
810
+ keys=keys_step_1,
811
+ save=step_1_tmp_file,
812
+ )
813
+
814
+ step_1_results = load(step_1_tmp_file)
815
+ data_step_1 = save_step_1_steps(data_step_1, step_1_results)  # step_1_result: -1 = parse failure, 0/1 = scored directly, 2 = needs step-2 judging
816
+
817
+ data_no_step_1_results = data_step_1[data_step_1["step_1_result"] == -1]
818
+ data_step_1_over = data_step_1[data_step_1["step_1_result"].isin([0, 1])]
819
+ data_step_2 = data_step_1[data_step_1["step_1_result"] == 2]
820
+
821
+ print(judge_kwargs)
822
+
823
+ model_step_2 = build_judge(system_prompt=sys_prompt_open_eval_step_2, **judge_kwargs)
824
+
825
+ lines_step_2 = data_step_2.to_dict("records")
826
+
827
+ tups_step_2 = []
828
+
829
+ for line in tqdm(lines_step_2):
830
+ clue_intervals = eval(line["clue_intervals"])
831
+ lmu_root = LMUDataRoot()
832
+ clue_frame_root = osp.join(lmu_root, "clue_images", self.dataset)
833
+ data_root = self.data_root
834
+ frame_paths, _, _ = save_clue_video_frames(
835
+ data_root,
836
+ clue_frame_root,
837
+ video=line["video"],
838
+ uid=line["qid"],
839
+ clue_intervals=clue_intervals,
840
+ num_frames=32,
841
+ )
842
+ tups_step_2.append((model_step_2, line, frame_paths))
843
+
844
+ keys_step_2 = [line["qid"] for line in lines_step_2]  # a list (not a set) keeps the qids aligned with tups_step_2 for the zip below
845
+
846
+ ans = {}
847
+ if osp.exists(step_2_tmp_file):
848
+ ans = load(step_2_tmp_file)
849
+ tups_step_2 = [x for x, i in zip(tups_step_2, keys_step_2) if i not in ans]
850
+ keys_step_2 = [i for i in keys_step_2 if i not in ans]
851
+
852
+ _ = track_progress_rich(
853
+ eval_open_second,
854
+ tups_step_2,
855
+ nproc=nproc,
856
+ keys=keys_step_2,
857
+ save=step_2_tmp_file,
858
+ )
859
+
860
+ step_2_results = load(step_2_tmp_file)
861
+ data_step_2 = save_step_2_steps(data_step_2, step_2_results)
862
+
863
+ data_no_step_2_results = data_step_2[data_step_2["score"] == -1]
864
+ data_step_2_over = data_step_2[data_step_2["score"].isin([0, 1])]
865
+
866
+ data = pd.concat(
867
+ [
868
+ data_pred_na,
869
+ data_no_model_result,
870
+ data_no_step_1_results,
871
+ data_step_1_over,
872
+ data_no_step_2_results,
873
+ data_step_2_over,
874
+ ]
875
+ )
876
+
877
+ dump(data, score_file)
878
+
879
+ rating = get_dimention_rating_open_ended(score_file)
880
+
881
+ dump(rating, tgt_file)
882
+
883
+ return rating
884
+
885
+
886
+ class CGBench_MCQ_Grounding(VideoBaseDataset):
887
+
888
+ TYPE = "Video-MCQ-Grounding"
889
+
890
+ MD5 = "eaead3d978a689269fefce4ae29c86df"
891
+
892
+ SYS = {
893
+ "long_acc": (
894
+ "You will be provided with sampled frames from a video, along with a "
895
+ "multiple-choice question that includes a question and several answer options.\n"
896
+ "Your task is to analyze the provided frames, infer the most plausible "
897
+ "answer based on the visual information.\n"
898
+ "If the video does not provide enough information, infer the answer based "
899
+ "on the options available and still provide a result. "
900
+ "Therefore, In all cases, an answer must be given.\n"
901
+ "Only output the answer in the following format:\n\n"
902
+ '```json\n{"result": "option"}\n```\n\n'
903
+ 'The "option" is the uppercase letter corresponding to your answer.\n\n'
904
+ ),
905
+ "clue_acc": (
906
+ "You will be provided with sampled frames from a video, along with a "
907
+ "multiple-choice question that includes a question and several answer options.\n"
908
+ "Your task is to analyze the provided frames, infer the most plausible "
909
+ "answer based on the visual information.\n"
910
+ "If the video does not provide enough information, infer the answer based "
911
+ "on the options available and still provide a result. "
912
+ "Therefore, In all cases, an answer must be given.\n"
913
+ "Only output the answer in the following format:\n\n"
914
+ '```json\n{"result": "option"}\n```\n\n'
915
+ "The 'option' is the uppercase letter corresponding to your answer.\n\n"
916
+ ),
917
+ "miou": (
918
+ "You will be provided with uniformly sampled frames from a video and their "
919
+ "timestamps, along with a multiple-choice question that includes a question "
920
+ "and several answer options.\n"
921
+ "Your task is to determine in which intervals the 'clue intervals' exist "
922
+ "that contain visual information needed to answer the question.\n"
923
+ "Only output the answer in the following format:\n\n"
924
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
925
+ "In this output format, each 'start' and 'end' represents the beginning and "
926
+ "end of an interval in seconds where relevant clues can be found.\n"
927
+ "You must provide at least one interval and at most five intervals. "
928
+ "Intervals exceeding five will NOT be considered valid.\n"
929
+ ),
930
+ "miou_wo_frame_time": (
931
+ "You will be provided with uniformly sampled frames from a video, along "
932
+ "with a multiple-choice question that includes a question and several "
933
+ "answer options.\n"
934
+ "Your task is to determine in which intervals the 'clue intervals' exist "
935
+ "that contain visual information needed to answer the question.\n"
936
+ "Only output the answer in the following format:\n\n"
937
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
938
+ 'In this output format, each "start" and "end" represents the start and '
939
+ "end of the video where the relevant clue can be found in the form of a "
940
+ "floating point number between 0 and 1, where 0 represents the start time "
941
+ "of the video and 1 represents the end time of the video.\n"
942
+ "You must provide at least one interval and at most five intervals. "
943
+ "Intervals exceeding five will NOT be considered valid.\n"
944
+ ),
945
+ }
946
+
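+ # The prompts above constrain answers to a fenced ```json {"result": ...}``` block.
+ # A minimal parsing sketch of that convention (the real extraction presumably
+ # lives in the cgbench utils' `post_process`, which is not shown here):
+ #
+ # import json, re
+ # def parse_fenced_result(response):
+ #     m = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
+ #     if m is None:
+ #         return -1  # the same failure sentinel the evaluate() methods use
+ #     try:
+ #         return json.loads(m.group(1))["result"]
+ #     except (json.JSONDecodeError, KeyError):
+ #         return -1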
947
+ def __init__(
948
+ self,
949
+ dataset="CG-Bench_MCQ_Grounding",
950
+ use_subtitle=False,
951
+ use_subtitle_time=False,
952
+ use_frame_time=False,
953
+ nframe=0,
954
+ fps=-1,
955
+ ):
956
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
957
+ self.use_subtitle = use_subtitle
958
+ self.use_subtitle_time = use_subtitle_time
959
+ self.use_frame_time = use_frame_time
960
+ self.dataset_name = dataset
961
+ lmu_root = LMUDataRoot()
962
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
963
+
964
+ @classmethod
965
+ def supported_datasets(cls):
966
+ return ["CG-Bench_MCQ_Grounding"]
967
+
968
+ def clue_frame_paths(self, qid, num_frames=8):
969
+ frame_root = osp.join(self.clue_frame_root, qid)
970
+ os.makedirs(frame_root, exist_ok=True)
971
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
972
+
973
+ def clue_frame_paths_fps(self, qid, num_frames=8, fps=-1):
974
+ frame_root = osp.join(self.clue_frame_root, qid)
975
+ os.makedirs(frame_root, exist_ok=True)
976
+ return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
977
+
978
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
979
+
980
+ subtitles = []
981
+
982
+ srt_path = osp.join(self.data_root, subtitle_path)
983
+ assert osp.exists(srt_path)
984
+ import pysubs2
985
+
986
+ subs = pysubs2.load(srt_path, encoding="utf-8")
987
+ if not frame_indices:
988
+ for sub in subs:
989
+ sub_text = sub.text.replace("\\N", " ")
990
+ if sub_time:
991
+ start_time = milliseconds_to_seconds(sub.start)
992
+ end_time = milliseconds_to_seconds(sub.end)
993
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
994
+ if sub_text.strip() and sub_text not in subtitles:
995
+ subtitles.append(sub_text)
996
+ else:
997
+ for selected_frame_id in frame_indices:
998
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
999
+ for sub in subs:
1000
+ if sub.start < cur_time and sub.end > cur_time:
1001
+ sub_text = sub.text.replace("\\N", " ")
1002
+ if sub_time:
1003
+ start_time = milliseconds_to_seconds(sub.start)
1004
+ end_time = milliseconds_to_seconds(sub.end)
1005
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
1006
+ if sub_text.strip() and sub_text not in subtitles:
1007
+ subtitles.append(sub_text)
1008
+
1009
+ if subtitles:
1010
+ subtitles_str = '\n'.join(subtitles)
1011
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
1012
+ else:
1013
+ return ""
1014
+
1015
+ def prepare_dataset(self, dataset_name="CG-Bench_MCQ_Grounding", repo_id="CG-Bench/CG-Bench"):
1016
+
1017
+ def check_integrity(pth):
1018
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
1019
+
1020
+ if not os.path.exists(data_file):
1021
+ return False
1022
+
1023
+ if md5(data_file) != self.MD5:
1024
+ return False
1025
+ data = load(data_file)
1026
+ for video_pth in data["video"]:
1027
+ if not osp.exists(osp.join(pth, video_pth)):
1028
+ return False
1029
+
1030
+ for clue_video_pth in data["clue_video_path"]:
1031
+ if clue_video_pth and not (isinstance(clue_video_pth, float) and np.isnan(clue_video_pth)):
1032
+ if not osp.exists(osp.join(pth, clue_video_pth)):
1033
+ return False
1034
+
1035
+ return True
1036
+
1037
+ cache_path = get_cache_path(repo_id)
1038
+
1039
+ if cache_path is not None and check_integrity(cache_path):
1040
+ dataset_path = cache_path
1041
+ else:
1042
+
1043
+ def generate_tsv(pth):
1044
+
1045
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
1046
+
1047
+ task_modes = ["long_acc", "clue_acc", "miou"]
1048
+ all_data = []
1049
+ for task_mode in task_modes:
1050
+ with open(osp.join(pth, "cgbench.json"), "r") as f:
1051
+ data_file = pd.DataFrame(json.load(f))
1052
+
1053
+ data_file = data_file.assign(index=range(len(data_file)))
1054
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
1055
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
1056
+ lambda x: (
1057
+ f"cg_subtitles/{x}.srt"
1058
+ if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt"))
1059
+ else ""
1060
+ )
1061
+ )
1062
+
1063
+ data_file["clue_video_path"] = ""
1064
+
1065
+ if task_mode in ["clue_acc"]:
1066
+ data_file["clue_video_path"] = data_file["clue_video_path"] = data_file.apply(
1067
+ lambda row: f"cg_clue_videos/{row['qid']}.mp4", axis=1
1068
+ )
1069
+
1070
+ data_file["task_mode"] = task_mode
1071
+
1072
+ if task_mode in ["clue_acc", "long_acc"]:
1073
+ data_file["answer"] = data_file["right_answer"]
1074
+
1075
+ if task_mode == "miou":
1076
+ data_file["answer"] = data_file["clue_intervals"]
1077
+
1078
+ if task_mode in ["long_acc", "miou"]:
1079
+ data_file["clue_intervals"] = ""
1080
+
1081
+ data_file = data_file[
1082
+ [
1083
+ "index",
1084
+ "video_uid",
1085
+ "video",
1086
+ "duration",
1087
+ "domain",
1088
+ "choices",
1089
+ "sub_category",
1090
+ "subtitle_path",
1091
+ "question",
1092
+ "answer",
1093
+ "task_mode",
1094
+ "clue_intervals",
1095
+ "qid",
1096
+ "clue_video_path",
1097
+ ]
1098
+ ]
1099
+
1100
+ all_data.append(data_file)
1101
+
1102
+ final_data = pd.concat(all_data, ignore_index=True)
1103
+ final_data["index"] = range(len(final_data))
1104
+ final_data.to_csv(tsv_file, sep="\t", index=False)
1105
+
1106
+ if modelscope_flag_set():
1107
+ from modelscope import dataset_snapshot_download
1108
+
1109
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
1110
+ else:
1111
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
1112
+
1113
+ unzip_hf_zip(dataset_path)
1114
+ generate_tsv(dataset_path)
1115
+
1116
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
1117
+
1118
+ return dict(data_file=tsv_file, root=dataset_path)
1119
+
1120
+ def build_prompt(self, line, video_llm):
1121
+
1122
+ if isinstance(line, int):
1123
+ assert line < len(self)
1124
+ line = self.data.iloc[line]
1125
+
1126
+ task_mode = line["task_mode"]
1127
+
1128
+ message = []
1129
+
1130
+ origin_use_subtitle_time = self.use_subtitle_time
1131
+
1132
+ try:
1133
+ if task_mode in ["long_acc", "clue_acc"]:
1134
+ system_prompt = self.SYS[task_mode]
1135
+ elif task_mode == "miou":
1136
+ if self.use_frame_time and not video_llm:
1137
+ system_prompt = self.SYS[task_mode]
1138
+ else:
1139
+ system_prompt = self.SYS["miou_wo_frame_time"]
1140
+ if self.use_subtitle_time is True:
1141
+ self.use_subtitle_time = False
1142
+
1143
+ user_prompt = ""
1144
+
1145
+ if task_mode in ["long_acc", "miou"]:
1146
+ video_path = line["video"]
1147
+
1148
+ if video_llm:
1149
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
1150
+
1151
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
1152
+ if self.nframe:
1153
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
1154
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
1155
+ )
1156
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
1157
+ fps=vid_fps, sub_time=self.use_subtitle_time)
1158
+ else:
1159
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
1160
+ else:
1161
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
1162
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
1163
+ )
1164
+ message.extend(dict(type="image", value=im) for im in image_paths)
1165
+
1166
+ if self.use_frame_time:
1167
+ user_prompt += get_timestampes(frame_indices, vid_fps)
1168
+
1169
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
1170
+ user_prompt += self.get_subtitles(
1171
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
1172
+ sub_time=self.use_subtitle_time
1173
+ )
1174
+
1175
+ elif task_mode == "clue_acc":
1176
+ clue_video_path = line["clue_video_path"]
1177
+ video_path = line["video"]
1178
+
1179
+ if video_llm:
1180
+ message.append(dict(type="video", value=osp.join(self.data_root, clue_video_path)))
1181
+ print(message)
1182
+
1183
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
1184
+ if self.nframe:
1185
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
1186
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
1187
+ )
1188
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
1189
+ fps=vid_fps, sub_time=self.use_subtitle_time)
1190
+ else:
1191
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
1192
+ else:
1193
+ if self.nframe > 32:
1194
+ self.nframe = 32
1195
+ print("The maximum number of frames is 32 when evaluating clue-based mcq in CG-Bench !")
1196
+
1197
+ clue_intervals = eval(line["clue_intervals"])
1198
+
1199
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
1200
+ video_path, uid=line["qid"], clue_intervals=clue_intervals, num_frames=self.nframe, fps=self.fps
1201
+ )
1202
+
1203
+ message.extend(dict(type="image", value=im) for im in image_paths)
1204
+
1205
+ if self.use_frame_time:
1206
+ user_prompt += get_timestampes(frame_indices, vid_fps)
1207
+
1208
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
1209
+ user_prompt += self.get_subtitles(
1210
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
1211
+ sub_time=self.use_subtitle_time
1212
+ )
1213
+
1214
+ question = line["question"]
1215
+ user_prompt += f"Question: {question}\n\n"
1216
+
1217
+ choices = eval(line["choices"])
1218
+ labels = [chr(ord("A") + i) for i in range(len(choices))]
1219
+ user_prompt += "\n".join([f"{label}:{value}" for label, value in zip(labels, choices)]) + "\n\n"
1220
+
1221
+ message.append(dict(type="text", value=system_prompt + user_prompt))
1222
+
1223
+ return message
1224
+
1225
+ finally:
1226
+ # Ensure that `use_subtitle_time` is always restored to its original value
1227
+ self.use_subtitle_time = origin_use_subtitle_time
1228
+
1229
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
1230
+
1231
+ if type(uid) is not str:
1232
+ uid = str(uid)
1233
+
1234
+ vid_path = osp.join(self.data_root, video)
1235
+ vid = decord.VideoReader(vid_path)
1236
+ vid_fps = vid.get_avg_fps()
1237
+ n_frames = len(vid)
1238
+
1239
+ if clue_intervals is not None:
1240
+ merged_intervals = merge_intervals(clue_intervals)
1241
+
1242
+ if num_frames > 0 and fps < 0:
1243
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
1244
+ frame_paths = self.clue_frame_paths(uid, len(indices))
1245
+
1246
+ elif fps > 0:
1247
+ frame_indices = []
1248
+ for start, end in merged_intervals:
1249
+ start_frame = int(start * vid_fps)
1250
+ end_frame = int(end * vid_fps)
1251
+ step = vid_fps / fps
1252
+ interval_indices = [
1253
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
1254
+ ]
1255
+ frame_indices.extend(interval_indices)
1256
+
1257
+ if len(frame_indices) < 32:
1258
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
1259
+ else:
1260
+ indices = frame_indices
1261
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
1262
+
1263
+ else:
1264
+ if num_frames > 0 and fps < 0:
1265
+ step_size = len(vid) / (num_frames + 1)
1266
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
1267
+
1268
+ frame_paths = self.frame_paths(uid)
1269
+ elif fps > 0:
1270
+ total_duration = n_frames / vid_fps
1271
+ required_frames = int(total_duration * fps)
1272
+ step_size = vid_fps / fps
1273
+ indices = [int(i * step_size) for i in range(required_frames)]
1274
+ frame_paths = self.frame_paths_fps(uid, len(indices))
1275
+
1276
+ # Save and validate frames
1277
+ valid_paths = []
1278
+ valid_indices = []
1279
+
1280
+ if not np.all([osp.exists(p) for p in frame_paths]):
1281
+ images = [vid[i].asnumpy() for i in indices]
1282
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
1283
+ if osp.exists(path):
1284
+ try:
1285
+ with Image.open(path) as img:
1286
+ img.verify()
1287
+ valid_paths.append(path)
1288
+ valid_indices.append(indices[i])
1289
+ except Exception:
1290
+ continue
1291
+ else:
1292
+ try:
1293
+ img = Image.fromarray(img_array)
1294
+ img.save(path)
1295
+ with Image.open(path) as saved: # verify() must run on a freshly opened file; on an in-memory image it is a no-op
+ saved.verify()
1296
+ valid_paths.append(path)
1297
+ valid_indices.append(indices[i])
1298
+ except Exception:
1299
+ continue
1300
+ else:
1301
+ for i, path in enumerate(frame_paths):
1302
+ try:
1303
+ with Image.open(path) as img:
1304
+ img.verify()
1305
+ valid_paths.append(path)
1306
+ valid_indices.append(indices[i])
1307
+ except Exception:
1308
+ continue
1309
+
1310
+ return valid_paths, valid_indices, vid_fps
1311
+
1312
+ def evaluate(self, eval_file, **judge_kwargs):
1313
+
1314
+ assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
1315
+
1316
+ tgt_file = eval_file.replace(".xlsx", "_rating.json")
1317
+ score_file = eval_file.replace(".xlsx", "_score.xlsx")
1318
+
1319
+ data = load(eval_file)
1320
+
1321
+ data_un = data[~pd.isna(data["prediction"])].copy()  # .copy() so the column assignments below do not modify a view of `data`
1322
+ data_pred_na = data[pd.isna(data["prediction"])].copy()
1323
+
1324
+ data_pred_na["score"] = -1
1325
+
1326
+ data_un["score"] = data_un.apply(
1327
+ lambda row: post_process(
1328
+ response=row["prediction"],
1329
+ right_answer=row["answer"],
1330
+ task_mode=row["task_mode"],
1331
+ duration=row["duration"],
1332
+ ),
1333
+ axis=1,
1334
+ )
1335
+
1336
+ data = pd.concat([data_pred_na, data_un])
1337
+
1338
+ rejected_count = (data["score"] == -1).sum()
1339
+
1340
+ print(
1341
+ f"Among {len(data)} questions, "
1342
+ f"failed to obtain prediction for {len(data_pred_na)} questions, "
1343
+ f"failed to obtain the score for {rejected_count - len(data_pred_na)} questions. "
1344
+ f"Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating."
1345
+ )
1346
+
1347
+ dump(data, score_file)
1348
+
1349
+ rating = get_dimention_rating_mcq_grouding(score_file)
1350
+
1351
+ dump(rating, tgt_file)
1352
+
1353
+ return rating
1354
+
1355
+
1356
+ # At evaluation time, for step 2 it is enough to supply [prompt] + image_paths
1357
+ class CGBench_OpenEnded(VideoBaseDataset):
1358
+
1359
+ TYPE = "Video-OpenEnded"
1360
+
1361
+ dataset = "CG-Bench_OpenEnded"
1362
+
1363
+ MD5 = "796035eda0b1e916c517cdc1bc145cfc"
1364
+
1365
+ SYS = (
1366
+ "You will be provided with sampled frames from a video, along with a "
1367
+ "question.\n"
1368
+ "Your task is to analyze the provided frames and infer the most plausible "
1369
+ "answer based on the visual information.\n"
1370
+ "If the visual information is ambiguous or insufficient, use the available "
1371
+ "context to reason your answer.\n"
1372
+ "Only output the answer in the following format:\n\n"
1373
+ '```json\n{"result": "answer"}\n```\n\n'
1374
+ 'The "answer" can be a word, phrase, or sentence that directly responds to '
1375
+ "the question.\n\n"
1376
+ )
1377
+
1378
+ def __init__(
1379
+ self,
1380
+ dataset="CG-Bench_OpenEnded",
1381
+ use_subtitle=False,
1382
+ use_subtitle_time=False,
1383
+ use_frame_time=False,
1384
+ nframe=0,
1385
+ fps=-1,
1386
+ ):
1387
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
1388
+ self.use_subtitle = use_subtitle
1389
+ self.use_subtitle_time = use_subtitle_time
1390
+ self.use_frame_time = use_frame_time
1391
+ self.dataset_name = dataset
1392
+ lmu_root = LMUDataRoot()
1393
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
1394
+
1395
+ @classmethod
1396
+ def supported_datasets(cls):
1397
+ return ["CG-Bench_OpenEnded"]
1398
+
1399
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
1400
+
1401
+ subtitles = []
1402
+
1403
+ srt_path = osp.join(self.data_root, subtitle_path)
1404
+ assert osp.exists(srt_path)
1405
+ import pysubs2
1406
+
1407
+ subs = pysubs2.load(srt_path, encoding="utf-8")
1408
+ if not frame_indices:
1409
+ for sub in subs:
1410
+ sub_text = sub.text.replace("\\N", " ")
1411
+ if sub_time:
1412
+ start_time = milliseconds_to_seconds(sub.start)
1413
+ end_time = milliseconds_to_seconds(sub.end)
1414
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
1415
+ if sub_text.strip() and sub_text not in subtitles:
1416
+ subtitles.append(sub_text)
1417
+ else:
1418
+ for selected_frame_id in frame_indices:
1419
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
1420
+ for sub in subs:
1421
+ if sub.start < cur_time and sub.end > cur_time:
1422
+ sub_text = sub.text.replace("\\N", " ")
1423
+ if sub_time:
1424
+ start_time = milliseconds_to_seconds(sub.start)
1425
+ end_time = milliseconds_to_seconds(sub.end)
1426
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
1427
+ if sub_text.strip() and sub_text not in subtitles:
1428
+ subtitles.append(sub_text)
1429
+
1430
+ if subtitles:
1431
+ subtitles_str = '\n'.join(subtitles)
1432
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
1433
+ else:
1434
+ return ""
1435
+
1436
+ def prepare_dataset(self, dataset_name="CG-Bench_OpenEnded", repo_id="CG-Bench/CG-Bench"):
1437
+
1438
+ def check_integrity(pth):
1439
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
1440
+
1441
+ if not os.path.exists(data_file):
1442
+ return False
1443
+
1444
+ if md5(data_file) != self.MD5:
1445
+ return False
1446
+ data = load(data_file)
1447
+ for video_pth in data["video"]:
1448
+ if not osp.exists(osp.join(pth, video_pth)):
1449
+ return False
1450
+
1451
+ return True
1452
+
1453
+ cache_path = get_cache_path(repo_id)
1454
+
1455
+ if cache_path is not None and check_integrity(cache_path):
1456
+ dataset_path = cache_path
1457
+ else:
1458
+
1459
+ def generate_tsv(pth):
1460
+
1461
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
1462
+
1463
+ with open(osp.join(pth, "cgbench.json"), "r") as f:
1464
+ data_file = pd.DataFrame(json.load(f))
1465
+
1466
+ data_file = data_file.assign(index=range(len(data_file)))
1467
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
1468
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
1469
+ lambda x: f"cg_subtitles/{x}.srt" if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt")) else ""
1470
+ )
1471
+
1472
+ data_file = data_file[
1473
+ [
1474
+ "index",
1475
+ "video_uid",
1476
+ "video",
1477
+ "duration",
1478
+ "domain",
1479
+ "sub_category",
1480
+ "subtitle_path",
1481
+ "question",
1482
+ "answer",
1483
+ "clue_intervals",
1484
+ "qid",
1485
+ ]
1486
+ ]
1487
+
1488
+ data_file.to_csv(tsv_file, sep="\t", index=False)
1489
+
1490
+ if modelscope_flag_set():
1491
+ from modelscope import dataset_snapshot_download
1492
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
1493
+ else:
1494
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
1495
+
1496
+ unzip_hf_zip(dataset_path)
1497
+ generate_tsv(dataset_path)
1498
+
1499
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
1500
+
1501
+ return dict(data_file=tsv_file, root=dataset_path)
1502
+
1503
+ def build_prompt(self, line, video_llm):
1504
+
1505
+ if isinstance(line, int):
1506
+ assert line < len(self)
1507
+ line = self.data.iloc[line]
1508
+
1509
+ message = []
1510
+
1511
+ sys_prompt = self.SYS
1512
+
1513
+ user_prompt = ""
1514
+
1515
+ video_path = line["video"]
1516
+
1517
+ if video_llm:
1518
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
1519
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
1520
+ if self.nframe:
1521
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
1522
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
1523
+ )
1524
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
1525
+ fps=vid_fps, sub_time=self.use_subtitle_time)
1526
+ else:
1527
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
1528
+ else:
1529
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
1530
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
1531
+ )
1532
+ message.extend(dict(type="image", value=im) for im in image_paths)
1533
+
1534
+ if self.use_frame_time:
1535
+ user_prompt += get_timestampes(frame_indices, vid_fps)
1536
+
1537
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
1538
+ user_prompt += self.get_subtitles(
1539
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
1540
+ sub_time=self.use_subtitle_time
1541
+ )
1542
+
1543
+ question = line["question"]
1544
+ user_prompt += f"Question: {question}\n\n"
1545
+
1546
+ message.append(dict(type="text", value=sys_prompt + user_prompt))
1547
+
1548
+ return message
1549
+
1550
+ def clue_frame_paths(self, qid, num_frames=8):
1551
+ frame_root = osp.join(self.clue_frame_root, qid)
1552
+ os.makedirs(frame_root, exist_ok=True)
1553
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
1554
+
1555
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
1556
+
1557
+ if type(uid) is not str:
1558
+ uid = str(uid)
1559
+
1560
+ vid_path = osp.join(self.data_root, video)
1561
+ vid = decord.VideoReader(vid_path)
1562
+ vid_fps = vid.get_avg_fps()
1563
+ n_frames = len(vid)
1564
+
1565
+ if clue_intervals is not None:
1566
+ merged_intervals = merge_intervals(clue_intervals)
1567
+
1568
+ if num_frames > 0 and fps < 0:
1569
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
1570
+ frame_paths = self.clue_frame_paths(uid, len(indices))
1571
+
1572
+ elif fps > 0:
1573
+ frame_indices = []
1574
+ for start, end in merged_intervals:
1575
+ start_frame = int(start * vid_fps)
1576
+ end_frame = int(end * vid_fps)
1577
+ step = vid_fps / fps
1578
+ interval_indices = [
1579
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
1580
+ ]
1581
+ frame_indices.extend(interval_indices)
1582
+
1583
+ if len(frame_indices) < 32:
1584
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
1585
+ else:
1586
+ indices = frame_indices
1587
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
1588
+
1589
+ else:
1590
+ if num_frames > 0 and fps < 0:
1591
+ step_size = len(vid) / (num_frames + 1)
1592
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
1593
+ frame_paths = self.frame_paths(uid)
1594
+ elif fps > 0:
1595
+ total_duration = n_frames / vid_fps
1596
+ required_frames = int(total_duration * fps)
1597
+ step_size = vid_fps / fps
1598
+ indices = [int(i * step_size) for i in range(required_frames)]
1599
+ frame_paths = self.frame_paths_fps(uid, len(indices))
1600
+
1601
+ valid_paths = []
1602
+ valid_indices = []
1603
+
1604
+ if not np.all([osp.exists(p) for p in frame_paths]):
1605
+ images = [vid[i].asnumpy() for i in indices]
1606
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
1607
+ if osp.exists(path):
1608
+ try:
1609
+ with Image.open(path) as img:
1610
+ img.verify()
1611
+ valid_paths.append(path)
1612
+ valid_indices.append(indices[i])
1613
+ except Exception:
1614
+ continue
1615
+ else:
1616
+ try:
1617
+ img = Image.fromarray(img_array)
1618
+ img.save(path)
1619
+ with Image.open(path) as saved: # verify() must run on a freshly opened file; on an in-memory image it is a no-op
+ saved.verify()
1620
+ valid_paths.append(path)
1621
+ valid_indices.append(indices[i])
1622
+ except Exception:
1623
+ continue
1624
+ else:
1625
+ for i, path in enumerate(frame_paths):
1626
+ try:
1627
+ with Image.open(path) as img:
1628
+ img.verify()
1629
+ valid_paths.append(path)
1630
+ valid_indices.append(indices[i])
1631
+ except Exception:
1632
+ continue
1633
+
1634
+ return valid_paths, valid_indices, vid_fps
1635
+
1636
+ def evaluate(self, eval_file, **judge_kwargs):
1637
+
1638
+ from .utils.cgbench import get_dimention_rating_open_ended, post_process_open
1639
+
1640
+ assert eval_file.endswith(".xlsx"), "data file should be an xlsx file"
1641
+
1642
+ tgt_file = eval_file.replace(".xlsx", "_rating.json")
1643
+ score_file = eval_file.replace(".xlsx", "_score.xlsx")
1644
+ step_1_tmp_file = eval_file.replace(".xlsx", "_step_1.pkl")
1645
+ step_2_tmp_file = eval_file.replace(".xlsx", "_step_2.pkl")
1646
+
1647
+ data = load(eval_file)
1648
+
1649
+ data_pred_no_na = data[~pd.isna(data["prediction"])].copy()  # .copy() so the column assignments below do not modify a view of `data`
1650
+ data_pred_na = data[pd.isna(data["prediction"])].copy()
1651
+
1652
+ data_pred_na["model_result"] = -1
1653
+ data_pred_na["step_1_result"] = -1
1654
+ data_pred_na["step_2_result"] = -1
1655
+ data_pred_na["score"] = -1
1656
+
1657
+ data_pred_no_na["model_result"] = data_pred_no_na.apply(
1658
+ lambda row: post_process_open(
1659
+ response=row["prediction"],
1660
+ ),
1661
+ axis=1,
1662
+ )
1663
+
1664
+ if judge_kwargs.get("model", None) != "gpt-4o-0806":
1665
+ judge_kwargs["model"] = "gpt-4o-0806"
1666
+ print("The judge model in cg-bench is gpt-4o-0806!")
1667
+
1668
+ data_no_model_result = data_pred_no_na[data_pred_no_na["model_result"] == -1]
1669
+ data_step_1 = data_pred_no_na[data_pred_no_na["model_result"] != -1]
1670
+
1671
+ model_step_1 = build_judge(system_prompt=sys_prompt_open_eval_step_1, **judge_kwargs)
1672
+ nproc = judge_kwargs.pop('nproc', 32)
1673
+
1674
+ lines_step_1 = data_step_1.to_dict("records")
1675
+ tups_step_1 = [(model_step_1, line) for line in lines_step_1]
1676
+
1677
+ keys_step_1 = [line["qid"] for line in lines_step_1]  # a list (not a set) keeps the qids aligned with tups_step_1 for the zip below
1678
+
1679
+ ans = {}
1680
+ if osp.exists(step_1_tmp_file):
1681
+ ans = load(step_1_tmp_file)
1682
+ tups_step_1 = [x for x, i in zip(tups_step_1, keys_step_1) if i not in ans]
1683
+ keys_step_1 = [i for i in keys_step_1 if i not in ans]
1684
+
1685
+ _ = track_progress_rich(
1686
+ eval_open_first,
1687
+ tups_step_1,
1688
+ nproc=nproc,
1689
+ keys=keys_step_1,
1690
+ save=step_1_tmp_file,
1691
+ )
1692
+
1693
+ step_1_results = load(step_1_tmp_file)
1694
+ data_step_1 = save_step_1_steps(data_step_1, step_1_results)  # step_1_result: -1 = parse failure, 0/1 = scored directly, 2 = needs step-2 judging
1695
+
1696
+ data_no_step_1_results = data_step_1[data_step_1["step_1_result"] == -1]
1697
+ data_step_1_over = data_step_1[data_step_1["step_1_result"].isin([0, 1])]
1698
+ data_step_2 = data_step_1[data_step_1["step_1_result"] == 2]
1699
+
1700
+ model_step_2 = build_judge(system_prompt=sys_prompt_open_eval_step_2, **judge_kwargs)
1701
+
1702
+ lines_step_2 = data_step_2.to_dict("records")
1703
+
1704
+ tups_step_2 = []
1705
+
1706
+ for line in tqdm(lines_step_2):
1707
+ clue_intervals = eval(line["clue_intervals"])
1708
+ lmu_root = LMUDataRoot()
1709
+ clue_frame_root = osp.join(lmu_root, "clue_images", self.dataset)
1710
+ data_root = self.data_root
1711
+ frame_paths, _, _ = save_clue_video_frames(
1712
+ data_root,
1713
+ clue_frame_root,
1714
+ video=line["video"],
1715
+ uid=line["qid"],
1716
+ clue_intervals=clue_intervals,
1717
+ num_frames=32,
1718
+ )
1719
+ tups_step_2.append((model_step_2, line, frame_paths))
1720
+
1721
+ keys_step_2 = [line["qid"] for line in lines_step_2]  # a list (not a set) keeps the qids aligned with tups_step_2 for the zip below
1722
+
1723
+ ans = {}
1724
+ if osp.exists(step_2_tmp_file):
1725
+ ans = load(step_2_tmp_file)
1726
+ tups_step_2 = [x for x, i in zip(tups_step_2, keys_step_2) if i not in ans]
1727
+ keys_step_2 = [i for i in keys_step_2 if i not in ans]
1728
+
1729
+ _ = track_progress_rich(
1730
+ eval_open_second,
1731
+ tups_step_2,
1732
+ nproc=nproc,
1733
+ keys=keys_step_2,
1734
+ save=step_2_tmp_file,
1735
+ )
1736
+
1737
+ step_2_results = load(step_2_tmp_file)
1738
+ data_step_2 = save_step_2_steps(data_step_2, step_2_results)
1739
+
1740
+ data_no_step_2_results = data_step_2[data_step_2["score"] == -1]
1741
+ data_step_2_over = data_step_2[data_step_2["score"].isin([0, 1])]
1742
+
1743
+ data = pd.concat(
1744
+ [
1745
+ data_pred_na,
1746
+ data_no_model_result,
1747
+ data_no_step_1_results,
1748
+ data_step_1_over,
1749
+ data_no_step_2_results,
1750
+ data_step_2_over,
1751
+ ]
1752
+ )
1753
+
1754
+ dump(data, score_file)
1755
+
1756
+ rating = get_dimention_rating_open_ended(score_file)
1757
+
1758
+ dump(rating, tgt_file)
1759
+
1760
+ return rating
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/cmmmu.py ADDED
@@ -0,0 +1,354 @@
1
+ from .image_base import ImageBaseDataset
2
+ import random
3
+ from collections import Counter
4
+ import os
5
+ import re
6
+ import tempfile
7
+ from ..smp import *
8
+
9
+
10
+ def get_multi_choice_prediction(response, all_choices, index2ans):
11
+ for char in [',', '.', '!', '?', ';', ':', "'"]:
12
+ response = response.strip(char)
13
+ response = " " + response + " " # add space to avoid partial match
14
+
15
+ candidates = []
16
+
17
+ for choice in all_choices: # (A) (B) (C) (D)
18
+ # Add the choice to candidates each time it appears in the response
19
+ candidates.extend([choice for _ in range(response.count(f'({choice})'))])
20
+
21
+ if len(candidates) == 0:
22
+ for choice in all_choices: # A B C D
23
+ # Similarly, add the choice for each occurrence
24
+ candidates.extend([choice for _ in range(response.count(f'{choice}'))])
25
+
26
+ if len(candidates) == 0 and len(response.split()) >= 1:
27
+ for index, ans in index2ans.items():
28
+ # Add index for each occurrence of ans in response
29
+ candidates.extend([index for _ in range(response.count(ans))])
30
+
31
+ # if all of the above yields no candidates, fall back to substring matching: accept any option whose answer text appears anywhere in the response
32
+ if len(candidates) == 0 and len(response.split()) >= 1:
33
+ for index, ans in index2ans.items():
34
+ if ans in response:
35
+ candidates.append(index)
36
+ # index_ans = False # it's content ans.
37
+
38
+ if len(candidates) == 0: # still no answer found, randomly choose one.
39
+ return random.choice(all_choices)
40
+ # return ''
41
+ else:
42
+ # Count the occurrence of each candidate
43
+ candidate_counts = Counter(candidates)
44
+
45
+ # Select the most frequent candidates
46
+ max_count = max(candidate_counts.values())
47
+ most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
48
+
49
+ # Combine the most frequent candidates in ABCD order
50
+ return ''.join(most_frequent_candidates)
51
+
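+ # Illustrative (hypothetical response): for response "我认为答案是(B)", all_choices
+ # ['A', 'B', 'C', 'D'] and any index2ans, the "(B)" occurrence makes candidates
+ # ['B'] and the function returns 'B'; when nothing matches at all, a random
+ # choice is returned so every question still receives an answer.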
52
+
53
+ def extract_numbers(string):
54
+ # Pattern for numbers with Chinese commas
55
+ pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
56
+ # Pattern for scientific notation
57
+ pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
58
+ # Pattern for simple numbers without Chinese commas
59
+ pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
60
+
61
+ # Extract numbers with Chinese commas
62
+ numbers_with_commas = re.findall(pattern_commas, string)
63
+ # Extract numbers in scientific notation
64
+ numbers_scientific = re.findall(pattern_scientific, string)
65
+ # Extract simple numbers without Chinese commas
66
+ numbers_simple = re.findall(pattern_simple, string)
67
+
68
+ # Combine all extracted numbers
69
+ all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
70
+ return all_numbers
71
+
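+ # Illustrative: extract_numbers("答案是42和-3.14") returns ['42', '-3.14'];
+ # the first two patterns additionally catch comma-grouped ("1,234") and
+ # scientific ("1.2e3") forms.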
72
+
73
+ def check_is_number(string):
74
+ try:
75
+ float(string.replace(',', ''))
76
+ return True
77
+ except ValueError:
78
+ # check if there's comma inside
79
+ return False
80
+
81
+
82
+ def count_letters(string):
83
+ return sum(('a' <= c <= 'z') or ('A' <= c <= 'Z') for c in string)  # count ASCII letters; parentheses make the intended precedence explicit
84
+
85
+
86
+ def normalize_str(string, answer):
87
+ # normalize a candidate answer string before comparing it with the gold answer
88
+
89
+ # if it is a number, parse it numerically.
90
+ if string is None:
91
+ return [string]
92
+ string = string.strip()
93
+
94
+ is_number = check_is_number(string)
95
+
96
+ if is_number:
97
+ string = string.replace(',', '')
98
+ string = float(string)
99
+ # leave 2 decimal
100
+ string = round(string, 2)
101
+ return [string]
102
+ else: # it's likely to be a string
103
+ if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
104
+ return []
105
+ return [string]
106
+
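+ # e.g. normalize_str("1,234.5", "1234.5") -> [1234.5] (parsed and rounded to two
+ # decimals), while a string far longer than the reference answer is dropped as []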
107
+
108
+ def get_fill_blank_prediction(response, answer):
109
+ """get the prediction from the generated response,
110
+ return a list of predicted strings or numbers"""
111
+
112
+ def get_key_subresponses(response):
113
+ response = response.strip("。").strip()
114
+ sub_responses = re.split(r'。|\n', response)
115
+ indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
116
+ '正确答案', '因此', '最后', '答案', '结果']
117
+ key_responses = []
118
+ for index, resp in enumerate(sub_responses):
119
+ # if this is the last sub-response, also accept '=' as an indicator (the entire response may be a single equation)
120
+ if index == len(sub_responses) - 1:
121
+ indicators_of_keys.extend(['='])
122
+ shortest_key_response = None
123
+ # the shortest response that may contain the answer (tail part of the response)
124
+ for indicator in indicators_of_keys:
125
+ if indicator in resp:
126
+ if not shortest_key_response:
127
+ shortest_key_response = resp.split(indicator)[-1].strip()
128
+ else:
129
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
130
+ shortest_key_response = resp.split(indicator)[-1].strip()
131
+
132
+ if shortest_key_response:
133
+ # and it's not trivial
134
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
135
+ key_responses.append(shortest_key_response)
136
+ if len(key_responses) == 0: # did not found any
137
+ return [response]
138
+ return key_responses
139
+
140
+ key_responses = get_key_subresponses(response)
141
+
142
+ pred_list = key_responses.copy() # keep the original string response
143
+ for resp in key_responses:
144
+ pred_list.extend(extract_numbers(resp))
145
+
146
+ tmp_pred_list = []
147
+ for i in range(len(pred_list)):
148
+ tmp_pred_list.extend(normalize_str(pred_list[i], answer))
149
+ pred_list = tmp_pred_list
150
+
151
+ # remove duplicates
152
+ pred_list = list(set(pred_list))
153
+
154
+ return pred_list
155
+
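+ # Illustrative: for the response "所以答案是42", the indicator split isolates the
+ # tail "42", extract_numbers re-adds the numeric form, and normalize_str parses
+ # it, so the deduplicated prediction list comes back as [42.0]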
156
+
157
+ def get_TF_prediction(response):
158
+ """get the prediction from the generated response,
159
+ return a list of predicted strings or numbers"""
160
+
161
+ def get_key_subresponses(response):
162
+ response = response.strip("。").strip()
163
+ sub_responses = re.split(r'。|\n', response)
164
+ indicators_of_keys = ['是', '为', '所以', '判断',
165
+ '陈述', '说法', '表达', '答案', '结果']
166
+ key_responses = []
167
+ for index, resp in enumerate(sub_responses):
168
+ shortest_key_response = None
169
+ # the shortest response that may contain the answer (tail part of the response)
170
+ for indicator in indicators_of_keys:
171
+ if indicator in resp:
172
+ if not shortest_key_response:
173
+ shortest_key_response = resp.split(indicator)[-1].strip()
174
+ else:
175
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
176
+ shortest_key_response = resp.split(indicator)[-1].strip()
177
+
178
+ if shortest_key_response:
179
+ # and it's not trivial
180
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
181
+ key_responses.append(shortest_key_response)
182
+ if len(key_responses) == 0: # did not find any
183
+ return [response]
184
+ return key_responses
185
+
186
+ key_responses = get_key_subresponses(response)
187
+
188
+ pred_list = key_responses.copy() # keep the original string response
189
+ # remove duplicates
190
+ pred_list = list(set(pred_list))
191
+
192
+ return pred_list
193
+
194
+
195
+ class CMMMU(ImageBaseDataset):
196
+ TYPE = 'VQA'
197
+
198
+ DATASET_URL = {
199
+ 'CMMMU_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/CMMMU_VAL.tsv'
200
+ }
201
+
202
+ DATASET_MD5 = {
203
+ 'CMMMU_VAL': 'b4727e2fce2415bf646379e60c11a726'
204
+ }
205
+
206
+ def dump_image(self, line):
207
+ os.makedirs(self.img_root, exist_ok=True)
208
+
209
+ tgt_path_z = []
210
+ if isinstance(line['image'], list):
211
+ for i in range(len(line['image'])):
212
+ tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
213
+ if not read_ok(tgt_path):
214
+ decode_base64_to_image_file(line['image'][i], tgt_path)
215
+ tgt_path_z.append(tgt_path)
216
+ else:
217
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
218
+ if not read_ok(tgt_path):
219
+ decode_base64_to_image_file(line['image'], tgt_path)
220
+ tgt_path_z.append(tgt_path)
221
+ return tgt_path_z
222
+
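+ # Naming convention implied above: multi-image rows are decoded to
+ # "<index>--1.jpg", "<index>--2.jpg", ..., single-image rows to "<index>.jpg"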
223
+ @classmethod
224
+ def evaluate(self, eval_file, **judge_kwargs):
225
+
226
+ suffix = eval_file.split('.')[-1]
227
+ result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
228
+
229
+ if not osp.exists(result_file):
230
+ data = load(eval_file)
231
+ assert 'answer' in data and 'prediction' in data
232
+ data['prediction'] = [str(x) for x in data['prediction']]
233
+ data['answer'] = [str(x) for x in data['answer']]
234
+
235
+ correct_count = 0
236
+ correct_category = {
237
+ '技术与工程': [0, 0],
238
+ '科学': [0, 0],
239
+ '健康与医学': [0, 0],
240
+ '商业': [0, 0],
241
+ '艺术与设计': [0, 0],
242
+ '人文社会科学': [0, 0],
243
+ }
244
+
245
+ for i in tqdm(data.iterrows()):
246
+ line = i[1]
247
+ correct_category[line['category']][0] += 1
248
+
249
+ # Options
250
+ if line['type'] == '选择':
251
+ index2ans = {
252
+ 'A': line['option1'],
253
+ 'B': line['option2'],
254
+ 'C': line['option3'],
255
+ 'D': line['option4']
256
+ }
257
+ fact_option = get_multi_choice_prediction(line['prediction'], ['A', 'B', 'C', 'D'], index2ans)
258
+ if fact_option == line['answer']:
259
+ correct_count += 1
260
+ correct_category[line['category']][1] += 1
261
+
262
+ # Binary
263
+ elif line['type'] == '判断':
264
+ positive_keywords = ['正确', '对', '准确', '肯定', '对的']
265
+ negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
266
+ ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
267
+
268
+ def judge_similarity(pred_list, positive_keywords, negative_keywords):
269
+ positive_count = 0
270
+ negative_count = 0
271
+
272
+ for pred in pred_list:
273
+ if any(pos_word in pred for pos_word in positive_keywords):
274
+ positive_count += 1
275
+ elif any(neg_word in pred for neg_word in negative_keywords):
276
+ negative_count += 1
277
+
278
+ if positive_count > negative_count:
279
+ return "对"
280
+ elif negative_count > positive_count:
281
+ return "错"
282
+ else:
283
+ return random.choice(['对', '错'])
284
+
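+ # e.g. a prediction list like ['这个说法是正确的'] hits one positive keyword and
+ # yields "对"; ties between positive and negative hits fall back to a random
+ # guess so a verdict is always produced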
285
+ answer = get_TF_prediction(line['prediction'])
286
+ answer = [word for word in answer if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
287
+ fact_answer = judge_similarity(answer, positive_keywords, negative_keywords)
288
+ if fact_answer == line['answer']:
289
+ correct_count += 1
290
+ correct_category[line['category']][1] += 1
291
+
292
+ # Fill the Blank
293
+ else:
294
+ norm_answers = normalize_str(line['answer'], line['answer'])
295
+ predicted_answer = get_fill_blank_prediction(line['prediction'], line['answer'])
296
+
297
+ for pred in predicted_answer:
298
+ # already normalized
299
+ if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
300
+ for norm_ans in norm_answers:
301
+ # only see if the string answer in the string pred
302
+ # print(norm_ans, pred)
303
+ if isinstance(norm_ans, str) and norm_ans in pred:
304
+ correct_count += 1
305
+ correct_category[line['category']][1] += 1
306
+ else: # it's a number
307
+ if pred in norm_answers:
308
+ correct_count += 1
309
+ correct_category[line['category']][1] += 1
310
+
311
+ accuracyz = {}
312
+ accuracyz['总准确率'] = correct_count / len(data)
313
+ for i in correct_category.keys():
314
+ accuracyz[i] = correct_category[i][1] / correct_category[i][0]
315
+
316
+ accuracyz = d2df(accuracyz)
317
+ accuracyz = accuracyz.round(10)  # round() returns a new DataFrame, so the result must be reassigned
318
+ dump(accuracyz, result_file)
319
+
320
+ result = pd.read_csv(result_file)
321
+ return result
322
+
323
+ def build_prompt(self, line):
324
+ if line['type'] == '选择':
325
+ tgt_path = self.dump_image(line)
326
+ question = line['question']
327
+ options_prompt = 'Options:\n'
328
+
329
+ for i in [['A', '1'], ['B', '2'], ['C', '3'], ['D', '4']]:
330
+ options_prompt += i[0] + '. ' + line['option' + i[1]] + '\n'
331
+
332
+ prompt = (f'问题: {question}\n' + options_prompt
333
+ + '请回答上述多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。')
334
+
335
+ msgs = []
336
+ if isinstance(tgt_path, list):
337
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
338
+ else:
339
+ msgs = [dict(type='image', value=tgt_path)]
340
+ msgs.append(dict(type='text', value=prompt))
341
+
342
+ return msgs
343
+
344
+ elif line['type'] == '判断':
345
+ msgs = super().build_prompt(line)
346
+ assert msgs[-1]['type'] == 'text'
347
+ msgs[-1]['value'] += '\n请回答上述判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。'
348
+ return msgs
349
+
350
+ else:
351
+ msgs = super().build_prompt(line)
352
+ assert msgs[-1]['type'] == 'text'
353
+ msgs[-1]['value'] += '\n请回答上述填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
354
+ return msgs
r1-a/response_generation/qwenomni.py ADDED
@@ -0,0 +1,451 @@
1
+ import os
2
+ import json
3
+ import base64
4
+ import uuid # For generating unique filenames
5
+ import time
6
+ import re # For parsing history
7
+ from io import BytesIO
8
+ import random
9
+ import concurrent.futures # <-- For ThreadPoolExecutor
10
+ from tqdm import tqdm # <-- For progress bar
11
+ import threading # <-- For potential thread-local data or locks if needed later
12
+ import traceback # <-- For detailed error printing
13
+
14
+ import numpy as np
15
+ import soundfile as sf
16
+ from openai import OpenAI
17
+ from datasets import load_from_disk, Dataset, Features, Value # Ensure Features is imported
18
+ from dotenv import load_dotenv
19
+
20
+ # --- Configuration ---
21
+ load_dotenv()
22
+
23
+ # 1. API Client Setup & Model Rotation Setup
24
+ QWEN_MODEL_LIST = [
25
+ "qwen-omni-turbo",
26
+ "qwen-omni-turbo-latest",
27
+ "qwen-omni-turbo-2025-03-26",
28
+ "qwen-omni-turbo-2025-01-19",
29
+ ]
30
+ NUM_MODELS = len(QWEN_MODEL_LIST)
31
+ print(f"Using Qwen models in rotation: {QWEN_MODEL_LIST}")
32
+
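+ # Assumed rotation convention (a sketch; the actual task dispatch lives further
+ # below in this file): each task picks its model round-robin by index, e.g.
+ # model_name = QWEN_MODEL_LIST[task_idx % NUM_MODELS]  # task_idx is illustrative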
33
+ client = OpenAI(
34
+ api_key=os.getenv("DASHSCOPE_API_KEY"),  # read the key from the environment (.env is loaded above); avoid committing a real credential
36
+ base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
37
+ )
38
+
39
+ # 2. Dataset Paths
40
+ INPUT_DATASET_DIR = "/home/chenyifu/audio-r1/r1-a/dataset/preference_tasks_fully_merged_with_audio/train/final_dataset"
41
+ OUTPUT_DATASET_DIR = "/root/autodl-tmp/audio-r1/r1-a/dataset/preference_tasks_with_qwen_rotated" # <-- Adjusted name
42
+
43
+ # 3. Output Audio Configuration
44
+ OUTPUT_AUDIO_ROOT_DIR = "/root/autodl-tmp/audio-r1/r1-a/generated_audio/qwen_omni_rotated" # <-- Adjusted name
45
+ OUTPUT_AUDIO_FORMAT = "wav"
46
+ AVAILABLE_QWEN_VOICES = ["Cherry", "Serena", "Ethan", "Chelsie"]
47
+ OUTPUT_AUDIO_SAMPLERATE = 24000
48
+
49
+ # 4. API Call Settings
50
+ API_RETRY_DELAY = 5
51
+ API_MAX_RETRIES = 3
52
+ MAX_WORKERS = 10 # <-- Set desired number of threads (Be mindful of rate limits!)
53
+
54
+ # 5. Checkpoint Saving Configuration
55
+ CHECKPOINT_INTERVAL = 50 # Save every 50 completed tasks
56
+
57
+ # --- Helper Functions ---
58
+
+ def encode_audio_base64(audio_path):
+     if not audio_path or not os.path.exists(audio_path):
+         print(f"Warning: Input audio file not found or path is empty: {audio_path}")
+         return None
+     try:
+         with open(audio_path, "rb") as audio_file:
+             return base64.b64encode(audio_file.read()).decode("utf-8")
+     except Exception as e:
+         print(f"Error encoding audio file {audio_path}: {e}")
+         return None
+
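+ # Usage sketch (hypothetical paths): a readable file yields its bytes as a base64
+ # string; a missing or empty path prints a warning and returns None.
+ # >>> encode_audio_base64("/tmp/example.wav")
+ # 'UklGRiQAAABXQVZF...'
+ # >>> encode_audio_base64("/tmp/missing.wav")
+ # (warning printed, returns None)
+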
+ def parse_ultra_history(history_str):
+     messages = []
+     pattern = re.compile(r"\[(USER|ASSISTANT)\]\s*([\s\S]*?)(?=\s*\[(?:USER|ASSISTANT)\]|$)")
+     matches = pattern.findall(history_str)
+     if not matches and history_str and history_str.strip():
+         if history_str.lower().startswith("user:") or history_str.lower().startswith("[user]"):
+             role = "user"
+             content = re.sub(r"^(user:|\[user\])\s*", "", history_str, flags=re.IGNORECASE).strip()
+             if content: messages.append({"role": role, "content": content})
+         elif history_str.lower().startswith("assistant:") or history_str.lower().startswith("[assistant]"):
+             role = "assistant"
+             content = re.sub(r"^(assistant:|\[assistant\])\s*", "", history_str, flags=re.IGNORECASE).strip()
+             if content: messages.append({"role": role, "content": content})
+         else:
+             return []
+     else:
+         for role_tag, content in matches:
+             role = role_tag.lower()
+             cleaned_content = content.strip()
+             if cleaned_content:
+                 messages.append({"role": role, "content": cleaned_content})
+     return messages
+
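+ # Usage sketch (hypothetical history string): the regex splits on [USER]/[ASSISTANT]
+ # tags, lowercases the roles, and strips whitespace from each turn, e.g.
+ # >>> parse_ultra_history("[USER] What is 2+2? [ASSISTANT] 4.")
+ # [{'role': 'user', 'content': 'What is 2+2?'}, {'role': 'assistant', 'content': '4.'}]
+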
+ # --- API Call Worker Function (Takes model_name) ---
+ def call_qwen_omni_api_worker(task_info):
+     """
+     Worker function to call the Qwen API for a single task using a specific model.
+     Returns results, including the model used.
+     """
+     row_idx = task_info["row_idx"]
+     slot_idx = task_info["slot_idx"]
+     model_to_use = task_info["model_name"]
+     history_messages = task_info["history_messages"]
+     prompt_text = task_info["prompt_text"]
+     question_audio_path = task_info["question_audio_path"]
+     output_audio_filepath = task_info["output_audio_filepath"]
+
+     retries = 0
+     selected_voice = random.choice(AVAILABLE_QWEN_VOICES)
+
+     while retries < API_MAX_RETRIES:
+         try:
+             base64_audio_data = encode_audio_base64(question_audio_path)
+             if not base64_audio_data:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}): Skipping API call due to missing input audio: {question_audio_path}")
+                 # Return the model name even on error for potential logging
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": "[ERROR: Missing input audio]", "saved_audio_path": None, "model_used": model_to_use}
+
+             input_audio_format = os.path.splitext(question_audio_path)[1].lstrip('.') or 'wav'
+
+             user_content = []
+             user_content.append({
+                 "type": "input_audio",
+                 "input_audio": {
+                     "data": f"data:audio/{input_audio_format};base64,{base64_audio_data}",
+                     "format": input_audio_format,
+                 },
+             })
+             user_content.append({"type": "text", "text": prompt_text})
+             messages = history_messages + [{"role": "user", "content": user_content}]
+
+             completion = client.chat.completions.create(
+                 model=model_to_use,
+                 messages=messages,
+                 modalities=["text", "audio"],
+                 audio={"voice": selected_voice, "format": OUTPUT_AUDIO_FORMAT},
+                 stream=True,
+                 stream_options={"include_usage": True},
+             )
+
+             collected_text = ""
+             audio_base64_string = ""
+             usage_info = None
+
+             for chunk in completion:
+                 if chunk.choices and len(chunk.choices) > 0:
+                     delta = chunk.choices[0].delta
+                     if hasattr(delta, 'content') and delta.content:
+                         collected_text += delta.content
+                     if hasattr(delta, "audio") and delta.audio:
+                         if "data" in delta.audio and delta.audio["data"]:
+                             audio_base64_string += delta.audio["data"]
+                         if "transcript" in delta.audio and delta.audio["transcript"]:
+                             collected_text += delta.audio["transcript"]
+                 elif hasattr(chunk, "usage") and chunk.usage:
+                     usage_info = chunk.usage
+
+             if audio_base64_string:
+                 try:
+                     wav_bytes = base64.b64decode(audio_base64_string)
+                     if len(wav_bytes) == 0:
+                         print(f"Warning (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Decoded audio bytes are empty.")
+                         return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
+                     if len(wav_bytes) % 2 != 0:
+                         wav_bytes = wav_bytes[:-1]  # Truncate to an even byte count for int16
+                         if len(wav_bytes) == 0:
+                             print(f"Warning (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Audio bytes became empty after truncation.")
+                             return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
+
+                     audio_np = np.frombuffer(wav_bytes, dtype=np.int16)
+                     os.makedirs(os.path.dirname(output_audio_filepath), exist_ok=True)
+                     sf.write(output_audio_filepath, audio_np, OUTPUT_AUDIO_SAMPLERATE, format=OUTPUT_AUDIO_FORMAT.upper())
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": output_audio_filepath, "model_used": model_to_use}
+
+                 except base64.binascii.Error as b64_e:
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Decoding base64 failed: {b64_e}")
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
+                 except ValueError as val_e:
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Interpreting buffer as int16 failed: {val_e} (Bytes: {len(wav_bytes)})")
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
+                 except Exception as e:
+                     print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Processing/saving audio bytes failed: {e}")
+                     traceback.print_exc()
+                     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
+             else:
+                 print(f"Warning (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): No audio data received in the stream.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": collected_text.strip(), "saved_audio_path": None, "model_used": model_to_use}
+
+         except Exception as e:
+             retries += 1
+             print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): API Call Attempt {retries}/{API_MAX_RETRIES} failed: {e}")
+             if "rate limit" in str(e).lower() or "too many requests" in str(e).lower():
+                 print("Rate limit likely hit. Consider reducing MAX_WORKERS or increasing delays.")
+                 time.sleep(API_RETRY_DELAY * 2)
+             elif retries < API_MAX_RETRIES:
+                 time.sleep(API_RETRY_DELAY)
+             else:
+                 print(f"Error (Row {row_idx}, Slot {slot_idx}, Model: {model_to_use}): Max retries reached. Giving up.")
+                 return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[API ERROR: Max retries on {model_to_use}]", "saved_audio_path": None, "model_used": model_to_use}
+
+     return {"row_idx": row_idx, "slot_idx": slot_idx, "response_text": f"[UNEXPECTED ERROR on {model_to_use}]", "saved_audio_path": None, "model_used": model_to_use}
+
+
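+ # Sketch of the audio handling above (assumption: the compatible-mode stream returns
+ # base64-encoded raw 16-bit PCM at 24 kHz rather than a WAV container), e.g.:
+ #     pcm = np.frombuffer(base64.b64decode(audio_base64_string), dtype=np.int16)
+ #     sf.write("out.wav", pcm, OUTPUT_AUDIO_SAMPLERATE, format="WAV")
+ # soundfile adds the WAV header, which is why the raw bytes are not written directly.
+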
+ # --- Checkpoint Saving Function (strictly using original_features) ---
+ def save_checkpoint(data_to_save, output_dir, dataset_features):
+     """Saves the current state of the data list as a Hugging Face Dataset,
+     strictly adhering to the provided dataset_features."""
+     if not data_to_save:
+         print("Checkpoint: No data available to save.")
+         return
+
+     print(f"\nCheckpoint: Saving {len(data_to_save)} rows to {output_dir}...")
+     try:
+         # Create the dataset using the original features passed to the function.
+         # This raises an error if data_to_save contains keys not in dataset_features
+         # or if data types are incompatible, so first filter each row down to the
+         # keys present in the original schema.
+         feature_keys = set(dataset_features.keys())
+         filtered_data_to_save = []
+         for item in data_to_save:
+             filtered_item = {k: v for k, v in item.items() if k in feature_keys}
+             # Optional: fill missing keys with None if required by the schema,
+             # though Dataset.from_list handles this.
+             filtered_data_to_save.append(filtered_item)
+
+         checkpoint_dataset = Dataset.from_list(filtered_data_to_save, features=dataset_features)
+
+         os.makedirs(output_dir, exist_ok=True)
+         checkpoint_dataset.save_to_disk(output_dir)
+         print(f"Checkpoint: Saved successfully to {output_dir}")
+
+     except Exception as ckpt_save_e:
+         print(f"Error saving checkpoint dataset using the datasets lib: {ckpt_save_e}")
+         print("Detailed error:", traceback.format_exc())  # Full traceback for save errors
+         # Fallback to JSON Lines (does not strictly enforce the schema)
+         output_jsonl_path = output_dir + "_checkpoint.jsonl"
+         print(f"Attempting to save checkpoint as JSON lines to {output_jsonl_path}...")
+         try:
+             # Save the original unfiltered data to JSONL for debugging if needed
+             with open(output_jsonl_path, 'w', encoding='utf-8') as f:
+                 for item in data_to_save:  # Use the original data for the JSON fallback
+                     serializable_item = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in item.items()}
+                     f.write(json.dumps(serializable_item, ensure_ascii=False) + '\n')
+             print(f"Checkpoint: Fallback save successful to {output_jsonl_path}")
+         except Exception as json_save_e:
+             print(f"Error saving checkpoint as JSON lines: {json_save_e}")
+
+
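+ # Minimal sketch of the schema-strict save (hypothetical two-column schema):
+ #     feats = Features({"question_text": Value("string"), "response_text_1": Value("string")})
+ #     rows = [{"question_text": "hi", "response_text_1": None, "extra": 1}]
+ #     ds = Dataset.from_list([{k: v for k, v in r.items() if k in feats} for r in rows], features=feats)
+ # The dict comprehension drops "extra", mirroring the filtering above.
+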
+ # --- Main Processing Logic ---
+
+ print("Checking for existing checkpoint/output dataset...")
+ dataset = None
+ original_features = None
+
+ try:
+     potential_checkpoint_info = os.path.join(OUTPUT_DATASET_DIR, "dataset_info.json")
+     potential_checkpoint_state = os.path.join(OUTPUT_DATASET_DIR, "state.json")
+
+     if os.path.exists(OUTPUT_DATASET_DIR) and \
+        (os.path.exists(potential_checkpoint_info) or os.path.exists(potential_checkpoint_state)):
+         print(f"Attempting to load existing data from output directory: {OUTPUT_DATASET_DIR}")
+         try:
+             dataset = load_from_disk(OUTPUT_DATASET_DIR)
+             original_features = dataset.features
+             print(f"Successfully resumed from {OUTPUT_DATASET_DIR}. Loaded {len(dataset)} rows.")
+             print(f"Resumed features: {original_features}")  # Log the features
+         except Exception as load_ckpt_e:
+             print(f"Warning: Failed to load from {OUTPUT_DATASET_DIR}: {load_ckpt_e}")
+             dataset = None
+     else:
+         print(f"No valid existing data found in {OUTPUT_DATASET_DIR}.")
+
+     if dataset is None:
+         print(f"Loading original dataset from {INPUT_DATASET_DIR}...")
+         if not os.path.exists(INPUT_DATASET_DIR):
+             print(f"FATAL: Original input dataset directory not found at {INPUT_DATASET_DIR}")
+             exit(1)
+         dataset = load_from_disk(INPUT_DATASET_DIR)
+         original_features = dataset.features
+         print(f"Original dataset loaded successfully with {len(dataset)} rows.")
+         print(f"Original features: {original_features}")  # Log the features
+
+ except Exception as initial_load_e:
+     print(f"FATAL: Error during initial dataset loading: {initial_load_e}")
+     traceback.print_exc()
+     exit(1)
+
+ # Ensure original_features is loaded
+ if original_features is None:
+     print("FATAL: Failed to load dataset features. Exiting.")
+     exit(1)
+
+ os.makedirs(OUTPUT_AUDIO_ROOT_DIR, exist_ok=True)
+
+ # --- Pre-calculation Step (Assign Models Round-Robin) ---
+ print("Pre-calculating tasks and assigning models...")
+ tasks_to_process = []
+ updated_data = list(dataset)  # Use a mutable list of dicts
+ task_creation_counter = 0
+
+ for idx, row in enumerate(tqdm(updated_data, desc="Scanning dataset")):
+     needs_processing_in_row = False
+     qwen_tasks_in_row = []
+     for i in range(1, 4):
+         model_key = f"model_{i}"
+         response_text_key = f"response_text_{i}"
+         model_assigned = row.get(model_key)
+         response_text_exists = row.get(response_text_key) is not None
+         if model_assigned == "qwen_omni" and not response_text_exists:
+             needs_processing_in_row = True
+             qwen_tasks_in_row.append(i)
+
+     if needs_processing_in_row:
+         slot_to_process = qwen_tasks_in_row[0]
+         i = slot_to_process
+         prompt_text_key = f"prompt_text_{i}"
+         response_audio_key = f"response_audio_path_{i}"  # Define the key for clarity
+
+         question_audio_path = row.get('question_audio')
+         if not question_audio_path or not os.path.exists(question_audio_path):
+             print(f"Warning (Row {idx}, Slot {i}): Skipping task creation - Missing or non-existent 'question_audio': {question_audio_path}")
+             # Mark the error state in updated_data when skipping task creation
+             response_text_key_for_error = f"response_text_{i}"
+             response_audio_key_for_error = f"response_audio_path_{i}"
+             if 0 <= idx < len(updated_data):
+                 updated_data[idx][response_text_key_for_error] = "[SKIPPED: Missing input audio]"
+                 updated_data[idx][response_audio_key_for_error] = None
+             continue
+
+         metadata_str = row.get('metadata', "{}")
+         source_dataset = row.get('source_dataset')
+         metadata = {}
+         try:
+             if metadata_str and isinstance(metadata_str, str):
+                 metadata = json.loads(metadata_str)
+             elif isinstance(metadata_str, dict):
+                 metadata = metadata_str
+         except (json.JSONDecodeError, TypeError):
+             pass
+
+         history_messages = []
+         if source_dataset == 'ultra':
+             history_str = metadata.get('history', '')
+             if history_str:
+                 history_messages = parse_ultra_history(history_str)
+
+         model_to_use_for_this_task = QWEN_MODEL_LIST[task_creation_counter % NUM_MODELS]
+         task_creation_counter += 1
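+         # Round-robin sketch: with the 4-model list above, tasks 0,1,2,3,4,5,...
+         # receive models at indices 0,1,2,3,0,1,... so load spreads evenly across
+         # QWEN_MODEL_LIST regardless of how many tasks are found.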
+
+         unique_id = str(uuid.uuid4()).replace("-", "")
+         output_audio_filename = f"qwen_r{idx}_s{i}_{unique_id}.{OUTPUT_AUDIO_FORMAT}"
+         output_audio_filepath = os.path.join(OUTPUT_AUDIO_ROOT_DIR, output_audio_filename)
+
+         task_info = {
+             "row_idx": idx,
+             "slot_idx": i,
+             "model_name": model_to_use_for_this_task,
+             "history_messages": history_messages,
+             "prompt_text": row.get(prompt_text_key, ""),
+             "question_audio_path": question_audio_path,
+             "output_audio_filepath": output_audio_filepath,
+         }
+         tasks_to_process.append(task_info)
+
+ total_tasks = len(tasks_to_process)
+ if total_tasks == 0:
+     print("No Qwen tasks found needing processing in the loaded dataset.")
+     exit(0)
+
+ print(f"Found {total_tasks} Qwen tasks to process.")
+ model_counts = {model: 0 for model in QWEN_MODEL_LIST}
+ for task in tasks_to_process:
+     model_counts[task['model_name']] += 1
+ print("Task distribution per model:", model_counts)
+
+ # --- Threaded Execution with Checkpointing ---
+ print(f"Starting processing with up to {MAX_WORKERS} worker threads...")
+ start_total_time = time.time()
+ tasks_completed = 0
+ tasks_failed = 0
+ completed_since_last_save = 0
+
+ # Note: workers return a 'model_used' field, but it is intentionally not added to
+ # the dataset schema; save_checkpoint filters rows down to the original features.
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+     future_to_task = {executor.submit(call_qwen_omni_api_worker, task): task for task in tasks_to_process}
+
+     for future in tqdm(concurrent.futures.as_completed(future_to_task), total=total_tasks, desc="Processing tasks"):
+         task_info = future_to_task[future]
+         row_idx = task_info["row_idx"]
+         slot_idx = task_info["slot_idx"]
+         result = None
+
+         try:
+             result = future.result()
+             response_text_key = f"response_text_{slot_idx}"
+             response_audio_key = f"response_audio_path_{slot_idx}"
+
+             if 0 <= row_idx < len(updated_data):
+                 updated_data[row_idx][response_text_key] = result["response_text"]
+                 updated_data[row_idx][response_audio_key] = result["saved_audio_path"]
+
+                 if result["saved_audio_path"] is None or "ERROR" in result["response_text"]:
+                     tasks_failed += 1
+             else:
+                 print(f"Warning: Invalid row index {row_idx} encountered during result merge. Skipping update.")
+                 tasks_failed += 1
+
+             tasks_completed += 1
+             completed_since_last_save += 1
+
+             if completed_since_last_save >= CHECKPOINT_INTERVAL:
+                 # Pass the unmodified original_features
+                 save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
+                 completed_since_last_save = 0
+
+         except Exception as exc:
+             print(f"Error (Row {row_idx}, Slot {slot_idx}): Task generated an exception: {exc}")
+             traceback.print_exc()
+             response_text_key = f"response_text_{slot_idx}"
+             response_audio_key = f"response_audio_path_{slot_idx}"
+
+             if 0 <= row_idx < len(updated_data):
+                 updated_data[row_idx][response_text_key] = f"[ERROR: Task Exception {type(exc).__name__}]"
+                 updated_data[row_idx][response_audio_key] = None
+
+             tasks_failed += 1
+             tasks_completed += 1
+             completed_since_last_save += 1
+
+             if completed_since_last_save >= CHECKPOINT_INTERVAL:
+                 # Pass the unmodified original_features
+                 save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
+                 completed_since_last_save = 0
+
+
+ end_total_time = time.time()
+ print("\n--- Processing Complete ---")
+ print(f"Total tasks submitted: {total_tasks}")
+ print(f"Total tasks processed: {tasks_completed} (Succeeded: {tasks_completed - tasks_failed}, Failed: {tasks_failed})")
+ print(f"Total processing time: {(end_total_time - start_total_time)/60:.2f} minutes")
+
+ # --- Final Save ---
+ print("\nPerforming final save...")
+ # Pass the unmodified original_features
+ save_checkpoint(updated_data, OUTPUT_DATASET_DIR, original_features)
+
+ print("\nScript finished.")