Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- 4JOB/.ipynb_checkpoints/overlapOriginal-checkpoint.json +0 -0
- 4JOB/.ipynb_checkpoints/silence-checkpoint.json +0 -0
- 4JOB/overlap.json +0 -0
- 4JOB/overlap/overlap.json +0 -0
- 4JOB/overlap/trimmed_dialogues_pause_200_400_output.json +0 -0
- 4JOB/overlap/trimmed_dialogues_pause_400_600_output.json +0 -0
- 4JOB/overlapOriginal.json +0 -0
- 4JOB/overlap_filtered_output/overlap.json +0 -0
- 4JOB/overlap_filtered_output/trimmed_dialogues_pause_400_600_output.json +0 -0
- 4JOB/silence.json +0 -0
- 4JOB/silence/.ipynb_checkpoints/silence-checkpoint.json +0 -0
- 4JOB/silence/silence.json +0 -0
- 4JOB/silence/trimmed_dialogues_pause_0_100_output.json +0 -0
- 4JOB/silence/trimmed_dialogues_pause_200_300_output.json +0 -0
- 4JOB/silence/trimmed_dialogues_pause_300_400_output.json +0 -0
- 4JOB/silence/trimmed_dialogues_pause_400_500_output.json +0 -0
- 4JOB/train/.ipynb_checkpoints/overlap_overlapgap-checkpoint.json +0 -0
- 4JOB/train/.ipynb_checkpoints/silence_silencegap-checkpoint.json +0 -0
- 4JOB/train/overlap_overlapgap.json +0 -0
- 4JOB/train/overlap_silencegap.json +0 -0
- 4JOB/train/overlap_transcription.json +0 -0
- 4JOB/train/silence_overlapgap.json +0 -0
- 4JOB/train/silence_silencegap.json +0 -0
- 4JOB/train/silence_speaker.json +0 -0
- 4JOB/train/silence_transcription.json +0 -0
- GRPO/.ipynb_checkpoints/Reward-checkpoint.py +87 -0
- GRPO/.ipynb_checkpoints/formatReward-checkpoint.py +26 -0
- GRPO/__pycache__/Reward.cpython-310.pyc +0 -0
- GRPO/__pycache__/Reward.cpython-312.pyc +0 -0
- GRPO/formatReward.py +26 -0
- asset/discord_qr.jpg +0 -0
- asset/wechat.png +0 -0
- cotSFT/.ipynb_checkpoints/add_streoaudio-checkpoint.py +34 -0
- cotSFT/.ipynb_checkpoints/onlyAudios_longdelay_add_silence-checkpoint.json +0 -0
- cotSFT/.ipynb_checkpoints/overlaps-checkpoint.json +0 -0
- cotSFT/Results/.ipynb_checkpoints/isoverlapresults-checkpoint.json +0 -0
- cotSFT/Results/.ipynb_checkpoints/issilenceresults-checkpoint.json +0 -0
- cotSFT/Results/.ipynb_checkpoints/texterror_results-checkpoint.json +0 -0
- cotSFT/Results/correctresults.json +0 -0
- cotSFT/Results/isoverlapresults.json +0 -0
- cotSFT/Results/issilenceresults.json +0 -0
- cotSFT/Results/texterror_results.json +0 -0
- cotSFT/add_streoaudio.py +34 -0
- cotSFT/gemini-correct/.ipynb_checkpoints/correctresults-checkpoint.json +0 -0
- cotSFT/gemini-correct/.ipynb_checkpoints/gemini2.5_metainfo-checkpoint.py +317 -0
- cotSFT/gemini-correct/.ipynb_checkpoints/run_gemini_meta-checkpoint.sh +1 -0
- cotSFT/gemini-correct/.ipynb_checkpoints/thinkSFT_correct_transcriptions-checkpoint.json +0 -0
- cotSFT/gemini-correct/checkpoints_test_text/checkpoint_100_20250612_182923.json +0 -0
- cotSFT/gemini-correct/checkpoints_test_text/checkpoint_1100_20250613_023551.json +0 -0
- cotSFT/gemini-correct/checkpoints_test_text/checkpoint_200_20250612_191111.json +0 -0
4JOB/.ipynb_checkpoints/overlapOriginal-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/.ipynb_checkpoints/silence-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlap/overlap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlap/trimmed_dialogues_pause_200_400_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlap/trimmed_dialogues_pause_400_600_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlapOriginal.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlap_filtered_output/overlap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/overlap_filtered_output/trimmed_dialogues_pause_400_600_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence/.ipynb_checkpoints/silence-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence/silence.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence/trimmed_dialogues_pause_0_100_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence/trimmed_dialogues_pause_200_300_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence/trimmed_dialogues_pause_300_400_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/silence/trimmed_dialogues_pause_400_500_output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/.ipynb_checkpoints/overlap_overlapgap-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/.ipynb_checkpoints/silence_silencegap-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/overlap_overlapgap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/overlap_silencegap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/overlap_transcription.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/silence_overlapgap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/silence_silencegap.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/silence_speaker.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
4JOB/train/silence_transcription.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GRPO/.ipynb_checkpoints/Reward-checkpoint.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import math
|
| 4 |
+
import json
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from swift.plugin import ORM,orms
|
| 7 |
+
from typing import Dict, List, Union
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MultiModalAccuracyORM(ORM):
    """Outcome reward comparing the predicted overall score with the label."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion against its ground-truth label.

        Args:
            completions (list[str]): Generated outputs; each is searched for
                an ``<overall score>N</overall score>`` tag.
            solution (list[str]): Ground truths; the first element/character
                of each entry is parsed as the integer label.

        Returns:
            list[float]: One reward per (completion, solution) pair:
            5.0 for an exact match, 1.0 when both scores are valid but
            differ by at most 1, and 0.0 otherwise (including when either
            score is missing or outside the range 1..2).
        """
        rewards = []
        for content, gt_score_orig in zip(completions, solution):
            pred_score = None
            score_match = re.search(r"<overall score>(\d+)</overall score>", content)
            if score_match:
                try:
                    pred_score = int(score_match.group(1))
                except ValueError:
                    # Only reachable for absurdly long digit runs; treat as missing.
                    pred_score = None
                else:
                    if not 1 <= pred_score <= 2:
                        pred_score = None

            gt_score = None
            try:
                gt_score = int(gt_score_orig[0])
            except (ValueError, TypeError, IndexError, KeyError):
                # Unparseable label: best-effort, this sample earns no reward.
                gt_score = None
            else:
                if not 1 <= gt_score <= 2:
                    gt_score = None

            # Tiered reward logic.
            if pred_score is None or gt_score is None:
                reward = 0.0
            elif pred_score == gt_score:
                reward = 5.0
            elif abs(pred_score - gt_score) <= 1:
                reward = 1.0
            else:
                # NOTE(review): unreachable while valid scores are limited to
                # 1..2; kept so widening the range keeps the tiering.
                reward = 0.0

            rewards.append(reward)
        return rewards
