import json
import os
import hashlib
from typing import Any, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
from loguru import logger
from tqdm import tqdm

# Directory containing the eval_log_*.json files to re-score.
working_dir = "/workspace/v119rc_exp4/A8_LoRA"
# Optional filename suffix filter; the empty string matches every eval log.
particular = ""

checkpoints = [18000, 20000]

base_port = 8002
# Endpoint -> checkpoint step: the server for checkpoints[i] listens on
# port base_port + i.
models = {
    f"http://localhost:{base_port + i}/v1/chat/completions": checkpoints[i]
    for i in range(len(checkpoints))
}
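
# NOTE: this script assumes one OpenAI-compatible chat-completions server is
# already running per checkpoint on consecutive ports. How those servers are
# launched is deployment-specific; a purely hypothetical example with vLLM:
#   vllm serve /path/to/ckpt-18000 --port 8002
#   vllm serve /path/to/ckpt-20000 --port 8003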

# One request thread per endpoint, capped defensively at 16.
MAX_WORKERS = min(16, max(1, len(models)))


def thought_generator_with_local_LLM_requests(
    message,
    LLM_model,
    LLM_max_new_tokens=8196,
    n=1,
    API_URL="http://localhost:8000/v1/chat/completions",
    timeout_sec=600,
    stream=True,
) -> str | list[str]:
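    """Call an OpenAI-compatible /v1/chat/completions endpoint.

    `message` is a list of chat messages in OpenAI format, e.g.
    [{"role": "system", ...}, {"role": "user", ...}]. Returns the completion
    text for n == 1, or a list of texts for n > 1 (non-streaming only).
    """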
    if stream and n != 1:
        raise ValueError("You can't set n > 1 when using stream")

    if not stream:
        payload = {
            "model": LLM_model,
            "messages": message,
            "n": n,
            "max_tokens": LLM_max_new_tokens,
        }
        r = requests.post(
            API_URL,
            json=payload,
            headers={"Content-Type": "application/json", "Authorization": "Bearer 0"},
            timeout=timeout_sec,
        )
        if r.status_code != 200:
            logger.error(f"LLM API error {r.status_code}: {r.text}")
            raise RuntimeError(f"LLM API returned {r.status_code}")

        data = r.json()
        if n == 1:
            return data["choices"][0]["message"]["content"]
        return [c["message"]["content"] for c in data["choices"]]

    # Streaming: consume server-sent events and accumulate the content deltas.
    payload = {
        "model": LLM_model,
        "messages": message,
        "n": 1,
        "max_tokens": LLM_max_new_tokens,
        "stream": True,
    }

    full_text = ""
    try:
        with requests.post(
            API_URL,
            json=payload,
            headers={"Content-Type": "application/json", "Authorization": "Bearer 0"},
            timeout=timeout_sec,
            stream=True,
        ) as r:
            if r.status_code != 200:
                logger.error(f"LLM streaming API error {r.status_code}: {r.text}")
                raise RuntimeError(f"LLM streaming API returned {r.status_code}")

            for line in r.iter_lines(decode_unicode=True):
                # SSE frames look like `data: {...}`; skip blanks and keep-alives.
                if not line or not line.startswith("data: "):
                    continue

                data = line[len("data: "):]
                if data.strip() == "[DONE]":
                    break

                try:
                    j = json.loads(data)
                except json.JSONDecodeError:
                    continue

                choices = j.get("choices", [])
                if not choices:
                    continue

                content = choices[0].get("delta", {}).get("content", "")
                if content:
                    full_text += content

        return full_text
    except Exception as e:
        logger.error(f"Local LLM streaming request failed: {e}")
        raise
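
# Example call (hypothetical values; assumes a server is live on port 8002):
#   msgs = [{"role": "user", "content": "Say hi"}]
#   text = thought_generator_with_local_LLM_requests(
#       msgs, "custom-model", stream=False,
#       API_URL="http://localhost:8002/v1/chat/completions",
#   )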


def extract_label(response: str) -> str:
    """Return the model's verdict tag, or "" if absent or ambiguous."""
    has_yes = "<Yes>" in response
    has_no = "<No>" in response
    if has_yes and not has_no:
        return "<Yes>"
    if has_no and not has_yes:
        return "<No>"
    return ""


def call_one_model(
    model_url: str,
    ckpt: int,
    msgs,
    gold_label: str,
) -> Tuple[int, Dict[str, Any]]:
    """Query one checkpoint's endpoint and score its label against the gold."""
    try:
        response = thought_generator_with_local_LLM_requests(
            message=msgs,
            LLM_model="custom-model",
            LLM_max_new_tokens=1024,
            n=1,
            API_URL=model_url,
            timeout_sec=300,
            stream=False,
        )
    except Exception as e:
        logger.error(f"Error getting response from model at {model_url}: {e}")
        response = ""

    label = extract_label(response)
    return ckpt, {
        "label": label,
        "output": response,
        "full_output": response,
        "accuracy": 1 if label == gold_label else 0,
    }
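
# A failed request yields an empty "output"; should_run_step() treats that as
# missing, so rerunning the script retries only failed or new calls.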


def entry_uid(system: str, prompt: str, gold_label: str, gold_output: str) -> str:
    """Stable content-addressed ID for an eval entry, used as the cache key."""
    payload = {"system": system, "prompt": prompt, "gold_label": gold_label, "gold_output": gold_output}
    s = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(s.encode("utf-8")).hexdigest()
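
# Sorted keys plus fixed separators give a canonical JSON encoding, so the
# SHA-1 digest is stable across runs regardless of dict insertion order.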


def load_cache(path: str) -> Dict[str, Dict[str, Any]]:
    """Return {uid: cached_entry_dict}."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r") as f:
            data = json.load(f)
        cache = {}
        for e in data:
            uid = entry_uid(
                e.get("system", ""),
                e.get("prompt", ""),
                e.get("gold_label", ""),
                e.get("gold_output", ""),
            )
            cache[uid] = e
        logger.info(f"Loaded cache from {path}: {len(cache)} entries")
        return cache
    except Exception as ex:
        logger.warning(f"Failed to load cache from {path} (starting fresh): {ex}")
        return {}


def should_run_step(o_entry: Dict[str, Any], ckpt: int) -> bool:
    """Compute this ckpt if missing or empty."""
    key = f"step_{ckpt}"
    if key not in o_entry:
        return True
    v = o_entry.get(key) or {}
    out = v.get("output", "")
    return not isinstance(out, str) or out.strip() == ""


def atomic_write_json(path: str, obj: Any) -> None:
    """Write JSON to a sibling temp file, then atomically swap it into place."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)
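
# os.replace() is an atomic rename on the same filesystem, so a crash during
# json.dump leaves the previous output file intact rather than truncated.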


if __name__ == "__main__":
    for original_eval_log_file in os.listdir(working_dir):
        # Process eval_log_*{particular}.json, skipping our own *_cps.json outputs.
        if (
            not original_eval_log_file.startswith("eval_log_")
            or not original_eval_log_file.endswith(f"{particular}.json")
            or original_eval_log_file.endswith("_cps.json")
        ):
            continue

        original_eval_file = os.path.join(working_dir, original_eval_log_file)
        output_eval_file = original_eval_file.replace(".json", "_cps.json")

        with open(original_eval_file, "r") as f:
            eval_data = json.load(f)

        # Resume from any previous partial run of this script.
        cache_map = load_cache(output_eval_file)

        output_eval_data = []

        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            for idx, entry in enumerate(tqdm(eval_data)):
                system = entry["system"]
                prompt = entry["prompt"]
                gold_label = entry["gold_label"]
                gold_output = entry["gold_output"]

                uid = entry_uid(system, prompt, gold_label, gold_output)

                # Start from the cached entry (if any) and refresh its fields.
                o_entry = cache_map.get(uid, {})
                o_entry.update({
                    "system": system,
                    "prompt": prompt,
                    "gold_label": gold_label,
                    "gold_output": gold_output,
                })

                msgs = [
                    {"role": "system", "content": system},
                    {"role": "user", "content": prompt},
                ]

                # Query every checkpoint that has no usable cached result yet.
                futures = []
                for model_url, ckpt in models.items():
                    if should_run_step(o_entry, ckpt):
                        futures.append(
                            executor.submit(call_one_model, model_url, ckpt, msgs, gold_label)
                        )

                for fut in as_completed(futures):
                    ckpt, result = fut.result()
                    o_entry[f"step_{ckpt}"] = result

                output_eval_data.append(o_entry)

                # Periodically checkpoint progress so a crash loses at most 50 entries.
                if (idx + 1) % 50 == 0:
                    atomic_write_json(output_eval_file, output_eval_data)

        atomic_write_json(output_eval_file, output_eval_data)

    print("Evaluation with checkpoints completed.")