| | import json |
| | import os |
| | import hashlib |
| | from typing import Any, Dict, Tuple, List |
| | from concurrent.futures import ThreadPoolExecutor, as_completed |
| |
|
| | from tqdm import tqdm |
| | import requests |
| | import re |
| | from loguru import logger |
| |
|
| |
|
def getenv_str(key: str, default: str) -> str:
    """Return the value of environment variable *key*, or *default* if unset."""
    return os.environ.get(key, default)
| |
|
| |
|
def getenv_int(key: str, default: int) -> int:
    """Return environment variable *key* parsed as an int.

    Falls back to *default* when the variable is unset or blank; raises
    ValueError (naming the key and the bad value) when it is set but not
    parseable as an integer.
    """
    raw = os.environ.get(key)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"Env var {key} must be int, got: {raw!r}")
| |
|
def extract_first_int_in_the_string(txt: str):
    """Return the first run of digits in *txt* as an int, or None if absent."""
    found = re.search(r"\d+", txt)
    if found is None:
        return None
    return int(found.group())
| |
|
| |
|
| | |
| | |
| | |
# --- Paths ---
# Base config directory; also the default save location.
CONFIG_DIR = getenv_str("CONFIG_DIR", "")
# Where *_results.json output files are written (defaults to CONFIG_DIR).
SAVE_DIR = getenv_str("SAVE_DIR", CONFIG_DIR)

# Directory scanned for input eval .json files.
WORKING_DIR = getenv_str("EVAL_WORKING_DIR", "")
# When non-empty, only filenames containing this substring are processed.
WORKING_EVAL_SUBWORD = getenv_str("EVAL_SUBWORD", "")

# --- Filename filters ---
# JSON list of substrings; any filename containing one of them is skipped.
FORBIDDEN_SUBWORDS: List[str] = json.loads(getenv_str("FORBIDDEN_SUBWORDS_JSON", "[]"))
# When non-empty, an additional required-substring filter on filenames.
PARTICULAR = getenv_str("PARTICULAR", "")

# --- Model serving / generation ---
# First local port used when deriving endpoint URLs from CKPTS_JSON.
BASE_PORT = getenv_int("BASE_PORT", 8002)
# max_tokens budget passed to each completion request.
MAX_TOKEN = getenv_int("MAX_TOKEN", 512)

# When non-empty, replaces each eval entry's own system message everywhere.
SYSTEM_PROMPT = getenv_str("OVERRIDING_SYSTEM_PROMPT", "")
| |
|
| | |
# Mapping of chat-completions endpoint URL -> checkpoint step it serves.
MODELS_JSON_ENV = getenv_str("MODELS_JSON", "").strip()
if MODELS_JSON_ENV:
    # Explicit {url: step} mapping supplied via MODELS_JSON.
    MODELS: Dict[str, int] = {
        str(url): int(step) for url, step in json.loads(MODELS_JSON_ENV).items()
    }
else:
    # Derive one localhost URL per checkpoint, on consecutive ports from BASE_PORT.
    checkpoints = json.loads(getenv_str("CKPTS_JSON", "[1000]"))
    MODELS = {
        f"http://localhost:{BASE_PORT + offset}/v1/chat/completions": int(step)
        for offset, step in enumerate(checkpoints)
    }

# One worker per model endpoint, capped at 16, with a floor of 1.
MAX_WORKERS = min(16, max(1, len(MODELS)))
| |
|
| |
|
def thought_generator_with_local_LLM_requests(
    message,
    LLM_model,
    LLM_max_new_tokens=128,
    n=1,
    API_URL="http://localhost:8000/v1/chat/completions",
    timeout_sec=600,
    stream=False,
) -> str | list[Any] | Any:
    """Query an OpenAI-compatible chat-completions endpoint.

    Returns the single completion string when n == 1, otherwise a list of
    completion strings. Raises RuntimeError on any non-200 HTTP status.
    """
    body = {
        "model": LLM_model,
        "messages": message,
        "n": n,
        "max_tokens": LLM_max_new_tokens,
        "enable_thinking": False,
        "stream": stream,
        "do_sample": False
    }

    reply = requests.post(
        API_URL,
        json=body,
        headers={"Content-Type": "application/json", "Authorization": "Bearer 0"},
        timeout=timeout_sec,
    )

    if reply.status_code != 200:
        logger.error(f"LLM API error {reply.status_code}: {reply.text}")
        raise RuntimeError(f"LLM API returned {reply.status_code}")

    choices = reply.json()["choices"]
    if n == 1:
        return choices[0]["message"]["content"]
    return [choice["message"]["content"] for choice in choices]
