| import requests |
| import json |
| import csv |
| import os |
| import time |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from threading import Lock |
| from sklearn.metrics.pairwise import cosine_similarity |
| from sklearn.feature_extraction.text import CountVectorizer |
| from rouge_score import rouge_scorer |
| from nltk.translate.bleu_score import sentence_bleu |
|
|
| |
| |
| |
# Directory that holds the input .jsonl files; the .csv/.txt outputs are
# written into the same directory.
BASE_DIR = "test/new"

# Dataset slices to evaluate, in processing order. Each name expands to
# "detectrl_<name>_500.jsonl".
_DATASET_NAMES = (
    "arxiv_human",
    "arxiv_llm",
    "codefeedback_llm",
    "longwriter_llm",
    "math_llm",
    "paraphrase_attack",
    "perturbation_attack",
    "prompt_attack",
    "writing_human",
    "writing_llm",
    "xsum_human",
    "xsum_llm",
    "yelp_human",
    "yelp_llm",
)
INPUT_FILES = [f"detectrl_{name}_500.jsonl" for name in _DATASET_NAMES]

# Inserted between the input file stem and the output extension.
SUFFIX = ""
# Chat-completions endpoint every request is posted to.
API_URL = "http://127.0.0.1:8000/v1/chat/completions"
# Size of the thread pool that issues requests.
MAX_WORKERS = 1
# Generation parameters sent with every request.
MAX_TOKENS = 8192
TEMPERATURE = 0

# Per-request timeout, in seconds.
TIMEOUT = 60
# Number of attempts made for each sample before it is recorded as blank.
MAX_RETRIES = 3
# Upper bound, in seconds, on the doubling back-off between attempts.
RETRY_MAX_WAIT = 10
# The CSV checkpoint is rewritten after this many newly completed samples.
SAVE_INTERVAL = 5
| |
|
|
# Headers sent with every POST to API_URL: the body is always JSON.
headers = {"Content-Type": "application/json"}

# One scorer instance reused for every sample; reports ROUGE-1, ROUGE-2 and
# ROUGE-L with stemming enabled.
scorer = rouge_scorer.RougeScorer(
    rouge_types=["rouge1", "rouge2", "rougeL"],
    use_stemmer=True,
)
|
|
|
|
def check_server_alive():
    """Probe the API endpoint and report whether it answers.

    Posts a minimal chat request ("hi", max_tokens=10) to API_URL with a
    15-second timeout.

    Returns:
        True if the server replied with HTTP 200. False for any other status
        code or when the request itself failed (connection refused, timeout,
        and so on).
    """
    probe_payload = json.dumps({
        "model": "string",
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 10,
        "stream": False,
    })
    try:
        resp = requests.post(API_URL, data=probe_payload, headers=headers, timeout=15)
    except requests.exceptions.RequestException:
        # Only request-level failures mean "not alive". A bare `except:` here
        # would also swallow KeyboardInterrupt/SystemExit, turning Ctrl-C
        # during a probe into "server down" and keeping the caller polling.
        return False
    return resp.status_code == 200
|
|
|
|
def wait_for_server_recovery(max_wait=300, poll_interval=10):
    """Poll the server until it responds or the time budget runs out.

    Args:
        max_wait: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between two failed probes.

    Returns:
        True as soon as check_server_alive() succeeds, False if the server
        is still unreachable once max_wait seconds have elapsed.
    """
    print(f" 🔄 Checking server status, waiting up to {max_wait}s ...")
    # A monotonic clock keeps the budget correct even if the system clock is
    # adjusted while we are waiting.
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        if check_server_alive():
            print(" ✅ Server has recovered")
            return True
        print(f" ⏳ Server not responding, waiting {poll_interval}s...")
        time.sleep(poll_interval)
    print(f" ❌ Server did not recover within {max_wait}s")
    return False
