Student0809 committed
Commit 3438cdb · verified · 1 parent: b6a70f8

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .ipynb_checkpoints/COT_TRAIN-checkpoint.jsonl +0 -0
  2. .ipynb_checkpoints/GRPO_TRAIN-checkpoint.jsonl +0 -0
  3. .ipynb_checkpoints/test-checkpoint.sh +6 -0
  4. 4JOB/.ipynb_checkpoints/filter-checkpoint.py +132 -0
  5. 4JOB/.ipynb_checkpoints/process_silence-checkpoint.py +84 -0
  6. 4JOB/.ipynb_checkpoints/process_speaker-checkpoint.py +74 -0
  7. 4JOB/.ipynb_checkpoints/process_transcription-checkpoint.py +80 -0
  8. 4JOB/filter_logs/.ipynb_checkpoints/removed_entries_20250618_162013-checkpoint.log +72 -0
  9. 4JOB/filter_logs/removed_entries_20250618_162013.log +72 -0
  10. 4JOB/filter_logs/removed_entries_20250618_162341.log +92 -0
  11. 4JOB/overlap/.ipynb_checkpoints/mergeAll-checkpoint.py +44 -0
  12. 4JOB/overlap/mergeAll.py +44 -0
  13. 4JOB/overlap/trimmed_dialogues_pause_0_200_output.json +0 -0
  14. 4JOB/overlap/trimmed_dialogues_pause_600_800_output.json +0 -0
  15. 4JOB/overlap_filtered_output/trimmed_dialogues_pause_0_200_output.json +0 -0
  16. 4JOB/overlap_filtered_output/trimmed_dialogues_pause_200_400_output.json +0 -0
  17. 4JOB/overlap_filtered_output/trimmed_dialogues_pause_600_800_output.json +0 -0
  18. 4JOB/silence/mergeAll.py +44 -0
  19. 4JOB/silence/trimmed_dialogues_pause_100_200_output.json +0 -0
  20. 4JOB/silenceOringal.json +0 -0
  21. 4JOB/train/overlap_speaker.json +0 -0
  22. GRPO/Reward.py +87 -0
  23. cotSFT/gemini-text/.ipynb_checkpoints/texterror_results-checkpoint.json +0 -0
  24. cotSFT/gemini-text/texterror_results.json +0 -0
  25. cotSFT/train/.ipynb_checkpoints/correctresults_with_audio-checkpoint.json +0 -0
  26. cotSFT/train/correctresults_with_audio.json +0 -0
  27. cotSFT_new/.ipynb_checkpoints/correct_output_transcription-checkpoint.json +0 -0
  28. cotSFT_new/.ipynb_checkpoints/delay_output-checkpoint.json +0 -0
  29. cotSFT_new/.ipynb_checkpoints/gemini2.5_metainfo-checkpoint.py +317 -0
  30. cotSFT_new/.ipynb_checkpoints/overlaps1_output-checkpoint.json +0 -0
  31. cotSFT_new/.ipynb_checkpoints/process_transcription-checkpoint.py +80 -0
  32. cotSFT_new/cotSFT_10data/.ipynb_checkpoints/dataset_real_sft-checkpoint.jsonl +0 -0
  33. cotSFT_new/cotSFT_10data/dataset_real_sft.jsonl +0 -0
  34. cotSFT_new/cotSFT_10data/gemini2.5_metainfo.py +329 -0
  35. cotSFT_new/cotSFT_gemini.json +0 -0
  36. cotSFT_new/delay_output.json +0 -0
  37. cotSFT_new/filtered_output/.ipynb_checkpoints/delay_output-checkpoint.json +0 -0
  38. cotSFT_new/filtered_output/.ipynb_checkpoints/process_transcription-checkpoint.py +80 -0
  39. cotSFT_new/filtered_output/.ipynb_checkpoints/texterror_output_transcription_gemini-checkpoint.json +0 -0
  40. cotSFT_new/filtered_output/alltrain/.ipynb_checkpoints/correct_output_transcription_merged_output_990-checkpoint.json +0 -0
  41. cotSFT_new/filtered_output/alltrain/correct_output_transcription_merged_output_990.json +0 -0
  42. cotSFT_new/filtered_output/alltrain/overlaps1_gemini_merged_output.json +0 -0
  43. cotSFT_new/filtered_output/alltrain/texterror_output_transcription_merged_output.json +0 -0
  44. cotSFT_new/filtered_output/correc/.ipynb_checkpoints/correct_output_transcription_gemini_error-checkpoint.json +1 -0
  45. cotSFT_new/filtered_output/correc/correct_output_transcription.json +0 -0
  46. cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk2.json +0 -0
  47. cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk3.json +0 -0
  48. cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk4.json +0 -0
  49. cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk6.json +0 -0
  50. cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk7.json +0 -0
.ipynb_checkpoints/COT_TRAIN-checkpoint.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
.ipynb_checkpoints/GRPO_TRAIN-checkpoint.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
.ipynb_checkpoints/test-checkpoint.sh ADDED
@@ -0,0 +1,6 @@
+ CUDA_VISIBLE_DEVICES=0 \
+ swift infer \
+     --adapters /root/autodl-tmp/output_7B_SFT/v0-20250605-155458/checkpoint-1095 \
+     --stream true \
+     --temperature 0 \
+     --max_new_tokens 2048
4JOB/.ipynb_checkpoints/filter-checkpoint.py ADDED
@@ -0,0 +1,132 @@
+ import json
+ import os
+ from datetime import datetime
+ 
+ def filter_by_duration(input_file, output_file, min_duration=30, max_duration=90):
+     """
+     Filter a JSON file, keeping only entries whose total_duration falls within
+     [min_duration, max_duration], and log the removed entries to a log file.
+ 
+     :param input_file: path to the input JSON file
+     :param output_file: path to the output JSON file
+     :param min_duration: minimum duration in seconds
+     :param max_duration: maximum duration in seconds
+     """
+     # Create the log directory
+     log_dir = os.path.join(os.path.dirname(output_file), "filter_logs")
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+ 
+     # Create the log file (named after the current time)
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     log_file = os.path.join(log_dir, f"removed_entries_{timestamp}.log")
+ 
+     # Load the original JSON file
+     with open(input_file, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+ 
+     # Initialize the filtered result and the removal list
+     filtered_data = {}
+     removed_entries = []
+ 
+     # Filter the data and record removed entries
+     for key, value in data.items():
+         if 'total_duration' in value and min_duration <= value['total_duration'] <= max_duration:
+             filtered_data[key] = value
+         else:
+             duration = value.get('total_duration', 'N/A')
+             removed_entries.append({
+                 'key': key,
+                 'duration': duration,
+                 'original_dialog_id': value.get('original_dialog_id', 'N/A'),
+                 'reason': 'too_short' if isinstance(duration, (int, float)) and duration < min_duration
+                           else 'too_long' if isinstance(duration, (int, float)) and duration > max_duration
+                           else 'missing_or_invalid'
+             })
+ 
+     # Save the filtered result
+     with open(output_file, 'w', encoding='utf-8') as f:
+         json.dump(filtered_data, f, indent=2, ensure_ascii=False)
+ 
+     # Save the removal log
+     with open(log_file, 'w', encoding='utf-8') as f:
+         f.write(f"Filtering log - {timestamp}\n")
+         f.write(f"Input file: {input_file}\n")
+         f.write(f"Output file: {output_file}\n")
+         f.write(f"Duration range: {min_duration}s to {max_duration}s\n\n")
+         f.write("Removed Entries:\n")
+         f.write("="*50 + "\n")
+         for entry in removed_entries:
+             f.write(f"Key: {entry['key']}\n")
+             f.write(f"Original Dialog ID: {entry['original_dialog_id']}\n")
+             f.write(f"Duration: {entry['duration']}s\n")
+             f.write(f"Reason: {entry['reason']}\n")
+             f.write("-"*50 + "\n")
+ 
+     print(f"\nProcessing result: {os.path.basename(input_file)}")
+     print(f"Original entry count: {len(data)}")
+     print(f"Entry count after filtering: {len(filtered_data)}")
+     print(f"Removed {len(removed_entries)} entries that did not meet the duration requirement")
+     print(f"Filtered data saved to: {output_file}")
+     print(f"Removed-entry log saved to: {log_file}")
+ 
+ def process_directory(input_dir, output_dir, min_duration=30, max_duration=90):
+     """
+     Process all JSON files in a directory.
+     """
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+ 
+     # Create the summary log file
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     summary_log = os.path.join(output_dir, f"summary_removed_entries_{timestamp}.log")
+ 
+     total_removed = 0
+     total_processed = 0
+ 
+     with open(summary_log, 'w', encoding='utf-8') as summary_f:
+         summary_f.write(f"Summary Filtering Log - {timestamp}\n")
+         summary_f.write(f"Input directory: {input_dir}\n")
+         summary_f.write(f"Output directory: {output_dir}\n")
+         summary_f.write(f"Duration range: {min_duration}s to {max_duration}s\n\n")
+ 
+         for filename in os.listdir(input_dir):
+             if filename.endswith('.json'):
+                 input_path = os.path.join(input_dir, filename)
+                 output_path = os.path.join(output_dir, filename)
+ 
+                 print(f"\nProcessing file: {filename}")
+                 filter_by_duration(input_path, output_path, min_duration, max_duration)
+ 
+                 # Read the per-file log to collect statistics
+                 log_dir = os.path.join(output_dir, "filter_logs")
+                 latest_log = max(
+                     [f for f in os.listdir(log_dir) if f.startswith('removed_entries')],
+                     key=lambda f: os.path.getmtime(os.path.join(log_dir, f)))
+ 
+                 with open(os.path.join(log_dir, latest_log), 'r', encoding='utf-8') as log_f:
+                     log_content = log_f.read()
+                     removed_count = log_content.count("Key: ")
+ 
+                 summary_f.write(f"\nFile: {filename}\n")
+                 summary_f.write(f"Removed entries: {removed_count}\n")
+                 summary_f.write("-"*40 + "\n")
+ 
+                 total_removed += removed_count
+                 total_processed += 1
+ 
+         summary_f.write(f"\nTotal files processed: {total_processed}\n")
+         summary_f.write(f"Total entries removed: {total_removed}\n")
+ 
+     print(f"\nDone! Summary log for all files saved to: {summary_log}")
+ 
+ if __name__ == "__main__":
+     # Example usage - process a single file
+     input_json = "silence.json"  # replace with your input file path
+     output_json = "silence_filtered_output.json"  # output file path
+     filter_by_duration(input_json, output_json)
+ 
+     # Example usage - process an entire directory
+     # input_directory = "./input_4JOB_overlap"  # replace with your input directory
+     # output_directory = "./filtered_output"  # replace with your output directory
+     # process_directory(input_directory, output_directory)
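
For context, a minimal sketch of the input shape filter_by_duration expects (field names are taken from the code above; the keys, durations and dialog IDs below are illustrative only):

    # Hypothetical two-entry input: entry "0" is kept (60s), entry "1" is logged as too_long (95s).
    sample = {
        "0": {"original_dialog_id": "DialogSum--train--1", "total_duration": 60.0, "segments": []},
        "1": {"original_dialog_id": "DialogSum--train--2", "total_duration": 95.0, "segments": []},
    }
    # filter_by_duration("silence.json", "silence_filtered_output.json") keeps "0" and writes "1"
    # to filter_logs/removed_entries_<timestamp>.log with reason "too_long".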
4JOB/.ipynb_checkpoints/process_silence-checkpoint.py ADDED
@@ -0,0 +1,84 @@
+ import json
+ import os
+ import random
+ 
+ def seconds_to_mmss(seconds):
+     minutes = int(seconds // 60)
+     seconds = int(seconds % 60)
+     return f"{minutes:02d}:{seconds:02d}"
+ 
+ # Templates for silence gap descriptions
+ SILENCE_TEMPLATES = [
+     "Silence gaps longer than 3 seconds occur at: {gaps}",
+     "The conversation contains significant pauses at: {gaps}",
+     "There are silent periods of more than 3 seconds at: {gaps}",
+     "The dialogue features extended pauses at: {gaps}",
+     "Silent intervals exceeding 3 seconds are found at: {gaps}",
+     "The conversation includes notable gaps at: {gaps}",
+     "Extended periods of silence occur at: {gaps}",
+     "The dialogue has significant breaks at: {gaps}",
+     "Silent segments longer than 3 seconds appear at: {gaps}",
+     "The conversation shows substantial pauses at: {gaps}"
+ ]
+ 
+ # Templates for no silence case
+ NO_SILENCE_TEMPLATES = [
+     "No silence gaps longer than 3 seconds were found in this conversation.",
+     "The conversation flows continuously without significant pauses.",
+     "No extended periods of silence were detected in this dialogue.",
+     "The conversation maintains a steady pace without notable gaps.",
+     "No silent intervals exceeding 3 seconds were identified.",
+     "The dialogue proceeds without substantial pauses.",
+     "No significant breaks in conversation were observed.",
+     "The conversation shows no extended silent periods.",
+     "No notable gaps in speech were detected.",
+     "The dialogue continues without significant silent intervals."
+ ]
+ 
+ file = "silence"
+ 
+ def process_silence_gaps():
+     # Read the input JSON file (e.g. silence.json)
+     with open(f'{file}.json', 'r', encoding='utf-8') as f:
+         silence_data = json.load(f)
+ 
+     # List to store results for all conversations
+     results = []
+ 
+     # Process each conversation
+     for conversation_id, conversation in silence_data.items():
+         segments = conversation.get('segments', [])
+         audio_path = conversation.get('stereo_audio', [])
+         silence_gaps = []
+ 
+         # Find silence gaps > 3s between segments
+         for i in range(len(segments) - 1):
+             current_end = segments[i]['end_time']
+             next_start = segments[i + 1]['start_time']
+             gap_duration = next_start - current_end
+ 
+             if gap_duration > 3:
+                 silence_gaps.append(f"{seconds_to_mmss(current_end)}-{seconds_to_mmss(next_start)}")
+ 
+         # Create result entry with random template
+         if silence_gaps:
+             template = random.choice(SILENCE_TEMPLATES)
+             model_output = template.format(gaps=', '.join(silence_gaps))
+         else:
+             model_output = random.choice(NO_SILENCE_TEMPLATES)
+ 
+         result = {
+             "key": conversation_id,
+             "audio_url": audio_path,
+             "model_output": model_output
+         }
+         results.append(result)
+ 
+     # Save the results to a JSON file
+     output_file = f'{file}_silencegap.json'
+     with open(output_file, 'w', encoding='utf-8') as f:
+         json.dump(results, f, indent=2, ensure_ascii=False)
+ 
+     print(f"Processed {len(results)} conversations")
+     print(f"Results written to {output_file}")
+ 
+ if __name__ == "__main__":
+     process_silence_gaps()
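
As a quick illustration of the gap rule above (segment times are invented; seconds_to_mmss is the helper defined in the script):

    segments = [
        {"start_time": 0.0, "end_time": 12.0},
        {"start_time": 16.5, "end_time": 30.0},
    ]
    gap = segments[1]["start_time"] - segments[0]["end_time"]  # 4.5s > 3s, so it is reported
    print(f"{seconds_to_mmss(12.0)}-{seconds_to_mmss(16.5)}")  # -> 00:12-00:16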
4JOB/.ipynb_checkpoints/process_speaker-checkpoint.py ADDED
@@ -0,0 +1,74 @@
+ import json
+ import random
+ 
+ def seconds_to_mmss(seconds):
+     minutes = int(seconds // 60)
+     seconds = int(seconds % 60)
+     return f"{minutes:02d}:{seconds:02d}"
+ 
+ # Templates for speaker segment descriptions
+ SPEAKER_TEMPLATES = [
+     "Speaker {speaker} speaks during the following periods: {times}",
+     "Speaker {speaker}'s speaking segments occur at: {times}",
+     "Speaker {speaker} is active in the conversation at: {times}",
+     "The following time segments belong to Speaker {speaker}: {times}",
+     "Speaker {speaker} participates in the dialogue at: {times}",
+     "Speaker {speaker} contributes to the conversation during: {times}",
+     "Speaking turns for Speaker {speaker} are at: {times}",
+     "Speaker {speaker} takes the floor at: {times}",
+     "The voice of Speaker {speaker} is heard at: {times}",
+     "Speaker {speaker} engages in the discussion during: {times}"
+ ]
+ 
+ file = "silence"
+ 
+ def process_speaker_segments():
+     # Read the input JSON file (e.g. silence.json)
+     with open(f'{file}.json', 'r', encoding='utf-8') as f:
+         data = json.load(f)
+ 
+     # List to store results for all conversations
+     results = []
+ 
+     # Process each conversation
+     for conversation_id, conversation in data.items():
+         segments = conversation.get('segments', [])
+         audio_path = conversation.get('stereo_audio', [])
+         # Dictionary to store speaking times for each speaker
+         speaker_times = {}
+ 
+         # Process each segment
+         for segment in segments:
+             speaker = segment['speaker']
+             start_time = segment['start_time']  # Keep as float for accurate conversion
+             end_time = segment['end_time']  # Keep as float for accurate conversion
+ 
+             # Initialize list for this speaker if not exists
+             if speaker not in speaker_times:
+                 speaker_times[speaker] = []
+ 
+             # Add this speaking interval
+             speaker_times[speaker].append((start_time, end_time))
+ 
+         # Format the output string
+         output_lines = []
+         for speaker in sorted(speaker_times.keys()):
+             times = speaker_times[speaker]
+             time_ranges = [f"{seconds_to_mmss(start)}-{seconds_to_mmss(end)}" for start, end in times]
+             # Randomly select a template for each speaker
+             template = random.choice(SPEAKER_TEMPLATES)
+             output_lines.append(template.format(speaker=speaker, times=', '.join(time_ranges)))
+ 
+         # Create result entry
+         result = {
+             "key": conversation_id,
+             "audio_url": audio_path,
+             "model_output": "\n".join(output_lines)
+         }
+         results.append(result)
+ 
+     # Save the results to a JSON file
+     output_file = f'{file}_speaker.json'
+     with open(output_file, 'w', encoding='utf-8') as f:
+         json.dump(results, f, indent=2, ensure_ascii=False)
+ 
+ if __name__ == "__main__":
+     process_speaker_segments()
4JOB/.ipynb_checkpoints/process_transcription-checkpoint.py ADDED
@@ -0,0 +1,80 @@
+ import json
+ 
+ def seconds_to_mmss(seconds):
+     minutes = int(seconds // 60)
+     seconds = int(seconds % 60)
+     return f"{minutes:02d}:{seconds:02d}"
+ 
+ filename = "silence"
+ 
+ def is_overlapping(current_segment, other_segments):
+     """Check if the current segment overlaps with any other segment."""
+     current_start = current_segment['start_time']
+     current_end = current_segment['end_time']
+ 
+     for segment in other_segments:
+         if segment == current_segment:
+             continue
+ 
+         other_start = segment['start_time']
+         other_end = segment['end_time']
+ 
+         # Check if there's an overlap
+         if (current_start < other_end and current_end > other_start):
+             return True
+ 
+     return False
+ 
+ def process_transcriptions():
+     # Read the input JSON file (e.g. silence.json)
+     with open(f'./{filename}.json', 'r', encoding='utf-8') as f:
+         data = json.load(f)
+ 
+     # List to store results for all conversations
+     results = []
+ 
+     # Process each conversation
+     for conversation_id, conversation in data.items():
+         segments = conversation.get('segments', [])
+         audio_path = conversation.get('stereo_audio', [])
+         # Sort segments by start time
+         segments.sort(key=lambda x: x['start_time'])
+ 
+         # Process each segment and create transcription lines
+         transcription_lines = []
+ 
+         for segment in segments:
+             speaker = segment['speaker']
+             start_time = segment['start_time']
+             end_time = segment['end_time']
+             text = segment['text']
+             original_text = segment['original_text']
+             original_text = original_text.replace("[interrupt] ", "").strip()
+             # Format timestamp
+             timestamp = f"[{seconds_to_mmss(start_time)} - {seconds_to_mmss(end_time)}]"
+ 
+             # Check if this segment overlaps with any other segment
+             has_overlap = is_overlapping(segment, segments)
+ 
+             # Format the line
+             if has_overlap:
+                 line = f"{timestamp} Speaker {speaker}: {original_text}"
+             else:
+                 line = f"{timestamp} Speaker {speaker}: {text}"
+ 
+             transcription_lines.append(line)
+ 
+         # Create result entry
+         result = {
+             "key": conversation_id,
+             "audio_url": audio_path,
+             "model_output": "\n".join(transcription_lines)
+         }
+         results.append(result)
+ 
+     # Save the results to a JSON file
+     output_file = f'./{filename}_transcription.json'
+     with open(output_file, 'w', encoding='utf-8') as f:
+         json.dump(results, f, indent=2, ensure_ascii=False)
+ 
+ if __name__ == "__main__":
+     process_transcriptions()
4JOB/filter_logs/.ipynb_checkpoints/removed_entries_20250618_162013-checkpoint.log ADDED
@@ -0,0 +1,72 @@
1
+ Filtering log - 20250618_162013
2
+ Input file: overlap.json
3
+ Output file: overlap_filtered_output.json
4
+ Duration range: 30s to 90s
5
+
6
+ Removed Entries:
7
+ ==================================================
8
+ Key: 165
9
+ Original Dialog ID: DialogSum--train--713
10
+ Duration: 94.65241666666667s
11
+ Reason: too_long
12
+ --------------------------------------------------
13
+ Key: 131
14
+ Original Dialog ID: DialogSum--train--674
15
+ Duration: 91.20929166666667s
16
+ Reason: too_long
17
+ --------------------------------------------------
18
+ Key: 185_1
19
+ Original Dialog ID: DialogSum--train--988
20
+ Duration: 90.738625s
21
+ Reason: too_long
22
+ --------------------------------------------------
23
+ Key: 63_1
24
+ Original Dialog ID: DialogSum--train--837
25
+ Duration: 93.28445833333333s
26
+ Reason: too_long
27
+ --------------------------------------------------
28
+ Key: 74
29
+ Original Dialog ID: DialogSum--train--850
30
+ Duration: 90.21241666666667s
31
+ Reason: too_long
32
+ --------------------------------------------------
33
+ Key: 129_2
34
+ Original Dialog ID: DialogSum--train--153
35
+ Duration: 29.000666666666667s
36
+ Reason: too_short
37
+ --------------------------------------------------
38
+ Key: 174_1
39
+ Original Dialog ID: DialogSum--train--215
40
+ Duration: 92.936125s
41
+ Reason: too_long
42
+ --------------------------------------------------
43
+ Key: 119_2
44
+ Original Dialog ID: DialogSum--train--142
45
+ Duration: 91.69558333333333s
46
+ Reason: too_long
47
+ --------------------------------------------------
48
+ Key: 25_2
49
+ Original Dialog ID: DialogSum--train--29
50
+ Duration: 93.67675s
51
+ Reason: too_long
52
+ --------------------------------------------------
53
+ Key: 34_2
54
+ Original Dialog ID: DialogSum--train--40
55
+ Duration: 91.93370833333333s
56
+ Reason: too_long
57
+ --------------------------------------------------
58
+ Key: 21_3
59
+ Original Dialog ID: DialogSum--train--278
60
+ Duration: 29.887916666666666s
61
+ Reason: too_short
62
+ --------------------------------------------------
63
+ Key: 39_2
64
+ Original Dialog ID: DialogSum--train--300
65
+ Duration: 27.758333333333333s
66
+ Reason: too_short
67
+ --------------------------------------------------
68
+ Key: 146_3
69
+ Original Dialog ID: DialogSum--train--439
70
+ Duration: 97.25554166666667s
71
+ Reason: too_long
72
+ --------------------------------------------------
4JOB/filter_logs/removed_entries_20250618_162013.log ADDED
@@ -0,0 +1,72 @@
1
+ Filtering log - 20250618_162013
2
+ Input file: overlap.json
3
+ Output file: overlap_filtered_output.json
4
+ Duration range: 30s to 90s
5
+
6
+ Removed Entries:
7
+ ==================================================
8
+ Key: 165
9
+ Original Dialog ID: DialogSum--train--713
10
+ Duration: 94.65241666666667s
11
+ Reason: too_long
12
+ --------------------------------------------------
13
+ Key: 131
14
+ Original Dialog ID: DialogSum--train--674
15
+ Duration: 91.20929166666667s
16
+ Reason: too_long
17
+ --------------------------------------------------
18
+ Key: 185_1
19
+ Original Dialog ID: DialogSum--train--988
20
+ Duration: 90.738625s
21
+ Reason: too_long
22
+ --------------------------------------------------
23
+ Key: 63_1
24
+ Original Dialog ID: DialogSum--train--837
25
+ Duration: 93.28445833333333s
26
+ Reason: too_long
27
+ --------------------------------------------------
28
+ Key: 74
29
+ Original Dialog ID: DialogSum--train--850
30
+ Duration: 90.21241666666667s
31
+ Reason: too_long
32
+ --------------------------------------------------
33
+ Key: 129_2
34
+ Original Dialog ID: DialogSum--train--153
35
+ Duration: 29.000666666666667s
36
+ Reason: too_short
37
+ --------------------------------------------------
38
+ Key: 174_1
39
+ Original Dialog ID: DialogSum--train--215
40
+ Duration: 92.936125s
41
+ Reason: too_long
42
+ --------------------------------------------------
43
+ Key: 119_2
44
+ Original Dialog ID: DialogSum--train--142
45
+ Duration: 91.69558333333333s
46
+ Reason: too_long
47
+ --------------------------------------------------
48
+ Key: 25_2
49
+ Original Dialog ID: DialogSum--train--29
50
+ Duration: 93.67675s
51
+ Reason: too_long
52
+ --------------------------------------------------
53
+ Key: 34_2
54
+ Original Dialog ID: DialogSum--train--40
55
+ Duration: 91.93370833333333s
56
+ Reason: too_long
57
+ --------------------------------------------------
58
+ Key: 21_3
59
+ Original Dialog ID: DialogSum--train--278
60
+ Duration: 29.887916666666666s
61
+ Reason: too_short
62
+ --------------------------------------------------
63
+ Key: 39_2
64
+ Original Dialog ID: DialogSum--train--300
65
+ Duration: 27.758333333333333s
66
+ Reason: too_short
67
+ --------------------------------------------------
68
+ Key: 146_3
69
+ Original Dialog ID: DialogSum--train--439
70
+ Duration: 97.25554166666667s
71
+ Reason: too_long
72
+ --------------------------------------------------
4JOB/filter_logs/removed_entries_20250618_162341.log ADDED
@@ -0,0 +1,92 @@
1
+ Filtering log - 20250618_162341
2
+ Input file: silence.json
3
+ Output file: silence_filtered_output.json
4
+ Duration range: 30s to 90s
5
+
6
+ Removed Entries:
7
+ ==================================================
8
+ Key: 83
9
+ Original Dialog ID: SODA_PROCESSED--train--214477
10
+ Duration: 99.83525s
11
+ Reason: too_long
12
+ --------------------------------------------------
13
+ Key: 15
14
+ Original Dialog ID: SODA_PROCESSED--train--972977
15
+ Duration: 94.94791666666667s
16
+ Reason: too_long
17
+ --------------------------------------------------
18
+ Key: 18
19
+ Original Dialog ID: SODA_PROCESSED--train--795181
20
+ Duration: 92.62829166666667s
21
+ Reason: too_long
22
+ --------------------------------------------------
23
+ Key: 31_1
24
+ Original Dialog ID: SODA_PROCESSED--train--1113674
25
+ Duration: 95.01416666666667s
26
+ Reason: too_long
27
+ --------------------------------------------------
28
+ Key: 53_1
29
+ Original Dialog ID: SODA_PROCESSED--train--484021
30
+ Duration: 98.07645833333333s
31
+ Reason: too_long
32
+ --------------------------------------------------
33
+ Key: 74_1
34
+ Original Dialog ID: SODA_PROCESSED--train--1047480
35
+ Duration: 91.74375s
36
+ Reason: too_long
37
+ --------------------------------------------------
38
+ Key: 17_1
39
+ Original Dialog ID: SODA_PROCESSED--train--166191
40
+ Duration: 97.76666666666667s
41
+ Reason: too_long
42
+ --------------------------------------------------
43
+ Key: 46_2
44
+ Original Dialog ID: SODA_PROCESSED--train--727552
45
+ Duration: 91.58875s
46
+ Reason: too_long
47
+ --------------------------------------------------
48
+ Key: 84_2
49
+ Original Dialog ID: SODA_PROCESSED--train--286623
50
+ Duration: 94.22970833333333s
51
+ Reason: too_long
52
+ --------------------------------------------------
53
+ Key: 55_2
54
+ Original Dialog ID: SODA_PROCESSED--train--317784
55
+ Duration: 96.18079166666666s
56
+ Reason: too_long
57
+ --------------------------------------------------
58
+ Key: 35_2
59
+ Original Dialog ID: SODA_PROCESSED--train--1190867
60
+ Duration: 99.861s
61
+ Reason: too_long
62
+ --------------------------------------------------
63
+ Key: 99_2
64
+ Original Dialog ID: SODA_PROCESSED--train--304811
65
+ Duration: 91.12975s
66
+ Reason: too_long
67
+ --------------------------------------------------
68
+ Key: 44_3
69
+ Original Dialog ID: SODA_PROCESSED--train--1084179
70
+ Duration: 90.02725s
71
+ Reason: too_long
72
+ --------------------------------------------------
73
+ Key: 24_3
74
+ Original Dialog ID: SODA_PROCESSED--train--209436
75
+ Duration: 94.00129166666666s
76
+ Reason: too_long
77
+ --------------------------------------------------
78
+ Key: 10_3
79
+ Original Dialog ID: SODA_PROCESSED--train--606362
80
+ Duration: 95.01458333333333s
81
+ Reason: too_long
82
+ --------------------------------------------------
83
+ Key: 11_3
84
+ Original Dialog ID: SODA_PROCESSED--train--33760
85
+ Duration: 91.81675s
86
+ Reason: too_long
87
+ --------------------------------------------------
88
+ Key: 73_4
89
+ Original Dialog ID: SODA_PROCESSED--train--873625
90
+ Duration: 92.01975s
91
+ Reason: too_long
92
+ --------------------------------------------------
4JOB/overlap/.ipynb_checkpoints/mergeAll-checkpoint.py ADDED
@@ -0,0 +1,44 @@
+ import os
+ import json
+ 
+ def load_json(file_path):
+     with open(file_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+ 
+ def get_unique_key(base_key, existing_keys):
+     """Find a unique key among the existing keys, e.g. key, key_1, key_2..."""
+     if base_key not in existing_keys:
+         return base_key
+     i = 1
+     while f"{base_key}_{i}" in existing_keys:
+         i += 1
+     return f"{base_key}_{i}"
+ 
+ def merge_all_jsons_in_folder(folder_path='.', output_path="merged_all_unique.json"):
+     merged_data = {}
+ 
+     for filename in os.listdir(folder_path):
+         if filename.endswith(".json"):
+             file_path = os.path.join(folder_path, filename)
+             try:
+                 data = load_json(file_path)
+                 if not isinstance(data, dict):
+                     print(f"{filename} is not a dict, skipping.")
+                     continue
+ 
+                 for key, value in data.items():
+                     unique_key = get_unique_key(key, merged_data)
+                     if unique_key != key:
+                         print(f"Key '{key}' is a duplicate; merged as '{unique_key}' instead.")
+                     merged_data[unique_key] = value
+ 
+             except Exception as e:
+                 print(f"Failed to load {filename}: {e}")
+ 
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(merged_data, f, ensure_ascii=False, indent=2)
+ 
+     print(f"\nMerge complete: {len(merged_data)} records written to {output_path}")
+ 
+ if __name__ == "__main__":
+     merge_all_jsons_in_folder(folder_path='.', output_path="merged_all_unique.json")
4JOB/overlap/mergeAll.py ADDED
@@ -0,0 +1,44 @@
+ import os
+ import json
+ 
+ def load_json(file_path):
+     with open(file_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+ 
+ def get_unique_key(base_key, existing_keys):
+     """Find a unique key among the existing keys, e.g. key, key_1, key_2..."""
+     if base_key not in existing_keys:
+         return base_key
+     i = 1
+     while f"{base_key}_{i}" in existing_keys:
+         i += 1
+     return f"{base_key}_{i}"
+ 
+ def merge_all_jsons_in_folder(folder_path='.', output_path="merged_all_unique.json"):
+     merged_data = {}
+ 
+     for filename in os.listdir(folder_path):
+         if filename.endswith(".json"):
+             file_path = os.path.join(folder_path, filename)
+             try:
+                 data = load_json(file_path)
+                 if not isinstance(data, dict):
+                     print(f"{filename} is not a dict, skipping.")
+                     continue
+ 
+                 for key, value in data.items():
+                     unique_key = get_unique_key(key, merged_data)
+                     if unique_key != key:
+                         print(f"Key '{key}' is a duplicate; merged as '{unique_key}' instead.")
+                     merged_data[unique_key] = value
+ 
+             except Exception as e:
+                 print(f"Failed to load {filename}: {e}")
+ 
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(merged_data, f, ensure_ascii=False, indent=2)
+ 
+     print(f"\nMerge complete: {len(merged_data)} records written to {output_path}")
+ 
+ if __name__ == "__main__":
+     merge_all_jsons_in_folder(folder_path='.', output_path="merged_all_unique.json")
4JOB/overlap/trimmed_dialogues_pause_0_200_output.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/overlap/trimmed_dialogues_pause_600_800_output.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/overlap_filtered_output/trimmed_dialogues_pause_0_200_output.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/overlap_filtered_output/trimmed_dialogues_pause_200_400_output.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/overlap_filtered_output/trimmed_dialogues_pause_600_800_output.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/silence/mergeAll.py ADDED
@@ -0,0 +1,44 @@
+ import os
+ import json
+ 
+ def load_json(file_path):
+     with open(file_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+ 
+ def get_unique_key(base_key, existing_keys):
+     """Find a unique key among the existing keys, e.g. key, key_1, key_2..."""
+     if base_key not in existing_keys:
+         return base_key
+     i = 1
+     while f"{base_key}_{i}" in existing_keys:
+         i += 1
+     return f"{base_key}_{i}"
+ 
+ def merge_all_jsons_in_folder(folder_path='.', output_path="merged_all_unique.json"):
+     merged_data = {}
+ 
+     for filename in os.listdir(folder_path):
+         if filename.endswith(".json"):
+             file_path = os.path.join(folder_path, filename)
+             try:
+                 data = load_json(file_path)
+                 if not isinstance(data, dict):
+                     print(f"{filename} is not a dict, skipping.")
+                     continue
+ 
+                 for key, value in data.items():
+                     unique_key = get_unique_key(key, merged_data)
+                     if unique_key != key:
+                         print(f"Key '{key}' is a duplicate; merged as '{unique_key}' instead.")
+                     merged_data[unique_key] = value
+ 
+             except Exception as e:
+                 print(f"Failed to load {filename}: {e}")
+ 
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(merged_data, f, ensure_ascii=False, indent=2)
+ 
+     print(f"\nMerge complete: {len(merged_data)} records written to {output_path}")
+ 
+ if __name__ == "__main__":
+     merge_all_jsons_in_folder(folder_path='.', output_path="merged_all_unique.json")
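
A small usage sketch of the key de-duplication above (assuming the script is importable as mergeAll; the keys are invented):

    from mergeAll import get_unique_key

    existing = {"42": {}, "42_1": {}}
    print(get_unique_key("42", existing))  # -> "42_2" (first free suffix)
    print(get_unique_key("7", existing))   # -> "7" (no clash, returned unchanged)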
4JOB/silence/trimmed_dialogues_pause_100_200_output.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/silenceOringal.json ADDED
The diff for this file is too large to render. See raw diff
 
4JOB/train/overlap_speaker.json ADDED
The diff for this file is too large to render. See raw diff
 
GRPO/Reward.py ADDED
@@ -0,0 +1,87 @@
+ import os
+ import re
+ import math
+ import json
+ from datetime import datetime
+ from swift.plugin import ORM, orms
+ from typing import Dict, List, Union
+ 
+ 
+ class MultiModalAccuracyORM(ORM):
+     def __call__(self, completions, solution, **kwargs) -> List[float]:
+         """
+         Reward function that checks if the completion is correct.
+         Args:
+             completions (list[str]): Generated outputs
+             solution (list[str]): Ground truths.
+ 
+         Returns:
+             list[float]: Reward scores
+         """
+         rewards = []
+         # completion_contents = [completion[0]["content"] for completion in completions]
+         for content, gt_score_orig in zip(completions, solution):
+             score_match = re.search(r"<overall score>(\d+)</overall score>", content)
+             # score_match = re.search(r"<score>(\d+)</score>", content)
+             pred_score = None
+             gt_score = None
+             if score_match:
+                 try:
+                     pred_score = int(score_match.group(1))
+                     if not (1 <= pred_score <= 2):
+                         pred_score = None
+                 except ValueError:
+                     pass
+ 
+             try:
+                 gt_score = int(gt_score_orig[0])
+                 if not (1 <= gt_score <= 2):
+                     gt_score = None
+             except (ValueError, TypeError, IndexError):
+                 pass
+ 
+             # Tiered reward logic
+             if pred_score is not None and gt_score is not None:
+                 if pred_score == gt_score:
+                     reward = 5.0
+                 elif abs(pred_score - gt_score) <= 1:
+                     reward = 1.0
+                 else:
+                     reward = 0.0
+             else:
+                 reward = 0.0
+ 
+             rewards.append(reward)
+         return rewards
+ 
+ 
+ class MultiModalFormatAccuracyORM(ORM):
+     def __call__(self, completions, **kwargs) -> List[float]:
+         """Reward function that checks if the completion has a specific format."""
+         rewards = []
+         response_pattern = r"<response think>.*?</response think>"
+         react_pattern = r"<fluency think>.*?</fluency think>"
+         score_pattern = r"[*\s]*<overall score>(\d+)</overall score>[\s*]*"
+         # completion_contents = [completion[0]["content"] for completion in completions]
+         for content in completions:
+             has_response = bool(re.search(response_pattern, content, re.DOTALL))
+             has_react = bool(re.search(react_pattern, content, re.DOTALL))
+             has_score = bool(re.search(score_pattern, content, re.DOTALL))
+             if has_response and has_react and has_score:
+                 rewards.append(5.0)
+             # elif has_score and (has_response or has_react):
+             #     rewards.append(3.0)
+             # elif has_response or has_react:
+             #     rewards.append(1.0)
+             else:
+                 rewards.append(0)
+         return rewards
+ 
+ 
+ orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM
+ orms['external_r1v_acc'] = MultiModalAccuracyORM
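
A minimal sketch of how these reward functions score a completion, assuming ms-swift is installed and the file is importable as GRPO.Reward (the sample completion text is invented):

    from GRPO.Reward import MultiModalAccuracyORM, MultiModalFormatAccuracyORM

    completion = (
        "<response think>Turns stay on topic and answer each other.</response think>\n"
        "<fluency think>No long pauses or extended overlaps.</fluency think>\n"
        "<overall score>2</overall score>"
    )
    print(MultiModalAccuracyORM()([completion], ["2"]))  # -> [5.0], prediction matches ground truth
    print(MultiModalFormatAccuracyORM()([completion]))   # -> [5.0], all three required tags present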
cotSFT/gemini-text/.ipynb_checkpoints/texterror_results-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT/gemini-text/texterror_results.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT/train/.ipynb_checkpoints/correctresults_with_audio-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT/train/correctresults_with_audio.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/.ipynb_checkpoints/correct_output_transcription-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/.ipynb_checkpoints/delay_output-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/.ipynb_checkpoints/gemini2.5_metainfo-checkpoint.py ADDED
@@ -0,0 +1,317 @@
1
+ import os
2
+ import json
3
+ import re
4
+ import requests
5
+ from tqdm import tqdm
6
+ from datetime import datetime
7
+ import glob
8
+ from requests.exceptions import Timeout
9
+ import argparse
10
+
11
+ prompt_template = (
12
+ "# Interactional Dialogue Evaluation\n\n"
13
+ "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
14
+ "Evaluate the quality of the interaction in the given dialogue transcript, focusing on:\n"
15
+ "**Response Relevance:** \n"
16
+ "**logical consistency, topic coherence**\n"
17
+ "**Interactional Fluency:**\n"
18
+ "**Detect and evaluate extended overlaps in conversation.**\n"
19
+ "**Detect and evaluate long pauses between speaker turns.\n\n**"
20
+ "**Note**: Small pauses and brief overlaps in conversation are acceptable, while prolonged pauses and overlapping turns are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
21
+ "## Scoring Criteria\n"
22
+ "Assign a single holistic score based on the combined evaluation:\n"
23
+ "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
24
+ "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
25
+ "## Evaluation Output Format:\n"
26
+ "Strictly follow this template:\n"
27
+ "<response think>\n"
28
+ "[Analysing Response Relevance and giving reasons for scoring...]\n"
29
+ "</response think>\n"
30
+ "<fluency think>\n"
31
+ "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
32
+ "</fluency think>\n"
33
+ "<overall score>X</overall score>\n"
34
+ )
35
+
36
+ # API configuration
37
+ url = "https://api2.aigcbest.top/v1/chat/completions"
38
+ headers = {
39
+ "Authorization": "Bearer sk-yAIqUaGzzVNSesHq4mRPaCbt53MMFRJIMB97cS4FkRy6idwN",
40
+ "Content-Type": "application/json",
41
+ "Accept": "application/json"
42
+ }
43
+
44
+ def parse_args():
45
+ parser = argparse.ArgumentParser(description='Process text evaluation with Gemini model')
46
+ parser.add_argument('--input_file', type=str, required=True,
47
+ help='Input JSON file containing text data')
48
+ parser.add_argument('--output_file', type=str, default='texterror_gemini.json',
49
+ help='Output JSON file for results')
50
+ parser.add_argument('--error_file', type=str, default='texterror_gemini_error.json',
51
+ help='Output JSON file for errors')
52
+ parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_test_text',
53
+ help='Directory for storing checkpoints')
54
+ parser.add_argument('--max_retries', type=int, default=3,
55
+ help='Maximum number of retries for failed predictions')
56
+ parser.add_argument('--checkpoint_interval', type=int, default=20,
57
+ help='Number of items to process before saving checkpoint')
58
+ return parser.parse_args()
59
+
60
+ def extract_overall_score(output_str):
61
+ """Extract <overall score>X</overall score> from model output."""
62
+ score_pattern = r"<overall score>(\d+)</overall score>"
63
+ match = re.search(score_pattern, output_str)
64
+ if match:
65
+ try:
66
+ return int(match.group(1))
67
+ except ValueError:
68
+ pass
69
+ return None
70
+
71
+ def validate_model_output(output_str):
72
+ """Validate that the model output contains all required tags"""
73
+ required_tags = [
74
+ "<response think>",
75
+ "</response think>",
76
+ "<fluency think>",
77
+ "</fluency think>",
78
+ "<overall score>",
79
+ "</overall score>"
80
+ ]
81
+
82
+ for tag in required_tags:
83
+ if tag not in output_str:
84
+ return False
85
+ return True
86
+
87
+ def extract_tag_content(output_str, tag_name):
88
+ """Extract content between opening and closing tags"""
89
+ start_tag = f"<{tag_name}>"
90
+ end_tag = f"</{tag_name}>"
91
+ try:
92
+ start_idx = output_str.find(start_tag) + len(start_tag)
93
+ end_idx = output_str.find(end_tag)
94
+ if start_idx == -1 or end_idx == -1:
95
+ return None
96
+ return output_str[start_idx:end_idx].strip()
97
+ except:
98
+ return None
99
+
100
+ def format_model_output(output_str):
101
+ """Extract and format content from all required tags"""
102
+ response_content = extract_tag_content(output_str, "response think")
103
+ fluency_content = extract_tag_content(output_str, "fluency think")
104
+ score_content = extract_tag_content(output_str, "overall score")
105
+
106
+ if not all([response_content, fluency_content, score_content]):
107
+ return None
108
+
109
+ formatted_output = (
110
+ f"<response think>\n{response_content}\n</response think>\n\n"
111
+ f"<fluency think>\n{fluency_content}\n</fluency think>\n\n"
112
+ f"<overall score>{score_content}</overall score>"
113
+ )
114
+ return formatted_output
115
+
116
+ def make_api_call(text_input, retry_count=0, max_retries=5):
117
+ """Make API call with retry logic for API errors"""
118
+ try:
119
+ print(f"Attempting API call (attempt {retry_count + 1}/{max_retries + 1})")
120
+ data_req = {
121
+ "model": "gemini-2.5-flash-preview-05-20-thinking",
122
+ "messages": [
123
+ {
124
+ "role": "user",
125
+ "content": [
126
+ {
127
+ "type": "text",
128
+ "text": prompt_template
129
+ },
130
+ {
131
+ "type": "text",
132
+ "text": text_input
133
+ },
134
+ ]
135
+ }
136
+ ],
137
+ "temperature": 1,
138
+ }
139
+
140
+ response = requests.post(url, headers=headers, json=data_req, timeout=(200, 200))
141
+ print(f"API response received with status code: {response.status_code}")
142
+
143
+ if response.status_code == 200:
144
+ model_output = response.json()['choices'][0]['message']['content']
145
+ if not validate_model_output(model_output):
146
+ print("Model output missing required tags, retrying...")
147
+ return None, None
148
+
149
+ formatted_output = format_model_output(model_output)
150
+ if formatted_output is None:
151
+ print("Failed to extract content from tags, retrying...")
152
+ return None, None
153
+
154
+ pred_score = extract_overall_score(model_output)
155
+ return formatted_output, pred_score
156
+ else:
157
+ print(f"API returned error status {response.status_code}: {response.text}")
158
+ if retry_count >= max_retries:
159
+ raise Exception(f"POST error {response.status_code}: {response.text}")
160
+ return None, None
161
+ except requests.exceptions.ConnectTimeout:
162
+ print(f"Connection timeout (>10s)")
163
+ if retry_count >= max_retries:
164
+ raise Exception("Connection timeout")
165
+ return None, None
166
+ except requests.exceptions.ReadTimeout:
167
+ print(f"Read timeout (>30s)")
168
+ if retry_count >= max_retries:
169
+ raise Exception("Read timeout")
170
+ return None, None
171
+ except Exception as e:
172
+ print(f"Unexpected error during API call: {str(e)}")
173
+ if retry_count >= max_retries:
174
+ raise e
175
+ return None, None
176
+
177
+ def get_latest_checkpoint(checkpoint_dir):
178
+ """Get the latest checkpoint file and its processed count"""
179
+ checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.json"))
180
+ if not checkpoint_files:
181
+ return None, 0
182
+
183
+ latest_checkpoint = None
184
+ max_count = 0
185
+ for checkpoint in checkpoint_files:
186
+ try:
187
+ count = int(os.path.basename(checkpoint).split('_')[1])
188
+ if count > max_count:
189
+ max_count = count
190
+ latest_checkpoint = checkpoint
191
+ except (ValueError, IndexError):
192
+ continue
193
+
194
+ return latest_checkpoint, max_count
195
+
196
+ def save_checkpoint(results, processed_count, checkpoint_dir):
197
+ """Save results to a checkpoint file"""
198
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
199
+ checkpoint_file = os.path.join(checkpoint_dir, f"checkpoint_{processed_count}_{timestamp}.json")
200
+ with open(checkpoint_file, "w", encoding="utf-8") as f:
201
+ json.dump(results, f, indent=2, ensure_ascii=False)
202
+ print(f"Checkpoint saved: {checkpoint_file}")
203
+
204
+ def main():
205
+ args = parse_args()
206
+
207
+ # Initialize results storage
208
+ results = []
209
+ save_file_name = args.output_file
210
+ error_file_name = args.error_file
211
+
212
+ # Create checkpoints directory
213
+ checkpoint_dir = args.checkpoint_dir
214
+ if not os.path.exists(checkpoint_dir):
215
+ os.makedirs(checkpoint_dir)
216
+
217
+ # Load test data
218
+ all_data_file = args.input_file
219
+ with open(all_data_file, 'r', encoding='utf-8') as f:
220
+ all_data = json.load(f)
221
+
222
+ # Initialize error tracking
223
+ error_results = []
224
+
225
+ # Load checkpoint if exists
226
+ latest_checkpoint, checkpoint_count = get_latest_checkpoint(checkpoint_dir)
227
+ if latest_checkpoint:
228
+ print(f"Found latest checkpoint with {checkpoint_count} processed items: {latest_checkpoint}")
229
+ try:
230
+ with open(latest_checkpoint, 'r', encoding='utf-8') as f:
231
+ results = json.load(f)
232
+ print(f"Resumed from checkpoint: processed {len(results)} items")
233
+ except Exception as e:
234
+ print(f"Warning: Failed to load checkpoint {latest_checkpoint}: {e}")
235
+ results = []
236
+ else:
237
+ print("No checkpoint found, starting from scratch")
238
+ results = []
239
+
240
+ max_prediction_retries = args.max_retries
241
+ total_count = 0
242
+
243
+ for item in tqdm(all_data, desc="Processing texts"):
244
+ key = item.get('key')
245
+ text_input = item.get('model_output')
246
+
247
+ if not text_input:
248
+ print(f"No text input found for key {key}, skipping...")
249
+ continue
250
+
251
+ print(f"Processing text for key={key}")
252
+
253
+ prediction_retry_count = 0
254
+ success = False
255
+
256
+ while prediction_retry_count < max_prediction_retries and not success:
257
+ try:
258
+ print(f"\nProcessing attempt {prediction_retry_count + 1}")
259
+ model_output, pred_score = make_api_call(text_input)
260
+
261
+ if model_output is None or pred_score is None:
262
+ print("API call failed, retrying...")
263
+ prediction_retry_count += 1
264
+ continue
265
+
266
+ print(f"Received prediction: {pred_score}")
267
+
268
+ if pred_score == 1:
269
+ success = True
270
+ print("Prediction score is 1, accepting result")
271
+ else:
272
+ prediction_retry_count += 1
273
+ print(f"Prediction score is not 1 (attempt {prediction_retry_count}/{max_prediction_retries})")
274
+ if prediction_retry_count >= max_prediction_retries:
275
+ print("Max retries reached, accepting last prediction")
276
+ success = True
277
+ else:
278
+ continue
279
+
280
+ results.append({
281
+ "key": key,
282
+ "text_input": text_input,
283
+ "model_output": model_output,
284
+ "predicted_score": pred_score,
285
+ "prediction_attempts": prediction_retry_count + 1
286
+ })
287
+
288
+ with open(save_file_name, "w", encoding="utf-8") as f:
289
+ json.dump(results, f, indent=2, ensure_ascii=False)
290
+
291
+ total_count += 1
292
+
293
+ if total_count % args.checkpoint_interval == 0:
294
+ save_checkpoint(results, total_count, checkpoint_dir)
295
+
296
+ except Exception as e:
297
+ error_msg = str(e)
298
+ print(f"Failed to process text for key {key}: {error_msg}")
299
+ error_results.append({
300
+ "key": key,
301
+ "text_input": text_input,
302
+ "error": f"Exception: {error_msg}"
303
+ })
304
+ break
305
+
306
+ with open(error_file_name, "w", encoding="utf-8") as f:
307
+ json.dump(error_results, f, indent=2, ensure_ascii=False)
308
+
309
+ # Save final results
310
+ with open(save_file_name, "w", encoding="utf-8") as f:
311
+ json.dump(results, f, indent=2, ensure_ascii=False)
312
+
313
+ print(f"Results saved to {save_file_name}")
314
+ print(f"Total processed items: {total_count}")
315
+
316
+ if __name__ == "__main__":
317
+ main()
cotSFT_new/.ipynb_checkpoints/overlaps1_output-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/.ipynb_checkpoints/process_transcription-checkpoint.py ADDED
@@ -0,0 +1,80 @@
1
+ import json
2
+
3
+ def seconds_to_mmss(seconds):
4
+ minutes = int(seconds // 60)
5
+ seconds = int(seconds % 60)
6
+ return f"{minutes:02d}:{seconds:02d}"
7
+
8
+ filename = "correct_output"
9
+ def is_overlapping(current_segment, other_segments):
10
+ """Check if the current segment overlaps with any other segment."""
11
+ current_start = current_segment['start_time']
12
+ current_end = current_segment['end_time']
13
+
14
+ for segment in other_segments:
15
+ if segment == current_segment:
16
+ continue
17
+
18
+ other_start = segment['start_time']
19
+ other_end = segment['end_time']
20
+
21
+ # Check if there's an overlap
22
+ if (current_start < other_end and current_end > other_start):
23
+ return True
24
+
25
+ return False
26
+
27
+ def process_transcriptions():
28
+ # Read the overlap_5s_716.json file
29
+ with open(f'./{filename}.json', 'r', encoding='utf-8') as f:
30
+ data = json.load(f)
31
+
32
+ # List to store results for all conversations
33
+ results = []
34
+
35
+ # Process each conversation
36
+ for conversation_id, conversation in data.items():
37
+ segments = conversation.get('segments', [])
38
+ audio_path = conversation.get('stereo_audio', [])
39
+ # Sort segments by start time
40
+ segments.sort(key=lambda x: x['start_time'])
41
+
42
+ # Process each segment and create transcription lines
43
+ transcription_lines = []
44
+
45
+ for segment in segments:
46
+ speaker = segment['speaker']
47
+ start_time = segment['start_time']
48
+ end_time = segment['end_time']
49
+ text = segment['text']
50
+ original_text = segment['original_text']
51
+ original_text = original_text.replace("[interrupt] ", "").strip()
52
+ # Format timestamp
53
+ timestamp = f"[{seconds_to_mmss(start_time)} - {seconds_to_mmss(end_time)}]"
54
+
55
+ # Check if this segment overlaps with any other segment
56
+ has_overlap = is_overlapping(segment, segments)
57
+
58
+ # Format the line
59
+ if has_overlap:
60
+ line = f"{timestamp} Speaker {speaker}: {original_text}"
61
+ else:
62
+ line = f"{timestamp} Speaker {speaker}: {text}"
63
+
64
+ transcription_lines.append(line)
65
+
66
+ # Create result entry
67
+ result = {
68
+ "key": conversation_id,
69
+ "audio_url": audio_path,
70
+ "model_output": "\n".join(transcription_lines)
71
+ }
72
+ results.append(result)
73
+
74
+ # Save the results to a JSON file
75
+ output_file = f'./{filename}_transcription.json'
76
+ with open(output_file, 'w', encoding='utf-8') as f:
77
+ json.dump(results, f, indent=2, ensure_ascii=False)
78
+
79
+ if __name__ == "__main__":
80
+ process_transcriptions()
cotSFT_new/cotSFT_10data/.ipynb_checkpoints/dataset_real_sft-checkpoint.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/cotSFT_10data/dataset_real_sft.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/cotSFT_10data/gemini2.5_metainfo.py ADDED
@@ -0,0 +1,329 @@
1
+ import os
2
+ import json
3
+ import re
4
+ import requests
5
+ from tqdm import tqdm
6
+ from datetime import datetime
7
+ import glob
8
+ from requests.exceptions import Timeout
9
+ import argparse
10
+ import multiprocessing
11
+
12
+ prompt_template = (
13
+ "# Interactional Dialogue Evaluation\n\n"
14
+ "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
15
+ "Evaluate the quality of the interaction in the given dialogue transcript, focusing on:\n"
16
+ "**Response Relevance:** \n"
17
+ "**logical consistency, topic coherence**\n"
18
+ "**Interactional Fluency:**\n"
19
+ "**Detect and evaluate extended overlaps in conversation.**\n"
20
+ "**Detect and evaluate long pauses between speaker turns.\n\n**"
21
+ "**Note**: Small pauses and brief overlaps in conversation are acceptable, while prolonged pauses and overlapping turns are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
22
+ "## Scoring Criteria\n"
23
+ "Assign a single holistic score based on the combined evaluation:\n"
24
+ "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
25
+ "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
26
+ "## Evaluation Output Format:\n"
27
+ "Strictly follow this template:\n"
28
+ "<response think>\n"
29
+ "[Analysing Response Relevance and giving reasons for scoring...]\n"
30
+ "</response think>\n"
31
+ "<fluency think>\n"
32
+ "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
33
+ "</fluency think>\n"
34
+ "<overall score>X</overall score>\n"
35
+ )
36
+
37
+ # API configuration
38
+ url = "https://api2.aigcbest.top/v1/chat/completions"
39
+ headers = {
40
+ "Authorization": "Bearer sk-yAIqUaGzzVNSesHq4mRPaCbt53MMFRJIMB97cS4FkRy6idwN",
41
+ "Content-Type": "application/json",
42
+ "Accept": "application/json"
43
+ }
44
+
45
+ def parse_args():
46
+ parser = argparse.ArgumentParser(description='Process text evaluation with Gemini model')
47
+ parser.add_argument('--input_file', type=str, default='all_dialogues_processed.json',
48
+ help='Input JSON file containing text data')
49
+ parser.add_argument('--output_file', type=str, default='cotSFT_gemini.json',
50
+ help='Output JSON file for results')
51
+ parser.add_argument('--error_file', type=str, default='cotSFT_gemini_error.json',
52
+ help='Output JSON file for errors')
53
+ parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_test_text',
54
+ help='Directory for storing checkpoints')
55
+ parser.add_argument('--max_retries', type=int, default=6,
56
+ help='Maximum number of retries for failed predictions')
57
+ parser.add_argument('--checkpoint_interval', type=int, default=100,
58
+ help='Number of items to process before saving checkpoint')
59
+ parser.add_argument('--num_processes', type=int, default=5,
60
+ help='Number of parallel processes to use')
61
+ return parser.parse_args()

def extract_overall_score(output_str):
    """Extract <overall score>X</overall score> from model output."""
    score_pattern = r"<overall score>(\d+)</overall score>"
    match = re.search(score_pattern, output_str)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            pass
    return None

def validate_model_output(output_str):
    """Validate that the model output contains all required tags"""
    required_tags = [
        "<response think>",
        "</response think>",
        "<fluency think>",
        "</fluency think>",
        "<overall score>",
        "</overall score>"
    ]

    for tag in required_tags:
        if tag not in output_str:
            return False
    return True

def extract_tag_content(output_str, tag_name):
    """Extract content between opening and closing tags"""
    start_tag = f"<{tag_name}>"
    end_tag = f"</{tag_name}>"
    try:
        start_idx = output_str.find(start_tag)
        end_idx = output_str.find(end_tag)
        # find() returns -1 when a tag is missing; check before offsetting past the opening tag
        if start_idx == -1 or end_idx == -1:
            return None
        return output_str[start_idx + len(start_tag):end_idx].strip()
    except Exception:
        return None

def format_model_output(output_str):
    """Extract and format content from all required tags"""
    response_content = extract_tag_content(output_str, "response think")
    fluency_content = extract_tag_content(output_str, "fluency think")
    score_content = extract_tag_content(output_str, "overall score")

    if not all([response_content, fluency_content, score_content]):
        return None

    formatted_output = (
        f"<response think>\n{response_content}\n</response think>\n\n"
        f"<fluency think>\n{fluency_content}\n</fluency think>\n\n"
        f"<overall score>{score_content}</overall score>"
    )
    return formatted_output
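
As a quick illustration of how these helpers fit together, here is a minimal sketch with a made-up model reply (sample_output is purely illustrative, not real model output):

# Hypothetical, well-formed reply used only to exercise the parsing helpers above.
sample_output = (
    "<response think>\nTurns stay on topic and answers follow logically.\n</response think>\n"
    "<fluency think>\nOnly brief pauses; no extended overlaps.\n</fluency think>\n"
    "<overall score>2</overall score>"
)

assert validate_model_output(sample_output)                  # all six tags present
print(extract_overall_score(sample_output))                  # -> 2
print(format_model_output(sample_output).splitlines()[0])    # -> "<response think>"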

def make_api_call(text_input, retry_count=0, max_retries=5):
    """Make API call with retry logic for API errors"""
    try:
        print(f"Attempting API call (attempt {retry_count + 1}/{max_retries + 1})")
        data_req = {
            "model": "gemini-2.5-pro-preview-06-05-thinking",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt_template
                        },
                        {
                            "type": "text",
                            "text": "The correct overall score is: 2\n"
                        },
                        {
                            "type": "text",
                            "text": text_input
                        },
                    ]
                }
            ],
            "temperature": 1,
        }

        response = requests.post(url, headers=headers, json=data_req, timeout=(200, 200))
        print(f"API response received with status code: {response.status_code}")

        if response.status_code == 200:
            model_output = response.json()['choices'][0]['message']['content']
            if not validate_model_output(model_output):
                print("Model output missing required tags, retrying...")
                return None, None

            formatted_output = format_model_output(model_output)
            if formatted_output is None:
                print("Failed to extract content from tags, retrying...")
                return None, None

            pred_score = extract_overall_score(model_output)
            return formatted_output, pred_score
        else:
            print(f"API returned error status {response.status_code}: {response.text}")
            if retry_count >= max_retries:
                raise Exception(f"POST error {response.status_code}: {response.text}")
            return None, None
    except requests.exceptions.ConnectTimeout:
        print("Connection timeout (>200s)")
        if retry_count >= max_retries:
            raise Exception("Connection timeout")
        return None, None
    except requests.exceptions.ReadTimeout:
        print("Read timeout (>200s)")
        if retry_count >= max_retries:
            raise Exception("Read timeout")
        return None, None
    except Exception as e:
        print(f"Unexpected error during API call: {str(e)}")
        if retry_count >= max_retries:
            raise e
        return None, None

def get_latest_checkpoint(checkpoint_dir):
    """Get the latest checkpoint file and its processed count"""
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.json"))
    if not checkpoint_files:
        return None, 0

    latest_checkpoint = None
    max_count = 0
    for checkpoint in checkpoint_files:
        try:
            count = int(os.path.basename(checkpoint).split('_')[1])
            if count > max_count:
                max_count = count
                latest_checkpoint = checkpoint
        except (ValueError, IndexError):
            continue

    return latest_checkpoint, max_count

def save_checkpoint(results, processed_count, checkpoint_dir):
    """Save results to a checkpoint file"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_file = os.path.join(checkpoint_dir, f"checkpoint_{processed_count}_{timestamp}.json")
    with open(checkpoint_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Checkpoint saved: {checkpoint_file}")

def split_data(data, num_chunks):
    # Split data into num_chunks as evenly as possible
    chunk_size = len(data) // num_chunks
    remainder = len(data) % num_chunks
    chunks = []
    start = 0
    for i in range(num_chunks):
        end = start + chunk_size + (1 if i < remainder else 0)
        chunks.append(data[start:end])
        start = end
    return chunks
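
A quick sanity check of split_data with illustrative values: the remainder is spread over the first chunks, so 11 items across 3 workers yield sizes 4, 4 and 3:

# Illustrative only; any list works.
chunks = split_data(list(range(11)), 3)
print([len(c) for c in chunks])   # -> [4, 4, 3]
print(chunks[0])                  # -> [0, 1, 2, 3]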

def process_chunk(args_tuple):
    chunk_data, chunk_idx, args = args_tuple
    results = []
    error_results = []
    save_file_name = f"{os.path.splitext(args.output_file)[0]}_chunk{chunk_idx}.json"
    error_file_name = f"{os.path.splitext(args.error_file)[0]}_chunk{chunk_idx}.json"
    checkpoint_dir = f"{args.checkpoint_dir}_chunk{chunk_idx}"
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    max_prediction_retries = args.max_retries
    total_count = 0
    for item in tqdm(chunk_data, desc=f"Processing chunk {chunk_idx}"):
        key = item.get('key')
        text_input = item.get('process_dialogue')  # use the process_dialogue field
        if not text_input:
            print(f"No text input found for key {key}, skipping...")
            continue
        prediction_retry_count = 0
        success = False
        while prediction_retry_count < max_prediction_retries and not success:
            try:
                model_output, pred_score = make_api_call(text_input)
                if model_output is None or pred_score is None:
                    prediction_retry_count += 1
                    print(f"API call failed for key {key}, retry {prediction_retry_count}/{max_prediction_retries}")
                    continue

                # Only keep the result when the predicted score is 2
                if pred_score == 2:
                    success = True
                    results.append({
                        "key": key,
                        "text_input": text_input,
                        "model_output": model_output,
                        "predicted_score": pred_score,
                        "prediction_attempts": prediction_retry_count + 1
                    })
                    print(f"Success! Predicted score 2 for key {key} after {prediction_retry_count + 1} attempts")
                else:
                    prediction_retry_count += 1
                    print(f"Predicted score {pred_score} for key {key}, retry {prediction_retry_count}/{max_prediction_retries}")
                    if prediction_retry_count >= max_prediction_retries:
                        print(f"Max retries reached for key {key}, saving with score {pred_score}")
                        results.append({
                            "key": key,
                            "text_input": text_input,
                            "model_output": model_output,
                            "predicted_score": pred_score,
                            "prediction_attempts": prediction_retry_count
                        })
                        success = True
                    continue

                # Persist the results collected so far
                with open(save_file_name, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                total_count += 1
                if total_count % args.checkpoint_interval == 0:
                    save_checkpoint(results, total_count, checkpoint_dir)
            except Exception as e:
                error_msg = str(e)
                print(f"Exception for key {key}: {error_msg}")
                error_results.append({
                    "key": key,
                    "text_input": text_input,
                    "error": f"Exception: {error_msg}"
                })
                break
    # Save error results
    with open(error_file_name, "w", encoding="utf-8") as f:
        json.dump(error_results, f, indent=2, ensure_ascii=False)
    # Final save of the results
    with open(save_file_name, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    return save_file_name, error_file_name

def merge_json_files(file_list, output_file):
    merged = []
    for fname in file_list:
        if os.path.exists(fname):
            with open(fname, 'r', encoding='utf-8') as f:
                merged.extend(json.load(f))
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(merged, f, indent=2, ensure_ascii=False)

def main():
    args = parse_args()
    with open(args.input_file, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
    num_chunks = args.num_processes
    chunks = split_data(all_data, num_chunks)
    pool = multiprocessing.Pool(num_chunks)
    chunk_args = [(chunks[i], i, args) for i in range(num_chunks)]
    results = pool.map(process_chunk, chunk_args)
    pool.close()
    pool.join()
    # Merge all per-chunk output files
    output_files = [r[0] for r in results]
    error_files = [r[1] for r in results]
    merge_json_files(output_files, args.output_file)
    merge_json_files(error_files, args.error_file)
    print(f"Results saved to {args.output_file}")
    print(f"Errors saved to {args.error_file}")

if __name__ == "__main__":
    main()
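
Each worker writes per-chunk files that main() merges at the end. A minimal sketch of the naming scheme, assuming the default --output_file and --num_processes values from parse_args:

# With --output_file cotSFT_gemini.json and --num_processes 5, process_chunk()
# writes cotSFT_gemini_chunk0.json ... cotSFT_gemini_chunk4.json, and main()
# merges them back into cotSFT_gemini.json. The same merge can be done manually:
chunk_files = [f"cotSFT_gemini_chunk{i}.json" for i in range(5)]
merge_json_files(chunk_files, "cotSFT_gemini.json")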
cotSFT_new/cotSFT_gemini.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/delay_output.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/.ipynb_checkpoints/delay_output-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/.ipynb_checkpoints/process_transcription-checkpoint.py ADDED
@@ -0,0 +1,80 @@
import json

def seconds_to_mmss(seconds):
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes:02d}:{seconds:02d}"

filename = "texterror_output"

def is_overlapping(current_segment, other_segments):
    """Check if the current segment overlaps with any other segment."""
    current_start = current_segment['start_time']
    current_end = current_segment['end_time']

    for segment in other_segments:
        if segment == current_segment:
            continue

        other_start = segment['start_time']
        other_end = segment['end_time']

        # Check if there's an overlap
        if (current_start < other_end and current_end > other_start):
            return True

    return False
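
The overlap test uses strict inequalities, so segments that merely touch at an endpoint do not count as overlapping. A small illustration with invented segment values:

seg_a = {"speaker": "A", "start_time": 3.0, "end_time": 7.5}
seg_b = {"speaker": "B", "start_time": 6.8, "end_time": 9.0}
seg_c = {"speaker": "A", "start_time": 9.5, "end_time": 11.0}

print(is_overlapping(seg_a, [seg_a, seg_b, seg_c]))  # True  (A and B share 6.8-7.5)
print(is_overlapping(seg_c, [seg_a, seg_b, seg_c]))  # False (C starts after B ends)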

def process_transcriptions():
    # Read the input file named by the `filename` variable above (./texterror_output.json)
    with open(f'./{filename}.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # List to store results for all conversations
    results = []

    # Process each conversation
    for conversation_id, conversation in data.items():
        segments = conversation.get('segments', [])
        audio_path = conversation.get('stereo_audio', [])
        # Sort segments by start time
        segments.sort(key=lambda x: x['start_time'])

        # Process each segment and create transcription lines
        transcription_lines = []

        for segment in segments:
            speaker = segment['speaker']
            start_time = segment['start_time']
            end_time = segment['end_time']
            text = segment['text']
            original_text = segment['original_text']
            original_text = original_text.replace("[interrupt] ", "").strip()
            # Format timestamp
            timestamp = f"[{seconds_to_mmss(start_time)} - {seconds_to_mmss(end_time)}]"

            # Check if this segment overlaps with any other segment
            has_overlap = is_overlapping(segment, segments)

            # Format the line
            if has_overlap:
                line = f"{timestamp} Speaker {speaker}: {original_text}"
            else:
                line = f"{timestamp} Speaker {speaker}: {text}"

            transcription_lines.append(line)

        # Create result entry
        result = {
            "key": conversation_id,
            "audio_url": audio_path,
            "model_output": "\n".join(transcription_lines)
        }
        results.append(result)

    # Save the results to a JSON file
    output_file = f'./{filename}_transcription.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    process_transcriptions()
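
For orientation, a hypothetical single segment (values invented) is rendered into a transcript line as follows:

segment = {"speaker": "A", "start_time": 65.2, "end_time": 68.9,
           "text": "Sure, that works for me."}
timestamp = f"[{seconds_to_mmss(segment['start_time'])} - {seconds_to_mmss(segment['end_time'])}]"
print(f"{timestamp} Speaker {segment['speaker']}: {segment['text']}")
# -> [01:05 - 01:08] Speaker A: Sure, that works for me.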
cotSFT_new/filtered_output/.ipynb_checkpoints/texterror_output_transcription_gemini-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/alltrain/.ipynb_checkpoints/correct_output_transcription_merged_output_990-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/alltrain/correct_output_transcription_merged_output_990.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/alltrain/overlaps1_gemini_merged_output.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/alltrain/texterror_output_transcription_merged_output.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/correc/.ipynb_checkpoints/correct_output_transcription_gemini_error-checkpoint.json ADDED
@@ -0,0 +1 @@
[]
cotSFT_new/filtered_output/correc/correct_output_transcription.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk2.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk3.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk4.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk6.json ADDED
The diff for this file is too large to render. See raw diff
 
cotSFT_new/filtered_output/correc/correct_output_transcription_gemini_chunk7.json ADDED
The diff for this file is too large to render. See raw diff