| |
|
def call_one_model(
    model_url: str,
    ckpt: int,
    msgs,
    gold_label: str,
    prev_retries: int,
    max_token: int
) -> Tuple[int, Dict[str, Any]]:
    """Run one prompt against one model endpoint and package the outcome.

    Args:
        model_url: chat-completions URL of the serving model.
        ckpt: checkpoint step served at that URL; echoed back so the caller
            can key the result under ``step_{ckpt}``.
        msgs: chat messages list passed straight to the API.
        gold_label: unused here; kept for interface compatibility with callers.
        prev_retries: retry count already recorded for this (entry, ckpt) pair.
        max_token: completion token budget.

    Returns:
        ``(ckpt, result)`` where ``result`` always contains "response" and
        "retries", plus an "error" key when the call failed or was empty.
        Never raises: failures are folded into the result dict so one bad
        endpoint cannot abort the whole evaluation loop.
    """
    try:
        response = thought_generator_with_local_LLM_requests(
            message=msgs,
            LLM_model="custom-model",
            LLM_max_new_tokens=max_token,
            n=1,
            API_URL=model_url,
            timeout_sec=300,
            stream=False,
        )
    except Exception as e:
        logger.error(f"Error getting response from model at {model_url}: {e}")
        # Count the failure as a consumed retry so should_run_step can cap attempts.
        # (Removed a dead `response = ""` assignment that was never read.)
        return ckpt, {
            "response": "",
            "retries": prev_retries + 1,
            "error": str(e)
        }

    # Treat non-string or blank completions as failures eligible for retry.
    if not isinstance(response, str) or response.strip() == "":
        return ckpt, {
            "response": "" if not isinstance(response, str) else response,
            "retries": prev_retries + 1,
            "error": "empty_response",
        }

    return ckpt, {
        "response": response,
        "retries": prev_retries,
    }
| |
|
def entry_uid(system: str, prompt: str, gold_label: str) -> str:
    """Stable SHA-1 identifier for one eval entry.

    Uses the module-level SYSTEM_PROMPT override (when non-empty) in place of
    the entry's own system message, so cached results produced under an
    override remain addressable. Canonical JSON (sorted keys, compact
    separators) makes the hash independent of dict ordering.

    Note: the original `global SYSTEM_PROMPT` declaration was removed — it is
    unnecessary (and misleading) for a read-only access to a module global.
    """
    payload = {"system": SYSTEM_PROMPT or system, "prompt": prompt, "response": gold_label}
    s = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(s.encode("utf-8")).hexdigest()
| |
|
| |
|
def load_cache(path: str) -> Dict[str, Dict[str, Any]]:
    """Load previously saved results from *path*, keyed by entry_uid.

    Returns an empty dict when the file is absent or unreadable, so a
    corrupt/partial results file means "start fresh" rather than a crash.
    """
    if not os.path.exists(path):
        return {}
    try:
        # Explicit UTF-8: results may contain non-ASCII text, so reading must
        # not depend on the platform's default encoding (e.g. cp1252 on Windows).
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        cache: Dict[str, Dict[str, Any]] = {}
        for e in data:
            # Re-derive each entry's uid so cache hits line up with entry_uid()
            # computed over the fresh eval data.
            uid = entry_uid(e.get("system", ""), e.get("prompt", ""), e.get("response", ""))
            cache[uid] = e
        logger.info(f"Loaded cache from {path}: {len(cache)} entries")
        return cache
    except Exception as ex:
        # Deliberate best-effort: log and fall back to an empty cache.
        logger.warning(f"Failed to load cache from {path} (starting fresh): {ex}")
        return {}
| |
|
| |
|
def should_run_step(o_entry: Dict[str, Any], ckpt: int) -> bool:
    """Decide whether checkpoint *ckpt* still needs a model call for this entry.

    Runs when no result is recorded yet, or when the recorded response is
    missing/blank and fewer than 3 retries have been spent. A non-empty
    recorded response means the step is already done.
    """
    key = f"step_{ckpt}"
    if key not in o_entry:
        return True  # never attempted
    recorded = o_entry.get(key) or {}
    attempts = int(recorded.get("retries", 0) or 0)
    response = recorded.get("response", "")
    has_usable_response = isinstance(response, str) and response.strip() != ""
    if has_usable_response:
        return False
    return attempts < 3