|
|
|
|
def process_item(idx_item):
    """
    Query the API for one (index, record) pair and return its result row.

    The request is attempted up to MAX_RETRIES times. Between attempts the
    function sleeps, starting at 2s and doubling up to RETRY_MAX_WAIT. A
    timeout on the final attempt additionally calls
    wait_for_server_recovery() before giving up.

    Returns a tuple (idx, input text, stripped expected output, prediction);
    the prediction is "" when every attempt failed.
    """
    idx, item = idx_item
    request_body = json.dumps({
        "model": "string",
        "messages": [{
            "role": "user",
            "content": f"what is the prompt that generates the input?\n\n{item['input']}",
        }],
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "stream": False,
    })

    began = time.time()
    backoff = 2
    attempt = 0

    while attempt < MAX_RETRIES:
        attempt += 1
        is_last_attempt = attempt == MAX_RETRIES
        try:
            response = requests.post(API_URL, data=request_body, headers=headers, timeout=TIMEOUT)
            took = time.time() - began

            if response.status_code != 200:
                print(f" [{idx}] Attempt {attempt}/{MAX_RETRIES} failed, status code {response.status_code}, retrying in {backoff}s...")
            else:
                # Parsing stays inside the try so a malformed body is retried
                # like any other failure.
                answer = response.json()['choices'][0]['message']['content'].strip()
                if took > 10:
                    print(f" ⏱️ [{idx}] Succeeded, took {took:.1f}s")
                return idx, item['input'], item['output'].strip(), answer
        except requests.exceptions.Timeout:
            print(f" [{idx}] Attempt {attempt}/{MAX_RETRIES} timed out ({TIMEOUT}s), retrying in {backoff}s...")
            # Out of attempts after a timeout: block until the server is
            # reachable again (or the recovery wait expires).
            if is_last_attempt:
                wait_for_server_recovery()
        except Exception as exc:
            print(f" [{idx}] Attempt {attempt}/{MAX_RETRIES} exception: {type(exc).__name__}: {exc}, retrying in {backoff}s...")

        if not is_last_attempt:
            time.sleep(backoff)
            backoff = min(backoff * 2, RETRY_MAX_WAIT)

    took = time.time() - began
    print(f" ❌ [{idx}] Max retries reached, skipping this sample (set to blank). Total time {took:.1f}s")
    return idx, item['input'], item['output'].strip(), ""
|
|
|
|
def save_csv_full(output_csv, test_data, all_results):
    """Write the complete checkpoint CSV, one row per sample, in index order.

    Args:
        output_csv: Destination path of the CSV file.
        test_data: Sequence of records, each with 'input' and 'output' keys.
        all_results: Mapping index -> (input, expected, predicted) for the
            samples that have been attempted so far.

    Rows present in all_results get status "success" when the prediction is
    non-blank and "failed" otherwise. Every other index is written as a
    "pending" placeholder taken from test_data.

    The file is written to "<output_csv>.tmp" first and then moved over the
    destination with os.replace, so an interruption in the middle of a write
    cannot leave a truncated checkpoint behind.
    """
    fieldnames = ['index', 'input', 'expected_output', 'predicted_output', 'status']
    tmp_path = f"{output_csv}.tmp"

    try:
        with open(tmp_path, 'w', newline='', encoding='utf-8') as csv_f:
            writer = csv.DictWriter(csv_f, fieldnames=fieldnames)
            writer.writeheader()
            for i, item in enumerate(test_data):
                if i in all_results:
                    inp, expected, predicted = all_results[i]
                    status = "success" if predicted and predicted.strip() else "failed"
                else:
                    inp, expected, predicted = item['input'], item['output'], ""
                    status = "pending"
                writer.writerow({
                    'index': i,
                    'input': inp,
                    'expected_output': expected,
                    'predicted_output': predicted,
                    'status': status,
                })
        # Swap the finished file into place in a single step.
        os.replace(tmp_path, output_csv)
    except BaseException:
        # Best-effort cleanup of the half-written temp file; the previous
        # checkpoint (if any) is still intact. The original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