|
| 60 |
+
class MultiModalFormatAccuracyORM(ORM):
    """Format reward: checks that a completion contains all required tags."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format.

        Args:
            completions (list[str]): Generated outputs.

        Returns:
            list[float]: 5.0 when the completion contains a
            ``<response think>`` section, a ``<fluency think>`` section and a
            numeric ``<overall score>`` tag; 0.0 otherwise.
        """
        response_pattern = r"<response think>.*?</response think>"
        react_pattern = r"<fluency think>.*?</fluency think>"
        score_pattern = r"[*\s]*<overall score>(\d+)</overall score>[\s*]*"
        required_patterns = (response_pattern, react_pattern, score_pattern)

        rewards = []
        for content in completions:
            # DOTALL lets the think sections span multiple lines.
            well_formed = all(
                re.search(pattern, content, re.DOTALL)
                for pattern in required_patterns
            )
            rewards.append(5.0 if well_formed else 0.0)
        return rewards
|
| 86 |
+
# Register both reward classes in the `orms` mapping under the keys used to select them.
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
|
GRPO/.ipynb_checkpoints/formatReward-checkpoint.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
from typing import List

from swift.plugin import ORM, orms


class MultiModalFormatAccuracyORM(ORM):
    """Format reward: checks that a completion contains all required tags."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format.

        Args:
            completions (list[str]): Generated outputs.

        Returns:
            list[float]: 5.0 when the completion contains a
            ``<response think>`` section, a ``<fluency think>`` section and a
            numeric ``<overall score>`` tag; 0.0 otherwise.
        """
        response_pattern = r"<response think>.*?</response think>"
        react_pattern = r"<fluency think>.*?</fluency think>"
        score_pattern = r"[*\s]*<overall score>(\d+)</overall score>[\s*]*"
        required_patterns = (response_pattern, react_pattern, score_pattern)

        rewards = []
        # Iterate the `completions` argument; the previous loop variable
        # `completion_contents` was never defined and raised NameError.
        for content in completions:
            well_formed = all(
                re.search(pattern, content, re.DOTALL)
                for pattern in required_patterns
            )
            rewards.append(5.0 if well_formed else 0.0)
        return rewards


