# interactSpeech / cotSFT_new / cotSFT_10data / gemini2.5_metainfo.py
# Uploaded by Student0809 ("Add files using upload-large-folder tool"), commit 3438cdb (verified).
import os
import json
import re
import requests
from tqdm import tqdm
from datetime import datetime
import glob
from requests.exceptions import Timeout
import argparse
import multiprocessing
# Evaluation prompt sent to the model ahead of each dialogue transcript.
# The model is asked to answer with the three tagged sections shown under
# "Evaluation Output Format"; validate_model_output / format_model_output
# check for and extract exactly those tags.
prompt_template = (
    "# Interactional Dialogue Evaluation\n\n"
    "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"
    "Evaluate the quality of the interaction in the given dialogue transcript, focusing on:\n"
    "**Response Relevance:** \n"
    "**logical consistency, topic coherence**\n"
    "**Interactional Fluency:**\n"
    "**Detect and evaluate extended overlaps in conversation.**\n"
    # The closing ** must come before the newlines; otherwise it fuses with
    # the following "**Note**" into "****Note**" and breaks the markdown.
    "**Detect and evaluate long pauses between speaker turns.**\n\n"
    "**Note**: Small pauses and brief overlaps in conversation are acceptable, while prolonged pauses and overlapping turns are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
    "## Scoring Criteria\n"
    "Assign a single holistic score based on the combined evaluation:\n"
    # No whitespace directly before a closing ** (markdown would not bold it).
    "`1` (Poor): Significant issues in either **Response Relevance** or **Interactional Fluency**.\n"
    "`2` (Excellent): Both **Response Relevance** and **Interactional Fluency** are consistently appropriate and natural.\n"
    "## Evaluation Output Format:\n"
    "Strictly follow this template:\n"
    "<response think>\n"
    "[Analysing Response Relevance and giving reasons for scoring...]\n"
    "</response think>\n"
    "<fluency think>\n"
    "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
    "</fluency think>\n"
    "<overall score>X</overall score>\n"
)
# API configuration
url = "https://api2.aigcbest.top/v1/chat/completions"
# The bearer token is taken from the AIGCBEST_API_KEY environment variable
# rather than being committed to source control. Any key that was previously
# hard-coded here should be treated as exposed and revoked.
API_KEY = os.environ.get("AIGCBEST_API_KEY", "")
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
    "Accept": "application/json",
}
def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input_file``, ``output_file``,
        ``error_file``, ``checkpoint_dir``, ``max_retries``,
        ``checkpoint_interval`` and ``num_processes``.

    Exits via ``parser.error`` (``SystemExit``) when ``num_processes``,
    ``max_retries`` or ``checkpoint_interval`` is below 1. These values feed
    ``multiprocessing.Pool``, the per-item retry loop and a ``%`` checkpoint
    test, none of which make sense for values below 1.
    """
    parser = argparse.ArgumentParser(description='Process text evaluation with Gemini model')
    parser.add_argument('--input_file', type=str, default='all_dialogues_processed.json',
                        help='Input JSON file containing text data')
    parser.add_argument('--output_file', type=str, default='cotSFT_gemini.json',
                        help='Output JSON file for results')
    parser.add_argument('--error_file', type=str, default='cotSFT_gemini_error.json',
                        help='Output JSON file for errors')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_test_text',
                        help='Directory for storing checkpoints')
    parser.add_argument('--max_retries', type=int, default=6,
                        help='Maximum number of retries for failed predictions')
    parser.add_argument('--checkpoint_interval', type=int, default=100,
                        help='Number of items to process before saving checkpoint')
    parser.add_argument('--num_processes', type=int, default=5,
                        help='Number of parallel processes to use')
    args = parser.parse_args(argv)
    for option in ("num_processes", "max_retries", "checkpoint_interval"):
        value = getattr(args, option)
        if value < 1:
            parser.error(f"--{option} must be >= 1, got {value}")
    return args
def extract_overall_score(output_str):
    """Extract the integer from ``<overall score>X</overall score>`` in model output.

    Whitespace around the number is tolerated. This keeps the function
    consistent with ``extract_tag_content``, which strips the tag body, so an
    output that formats successfully also yields a score.

    Args:
        output_str: Raw model output text.

    Returns:
        The first digit-only score found, as an ``int``, or ``None`` when no
        numeric score tag is present.
    """
    # ``\d+`` only matches decimal digits, all of which int() accepts, so no
    # ValueError handling is needed.
    match = re.search(r"<overall score>\s*(\d+)\s*</overall score>", output_str)
    if match is None:
        return None
    return int(match.group(1))
def validate_model_output(output_str):
    """Report whether the model output contains every template tag.

    The template has three sections (response think, fluency think, overall
    score). Each needs both an opening and a closing tag somewhere in
    ``output_str``. Tag order and nesting are not checked.

    Returns:
        True if all six tags are present, otherwise False.
    """
    section_names = ("response think", "fluency think", "overall score")
    expected_tags = [
        tag
        for name in section_names
        for tag in (f"<{name}>", f"</{name}>")
    ]
    return all(tag in output_str for tag in expected_tags)
def extract_tag_content(output_str, tag_name):
    """Return the stripped text between ``<tag_name>`` and the next ``</tag_name>``.

    The closing tag is searched only after the opening tag, so a stray closing
    tag earlier in the text does not truncate the result.

    Args:
        output_str: Text to search; non-string input yields ``None``.
        tag_name: Tag name without angle brackets, e.g. ``"overall score"``.

    Returns:
        The whitespace-stripped content, or ``None`` when ``output_str`` is not
        a string, the opening tag is missing, or no closing tag follows it.
    """
    if not isinstance(output_str, str):
        return None
    start_tag = f"<{tag_name}>"
    end_tag = f"</{tag_name}>"
    # Check for "not found" BEFORE adding the tag length; adding first makes
    # -1 impossible to detect.
    open_idx = output_str.find(start_tag)
    if open_idx == -1:
        return None
    content_start = open_idx + len(start_tag)
    end_idx = output_str.find(end_tag, content_start)
    if end_idx == -1:
        return None
    return output_str[content_start:end_idx].strip()
def format_model_output(output_str):
    """Rebuild the model answer in a canonical layout.

    The three tagged sections are pulled out of ``output_str`` and re-emitted
    in a fixed order. Each think section is wrapped in its tags on separate
    lines, blank lines separate the sections, and the score sits inline.

    Returns:
        The normalized string, or ``None`` if any section is missing or empty.
    """
    section_names = ("response think", "fluency think", "overall score")
    pieces = [extract_tag_content(output_str, name) for name in section_names]
    if not all(pieces):
        return None
    response_text, fluency_text, score_text = pieces
    return (
        f"<response think>\n{response_text}\n</response think>\n\n"
        f"<fluency think>\n{fluency_text}\n</fluency think>\n\n"
        f"<overall score>{score_text}</overall score>"
    )
def make_api_call(text_input, retry_count=0, max_retries=5, timeout=(200, 200)):
    """Send one evaluation request for a dialogue to the chat-completions API.

    The user message has three text parts: ``prompt_template``, a hint that
    the correct overall score is 2, and the dialogue transcript.

    Args:
        text_input: Dialogue transcript to evaluate.
        retry_count: Zero-based index of this attempt. It is used in the log
            line and to decide whether a failure is final.
        max_retries: Index of the last allowed attempt. When
            ``retry_count >= max_retries``, HTTP errors, timeouts and other
            request failures are raised instead of returning ``(None, None)``.
        timeout: ``(connect, read)`` timeout in seconds, or a single number,
            passed to ``requests.post``.

    Returns:
        ``(formatted_output, pred_score)`` on success. Returns
        ``(None, None)`` when the call failed (non-final attempt) or the output
        lacked the required tags, so the caller may retry.

    Raises:
        Exception: on the final attempt, for a non-200 status, a timeout or an
            unexpected request/response error.
    """
    is_final_attempt = retry_count >= max_retries
    if isinstance(timeout, tuple):
        connect_limit, read_limit = timeout
    else:
        connect_limit = read_limit = timeout
    print(f"Attempting API call (attempt {retry_count + 1}/{max_retries + 1})")
    data_req = {
        "model": "gemini-2.5-pro-preview-06-05-thinking",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_template},
                    # Target-label hint: the dataset is built for score 2.
                    {"type": "text", "text": "The correct overall score is: 2\n"},
                    {"type": "text", "text": text_input},
                ],
            }
        ],
        "temperature": 1,
    }

    http_error = None
    model_output = None
    try:
        response = requests.post(url, headers=headers, json=data_req, timeout=timeout)
        print(f"API response received with status code: {response.status_code}")
        if response.status_code == 200:
            model_output = response.json()['choices'][0]['message']['content']
        else:
            print(f"API returned error status {response.status_code}: {response.text}")
            # Handled after the try so the final-attempt raise is not re-caught
            # (and double-logged) by the generic handler below.
            http_error = f"POST error {response.status_code}: {response.text}"
    except requests.exceptions.ConnectTimeout:
        print(f"Connection timeout (>{connect_limit}s)")
        if is_final_attempt:
            raise Exception("Connection timeout")
        return None, None
    except requests.exceptions.ReadTimeout:
        print(f"Read timeout (>{read_limit}s)")
        if is_final_attempt:
            raise Exception("Read timeout")
        return None, None
    except Exception as e:
        print(f"Unexpected error during API call: {str(e)}")
        if is_final_attempt:
            raise
        return None, None

    if http_error is not None:
        if is_final_attempt:
            raise Exception(http_error)
        return None, None
    # A 200 response may still carry non-string content (e.g. null).
    if not isinstance(model_output, str) or not validate_model_output(model_output):
        print("Model output missing required tags, retrying...")
        return None, None
    formatted_output = format_model_output(model_output)
    if formatted_output is None:
        print("Failed to extract content from tags, retrying...")
        return None, None
    return formatted_output, extract_overall_score(model_output)
def get_latest_checkpoint(checkpoint_dir):
    """Find the checkpoint with the highest processed count in ``checkpoint_dir``.

    Checkpoint files are named ``checkpoint_<count>_<YYYYmmdd>_<HHMMSS>.json``,
    the same pattern ``save_checkpoint`` writes. Files whose count is not a
    positive integer are ignored. When several files share the highest count
    (e.g. after a re-run), the lexicographically greatest name wins. Because
    the timestamp follows the count, that is the most recent one.

    Returns:
        ``(path, count)`` for the chosen checkpoint, or ``(None, 0)`` if none
        qualifies.
    """
    best_key = None
    latest_checkpoint = None
    for checkpoint in glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.json")):
        name = os.path.basename(checkpoint)
        try:
            count = int(name.split('_')[1])
        except (ValueError, IndexError):
            continue
        if count <= 0:
            continue
        key = (count, name)
        if best_key is None or key > best_key:
            best_key = key
            latest_checkpoint = checkpoint
    if best_key is None:
        return None, 0
    return latest_checkpoint, best_key[0]
def save_checkpoint(results, processed_count, checkpoint_dir):
    """Dump ``results`` to a new timestamped checkpoint file.

    The file is ``checkpoint_<processed_count>_<YYYYmmdd_HHMMSS>.json`` inside
    ``checkpoint_dir``, which must already exist. It is written as indented
    UTF-8 JSON with non-ASCII characters kept as-is, and its path is printed.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = "checkpoint_{}_{}.json".format(processed_count, stamp)
    target_path = os.path.join(checkpoint_dir, file_name)
    with open(target_path, "w", encoding="utf-8") as out:
        json.dump(results, out, ensure_ascii=False, indent=2)
    print(f"Checkpoint saved: {target_path}")