def load_existing_results(output_csv):
    """
    Load a previously written checkpoint CSV, if there is one.

    Args:
        output_csv: Path of the checkpoint CSV.

    Returns:
        A pair (all_results, success_indices):
        - all_results: index -> (input, expected_output, predicted_output)
          for every row read, including failed/empty ones, so rewriting the
          CSV does not lose data.
        - success_indices: indices whose prediction is non-blank; everything
          else is expected to be re-run by the caller.

    If the file does not exist both containers are empty. If reading fails
    part-way, a warning is printed and the rows read before the failure are
    returned.
    """
    all_results = {}
    success_indices = set()

    if not os.path.exists(output_csv):
        return all_results, success_indices

    print("Found an existing CSV file, checking checkpoint...")
    try:
        # newline='' lets the csv module handle line endings itself, so
        # "\r\n" embedded in quoted fields is returned unchanged.
        with open(output_csv, 'r', newline='', encoding='utf-8') as csv_f:
            for row in csv.DictReader(csv_f):
                idx = int(row['index'])
                # Short rows yield None for missing columns; normalise to ''
                # so downstream code only ever sees strings.
                predicted = row.get('predicted_output') or ''
                all_results[idx] = (row['input'], row['expected_output'], predicted)
                if predicted.strip():
                    success_indices.add(idx)

        total_seen = len(all_results)
        success_count = len(success_indices)
        failed_count = total_seen - success_count
        print(f" -> Loaded {total_seen} records")
        print(f" -> Success: {success_count}, Failed/empty (will retry): {failed_count}")
    except Exception as e:
        # Best-effort: an unreadable checkpoint must not abort the run.
        print(f" ⚠️ Failed to read old CSV: {e}, starting over.")

    return all_results, success_indices
|
|
|
|
def _load_jsonl(path):
    """Return the records of a .jsonl file, one parsed object per non-blank line."""
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                records.append(json.loads(stripped))
    return records


def _run_requests(to_process, test_data, all_results, output_csv):
    """Run process_item over every pending sample, checkpointing as results arrive.

    all_results is updated in place. Returns (success_count, error_count)
    for the samples handled in this call.
    """
    total = len(to_process)
    success_count = 0
    error_count = 0
    last_save_count = 0

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(process_item, task) for task in to_process]
        try:
            # as_completed yields on the calling thread, so all_results and
            # the CSV are only ever touched from here; no lock is needed.
            for count, future in enumerate(as_completed(futures), 1):
                idx, inp, actual, predicted = future.result()
                all_results[idx] = (inp, actual, predicted)

                if predicted and predicted.strip():
                    success_count += 1
                else:
                    error_count += 1

                if count % 10 == 0 or count == total:
                    print(f" 📊 Progress: {count}/{total} | Success: {success_count} | Failed/empty: {error_count} | "
                          f"Success rate: {success_count/count*100:.1f}%")

                if count - last_save_count >= SAVE_INTERVAL or count == total:
                    save_csv_full(output_csv, test_data, all_results)
                    last_save_count = count
                    if count % 50 == 0:
                        print(f" 💾 Checkpoint saved (record {count})")
        finally:
            # On Ctrl-C or a worker error, drop everything that has not
            # started yet (otherwise leaving the `with` block would run the
            # whole queue first) and persist what was collected so far.
            # On the normal path cancel() is a no-op and this is the final save.
            for future in futures:
                future.cancel()
            save_csv_full(output_csv, test_data, all_results)

    return success_count, error_count


def _score_predictions(actual_labels, predictions):
    """Score each (expected, predicted) pair with ROUGE, BLEU and cosine similarity.

    Returns three parallel lists (rouge_scores, bleu_scores,
    cosine_similarities), each with one entry per input pair. Blank
    predictions, and pairs for which any metric raises, score 0 everywhere.
    """
    try:
        all_texts = actual_labels + predictions
        if len(set(all_texts)) >= 2:
            vectorizer = CountVectorizer().fit(all_texts)
        else:
            vectorizer = CountVectorizer()
            vectorizer.fit(actual_labels + ["dummy text for fitting"])
    except Exception as e:
        print(f" ⚠️ Failed to initialize vectorizer: {e}, using default values")
        vectorizer = CountVectorizer()
        vectorizer.fit(["sample text", "another sample"])

    rouge_scores, bleu_scores, cosine_similarities = [], [], []
    for actual, predicted in zip(actual_labels, predictions):
        rouge_entry = {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}
        bleu = 0.0
        cosine_sim = 0.0

        if predicted and predicted.strip():
            try:
                # Compute all three metrics before recording any of them, so
                # a failure part-way can never leave the lists with
                # different lengths.
                rouge = scorer.score(actual, predicted)
                sample_rouge = {
                    'rouge1': rouge['rouge1'].fmeasure,
                    'rouge2': rouge['rouge2'].fmeasure,
                    'rougeL': rouge['rougeL'].fmeasure,
                }
                sample_bleu = sentence_bleu([actual.split()], predicted.split())
                actual_vec = vectorizer.transform([actual])
                predicted_vec = vectorizer.transform([predicted])
                sample_cosine = cosine_similarity(actual_vec, predicted_vec)[0][0]
            except Exception as e:
                print(f" ⚠️ Error computing metrics: {e}")
            else:
                rouge_entry, bleu, cosine_sim = sample_rouge, sample_bleu, sample_cosine

        rouge_scores.append(rouge_entry)
        bleu_scores.append(bleu)
        cosine_similarities.append(cosine_sim)

    return rouge_scores, bleu_scores, cosine_similarities


