"""Evaluate dialogue transcripts with a Gemini model via an OpenAI-style chat API.

For each item in the input JSON, the script asks the model to rate the
dialogue's Response Relevance and Interactional Fluency, validates that the
reply follows the required tagged template, retries malformed replies, and
checkpoints results periodically so interrupted runs can be resumed.
"""

import argparse
import glob
import json
import os
import re
from datetime import datetime

import requests
from requests.exceptions import Timeout  # noqa: F401 -- kept from the original import list
from tqdm import tqdm

# Tag names the model must wrap its analysis in.
# NOTE(review): the literal angle-bracket tags were stripped from the original
# file (most likely by an HTML sanitizer -- the required-tags list contained
# six empty strings and end_tag was f"").  They are reconstructed here from
# the extract_tag_content() call sites; confirm against the original prompt.
RESPONSE_TAG = "response think"
FLUENCY_TAG = "fluency think"
SCORE_TAG = "overall score"

prompt_template = (
    "# Interactional Dialogue Evaluation\n\n"
    f"**IMPORTANT**: Evaluation must include <{RESPONSE_TAG}> and <{FLUENCY_TAG}> "
    f"analysis and <{SCORE_TAG}> rating.\n"
    "Evaluate the quality of the interaction in the given dialogue transcript, focusing on:\n"
    "**Response Relevance:** \n"
    "**logical consistency, topic coherence**\n"
    "**Interactional Fluency:**\n"
    "**Detect and evaluate extended overlaps in conversation.**\n"
    "**Detect and evaluate long pauses between speaker turns.\n\n**"
    "**Note**: Small pauses and brief overlaps in conversation are acceptable, "
    "while prolonged pauses and overlapping turns are harmful. You should consider "
    "Response Relevance and Interactional Fluency separately, and provide the "
    "corresponding thinking process.\n\n"
    "## Scoring Criteria\n"
    "Assign a single holistic score based on the combined evaluation:\n"
    "`1` (Poor): Significant issues in either **Response Relevance ** or **Interactional Fluency. **\n"
    "`2` (Excellent): Both **Response Relevance ** and **Interactional Fluency ** "
    "are consistently appropriate and natural.\n"
    "## Evaluation Output Format:\n"
    "Strictly follow this template:\n"
    f"<{RESPONSE_TAG}>\n"
    "[Analysing Response Relevance and giving reasons for scoring...]\n"
    f"</{RESPONSE_TAG}>\n"
    f"<{FLUENCY_TAG}>\n"
    "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
    f"</{FLUENCY_TAG}>\n"
    f"<{SCORE_TAG}>X</{SCORE_TAG}>\n"
)

# API configuration.
url = "https://api2.aigcbest.top/v1/chat/completions"
# SECURITY: the API key was hard-coded (and therefore leaked) in the original
# source.  It is now read from the environment: export API_KEY before running.
headers = {
    "Authorization": f"Bearer {os.environ.get('API_KEY', '')}",
    "Content-Type": "application/json",
    "Accept": "application/json",
}


def parse_args():
    """Parse command-line options for the evaluation run."""
    parser = argparse.ArgumentParser(description='Process text evaluation with Gemini model')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Input JSON file containing text data')
    parser.add_argument('--output_file', type=str, default='texterror_gemini.json',
                        help='Output JSON file for results')
    parser.add_argument('--error_file', type=str, default='texterror_gemini_error.json',
                        help='Output JSON file for errors')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_test_text',
                        help='Directory for storing checkpoints')
    parser.add_argument('--max_retries', type=int, default=3,
                        help='Maximum number of retries for failed predictions')
    parser.add_argument('--checkpoint_interval', type=int, default=20,
                        help='Number of items to process before saving checkpoint')
    return parser.parse_args()


def extract_overall_score(output_str):
    """Extract the integer X from '<overall score>X...' in the model output.

    Returns None when no score tag with a leading integer is present.
    """
    match = re.search(rf"<{SCORE_TAG}>\s*(\d+)", output_str)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            pass
    return None


