# v120rc_exp4 / runA.py
import json
import os
import hashlib
from typing import Any, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import requests
from loguru import logger
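# Re-scores each eval_log_*.json entry in `working_dir` against several
# checkpoint models served behind local OpenAI-compatible endpoints, writing
# results to a sibling *_cps.json file with resume support (already-computed
# steps are skipped). Each input entry is expected to carry "system",
# "prompt", "gold_label", and "gold_output" fields; the output adds one
# "step_<ckpt>" dict per checkpoint.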
working_dir = "/workspace/v119rc_exp4/A8_LoRA"
particular = "" # tune to "" if no specific file to repair
checkpoints = [
# 2000, 4000, 6000, 8000,
# 10000, 12000, 14000, 16000,
18000, 20000
]
base_port = 8002
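# One OpenAI-compatible server (e.g. vLLM) is assumed to be listening per
# checkpoint, at base_port, base_port + 1, ... in the same order as `checkpoints`.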
models = {
    f"http://localhost:{base_port + i}/v1/chat/completions": ckpt
    for i, ckpt in enumerate(checkpoints)
}
MAX_WORKERS = min(16, max(1, len(models)))
def thought_generator_with_local_LLM_requests(
message,
LLM_model,
LLM_max_new_tokens=8196,
n=1,
API_URL="http://localhost:8000/v1/chat/completions",
timeout_sec=600,
stream=True,
) -> str | list[str]:
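    """Query a local OpenAI-compatible chat-completions endpoint.

    Returns the completion text (or a list of texts when n > 1 and
    stream=False). With stream=True, SSE chunks are accumulated into a
    single string.
    """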
    assert not (stream and n != 1), "You can't set n>1 when using stream"
if not stream:
payload = {
"model": LLM_model,
"messages": message,
"n": n,
"max_tokens": LLM_max_new_tokens,
}
r = requests.post(
API_URL,
json=payload,
headers={"Content-Type": "application/json", "Authorization": "Bearer 0"},
timeout=timeout_sec,
)
if r.status_code != 200:
logger.error(f"LLM API error {r.status_code}: {r.text}")
raise RuntimeError(f"LLM API returned {r.status_code}")
data = r.json()
if n == 1:
return data["choices"][0]["message"]["content"]
else:
return [c["message"]["content"] for c in data["choices"]]
if n != 1:
raise ValueError("Streaming only supports n=1")
payload = {
"model": LLM_model,
"messages": message,
"n": 1,
"max_tokens": LLM_max_new_tokens,
"stream": True,
}
full_text = ""
try:
with requests.post(
API_URL,
json=payload,
headers={"Content-Type": "application/json", "Authorization": "Bearer 0"},
timeout=timeout_sec,
stream=True,
) as r:
if r.status_code != 200:
logger.error(f"LLM streaming API error {r.status_code}: {r.text}")
raise RuntimeError(f"LLM streaming API returned {r.status_code}")
for line in r.iter_lines(decode_unicode=True):
if not line:
continue
if not line.startswith("data: "):
continue
data = line[6:]
if data.strip() == "[DONE]":
break
try:
j = json.loads(data)
except json.JSONDecodeError:
continue
choices = j.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if not content:
continue
full_text += content
return full_text
except Exception as e:
logger.error(f"Local LLM streaming request failed: {e}")
raise
def extract_label(response: str) -> str:
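    """Extract the predicted label; return "" when the response is ambiguous
    (contains both tags, or neither)."""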
has_yes = "<Yes>" in response
has_no = "<No>" in response
if has_yes and not has_no:
return "<Yes>"
if has_no and not has_yes:
return "<No>"
return ""
def call_one_model(
model_url: str,
ckpt: int,
msgs,
gold_label: str,
) -> Tuple[int, Dict[str, Any]]:
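    """Query one checkpoint's endpoint and score its response against the
    gold label. Errors are logged and scored as an empty response rather
    than propagated, so one bad endpoint doesn't kill the batch."""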
try:
response = thought_generator_with_local_LLM_requests(
message=msgs,
LLM_model="custom-model",
LLM_max_new_tokens=1024,
n=1,
API_URL=model_url,
timeout_sec=300,
stream=False,
)
except Exception as e:
logger.error(f"Error getting response from model at {model_url}: {e}")
response = ""
label = extract_label(response)
return ckpt, {
"label": label,
"output": response,
"full_output": response,
"accuracy": 1 if label == gold_label else 0,
}
# --------- NEW: resume/checkpoint cache helpers ---------
def entry_uid(system: str, prompt: str, gold_label: str, gold_output: str) -> str:
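    """Stable content hash identifying an eval entry, used as the resume key."""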
payload = {"system": system, "prompt": prompt, "gold_label": gold_label, "gold_output": gold_output}
s = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
return hashlib.sha1(s.encode("utf-8")).hexdigest()
def load_cache(path: str) -> Dict[str, Dict[str, Any]]:
"""Return {uid: cached_entry_dict}."""
if not os.path.exists(path):
return {}
try:
with open(path, "r") as f:
data = json.load(f)
cache = {}
for e in data:
uid = entry_uid(e.get("system", ""), e.get("prompt", ""), e.get("gold_label", ""), e.get("gold_output", ""))
cache[uid] = e
logger.info(f"Loaded cache from {path}: {len(cache)} entries")
return cache
except Exception as ex:
logger.warning(f"Failed to load cache from {path} (starting fresh): {ex}")
return {}
def should_run_step(o_entry: Dict[str, Any], ckpt: int) -> bool:
"""Compute this ckpt if missing or empty."""
key = f"step_{ckpt}"
if key not in o_entry:
return True
v = o_entry.get(key) or {}
# Treat missing/empty output as incomplete
out = v.get("output", "")
return not isinstance(out, str) or out.strip() == ""
def atomic_write_json(path: str, obj: Any) -> None:
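    """Write JSON via a temp file + os.replace so an interrupt mid-write
    can't leave a truncated output file."""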
tmp = path + ".tmp"
with open(tmp, "w") as f:
json.dump(obj, f, indent=2)
os.replace(tmp, path)
if __name__ == "__main__":
for original_eval_log_file in os.listdir(working_dir):
        if (
            not original_eval_log_file.startswith("eval_log_")
            or not original_eval_log_file.endswith(f"{particular}.json")
            or original_eval_log_file.endswith("_cps.json")
        ):
continue
original_eval_file = os.path.join(working_dir, original_eval_log_file)
output_eval_file = original_eval_file.replace(".json", "_cps.json")
with open(original_eval_file, "r") as f:
eval_data = json.load(f)
# NEW: load existing output cache (if present)
cache_map = load_cache(output_eval_file)
output_eval_data = []
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
for idx, entry in enumerate(tqdm(eval_data)):
system = entry["system"]
prompt = entry["prompt"]
gold_label = entry["gold_label"]
gold_output = entry["gold_output"]
uid = entry_uid(system, prompt, gold_label, gold_output)
# Start from cached entry if available, otherwise create a fresh one
o_entry = cache_map.get(uid, {})
# Ensure base fields are up to date/present
o_entry.update({
"system": system,
"prompt": prompt,
"gold_label": gold_label,
"gold_output": gold_output,
})
msgs = [
{"role": "system", "content": system},
{"role": "user", "content": prompt},
]
futures = []
for model_url, ckpt in models.items():
if should_run_step(o_entry, ckpt):
futures.append(
executor.submit(call_one_model, model_url, ckpt, msgs, gold_label)
)
for fut in as_completed(futures):
ckpt, result = fut.result()
o_entry[f"step_{ckpt}"] = result
output_eval_data.append(o_entry)
# OPTIONAL: periodically flush progress so you can Ctrl-C and resume safely
if (idx + 1) % 50 == 0:
atomic_write_json(output_eval_file, output_eval_data)
atomic_write_json(output_eval_file, output_eval_data)
print("Evaluation with checkpoints completed.")