# Register the reward class in the `orms` mapping.
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM
|
GRPO/__pycache__/Reward.cpython-310.pyc
ADDED
|
Binary file (2.23 kB). View file
|
|
|
GRPO/__pycache__/Reward.cpython-312.pyc
ADDED
|
Binary file (3.13 kB). View file
|
|
|
GRPO/formatReward.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
from typing import List

from swift.plugin import ORM, orms


class MultiModalFormatAccuracyORM(ORM):
    """Format reward: checks that a completion contains all required tags."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format.

        Args:
            completions (list[str]): Generated outputs.

        Returns:
            list[float]: 5.0 when the completion contains a
            ``<response think>`` section, a ``<fluency think>`` section and a
            numeric ``<overall score>`` tag; 0.0 otherwise.
        """
        response_pattern = r"<response think>.*?</response think>"
        react_pattern = r"<fluency think>.*?</fluency think>"
        score_pattern = r"[*\s]*<overall score>(\d+)</overall score>[\s*]*"
        required_patterns = (response_pattern, react_pattern, score_pattern)

        rewards = []
        # Iterate the `completions` argument; the previous loop variable
        # `completion_contents` was never defined and raised NameError.
        for content in completions:
            well_formed = all(
                re.search(pattern, content, re.DOTALL)
                for pattern in required_patterns
            )
            rewards.append(5.0 if well_formed else 0.0)
        return rewards


# Register the reward class in the `orms` mapping.
orms['external_r1v_format_acc'] = MultiModalFormatAccuracyORM
|
asset/discord_qr.jpg
ADDED
|
asset/wechat.png
ADDED
|
cotSFT/.ipynb_checkpoints/add_streoaudio-checkpoint.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
# 加载JSON数据的工具函数
|
| 4 |
+
def load_json(file_path):
    """Read the UTF-8 JSON file at ``file_path`` and return the parsed object."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
| 7 |
+
|
| 8 |
+
# 保存JSON数据
|
| 9 |
+
def save_json(data, file_path):
    """Write ``data`` to ``file_path`` as indented UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
|
| 12 |
+
|
| 13 |
+
# 主逻辑
|
| 14 |
+
def add_stereo_audio_field(array_json_path, dict_json_path, output_path):
    """Copy each matching ``stereo_audio`` value into an ``audio_url`` field.

    Args:
        array_json_path: Path to a JSON array of items, each looked up by its
            ``"key"`` field.
        dict_json_path: Path to a JSON dictionary mapping keys to entries
            that may carry a ``"stereo_audio"`` field.
        output_path: Where the updated array is written.

    Items whose key is absent from the dictionary are left untouched. When
    the key is present but the entry has no ``"stereo_audio"``, ``audio_url``
    is set to ``None``.
    """
    array_data = load_json(array_json_path)  # JSON array of items
    dict_data = load_json(dict_json_path)    # JSON dictionary keyed by item key

    # For every item, look up its "key" in the dictionary and attach the
    # corresponding "stereo_audio" value.
    for item in array_data:
        key = item.get("key")
        if key in dict_data:
            item["audio_url"] = dict_data[key].get("stereo_audio")

    # Persist the result.
    save_json(array_data, output_path)
|
| 27 |
+
|
| 28 |
+
# 示例调用
|
| 29 |
+
if __name__ == "__main__":
    # Example invocation. Input and output name the same file, so test.json
    # is updated in place.
    add_stereo_audio_field(
        array_json_path="./test.json",
        dict_json_path="./overlaps.json",
        output_path="test.json"
    )