def split_data(data, num_chunks):
    """Split a sliceable sequence into ``num_chunks`` contiguous, near-equal parts.

    The first ``len(data) % num_chunks`` parts get one extra element. Parts
    may be empty when ``num_chunks > len(data)``. Every element lands in
    exactly one part, in order.

    Raises:
        ValueError: if ``num_chunks < 1``. Previously 0 raised an opaque
            ZeroDivisionError and a negative value silently returned ``[]``,
            dropping all data.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    chunk_size, remainder = divmod(len(data), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        end = start + chunk_size + (1 if i < remainder else 0)
        chunks.append(data[start:end])
        start = end
    return chunks
def _evaluate_item(key, text_input, max_prediction_retries):
    """Query the model for one dialogue until it predicts score 2 or attempts run out.

    Returns:
        ``(record, error)``; exactly one of the two is not ``None``. ``record``
        is the result dict to save: the score-2 output, or, once all attempts
        are used, the most recent parsable non-2 output. ``error`` is a
        message for the error file when no parsable output was obtained.
    """
    fallback = None  # most recent parsable prediction whose score was not 2
    error = None
    attempts_made = 0
    for attempt in range(max_prediction_retries):
        attempts_made = attempt + 1
        try:
            # Passing the attempt index lets make_api_call raise on the final
            # attempt instead of returning (None, None), so the real failure
            # reason reaches the error file.
            model_output, pred_score = make_api_call(
                text_input,
                retry_count=attempt,
                max_retries=max_prediction_retries - 1,
            )
        except Exception as e:
            error = f"Exception: {str(e)}"
            print(f"Exception for key {key}: {str(e)}")
            break
        if model_output is None or pred_score is None:
            print(f"API call failed for key {key}, retry {attempts_made}/{max_prediction_retries}")
            continue
        record = {
            "key": key,
            "text_input": text_input,
            "model_output": model_output,
            "predicted_score": pred_score,
            "prediction_attempts": attempts_made,
        }
        # Only a predicted score of 2 ends the retry loop early.
        if pred_score == 2:
            print(f"Success! Predicted score 2 for key {key} after {attempts_made} attempts")
            return record, None
        fallback = record
        print(f"Predicted score {pred_score} for key {key}, retry {attempts_made}/{max_prediction_retries}")

    if fallback is not None:
        fallback["prediction_attempts"] = attempts_made
        print(f"Max retries reached for key {key}, saving with score {fallback['predicted_score']}")
        return fallback, None
    if error is None:
        error = f"No valid model output after {attempts_made} attempts"
    return None, error


def process_chunk(args_tuple):
    """Evaluate one chunk of dialogues and write per-chunk result/error files.

    Args:
        args_tuple: ``(chunk_data, chunk_idx, args)``. ``chunk_data`` is a list
            of dicts read via ``item.get('key')`` and
            ``item.get('process_dialogue')``; ``args`` is the parsed CLI
            namespace (``output_file``, ``error_file``, ``checkpoint_dir``,
            ``max_retries``, ``checkpoint_interval`` are used).

    Returns:
        ``(save_file_name, error_file_name)``: paths of the chunk's results
        JSON and errors JSON. Every item with text ends up in exactly one of
        them.
    """
    chunk_data, chunk_idx, args = args_tuple
    results = []
    error_results = []
    save_file_name = f"{os.path.splitext(args.output_file)[0]}_chunk{chunk_idx}.json"
    error_file_name = f"{os.path.splitext(args.error_file)[0]}_chunk{chunk_idx}.json"
    checkpoint_dir = f"{args.checkpoint_dir}_chunk{chunk_idx}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    total_count = 0

    for item in tqdm(chunk_data, desc=f"Processing chunk {chunk_idx}"):
        key = item.get('key')
        text_input = item.get('process_dialogue')  # use the process_dialogue field
        if not text_input:
            print(f"No text input found for key {key}, skipping...")
            continue

        record, error = _evaluate_item(key, text_input, args.max_retries)
        if record is None:
            # Previously these items were silently dropped from both outputs.
            error_results.append({
                "key": key,
                "text_input": text_input,
                "error": error,
            })
            continue

        results.append(record)
        total_count += 1
        try:
            # Rewrite the chunk's results after every item so progress
            # survives a crash.
            with open(save_file_name, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            if args.checkpoint_interval > 0 and total_count % args.checkpoint_interval == 0:
                save_checkpoint(results, total_count, checkpoint_dir)
        except OSError as e:
            print(f"Exception for key {key}: {str(e)}")
            error_results.append({
                "key": key,
                "text_input": text_input,
                "error": f"Exception: {str(e)}",
            })

    # Save error results
    with open(error_file_name, "w", encoding="utf-8") as f:
        json.dump(error_results, f, indent=2, ensure_ascii=False)
    # Final save of results
    with open(save_file_name, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    return save_file_name, error_file_name
def merge_json_files(file_list, output_file):
    """Concatenate several JSON-list files into a single JSON-list file.

    Paths in ``file_list`` that do not exist are skipped. The rest are loaded
    in order and their lists appended one after another. The combined list is
    written to ``output_file`` as indented UTF-8 JSON.
    """
    combined = []
    present = (path for path in file_list if os.path.exists(path))
    for path in present:
        with open(path, 'r', encoding='utf-8') as src:
            combined += json.load(src)
    with open(output_file, 'w', encoding='utf-8') as dst:
        json.dump(combined, dst, indent=2, ensure_ascii=False)
def main():
    """Run the evaluation over the whole input file using a process pool.

    1. Load the input JSON.
    2. Split it into ``--num_processes`` chunks.
    3. Evaluate each chunk in its own worker process via ``process_chunk``.
    4. Merge the per-chunk result and error files into ``--output_file`` and
       ``--error_file``.
    """
    args = parse_args()
    with open(args.input_file, 'r', encoding='utf-8') as f:
        all_data = json.load(f)

    num_chunks = args.num_processes
    if num_chunks < 1:
        raise SystemExit(f"--num_processes must be >= 1, got {num_chunks}")
    chunks = split_data(all_data, num_chunks)
    chunk_args = [(chunk, i, args) for i, chunk in enumerate(chunks)]
    # The context manager tears the pool down even if a worker raises;
    # close/join first gives workers a graceful exit on the success path.
    with multiprocessing.Pool(num_chunks) as pool:
        results = pool.map(process_chunk, chunk_args)
        pool.close()
        pool.join()

    # Merge all per-chunk output files
    output_files = [r[0] for r in results]
    error_files = [r[1] for r in results]
    merge_json_files(output_files, args.output_file)
    merge_json_files(error_files, args.error_file)
    print(f"Results saved to {args.output_file}")
    print(f"Errors saved to {args.error_file}")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()