def validate_model_output(output_str):
    """Return True iff the output contains every required open and close tag."""
    required_tags = []
    for name in (RESPONSE_TAG, FLUENCY_TAG, SCORE_TAG):
        required_tags.append(f"<{name}>")
        required_tags.append(f"</{name}>")
    return all(tag in output_str for tag in required_tags)


def extract_tag_content(output_str, tag_name):
    """Return the stripped text between <tag_name> and </tag_name>, or None.

    Fixes two defects in the original: str.find() + len(start_tag) made the
    missing-start-tag check unreachable, and the closing tag was searched from
    position 0 instead of after the opening tag.
    """
    start_tag = f"<{tag_name}>"
    end_tag = f"</{tag_name}>"
    start_idx = output_str.find(start_tag)
    if start_idx == -1:
        return None
    start_idx += len(start_tag)
    # Search for the closing tag only after the opening tag.
    end_idx = output_str.find(end_tag, start_idx)
    if end_idx == -1:
        return None
    return output_str[start_idx:end_idx].strip()


def format_model_output(output_str):
    """Re-assemble the three tagged sections into a canonical string.

    Returns None when any required section is missing or empty.
    """
    response_content = extract_tag_content(output_str, RESPONSE_TAG)
    fluency_content = extract_tag_content(output_str, FLUENCY_TAG)
    score_content = extract_tag_content(output_str, SCORE_TAG)
    if not all([response_content, fluency_content, score_content]):
        return None
    return (
        f"<{RESPONSE_TAG}>\n{response_content}\n</{RESPONSE_TAG}>\n\n"
        f"<{FLUENCY_TAG}>\n{fluency_content}\n</{FLUENCY_TAG}>\n\n"
        f"<{SCORE_TAG}>{score_content}</{SCORE_TAG}>"
    )


def make_api_call(text_input, retry_count=0, max_retries=5):
    """Send one chat-completion request and parse the tagged reply.

    Returns (formatted_output, predicted_score); (None, None) signals a
    retryable failure.  Raises once retry_count reaches max_retries so the
    caller can record the item in the error file.
    """
    try:
        print(f"Attempting API call (attempt {retry_count + 1}/{max_retries + 1})")
        data_req = {
            "model": "gemini-2.5-flash-preview-05-20-thinking",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_template},
                        {"type": "text", "text": text_input},
                    ],
                }
            ],
            "temperature": 1,
        }
        # (connect, read) timeouts in seconds.
        response = requests.post(url, headers=headers, json=data_req, timeout=(200, 200))
        print(f"API response received with status code: {response.status_code}")
        if response.status_code == 200:
            model_output = response.json()['choices'][0]['message']['content']
            if not validate_model_output(model_output):
                print("Model output missing required tags, retrying...")
                return None, None
            formatted_output = format_model_output(model_output)
            if formatted_output is None:
                print("Failed to extract content from tags, retrying...")
                return None, None
            pred_score = extract_overall_score(model_output)
            return formatted_output, pred_score
        print(f"API returned error status {response.status_code}: {response.text}")
        if retry_count >= max_retries:
            raise Exception(f"POST error {response.status_code}: {response.text}")
        return None, None
    except requests.exceptions.ConnectTimeout:
        print("Connection timeout")
        if retry_count >= max_retries:
            raise Exception("Connection timeout")
        return None, None
    except requests.exceptions.ReadTimeout:
        print("Read timeout")
        if retry_count >= max_retries:
            raise Exception("Read timeout")
        return None, None
    except Exception as e:
        print(f"Unexpected error during API call: {str(e)}")
        if retry_count >= max_retries:
            raise
        return None, None


def get_latest_checkpoint(checkpoint_dir):
    """Return (path, count) of the checkpoint with the highest item count.

    Checkpoint files are named checkpoint_<count>_<timestamp>.json; files
    whose names do not fit that pattern are ignored.
    """
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.json"))
    if not checkpoint_files:
        return None, 0
    latest_checkpoint = None
    max_count = 0
    for checkpoint in checkpoint_files:
        try:
            count = int(os.path.basename(checkpoint).split('_')[1])
        except (ValueError, IndexError):
            continue
        if count > max_count:
            max_count = count
            latest_checkpoint = checkpoint
    return latest_checkpoint, max_count


