| """ |
| api_test.py |
| =========== |
| Batch-test all .jsonl files under test_task1/to_be_tested/task1/ (BASE_DIR), call the LlamaFactory API, |
| and capture the yes/no probabilities written by hf_engine to |
| /tmp/llama_yes_no_prob.json. |
| |
| Dependencies: |
| pip install requests |
| |
| Output (one <stem>_results.csv per jsonl file, written next to the input file): |
| index | instruction | input | output | yes_prob | no_prob | yes_confidence |
| """ |
|
|
| import csv |
| import json |
| import os |
| import time |
| import glob |
| import requests |
| from threading import Lock |
|
|
| |
| |
| |
# --- Locations -------------------------------------------------------------
# Directory whose *.jsonl files are processed by main().
BASE_DIR = "test_task1/to_be_tested/task1"
# Chat-completions endpoint that call_api() posts to.
API_URL = "http://127.0.0.1:8000/v1/chat/completions"
# JSON file read back after every successful request (see read_prob_file()).
PROB_FILE = "/tmp/llama_yes_no_prob.json"

# --- Request parameters ----------------------------------------------------
MAX_TOKENS = 512    # sent as "max_tokens" in the request payload
TEMPERATURE = 0     # sent as "temperature" in the request payload
TIMEOUT = 60        # passed as the `timeout` argument of requests.post
MAX_RETRIES = 3     # total number of attempts per item, not extra retries

# --- Concurrency -----------------------------------------------------------
# NOTE(review): not referenced by the code in this file. Presumably it must
# stay at 1 because every request shares the single PROB_FILE — confirm
# before raising it.
MAX_WORKERS = 1

# HTTP headers attached to every API request.
headers = {"Content-Type": "application/json"}
|
|
|
|
def read_prob_file(path=None) -> dict:
    """Read the yes/no probability JSON file and return its contents.

    Args:
        path: File to read. When omitted (or ``None``) the module-level
            ``PROB_FILE`` is used, which keeps existing zero-argument
            callers working.

    Returns:
        The parsed JSON object. If the file is missing, unreadable, not
        valid JSON, or does not contain a JSON object, a dict with empty
        strings for ``yes_prob``, ``no_prob`` and ``yes_confidence`` is
        returned instead, so callers can always use ``.get`` on the result.
    """
    empty = {"yes_prob": "", "no_prob": "", "yes_confidence": ""}
    if path is None:
        path = PROB_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing or unreadable file.
        # ValueError: json.JSONDecodeError and UnicodeDecodeError are both
        # subclasses, e.g. when the file is read while half-written.
        return empty
    # A list/number/string payload would break the callers' `.get` lookups.
    return data if isinstance(data, dict) else empty
|
|
|
|
def call_api(instruction: str, input_text: str) -> tuple[str, dict]:
    """Send one chat-completion request and collect its yes/no probabilities.

    The old probability file is removed before the request so that whatever
    is read back afterwards cannot be left over from a previous item.

    Args:
        instruction: Task instruction; becomes the first part of the user
            message.
        input_text: Item text; appended to the instruction after a blank
            line.

    Returns:
        ``(model_output, prob_dict)``. ``model_output`` is the stripped
        message content of the first choice and ``prob_dict`` is the result
        of ``read_prob_file()``. After ``MAX_RETRIES`` failed attempts the
        result is ``("", {...})`` with empty-string probability values.
    """
    # EAFP: avoids the race between an exists() check and the remove().
    try:
        os.remove(PROB_FILE)
    except FileNotFoundError:
        pass

    user_content = f"{instruction}\n\n{input_text}"
    payload = json.dumps({
        "model": "string",
        "messages": [{"role": "user", "content": user_content}],
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "stream": False
    })

    wait = 2
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            resp = requests.post(API_URL, data=payload, headers=headers, timeout=TIMEOUT)
            if resp.status_code == 200:
                model_output = resp.json()["choices"][0]["message"]["content"].strip()
                return model_output, read_prob_file()
            failure = f"HTTP {resp.status_code}"
        except Exception as e:
            # Deliberately broad: network errors and malformed response
            # bodies are both treated as a failed attempt and reported.
            failure = f"Exception: {e}"

        if attempt < MAX_RETRIES:
            print(f" [attempt {attempt}/{MAX_RETRIES}] {failure}, retry in {wait}s...")
            time.sleep(wait)
            wait = min(wait * 2, 30)  # exponential back-off, capped at 30 s
        else:
            # Last attempt: do not announce a retry that will not happen.
            print(f" [attempt {attempt}/{MAX_RETRIES}] {failure}")

    print(" ❌ Max retries reached, skipping this item.")
    return "", {"yes_prob": "", "no_prob": "", "yes_confidence": ""}
|
|
|
|
# Column order of every results CSV.
_RESULT_FIELDS = ["index", "instruction", "input", "output",
                  "yes_prob", "no_prob", "yes_confidence"]

# Rewrite the results CSV after this many newly fetched rows so that a hard
# kill loses at most this much work. Each flush rewrites the whole file,
# which is negligible next to the cost of the API calls.
_FLUSH_EVERY = 10