|
cotSFT/.ipynb_checkpoints/onlyAudios_longdelay_add_silence-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/.ipynb_checkpoints/overlaps-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/.ipynb_checkpoints/isoverlapresults-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/.ipynb_checkpoints/issilenceresults-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/.ipynb_checkpoints/texterror_results-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/correctresults.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/isoverlapresults.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/issilenceresults.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/Results/texterror_results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/add_streoaudio.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
# 加载JSON数据的工具函数
|
| 4 |
+
def load_json(file_path):
    """Read the UTF-8 JSON file at ``file_path`` and return the parsed object."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
| 7 |
+
|
| 8 |
+
# 保存JSON数据
|
| 9 |
+
def save_json(data, file_path):
    """Write ``data`` to ``file_path`` as indented UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
|
| 12 |
+
|
| 13 |
+
# 主逻辑
|
| 14 |
+
def add_stereo_audio_field(array_json_path, dict_json_path, output_path):
    """Copy each matching ``stereo_audio`` value into an ``audio_url`` field.

    Args:
        array_json_path: Path to a JSON array of items, each looked up by its
            ``"key"`` field.
        dict_json_path: Path to a JSON dictionary mapping keys to entries
            that may carry a ``"stereo_audio"`` field.
        output_path: Where the updated array is written.

    Items whose key is absent from the dictionary are left untouched. When
    the key is present but the entry has no ``"stereo_audio"``, ``audio_url``
    is set to ``None``.
    """
    array_data = load_json(array_json_path)  # JSON array of items
    dict_data = load_json(dict_json_path)    # JSON dictionary keyed by item key

    # For every item, look up its "key" in the dictionary and attach the
    # corresponding "stereo_audio" value.
    for item in array_data:
        key = item.get("key")
        if key in dict_data:
            item["audio_url"] = dict_data[key].get("stereo_audio")

    # Persist the result.
    save_json(array_data, output_path)
|
| 27 |
+
|
| 28 |
+
# 示例调用
|
| 29 |
+
if __name__ == "__main__":
    # Example invocation. Input and output name the same file, so test.json
    # is updated in place.
    add_stereo_audio_field(
        array_json_path="./test.json",
        dict_json_path="./overlaps.json",
        output_path="test.json"
    )
|
cotSFT/gemini-correct/.ipynb_checkpoints/correctresults-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/gemini-correct/.ipynb_checkpoints/gemini2.5_metainfo-checkpoint.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import requests
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import glob
|
| 8 |
+
from requests.exceptions import Timeout
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
# Instruction text sent as the first content part of every request.
# NOTE(review): some of the markdown emphasis is malformed (a closing `**`
# placed after the blank line following "speaker turns.", and spaces before
# closing `**` in the scoring criteria). The text is left byte-identical
# because changing it changes what the model is asked.
prompt_template = (
    "# Interactional Dialogue Evaluation\n\n"
    "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
    "Evaluate the quality of the interaction in the given dialogue transcript, focusing on:\n"
    "**Response Relevance:** \n"
    "**logical consistency, topic coherence**\n"
    "**Interactional Fluency:**\n"
    "**Detect and evaluate extended overlaps in conversation.**\n"
    "**Detect and evaluate long pauses between speaker turns.\n\n**"
    "**Note**: Small pauses and brief overlaps in conversation are acceptable, while prolonged pauses and overlapping turns are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
    "## Scoring Criteria\n"
    "Assign a single holistic score based on the combined evaluation:\n"
    "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
    "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** are consistently appropriate and natural.\n"
    "## Evaluation Output Format:\n"
    "Strictly follow this template:\n"
    "<response think>\n"
    "[Analysing Response Relevance and giving reasons for scoring...]\n"
    "</response think>\n"
    "<fluency think>\n"
    "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
    "</fluency think>\n"
    "<overall score>X</overall score>\n"
)
|
| 35 |
+
|
| 36 |
+
# API configuration
url = "https://api2.aigcbest.top/v1/chat/completions"
# SECURITY: the bearer token used to be hard-coded here and was committed to
# the repository; that key should be treated as leaked and revoked. The token
# is now read from the AIGCBEST_API_KEY environment variable.
headers = {
    "Authorization": f"Bearer {os.environ.get('AIGCBEST_API_KEY', '')}",
    "Content-Type": "application/json",
    "Accept": "application/json"
}
|
| 43 |
+
|
| 44 |
+
def parse_args(argv=None):
    """Parse the command-line options for the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``input_file``, ``output_file``,
        ``error_file``, ``checkpoint_dir``, ``max_retries`` and
        ``checkpoint_interval``.
    """
    parser = argparse.ArgumentParser(description='Process text evaluation with Gemini model')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Input JSON file containing text data')
    parser.add_argument('--output_file', type=str, default='texterror_gemini.json',
                        help='Output JSON file for results')
    parser.add_argument('--error_file', type=str, default='texterror_gemini_error.json',
                        help='Output JSON file for errors')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_test_text',
                        help='Directory for storing checkpoints')
    parser.add_argument('--max_retries', type=int, default=3,
                        help='Maximum number of retries for failed predictions')
    parser.add_argument('--checkpoint_interval', type=int, default=20,
                        help='Number of items to process before saving checkpoint')
    return parser.parse_args(argv)
