import argparse
import requests
import json
from pathlib import Path
import logging

logger = logging.getLogger("compare-logprobs")
logging.basicConfig(level=logging.INFO)


DESCRIPTION = """
Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.

Unlike compare-logits.py, this script can dump logits from a hosted API endpoint, which is useful when it is not possible to run both models locally.

Example usage:

  Step 1: Dump logits from two different servers
    python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
    python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions

  (optionally, add --api-key <key> if the endpoint requires authentication)

  Step 2: Compare the dumped logits
    python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
"""


def get_remote_corpus(url: str, length: int) -> list[str]:
    response = requests.get(url)
    response.raise_for_status()
    corpus = response.text
    words = [w.strip() for w in corpus.strip().split(" ")]
    words = [w for w in words if "<" not in w]  # drop words containing markup such as HTML tags
    words = [w for w in words if len(w) > 0]
    # repeat the word list until it is long enough, then truncate to the requested length
    while len(words) < length:
        words += words
    return words[:length]


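# The get/skip pattern (see --pattern) keeps the number of requests manageable: words in a
# "skip" block only extend the prompt, while each word in a "get" block triggers one request,
# so the default pattern probes a few short windows at increasing context depths.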
def dump_logits(
    endpoint: str,
    output_path: Path,
    input_words: list[str],
    pattern: list[tuple[bool, int]],
    api_key=None,
):
    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
    words = input_words
    curr_text = ""
    n_total = sum(n for get, n in pattern if get)
    n_done = 0
    i_cur = 0
    i_total = len(words)
    with output_path.open("w") as f:
        for get, n in pattern:
            if not get:
                # "skip" block: extend the prompt without querying the endpoint
                for i in range(n):
                    curr_text += words.pop(0) + " "
                    i_cur += 1
                continue

            # "get" block: query the endpoint once after each appended word
            for i in range(n):
                curr_text += words.pop(0) + " "
                payload = {
                    "prompt": curr_text.strip(),
                    "temperature": 0.0,
                    "top_k": 1,
                    "max_tokens": 1,
                    "logprobs": 1,
                    "stream": False,
                }
                response = requests.post(
                    endpoint,
                    json=payload,
                    headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
                )
                response.raise_for_status()
                data = response.json()
                data["__index"] = i_cur
                data = json.dumps(data)
                f.write(f"{data}\n")
                n_done += 1
                i_cur += 1
                logger.info(
                    f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
                )
    logger.info(f"Logits dumped to {output_path}")


def get_token_logprobs(data: dict):
    logprobs = data["choices"][0]["logprobs"]
    if "content" in logprobs:
        # chat-style logprobs: nested under "content" with explicit "top_logprobs"
        top = logprobs["content"][0]["top_logprobs"][0]
        return top["token"], top["logprob"]
    else:
        # legacy completions-style logprobs: parallel "tokens" and "token_logprobs" arrays
        tokens = logprobs["tokens"]
        token_logprobs = logprobs["token_logprobs"]
        return tokens[0], token_logprobs[0]


def clean_text(text: str) -> str:
    # escape control characters and pipes so tokens render safely inside a Markdown table
    return (
        "'"
        + text.replace("\n", "\\n")
        .replace("\t", "\\t")
        .replace("\r", "\\r")
        .replace("|", "\\|")
        + "'"
    )


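# Read the two dump files line by line and write a Markdown report with one row per probed
# position: the greedy token and its logprob from each file, plus the absolute difference.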
def compare_logits(input1: Path, input2: Path, output_path: Path):
    with input1.open("r") as f1, input2.open("r") as f2, output_path.open("w") as fout:
        lines1 = f1.readlines()
        lines2 = f2.readlines()

        tab_header = [
            "idx",
            input1.name,
            "logprob_1",
            input2.name,
            "logprob_2",
            "diff (abs)",
        ]
        tab_entries = []
        tab_max_widths = [len(h) for h in tab_header]

        assert len(lines1) == len(
            lines2
        ), "Input files must have the same number of lines."

        fout.write("# Logits Comparison Report\n\n")
        for i, (line1, line2) in enumerate(zip(lines1, lines2)):
            if not line1.strip() or not line2.strip():
                continue

            data1 = json.loads(line1)
            data2 = json.loads(line2)

            idx1 = data1.get("__index", -1)
            idx2 = data2.get("__index", -1)
            if idx1 != idx2:
                logger.warning(f"Mismatched indices at line {i}: {idx1} vs {idx2}")

            token1, logprob1 = get_token_logprobs(data1)
            token2, logprob2 = get_token_logprobs(data2)

            token1 = clean_text(token1)
            token2 = clean_text(token2)
            abs_diff = abs(logprob1 - logprob2)

            tab_entries.append(
                (
                    str(idx1 + 1),
                    token1,
                    f"{logprob1:.4f}",
                    token2,
                    f"{logprob2:.4f}",
                    f"{abs_diff:.4f}",
                )
            )

        # compute column widths so the Markdown table lines up
        for i in range(len(tab_entries)):
            for j in range(len(tab_header)):
                tab_max_widths[j] = max(tab_max_widths[j], len(tab_entries[i][j]))

        # render header, separator and data rows
        output = ""
        for j in range(len(tab_header)):
            output += f"| {tab_header[j]:<{tab_max_widths[j]}} "
        output += "|\n"
        for j in range(len(tab_header)):
            output += f"|{'-' * (tab_max_widths[j] + 2)}"
        output += "|\n"
        for entry in tab_entries:
            for j in range(len(tab_header)):
                output += f"| {entry[j]:<{tab_max_widths[j]}} "
            output += "|\n"

        logger.info("\n" + output)
        fout.write(output)
    logger.info(f"Report written to {output_path}")


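# Parse the pattern string into (get, count) pairs: entries at even positions are "get"
# counts and entries at odd positions are "skip" counts, e.g.
#   "10,1000,10" -> [(True, 10), (False, 1000), (True, 10)]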
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
    parts = pattern.split(",")
    result = []
    for i, part in enumerate(parts):
        n = int(part)
        if i % 2 == 0:
            result.append((True, n))
        else:
            result.append((False, n))
    return result


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
    )
    subparsers = parser.add_subparsers(
        dest="verb", required=True, help="action to perform"
    )

    parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
    parser_dump.add_argument(
        "output", type=Path, help="output path for dumped logits (.log)"
    )
    parser_dump.add_argument(
        "endpoint", type=str, help="OAI-compat /completions endpoint"
    )
    parser_dump.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key for authentication (if required)",
    )
    parser_dump.add_argument(
        "--file",
        type=str,
        default="https://raw.githubusercontent.com/ggml-org/llama.cpp/eaba92c3dcc980ebe753348855d4a5d75c069997/tools/server/README.md",
        help="file containing the prompt to use instead of the default (can also be a URL)",
    )
    parser_dump.add_argument(
        "--pattern",
        type=str,
        default="10,1000,10,4000,10",
        help="pattern n_get,n_skip,... where n_get is the number of words to get logprobs for and n_skip is the number of words to skip (counted in words, NOT tokens)",
    )

    parser_compare = subparsers.add_parser(
        "compare", help="compare two dumped logits files"
    )
    parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
    parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
    parser_compare.add_argument(
        "output", type=Path, help="output path for comparison report (.md)"
    )

    try:
        return parser.parse_args()
    except Exception as e:
        parser.print_help()
        raise e


def main():
    args = parse_args()

    if args.verb == "dump":
        pattern = parse_pattern(args.pattern)
        required_words = sum(n for _, n in pattern)
        if args.file.startswith("http"):
            input_words = get_remote_corpus(args.file, required_words)
            logger.info(f"Fetched {len(input_words)} words from remote {args.file}")
        else:
            with open(args.file, "r") as f:
                input_words = f.read().strip().split(" ")
            input_words = [w for w in input_words if len(w) > 0]
            if len(input_words) < required_words:
                raise ValueError(
                    f"Input file has only {len(input_words)} words, but the pattern requires at least {required_words} words."
                )
            logger.info(f"Using {len(input_words)} words")
        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
    elif args.verb == "compare":
        compare_logits(args.input1, args.input2, args.output)
    else:
        raise ValueError(f"Unknown verb: {args.verb}")


if __name__ == "__main__":
    main()