def run_file(input_filename):
    """Process one input .jsonl file end to end.

    Steps: load the records from BASE_DIR, resume from an existing CSV
    checkpoint, send every sample without a successful prediction to the API,
    then compute average ROUGE/BLEU/cosine metrics over all recorded samples
    and write them to a .txt report next to the CSV.

    Args:
        input_filename: File name (relative to BASE_DIR) of the .jsonl input.

    Returns None. Returns early if the server cannot be reached or if there
    is nothing to score. Exceptions raised while requests are in flight
    (including KeyboardInterrupt) propagate after the checkpoint is saved.
    """
    input_path = os.path.join(BASE_DIR, input_filename)
    stem = os.path.splitext(input_filename)[0]
    output_csv = os.path.join(BASE_DIR, stem + SUFFIX + ".csv")
    output_txt = os.path.join(BASE_DIR, stem + SUFFIX + ".txt")

    print(f"\n{'='*60}")
    print(f"Processing file: {input_filename}")
    print(f" -> CSV: {output_csv}")
    print(f" -> TXT: {output_txt}")
    print(f"{'='*60}")

    test_data = _load_jsonl(input_path)
    print(f"Raw jsonl has {len(test_data)} records")

    # Resume: anything without a successful prediction is (re)processed.
    all_results, success_indices = load_existing_results(output_csv)
    to_process = [(i, item) for i, item in enumerate(test_data) if i not in success_indices]
    print(f" -> Records to process: {len(to_process)} (including failed retries)")

    if not to_process:
        print(" ✅ All data already processed successfully, skipping request phase")
    else:
        print(" 🔍 Checking server connectivity...")
        if not check_server_alive():
            print(" ❌ Server not responding, attempting to wait for recovery...")
            if not wait_for_server_recovery():
                print(" ❌ Could not connect to server, skipping this file")
                return

        print(f" 🚀 Starting processing, concurrency: {MAX_WORKERS}")
        success_count, error_count = _run_requests(to_process, test_data, all_results, output_csv)
        print(f" ✅ Request phase complete! Success: {success_count}, Failed: {error_count}, Total processed: {len(to_process)}")

    # Score every recorded sample in index order; failed ones count as zeros.
    valid_results = [
        (all_results[i][1], all_results[i][2])
        for i in range(len(test_data))
        if i in all_results
    ]
    if not valid_results:
        print(" ⚠️ No valid data available to compute metrics.")
        return

    actual_labels = [expected for expected, _ in valid_results]
    predictions = [predicted for _, predicted in valid_results]

    non_empty_predictions = [p for p in predictions if p and p.strip()]
    print(f" 📈 Computing metrics: total samples {len(valid_results)}, valid predictions {len(non_empty_predictions)}")

    rouge_scores, bleu_scores, cosine_similarities = _score_predictions(actual_labels, predictions)

    avg_rouge1 = sum(s['rouge1'] for s in rouge_scores) / len(rouge_scores)
    avg_rouge2 = sum(s['rouge2'] for s in rouge_scores) / len(rouge_scores)
    avg_rougeL = sum(s['rougeL'] for s in rouge_scores) / len(rouge_scores)
    avg_bleu = sum(bleu_scores) / len(bleu_scores)
    avg_cosine = sum(cosine_similarities) / len(cosine_similarities)

    empty_predictions = len(valid_results) - len(non_empty_predictions)

    with open(output_txt, 'w', encoding='utf-8') as f:
        f.write(f"Total samples: {len(valid_results)}\n")
        f.write(f"Valid predictions: {len(non_empty_predictions)}\n")
        f.write(f"Empty predictions: {empty_predictions}\n")
        f.write("\nAverage metrics:\n")
        f.write(f"ROUGE-1: {avg_rouge1:.4f}\n")
        f.write(f"ROUGE-2: {avg_rouge2:.4f}\n")
        f.write(f"ROUGE-L: {avg_rougeL:.4f}\n")
        f.write(f"BLEU: {avg_bleu:.4f}\n")
        f.write(f"Cosine similarity: {avg_cosine:.4f}\n")

    print("\n 📊 Final metrics:")
    print(f" Total samples: {len(valid_results)} | Valid predictions: {len(non_empty_predictions)} | Empty predictions: {empty_predictions}")
    print(f" ROUGE-1: {avg_rouge1:.4f} | ROUGE-2: {avg_rouge2:.4f} | ROUGE-L: {avg_rougeL:.4f}")
    print(f" BLEU: {avg_bleu:.4f} | Cosine similarity: {avg_cosine:.4f}")
    print(f" ✅ Results saved to: {output_txt}")
    print(f" ✅ CSV saved to: {output_csv}")