|
| 59 |
+
|
| 60 |
+
def extract_overall_score(output_str):
    """Extract <overall score>X</overall score> from model output.

    Returns:
        The integer inside the first ``<overall score>`` tag, or ``None``
        when no such tag with a purely numeric body is present.
    """
    score_pattern = r"<overall score>(\d+)</overall score>"
    match = re.search(score_pattern, output_str)
    if match is None:
        return None
    try:
        return int(match.group(1))
    except ValueError:
        return None
|
| 70 |
+
|
| 71 |
+
def validate_model_output(output_str):
    """Validate that the model output contains all required tags.

    Returns:
        ``True`` when every opening and closing tag appears somewhere in
        ``output_str``; order and nesting are not checked.
    """
    required_tags = (
        "<response think>",
        "</response think>",
        "<fluency think>",
        "</fluency think>",
        "<overall score>",
        "</overall score>",
    )
    return all(tag in output_str for tag in required_tags)
|
| 86 |
+
|
| 87 |
+
def extract_tag_content(output_str, tag_name):
    """Extract content between opening and closing tags.

    Args:
        output_str: Text to search.
        tag_name: Tag name without angle brackets, e.g. ``"overall score"``.

    Returns:
        The stripped text between the first ``<tag_name>`` and the first
        ``</tag_name>`` that follows it, or ``None`` when either tag is
        missing or ``output_str`` is not a string.
    """
    if not isinstance(output_str, str):
        return None

    start_tag = f"<{tag_name}>"
    end_tag = f"</{tag_name}>"

    # Test the raw find() result before adding the tag length; adding first
    # made the "not found" value (-1) impossible to detect.
    tag_pos = output_str.find(start_tag)
    if tag_pos == -1:
        return None
    start_idx = tag_pos + len(start_tag)

    # Only accept a closing tag that comes after the opening tag.
    end_idx = output_str.find(end_tag, start_idx)
    if end_idx == -1:
        return None
    return output_str[start_idx:end_idx].strip()
|
| 99 |
+
|
| 100 |
+
def format_model_output(output_str):
    """Extract and format content from all required tags.

    Returns:
        A normalized string holding the ``<response think>``,
        ``<fluency think>`` and ``<overall score>`` sections in that order,
        or ``None`` when any section is missing or empty.
    """
    response_content = extract_tag_content(output_str, "response think")
    fluency_content = extract_tag_content(output_str, "fluency think")
    score_content = extract_tag_content(output_str, "overall score")

    # An empty section is treated like a missing one.
    if not (response_content and fluency_content and score_content):
        return None

    return (
        f"<response think>\n{response_content}\n</response think>\n\n"
        f"<fluency think>\n{fluency_content}\n</fluency think>\n\n"
        f"<overall score>{score_content}</overall score>"
    )
