import hashlib
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple

import requests
from loguru import logger
from tqdm import tqdm


def getenv_str(key: str, default: str) -> str:
    """Return env var *key*, or *default* when the variable is unset.

    An empty-string value counts as "set" and is returned as-is.
    """
    v = os.environ.get(key)
    return default if v is None else v


def getenv_int(key: str, default: int) -> int:
    """Return env var *key* parsed as int.

    Unset or blank values yield *default*; a non-integer value raises
    ValueError so misconfiguration fails loudly instead of silently.
    """
    v = os.environ.get(key)
    if v is None or v.strip() == "":
        return default
    try:
        return int(v)
    except ValueError:
        raise ValueError(f"Env var {key} must be int, got: {v!r}")


def extract_first_int_in_the_string(txt: str):
    """Return the first run of digits in *txt* as an int, or None if absent."""
    match = re.search(r'\d+', txt)
    return int(match.group()) if match else None


# ----------------------------
# Read config from environment
# ----------------------------
CONFIG_DIR = getenv_str("CONFIG_DIR", "")
SAVE_DIR = getenv_str("SAVE_DIR", CONFIG_DIR)
WORKING_DIR = getenv_str("EVAL_WORKING_DIR", "")
WORKING_EVAL_SUBWORD = getenv_str("EVAL_SUBWORD", "")
FORBIDDEN_SUBWORDS: List[str] = json.loads(getenv_str("FORBIDDEN_SUBWORDS_JSON", "[]"))
PARTICULAR = getenv_str("PARTICULAR", "")
BASE_PORT = getenv_int("BASE_PORT", 8002)
MAX_TOKEN = getenv_int("MAX_TOKEN", 512)
SYSTEM_PROMPT = getenv_str("OVERRIDING_SYSTEM_PROMPT", "")

# Prefer explicit URL->ckpt mapping from RUNME.sh
MODELS_JSON_ENV = getenv_str("MODELS_JSON", "").strip()
if MODELS_JSON_ENV:
    MODELS: Dict[str, int] = json.loads(MODELS_JSON_ENV)
    MODELS = {str(k): int(v) for k, v in MODELS.items()}
else:
    # Fallback sequential mapping (rarely used now): checkpoint i is served
    # on BASE_PORT + i.
    checkpoints = json.loads(getenv_str("CKPTS_JSON", "[1000]"))
    MODELS = {
        f"http://localhost:{BASE_PORT + i}/v1/chat/completions": int(ckpt)
        for i, ckpt in enumerate(checkpoints)
    }

MAX_WORKERS = min(16, max(1, len(MODELS)))


def thought_generator_with_local_LLM_requests(
    message,
    LLM_model,
    LLM_max_new_tokens=128,
    n=1,
    API_URL="http://localhost:8000/v1/chat/completions",
    timeout_sec=600,
    stream=False,
) -> str | list[str]:
    """POST a chat-completions request to a local OpenAI-compatible server.

    Args:
        message: chat messages list ([{"role": ..., "content": ...}, ...]).
        LLM_model: model name sent in the payload.
        LLM_max_new_tokens: generation cap (payload "max_tokens").
        n: number of completions; n == 1 returns a single string, otherwise
           a list of strings (one per choice).
        API_URL: endpoint URL.
        timeout_sec: request timeout in seconds.
        stream: passed through; this eval uses stream=False, keep it simple.

    Raises:
        RuntimeError: on any non-200 HTTP status (body is logged first).
    """
    payload = {
        "model": LLM_model,
        "messages": message,
        "n": n,
        "max_tokens": LLM_max_new_tokens,
        # Server-specific extensions (e.g. vLLM): disable chain-of-thought
        # and sampling for deterministic eval output.
        "enable_thinking": False,
        "stream": stream,
        "do_sample": False,
    }
    r = requests.post(
        API_URL,
        json=payload,
        headers={"Content-Type": "application/json", "Authorization": "Bearer 0"},
        timeout=timeout_sec,
    )
    if r.status_code != 200:
        logger.error(f"LLM API error {r.status_code}: {r.text}")
        raise RuntimeError(f"LLM API returned {r.status_code}")
    data = r.json()
    if n == 1:
        return data["choices"][0]["message"]["content"]
    return [c["message"]["content"] for c in data["choices"]]


def call_one_model(
    model_url: str,
    ckpt: int,
    msgs,
    gold_label: str,
    prev_retries: int,
    max_token: int,
) -> Tuple[int, Dict[str, Any]]:
    """Query one checkpoint server and package the outcome for the cache.

    Returns (ckpt, result_dict). result_dict always has "response" and
    "retries"; on failure it also carries "error" and the retry count is
    bumped so should_run_step() can cap re-attempts.
    """
    try:
        response = thought_generator_with_local_LLM_requests(
            message=msgs,
            LLM_model="custom-model",
            LLM_max_new_tokens=max_token,
            n=1,
            API_URL=model_url,
            timeout_sec=300,
            stream=False,
        )
    except Exception as e:
        logger.error(f"Error getting response from model at {model_url}: {e}")
        return ckpt, {
            "response": "",
            "retries": prev_retries + 1,
            "error": str(e),
        }

    if not isinstance(response, str) or response.strip() == "":
        # Non-string or blank output counts as a failed attempt.
        return ckpt, {
            "response": "" if not isinstance(response, str) else response,
            "retries": prev_retries + 1,
            "error": "empty_response",
        }

    return ckpt, {
        "response": response,
        "retries": prev_retries,  # keep retry count (or reset if you prefer)
    }