|
|
|
|
def check_missing_samples():
    """Print a completion report for every file in INPUT_FILES.

    For each input the number of non-blank jsonl lines is compared with the
    rows of its checkpoint CSV:
    - Processed: rows that are not "pending" placeholders.
    - Success: rows with a non-blank prediction.
    - Failed/empty: rows without a prediction (failed or still pending),
      i.e. everything a re-run will attempt again.

    Only prints; returns None.
    """
    print("\n" + "="*60)
    print("Checking completion status of each file")
    print("="*60)

    for fname in INPUT_FILES:
        # Derive the CSV name the same way run_file does.
        stem = os.path.splitext(fname)[0]
        csv_path = os.path.join(BASE_DIR, stem + SUFFIX + ".csv")
        jsonl_path = os.path.join(BASE_DIR, fname)

        if not os.path.exists(jsonl_path):
            print(f"⚠️ {fname}: jsonl file does not exist")
            continue

        # Same encoding as everywhere else in this script, rather than the
        # platform default.
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            total = sum(1 for line in f if line.strip())

        if not os.path.exists(csv_path):
            print(f"❌ {fname}: CSV does not exist, need to process {total} records")
            continue

        with open(csv_path, 'r', newline='', encoding='utf-8') as f:
            rows = list(csv.DictReader(f))

        # The checkpoint holds one row per sample, including "pending"
        # placeholders; those must not count as processed or the completion
        # figure would always read 100%.
        pending_count = sum(1 for r in rows if r.get('status') == 'pending')
        processed = len(rows) - pending_count
        # `or ''` guards against None from short rows.
        empty_count = sum(1 for r in rows if not (r.get('predicted_output') or '').strip())
        success_count = len(rows) - empty_count
        # An empty jsonl would otherwise divide by zero.
        completion = processed / total * 100 if total else 0.0

        print(f"\n📄 {fname}")
        print(f" Total samples: {total}")
        print(f" Processed: {processed}")
        print(f" Success: {success_count}")
        print(f" Failed/empty: {empty_count}")
        print(f" Completion: {completion:.1f}%")

        if empty_count > 0:
            print(f" ⚠️ {empty_count} failed/empty predictions remain, re-running will retry them automatically")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: report status, process every input file in order,
    # then report status again.
    rule = "=" * 60

    print(rule)
    print("Note: if interrupted midway, re-running will resume automatically from the checkpoint (failed entries will also be retried)")
    print(f"Concurrency: {MAX_WORKERS}, Timeout: {TIMEOUT}s, Save interval: {SAVE_INTERVAL}")
    print(rule)

    check_missing_samples()

    print("\n" + rule)
    print("Starting to process files")
    print(rule)

    for fname in INPUT_FILES:
        try:
            run_file(fname)
        except KeyboardInterrupt:
            # Ctrl-C stops the whole batch, not just the current file.
            print("\n\n⚠️ Interrupted by user! Current progress has been saved, will resume next run")
            break
        except Exception as e:
            # A failure in one file is reported and the next file is tried.
            print(f"\n❌ Error occurred while processing {fname}: {e}")
            import traceback
            traceback.print_exc()

    print("\n" + rule)
    print("All processing complete!")
    print(rule)

    check_missing_samples()
|
|