|
| 115 |
+
|
| 116 |
+
def make_api_call(text_input, retry_count=0, max_retries=5):
    """Make API call with retry logic for API errors.

    This function performs a single request; the caller drives the retries
    and passes the current ``retry_count``.

    Args:
        text_input: Text sent as the second content part, after the prompt.
        retry_count: Zero-based index of the current attempt.
        max_retries: Once ``retry_count >= max_retries``, a failure raises
            instead of returning ``(None, None)``.

    Returns:
        ``(formatted_output, pred_score)`` on success. ``(None, None)`` when
        the request failed or the output lacked the required tags, and
        retries remain. ``pred_score`` may be ``None`` when no numeric score
        could be extracted.

    Raises:
        Exception: On an HTTP error, timeout or other request error once
            ``retry_count >= max_retries``.
    """
    # (connect, read) timeouts in seconds passed to requests.post.
    connect_timeout, read_timeout = 200, 200
    try:
        print(f"Attempting API call (attempt {retry_count + 1}/{max_retries + 1})")
        data_req = {
            "model": "gemini-2.5-flash-preview-05-20-thinking",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt_template
                        },
                        {
                            "type": "text",
                            "text": text_input
                        },
                    ]
                }
            ],
            "temperature": 1,
        }

        response = requests.post(
            url,
            headers=headers,
            json=data_req,
            timeout=(connect_timeout, read_timeout),
        )
        print(f"API response received with status code: {response.status_code}")

        if response.status_code != 200:
            print(f"API returned error status {response.status_code}: {response.text}")
            if retry_count >= max_retries:
                raise Exception(f"POST error {response.status_code}: {response.text}")
            return None, None

        model_output = response.json()['choices'][0]['message']['content']
        if not validate_model_output(model_output):
            print("Model output missing required tags, retrying...")
            return None, None

        formatted_output = format_model_output(model_output)
        if formatted_output is None:
            print("Failed to extract content from tags, retrying...")
            return None, None

        pred_score = extract_overall_score(model_output)
        return formatted_output, pred_score

    except requests.exceptions.ConnectTimeout:
        # Report the timeout actually configured above.
        print(f"Connection timeout (>{connect_timeout}s)")
        if retry_count >= max_retries:
            raise Exception("Connection timeout")
        return None, None
    except requests.exceptions.ReadTimeout:
        print(f"Read timeout (>{read_timeout}s)")
        if retry_count >= max_retries:
            raise Exception("Read timeout")
        return None, None
    except Exception as e:
        print(f"Unexpected error during API call: {str(e)}")
        if retry_count >= max_retries:
            raise
        return None, None
|
| 176 |
+
|
| 177 |
+
def get_latest_checkpoint(checkpoint_dir):
    """Get the latest checkpoint file and its processed count.

    "Latest" means the highest processed count embedded in the file name
    (``checkpoint_<count>_<timestamp>.json``), not the newest timestamp.

    Returns:
        ``(path, count)`` for the checkpoint with the highest positive
        count, or ``(None, 0)`` when no usable checkpoint exists.
    """
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.json"))

    latest_checkpoint = None
    max_count = 0
    for checkpoint in checkpoint_files:
        try:
            count = int(os.path.basename(checkpoint).split('_')[1])
        except (ValueError, IndexError):
            # File name does not follow the expected pattern; ignore it.
            continue
        if count > max_count:
            max_count = count
            latest_checkpoint = checkpoint

    return latest_checkpoint, max_count