| | |
| |
|
def atomic_write_json(path: str, obj: Any) -> None:
    """Atomically write *obj* as pretty-printed JSON to *path*.

    Writes to ``path + ".tmp"`` first, then os.replace()s it over the target,
    so readers never observe a half-written file. Fixes vs. the original:
    explicit UTF-8 encoding (required since ensure_ascii=False emits raw
    non-ASCII characters), and best-effort removal of the temp file if the
    write fails so no stale .tmp is left behind.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2, ensure_ascii=False)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
| |
|
| |
|
def should_process_file(filename: str) -> bool:
    """Return True when *filename* is a .json eval file passing all env filters.

    Filters (all pure substring checks): must contain WORKING_EVAL_SUBWORD and
    PARTICULAR when those are non-empty, and must not contain any entry of
    FORBIDDEN_SUBWORDS.
    """
    return (
        filename.endswith(".json")
        and (not WORKING_EVAL_SUBWORD or WORKING_EVAL_SUBWORD in filename)
        and not any(bad in filename for bad in FORBIDDEN_SUBWORDS)
        and (not PARTICULAR or PARTICULAR in filename)
    )
| |
|
| |
|
if __name__ == "__main__":
    # Evaluation driver: for every eval file in WORKING_DIR, query each model
    # endpoint in MODELS for every entry, resuming from any *_results.json
    # already present in SAVE_DIR.
    logger.info(f"WORKING_DIR={WORKING_DIR}")
    logger.info(f"SAVE_DIR={SAVE_DIR}")
    logger.info(f"MODELS={MODELS}")
    logger.info(f"MAX_WORKERS={MAX_WORKERS}")

    if not MODELS:
        print("No models to evaluate (MODELS is empty). Exiting.")
        raise SystemExit(0)

    os.makedirs(SAVE_DIR, exist_ok=True)

    for original_eval_log_file in os.listdir(WORKING_DIR):
        if not should_process_file(original_eval_log_file):
            continue
        print(f"Working in {original_eval_log_file}")

        original_eval_file = os.path.join(WORKING_DIR, original_eval_log_file)
        # Output lives next to SAVE_DIR under the same name with a _results suffix.
        output_eval_file = os.path.join(SAVE_DIR, original_eval_log_file.replace(".json", "_results.json"))

        # Assumes each input file is a JSON list of dicts with "system",
        # "prompt" and "response" keys (read below) — TODO confirm schema.
        with open(original_eval_file, "r") as f:
            eval_data: list[dict] = json.load(f)

        # Previously saved results keyed by entry_uid, enabling resume.
        cache_map = load_cache(output_eval_file)

        output_eval_data: list[dict] = []
        uids: list[str] = []

        # Pass 1: build the output skeleton (one dict per entry), seeding each
        # from the cache so already-completed steps are preserved.
        for entry in eval_data:
            system = entry["system"]
            prompt = entry["prompt"]
            gold_label = entry["response"]

            uid = entry_uid(system, prompt, gold_label)
            uids.append(uid)

            # Copy the cached entry (if any), then refresh its identity fields;
            # SYSTEM_PROMPT, when set, overrides the entry's own system message.
            o_entry = dict(cache_map.get(uid, {}))
            o_entry.update({"system": SYSTEM_PROMPT or system, "prompt": prompt, "response": gold_label})

            output_eval_data.append(o_entry)

        # Pass 2: for each entry, fan out one call per model endpoint in parallel.
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            for idx, entry in enumerate(tqdm(eval_data)):
                system = entry["system"]
                prompt = entry["prompt"]
                gold_label = entry["response"]

                uid = uids[idx]
                o_entry = output_eval_data[idx]

                msgs = [{"role": "system", "content": SYSTEM_PROMPT or system}, {"role": "user", "content": prompt}]

                futures = []
                for model_url, ckpt in MODELS.items():
                    step_key = f"step_{ckpt}"
                    prev = o_entry.get(step_key) or {}
                    prev_retries = int(prev.get("retries", 0) or 0)

                    # Skip steps that already have a non-empty response or have
                    # exhausted their retry budget (see should_run_step).
                    if should_run_step(o_entry, ckpt):
                        futures.append(
                            executor.submit(
                                call_one_model,
                                model_url,
                                ckpt,
                                msgs,
                                gold_label,
                                prev_retries,
                                MAX_TOKEN,
                            )
                        )

                # call_one_model never raises, so fut.result() is safe here;
                # results land under "step_{ckpt}" as they complete.
                for fut in as_completed(futures):
                    ckpt, result = fut.result()
                    o_entry[f"step_{ckpt}"] = result

                # Periodic checkpoint so a crash loses at most ~50 entries of work.
                if (idx + 1) % 50 == 0:
                    atomic_write_json(output_eval_file, output_eval_data)

        # Final flush for this input file.
        atomic_write_json(output_eval_file, output_eval_data)

    print("Evaluation with checkpoints completed.")
| |
|