def entry_uid(system: str, prompt: str, gold_label: str) -> str:
    """Stable content hash identifying one eval entry.

    SYSTEM_PROMPT, when set, overrides the entry's own system message so
    cached results are keyed on the prompt actually sent to the model.
    """
    payload = {"system": SYSTEM_PROMPT or system, "prompt": prompt, "response": gold_label}
    s = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(s.encode("utf-8")).hexdigest()


def load_cache(path: str) -> Dict[str, Dict[str, Any]]:
    """Load a previous results file and index its entries by entry_uid.

    Missing or unreadable files yield an empty cache (fresh start) rather
    than aborting the run.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        cache = {}
        for e in data:
            uid = entry_uid(e.get("system", ""), e.get("prompt", ""), e.get("response", ""))
            cache[uid] = e
        logger.info(f"Loaded cache from {path}: {len(cache)} entries")
        return cache
    except Exception as ex:
        logger.warning(f"Failed to load cache from {path} (starting fresh): {ex}")
        return {}


def should_run_step(o_entry: Dict[str, Any], ckpt: int) -> bool:
    """Decide whether checkpoint *ckpt* still needs a model call.

    Run when the step is absent, or when its cached response is empty and
    fewer than 3 attempts have been made; a non-empty response is final.
    """
    key = f"step_{ckpt}"
    if key not in o_entry:
        return True
    v = o_entry.get(key) or {}
    retries = int(v.get("retries", 0) or 0)
    out = v.get("response", "")
    if (not isinstance(out, str)) or (out.strip() == ""):
        return retries < 3
    return False


def atomic_write_json(path: str, obj: Any) -> None:
    """Write JSON to *path* via a temp file + os.replace so readers never
    observe a partially written file."""
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)
    os.replace(tmp, path)


def should_process_file(filename: str) -> bool:
    """Apply the include/exclude filters from the environment to *filename*."""
    if not filename.endswith(".json"):
        return False
    if WORKING_EVAL_SUBWORD and WORKING_EVAL_SUBWORD not in filename:
        return False
    if any(sub in filename for sub in FORBIDDEN_SUBWORDS):
        return False
    if PARTICULAR and PARTICULAR not in filename:
        return False
    return True


if __name__ == "__main__":
    logger.info(f"WORKING_DIR={WORKING_DIR}")
    logger.info(f"SAVE_DIR={SAVE_DIR}")
    logger.info(f"MODELS={MODELS}")
    logger.info(f"MAX_WORKERS={MAX_WORKERS}")

    if not MODELS:
        print("No models to evaluate (MODELS is empty). Exiting.")
        raise SystemExit(0)

    os.makedirs(SAVE_DIR, exist_ok=True)

    for original_eval_log_file in os.listdir(WORKING_DIR):
        if not should_process_file(original_eval_log_file):
            continue
        print(f"Working in {original_eval_log_file}")
        original_eval_file = os.path.join(WORKING_DIR, original_eval_log_file)
        # removesuffix (not str.replace) so an interior ".json" in the name
        # is never rewritten; should_process_file guarantees the suffix.
        output_eval_file = os.path.join(
            SAVE_DIR, original_eval_log_file.removesuffix(".json") + "_results.json"
        )

        with open(original_eval_file, "r", encoding="utf-8") as f:
            eval_data: list[dict] = json.load(f)

        cache_map = load_cache(output_eval_file)

        # Prebuild a full-length output list in the same order as eval_data
        output_eval_data: list[dict] = []
        uids: list[str] = []
        for entry in eval_data:
            system = entry["system"]
            prompt = entry["prompt"]
            gold_label = entry["response"]
            uid = entry_uid(system, prompt, gold_label)
            uids.append(uid)
            # Start from cached entry if present; otherwise new skeleton
            o_entry = dict(cache_map.get(uid, {}))
            o_entry.update({"system": SYSTEM_PROMPT or system, "prompt": prompt, "response": gold_label})
            output_eval_data.append(o_entry)

        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            for idx, entry in enumerate(tqdm(eval_data)):
                system = entry["system"]
                prompt = entry["prompt"]
                gold_label = entry["response"]
                uid = uids[idx]
                o_entry = output_eval_data[idx]  # already contains cached content + required fields

                msgs = [
                    {"role": "system", "content": SYSTEM_PROMPT or system},
                    {"role": "user", "content": prompt},
                ]

                # Fan out one request per checkpoint that still needs a run.
                futures = []
                for model_url, ckpt in MODELS.items():
                    step_key = f"step_{ckpt}"
                    prev = o_entry.get(step_key) or {}
                    prev_retries = int(prev.get("retries", 0) or 0)
                    if should_run_step(o_entry, ckpt):
                        futures.append(
                            executor.submit(
                                call_one_model,
                                model_url,
                                ckpt,
                                msgs,
                                gold_label,
                                prev_retries,
                                MAX_TOKEN,
                            )
                        )
                for fut in as_completed(futures):
                    ckpt, result = fut.result()
                    o_entry[f"step_{ckpt}"] = result

                # Periodic save now writes ALL entries (including cached tail),
                # so file never shrinks
                if (idx + 1) % 50 == 0:
                    atomic_write_json(output_eval_file, output_eval_data)

        # Final save
        atomic_write_json(output_eval_file, output_eval_data)

    print("Evaluation with checkpoints completed.")