|
| 195 |
+
|
| 196 |
+
def save_checkpoint(results, processed_count, checkpoint_dir):
    """Save results to a checkpoint file.

    The file is named ``checkpoint_<processed_count>_<timestamp>.json`` and
    is written inside ``checkpoint_dir``, which is created if missing.

    Returns:
        The path of the checkpoint file that was written.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_file = os.path.join(checkpoint_dir, f"checkpoint_{processed_count}_{timestamp}.json")
    with open(checkpoint_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Checkpoint saved: {checkpoint_file}")
    return checkpoint_file
|
| 203 |
+
|
| 204 |
+
def main():
    """Run the evaluation over every input item, with checkpointed resume.

    Each item's ``model_output`` text is sent to the API until a prediction
    with score 2 is returned or ``--max_retries`` attempts are used up; in
    the latter case the last valid prediction is kept. Results go to
    ``--output_file`` after every item, periodic snapshots go to
    ``--checkpoint_dir``, and failures go to ``--error_file``.
    """
    args = parse_args()

    save_file_name = args.output_file
    error_file_name = args.error_file

    # Create checkpoints directory
    checkpoint_dir = args.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Load test data
    with open(args.input_file, 'r', encoding='utf-8') as f:
        all_data = json.load(f)

    # Initialize error tracking
    error_results = []

    # Load checkpoint if exists
    results = []
    latest_checkpoint, checkpoint_count = get_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        print(f"Found latest checkpoint with {checkpoint_count} processed items: {latest_checkpoint}")
        try:
            with open(latest_checkpoint, 'r', encoding='utf-8') as f:
                results = json.load(f)
            if not isinstance(results, list):
                raise ValueError("checkpoint does not contain a list of results")
            print(f"Resumed from checkpoint: processed {len(results)} items")
        except Exception as e:
            print(f"Warning: Failed to load checkpoint {latest_checkpoint}: {e}")
            results = []
    else:
        print("No checkpoint found, starting from scratch")

    # Keys already present in the resumed results. Without this filter a
    # resumed run reprocessed every item and appended duplicates.
    # Assumes keys are hashable (e.g. strings) -- TODO confirm against input data.
    processed_keys = {
        entry.get("key") for entry in results if isinstance(entry, dict)
    }
    processed_keys.discard(None)

    max_prediction_retries = args.max_retries
    # Continue counting from the resumed results so checkpoint file names
    # keep increasing across restarts.
    total_count = len(results)

    for item in tqdm(all_data, desc="Processing texts"):
        key = item.get('key')
        text_input = item.get('model_output')

        if not text_input:
            print(f"No text input found for key {key}, skipping...")
            continue

        if key is not None and key in processed_keys:
            # Already handled before the restart.
            continue

        print(f"Processing text for key={key}")

        prediction_retry_count = 0  # failed or rejected attempts so far
        success = False
        error_recorded = False

        while prediction_retry_count < max_prediction_retries and not success:
            # 1-based number of the attempt about to run; recorded with the
            # result so the count is right on every path.
            attempt_number = prediction_retry_count + 1
            try:
                print(f"\nProcessing attempt {attempt_number}")
                # Pass the attempt index so the callee logs it correctly and
                # raises, with the real cause, on the final attempt.
                model_output, pred_score = make_api_call(
                    text_input,
                    retry_count=prediction_retry_count,
                    max_retries=max_prediction_retries - 1,
                )

                if model_output is None or pred_score is None:
                    print("API call failed, retrying...")
                    prediction_retry_count += 1
                    continue

                print(f"Received prediction: {pred_score}")

                if pred_score == 2:
                    success = True
                    print("Prediction score is 2, accepting result")
                else:
                    prediction_retry_count += 1
                    print(f"Prediction score is not 2 (attempt {prediction_retry_count}/{max_prediction_retries})")
                    if prediction_retry_count >= max_prediction_retries:
                        print("Max retries reached, accepting last prediction")
                        success = True
                    else:
                        continue

                results.append({
                    "key": key,
                    "text_input": text_input,
                    "model_output": model_output,
                    "predicted_score": pred_score,
                    "prediction_attempts": attempt_number
                })

                # Rewrite the results file after every item so a crash loses
                # at most the item in flight.
                with open(save_file_name, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)

                total_count += 1

                if total_count % args.checkpoint_interval == 0:
                    save_checkpoint(results, total_count, checkpoint_dir)

            except Exception as e:
                error_msg = str(e)
                print(f"Failed to process text for key {key}: {error_msg}")
                error_results.append({
                    "key": key,
                    "text_input": text_input,
                    "error": f"Exception: {error_msg}"
                })
                error_recorded = True
                break

        if not success and not error_recorded:
            # Every attempt came back unusable without raising; record the
            # item instead of dropping it silently.
            print(f"No valid model output for key {key} after {max_prediction_retries} attempts")
            error_results.append({
                "key": key,
                "text_input": text_input,
                "error": f"No valid model output after {max_prediction_retries} attempts"
            })

        with open(error_file_name, "w", encoding="utf-8") as f:
            json.dump(error_results, f, indent=2, ensure_ascii=False)

    # Save final results
    with open(save_file_name, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"Results saved to {save_file_name}")
    print(f"Total processed items: {total_count}")
|
| 315 |
+
|
| 316 |
+
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
cotSFT/gemini-correct/.ipynb_checkpoints/run_gemini_meta-checkpoint.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Run gemini2.5_metainfo.py on thinkSFT_correct_transcriptions.json, writing results.json and errors.json, with up to 5 retries per item and a checkpoint every 100 items.
python gemini2.5_metainfo.py --input_file thinkSFT_correct_transcriptions.json --output_file results.json --error_file errors.json --max_retries 5 --checkpoint_interval 100
|
cotSFT/gemini-correct/.ipynb_checkpoints/thinkSFT_correct_transcriptions-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/gemini-correct/checkpoints_test_text/checkpoint_100_20250612_182923.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/gemini-correct/checkpoints_test_text/checkpoint_1100_20250613_023551.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cotSFT/gemini-correct/checkpoints_test_text/checkpoint_200_20250612_191111.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|