def save_checkpoint(results, processed_count, checkpoint_dir):
    """Write the current results list to a timestamped checkpoint file."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_file = os.path.join(
        checkpoint_dir, f"checkpoint_{processed_count}_{timestamp}.json")
    with open(checkpoint_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Checkpoint saved: {checkpoint_file}")


def main():
    args = parse_args()
    save_file_name = args.output_file
    error_file_name = args.error_file

    checkpoint_dir = args.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open(args.input_file, 'r', encoding='utf-8') as f:
        all_data = json.load(f)

    error_results = []

    # Resume from the most recent checkpoint when one exists.
    results = []
    latest_checkpoint, checkpoint_count = get_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        print(f"Found latest checkpoint with {checkpoint_count} processed items: "
              f"{latest_checkpoint}")
        try:
            with open(latest_checkpoint, 'r', encoding='utf-8') as f:
                results = json.load(f)
            print(f"Resumed from checkpoint: processed {len(results)} items")
        except Exception as e:
            print(f"Warning: Failed to load checkpoint {latest_checkpoint}: {e}")
            results = []
    else:
        print("No checkpoint found, starting from scratch")

    # Fix: the original reprocessed every item after a resume, duplicating
    # entries.  Skip keys already present in the loaded results.
    processed_keys = {entry.get("key") for entry in results}

    max_prediction_retries = args.max_retries
    total_count = 0

    for item in tqdm(all_data, desc="Processing texts"):
        key = item.get('key')
        text_input = item.get('model_output')
        if not text_input:
            print(f"No text input found for key {key}, skipping...")
            continue
        if key in processed_keys:
            print(f"Key {key} already processed, skipping...")
            continue

        print(f"Processing text for key={key}")
        prediction_retry_count = 0
        success = False

        while prediction_retry_count < max_prediction_retries and not success:
            try:
                print(f"\nProcessing attempt {prediction_retry_count + 1}")
                # Pass the attempt number through so make_api_call can escalate
                # to an exception once retries are exhausted (the original
                # always called it with retry_count=0, making those paths dead).
                model_output, pred_score = make_api_call(
                    text_input,
                    retry_count=prediction_retry_count,
                    max_retries=max_prediction_retries,
                )
                if model_output is None or pred_score is None:
                    print("API call failed, retrying...")
                    prediction_retry_count += 1
                    continue

                print(f"Received prediction: {pred_score}")
                if pred_score == 1:
                    success = True
                    print("Prediction score is 1, accepting result")
                else:
                    prediction_retry_count += 1
                    print(f"Prediction score is not 1 "
                          f"(attempt {prediction_retry_count}/{max_prediction_retries})")
                    if prediction_retry_count >= max_prediction_retries:
                        print("Max retries reached, accepting last prediction")
                        success = True
                    else:
                        continue

                results.append({
                    "key": key,
                    "text_input": text_input,
                    "model_output": model_output,
                    "predicted_score": pred_score,
                    "prediction_attempts": prediction_retry_count + 1,
                })
                processed_keys.add(key)
                with open(save_file_name, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)

                total_count += 1
                if total_count % args.checkpoint_interval == 0:
                    save_checkpoint(results, total_count, checkpoint_dir)
            except Exception as e:
                error_msg = str(e)
                print(f"Failed to process text for key {key}: {error_msg}")
                error_results.append({
                    "key": key,
                    "text_input": text_input,
                    "error": f"Exception: {error_msg}",
                })
                break

        # Fix: the original silently dropped items whose retries were all
        # consumed by malformed API replies -- record them in the error file.
        if not success and not any(err["key"] == key for err in error_results):
            error_results.append({
                "key": key,
                "text_input": text_input,
                "error": "Exception: exhausted retries without a valid prediction",
            })

    with open(error_file_name, "w", encoding="utf-8") as f:
        json.dump(error_results, f, indent=2, ensure_ascii=False)

    # Save final results.
    with open(save_file_name, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {save_file_name}")
    print(f"Total processed items: {total_count}")


if __name__ == "__main__":
    main()