def _load_items(jsonl_path: str) -> list:
    """Parse every non-blank line of *jsonl_path* as one JSON record."""
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _load_checkpoint(output_csv: str) -> dict:
    """Return ``{index: row}`` for rows of a previous run that succeeded.

    A row counts as successful when its ``output`` or ``yes_prob`` column is
    non-empty. Rows without a usable integer ``index`` are skipped. If the
    file cannot be read or parsed at all, an empty dict is returned.
    """
    if not os.path.exists(output_csv):
        return {}

    existing = {}
    try:
        # newline="" is required by the csv module so that line breaks
        # embedded in quoted fields are read back unchanged.
        with open(output_csv, "r", newline="", encoding="utf-8") as cf:
            for row in csv.DictReader(cf):
                try:
                    idx = int(row["index"])
                except (KeyError, TypeError, ValueError):
                    continue
                # Short rows yield None for missing columns; treat as empty.
                output = (row.get("output") or "").strip()
                yes_prob = (row.get("yes_prob") or "").strip()
                if output or yes_prob:
                    # Keep only the known columns so DictWriter never sees
                    # unexpected keys when the row is written back.
                    existing[idx] = {name: row.get(name) or "" for name in _RESULT_FIELDS}
    except (OSError, ValueError, csv.Error) as e:
        print(f" Failed to read old CSV: {e}, starting from scratch")
        return {}

    print(f" Found {len(existing)} existing successful records (resuming from checkpoint)")
    return existing


def _write_rows(output_csv: str, rows: dict) -> None:
    """Write *rows* (keyed by index) to *output_csv* in ascending index order.

    The data goes to a temporary sibling file first and is then moved into
    place, so an interrupted write cannot leave a truncated results file.
    """
    tmp_path = output_csv + ".tmp"
    with open(tmp_path, "w", newline="", encoding="utf-8") as cf:
        writer = csv.DictWriter(cf, fieldnames=_RESULT_FIELDS)
        writer.writeheader()
        writer.writerows(rows[idx] for idx in sorted(rows))
    os.replace(tmp_path, output_csv)


def run_jsonl(jsonl_path: str):
    """Process one jsonl file and write ``<stem>_results.csv`` next to it.

    Every record is sent through ``call_api`` unless a previous results CSV
    already holds a successful row for its index. Results are flushed to
    disk periodically and again when the function exits for any reason,
    including an exception or Ctrl-C, so an interrupted run can resume.

    Args:
        jsonl_path: Path of the input file; one JSON object per line with
            optional ``instruction`` and ``input`` keys.
    """
    stem = os.path.splitext(os.path.basename(jsonl_path))[0]
    output_csv = os.path.join(os.path.dirname(jsonl_path), stem + "_results.csv")

    print(f"\n{'='*60}")
    print(f"Processing: {jsonl_path}")
    print(f"Output: {output_csv}")
    print(f"{'='*60}")

    items = _load_items(jsonl_path)
    print(f"Total {len(items)} records")

    existing = _load_checkpoint(output_csv)
    # Seed with every reusable checkpoint row up front so that an early
    # flush never drops rows the loop has not reached yet.
    all_rows = {i: existing[i] for i in range(len(items)) if i in existing}

    unsaved = 0
    try:
        for i, item in enumerate(items):
            if i in all_rows:
                continue

            instruction = item.get("instruction", "")
            input_text = item.get("input", "")

            print(f" [{i+1}/{len(items)}] Requesting...", end=" ", flush=True)
            model_output, prob = call_api(instruction, input_text)
            print(f"done yes={prob.get('yes_confidence', '?')}")

            all_rows[i] = {
                "index": i,
                "instruction": instruction,
                "input": input_text,
                "output": model_output,
                "yes_prob": prob.get("yes_prob", ""),
                "no_prob": prob.get("no_prob", ""),
                "yes_confidence": prob.get("yes_confidence", ""),
            }

            unsaved += 1
            if unsaved >= _FLUSH_EVERY:
                _write_rows(output_csv, all_rows)
                unsaved = 0
    finally:
        # Runs on success and on interruption alike.
        _write_rows(output_csv, all_rows)

    print(f" ✓ Done → {output_csv}")
|
|
|
|
def main():
    """Run run_jsonl on every .jsonl file directly under BASE_DIR.

    Files are handled one after another in sorted path order. When nothing
    matches, a message is printed and the function returns without doing
    any work.
    """
    pattern = os.path.join(BASE_DIR, "*.jsonl")
    jsonl_files = glob.glob(pattern)
    jsonl_files.sort()

    if len(jsonl_files) == 0:
        print(f"❌ No .jsonl files found (path: {pattern})")
        return

    # List everything up front so the user sees the full work queue.
    print(f"Found {len(jsonl_files)} jsonl files:")
    print("\n".join(f" {path}" for path in jsonl_files))

    for path in jsonl_files:
        run_jsonl(path)

    print("\n✅ All processing complete.")


# Script entry point: run the batch only when executed directly.
if __name__ == "__main__":
    main()
|
|