GreenStar24's picture
small
67eefaa
import json
import logging
import os
import re
import threading
import time
import traceback
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
import gradio as gr
import gspread
from google.oauth2.service_account import Credentials
from openai import OpenAI
# OAuth scopes requested for the Google service account in get_google_sheet().
SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]
# Model used by get_model_name() when OPENAI_MODEL_NAME is not set.
DEFAULT_MODEL_NAME = "gpt-5-mini"
# Default concurrency limits. Each can be overridden by an environment variable
# (see get_eval_max_workers, get_llm2_max_workers and
# get_openai_max_concurrent_requests).
DEFAULT_EVAL_MAX_WORKERS = 4
DEFAULT_LLM2_MAX_WORKERS = 8
DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS = 16
# NOTE(review): not referenced in this chunk; presumably how many times each
# eval case is run (cf. the "CASE n | RUN m" headers) — confirm.
EVAL_REPEAT_COUNT = 3
# NOTE(review): not referenced in this chunk; presumably the accepted
# submission types — confirm against the submission handler.
SUBMISSION_TYPES = {"practice", "final"}
# LLM 1 flags an irrelevant question by wrapping its reply in these tags
# (parsed by extract_user_message, emitted by build_expected_step1_output).
IRRELEVANT_TAG_START = "<user_message>"
IRRELEVANT_TAG_END = "</user_message>"
# Separator placed between sections of a multi-section sheet cell.
SHEET_SECTION_DIVIDER = "=" * 50
# Character caps applied when truncating text destined for sheet cells.
# NOTE(review): SHEET_CELL_MAX_CHARS is presumably chosen to stay below
# Google Sheets' 50,000-character cell limit — confirm.
STAGE_ONE_OUTPUT_MAX_CHARS = 4000
STAGE_TWO_OUTPUT_MAX_CHARS = 2000
STAGE_THREE_OUTPUT_MAX_CHARS = 4000
PROMPT_CELL_MAX_CHARS = 8000
EXPECTED_OUTPUT_MAX_CHARS = 4000
SHEET_CELL_MAX_CHARS = 49000
# Keys of the final (LLM 3) JSON answer: three counsel fields plus the
# yes/no user-question field. TARGET_FIRM_FIELD lives only in the expected data.
COUNSEL_FIELDS = ("buyer_counsel", "seller_counsel", "third_party_counsel")
TARGET_FIRM_FIELD = "target_firm"
USER_QUESTION_FIELD = "user_question"
FINAL_SCHEMA_KEYS = (*COUNSEL_FIELDS, USER_QUESTION_FIELD)
# A well-formed footnote-style citation such as "[^3]"; group 1 is the snippet number.
CITATION_PATTERN = re.compile(r"\[\^(\d+)\]")
# A bare true/false/yes/no verdict (case-insensitive) with nothing else.
BOOLEAN_ANSWER_PATTERN = re.compile(r"^(true|false|yes|no)\b$", re.IGNORECASE)
# NOTE(review): not referenced in this chunk; presumably in-memory attempt
# tracking — confirm.
LOCAL_ATTEMPTS = {}
# Lazily created OpenAI client (see get_openai_client).
_OPENAI_CLIENT = None
# Shared semaphore capping concurrent OpenAI requests, plus the limit it was
# built with. Rebuilt under the lock when the limit changes
# (see get_openai_request_semaphore).
_OPENAI_REQUEST_SEMAPHORE = None
_OPENAI_REQUEST_SEMAPHORE_LIMIT = None
_OPENAI_REQUEST_SEMAPHORE_LOCK = threading.Lock()
# Log level comes from LOG_LEVEL; unrecognized names fall back to INFO.
_LOG_LEVEL_NAME = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=getattr(logging, _LOG_LEVEL_NAME, logging.INFO))
logger = logging.getLogger(__name__)
# (field key, sheet header label) pairs, in sheet column order.
# NOTE(review): the keys are presumably the field names of the submission
# record written to the sheet — confirm against the writer.
SHEET_COLUMNS = [
    ("timestamp", "Submitted At"),
    ("submission_type", "Submission Type"),
    ("name", "Candidate Name"),
    ("email", "Candidate Email"),
    ("results", "Results"),
    ("results_without_llm_1", "Results LLM3"),
    ("llm_score_breakdown", "LLM Score Breakdown"),
    ("llm_output_vs_expected", "LLM Output vs Expected"),
    ("llm_2_output", "LLM 2 Output"),
    ("prompts", "Prompts"),
]
# Header row written to row 1 by ensure_submission_sheet_headers().
SHEET_HEADERS = [label for _, label in SHEET_COLUMNS]
# Custom CSS for the ".submission-note" element. It uses light colours by
# default; CSS variables override them in dark mode, whether that comes from
# the OS preference or from a ".dark" ancestor class.
APP_CSS = """
.submission-note {
background: var(--submission-note-bg, #fff7e6);
border: 1px solid var(--submission-note-border, #ffe5b4);
border-radius: 8px;
color: var(--submission-note-text, #3d2b00);
margin-bottom: 1em;
padding: 16px;
}
@media (prefers-color-scheme: dark) {
.submission-note {
--submission-note-bg: #2b2111;
--submission-note-border: #6b4b18;
--submission-note-text: #f7e8c5;
}
}
.dark .submission-note {
--submission-note-bg: #2b2111;
--submission-note-border: #6b4b18;
--submission-note-text: #f7e8c5;
}
"""
def get_model_name():
    """Return the chat model to call.

    ``OPENAI_MODEL_NAME`` wins when present (even if empty). Otherwise
    ``DEFAULT_MODEL_NAME`` is used.
    """
    configured_model = os.environ.get("OPENAI_MODEL_NAME")
    if configured_model is None:
        return DEFAULT_MODEL_NAME
    return configured_model
def get_positive_int_env(name, default):
    """Read an integer environment variable, clamped to at least 1.

    Args:
        name: environment variable to read.
        default: value used when the variable is missing or not an integer.
            Note that a missing variable is parsed from ``str(default)`` and
            clamped; an unparsable one returns ``default`` unchanged.

    Returns:
        The parsed integer, raised to 1 if smaller.
    """
    text_value = os.environ.get(name, str(default))
    try:
        parsed = int(text_value)
    except ValueError:
        logger.warning("Invalid integer for %s=%r; using default %s.", name, text_value, default)
        return default
    return parsed if parsed >= 1 else 1
def get_eval_max_workers():
    """Return the eval-case worker count.

    ``EVAL_CASE_MAX_WORKERS`` takes precedence whenever it is set. Otherwise
    ``EVAL_MAX_WORKERS`` is read. Both fall back to DEFAULT_EVAL_MAX_WORKERS.
    """
    if "EVAL_CASE_MAX_WORKERS" in os.environ:
        env_name = "EVAL_CASE_MAX_WORKERS"
    else:
        env_name = "EVAL_MAX_WORKERS"
    return get_positive_int_env(env_name, DEFAULT_EVAL_MAX_WORKERS)
def get_llm2_max_workers():
    """Return the per-case LLM 2 fan-out width (``LLM2_MAX_WORKERS``, min 1)."""
    worker_limit = get_positive_int_env("LLM2_MAX_WORKERS", DEFAULT_LLM2_MAX_WORKERS)
    return worker_limit
def get_openai_max_concurrent_requests():
    """Return the global cap on in-flight OpenAI requests (``OPENAI_MAX_CONCURRENT_REQUESTS``, min 1)."""
    env_name = "OPENAI_MAX_CONCURRENT_REQUESTS"
    return get_positive_int_env(env_name, DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS)
def get_openai_request_semaphore():
    """Return the shared BoundedSemaphore that caps concurrent OpenAI requests.

    The configured limit is re-read on every call. If it differs from the limit
    the current semaphore was built with, a new semaphore replaces it under
    ``_OPENAI_REQUEST_SEMAPHORE_LOCK``. Callers that already hold a reference to
    the old semaphore keep using (and releasing) that object.
    """
    global _OPENAI_REQUEST_SEMAPHORE, _OPENAI_REQUEST_SEMAPHORE_LIMIT
    wanted_limit = get_openai_max_concurrent_requests()
    with _OPENAI_REQUEST_SEMAPHORE_LOCK:
        needs_rebuild = (
            _OPENAI_REQUEST_SEMAPHORE is None
            or _OPENAI_REQUEST_SEMAPHORE_LIMIT != wanted_limit
        )
        if needs_rebuild:
            _OPENAI_REQUEST_SEMAPHORE = threading.BoundedSemaphore(wanted_limit)
            _OPENAI_REQUEST_SEMAPHORE_LIMIT = wanted_limit
        return _OPENAI_REQUEST_SEMAPHORE
def get_openai_client():
    """Return the module-wide OpenAI client, creating it on first use.

    Raises:
        RuntimeError: if ``OPENAI_API_KEY`` is missing or empty.

    NOTE: creation is not lock-protected, so concurrent first calls may each
    build a client; the last assignment wins.
    """
    global _OPENAI_CLIENT
    if _OPENAI_CLIENT is not None:
        return _OPENAI_CLIENT
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set.")
    _OPENAI_CLIENT = OpenAI(api_key=api_key)
    return _OPENAI_CLIENT
def get_google_sheet(ensure_headers=True):
    """Open the "Submissions" worksheet of the configured spreadsheet.

    Credentials come from the service-account JSON in ``GOOGLE_CREDS_JSON``.
    The spreadsheet is identified by ``SPREADSHEET_ID``.

    Args:
        ensure_headers: when true, (re)write the header row before returning.

    Raises:
        RuntimeError: if either environment variable is missing or empty.
    """
    creds_json = os.environ.get("GOOGLE_CREDS_JSON")
    spreadsheet_id = os.environ.get("SPREADSHEET_ID")
    if not (creds_json and spreadsheet_id):
        raise RuntimeError("GOOGLE_CREDS_JSON and SPREADSHEET_ID must be set.")
    service_account_info = json.loads(creds_json)
    credentials = Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
    client = gspread.authorize(credentials)
    worksheet = client.open_by_key(spreadsheet_id).worksheet("Submissions")
    if ensure_headers:
        ensure_submission_sheet_headers(worksheet)
    return worksheet
def load_eval_set(prefix):
    """Load an evaluation dataset from three parallel JSON-list env vars.

    Reads ``{prefix}_QUESTIONS_JSON``, ``{prefix}_DOCUMENTS_JSON`` and
    ``{prefix}_EXPECTED_JSON``. If any of them is missing or empty, the set is
    treated as not configured and ``{"cases": []}`` is returned.

    Returns:
        ``{"cases": [...]}``. Each case has ``question`` (str), ``docs`` (list of
        snippet strings) and ``expected`` (normalized expected answer).

    Raises:
        ValueError: if a payload is not a JSON list, if the lists differ in
            length, or if an expected entry is invalid.
    """
    questions_str = os.environ.get(f"{prefix}_QUESTIONS_JSON")
    documents_str = os.environ.get(f"{prefix}_DOCUMENTS_JSON")
    expected_str = os.environ.get(f"{prefix}_EXPECTED_JSON")
    if not all((questions_str, documents_str, expected_str)):
        return {"cases": []}

    questions, documents, expected = (
        json.loads(payload) for payload in (questions_str, documents_str, expected_str)
    )
    if not all(isinstance(payload, list) for payload in (questions, documents, expected)):
        raise ValueError(f"{prefix} dataset must be a JSON list for all fields.")
    if not len(questions) == len(documents) == len(expected):
        raise ValueError(
            f"{prefix} dataset lengths do not match for questions, documents, and expected answers."
        )

    cases = []
    rows = zip(questions, documents, expected)
    for case_index, (question, docs_entry, expected_entry) in enumerate(rows, start=1):
        question_text = str(question)
        case_docs = normalize_documents_entry(docs_entry)
        case_expected = normalize_expected_entry(
            expected_entry,
            prefix,
            case_index,
            question=question_text,
            docs=case_docs,
        )
        cases.append({"question": question_text, "docs": case_docs, "expected": case_expected})
    return {"cases": cases}
def build_submission_response(
    ok,
    message,
    notice="",
    disable_practice=False,
    disable_final=False,
):
    """Bundle the outcome of a submission attempt into a plain dict.

    Keys: ``ok``, ``message``, ``notice``, ``disable_practice``, ``disable_final``.
    """
    return dict(
        ok=ok,
        message=message,
        notice=notice,
        disable_practice=disable_practice,
        disable_final=disable_final,
    )
def build_internal_error_message(submission_type):
    """Return the generic retry-later message for a practice run or submission."""
    if submission_type == "practice":
        label = "practice run"
    else:
        label = "submission"
    return f"We hit an internal error while processing your {label}. Please try again in a minute."
def build_stage_error_message(submission_type, stage):
    """Return an internal-error message that names the pipeline stage that failed."""
    if submission_type == "practice":
        label = "practice run"
    else:
        label = "submission"
    return f"We hit an internal error while processing your {label}. Stage: {stage}."
def build_save_error_message(submission_type, stage="sheet_write"):
    """Return the message shown when a graded run could not be persisted."""
    if submission_type == "practice":
        label = "practice run"
    else:
        label = "submission"
    return (
        f"We could not save your {label} right now. Stage: {stage}. "
        "Please check the Space logs and spreadsheet permissions, then try again."
    )
def build_bypass_save_notice():
    """Return the notice shown when sheet saving is bypassed for debugging."""
    notice = "Debug mode is active: this run was graded but not saved to Google Sheets."
    return notice
def should_bypass_sheet_save():
    """Return True when ``BYPASS_SHEET_SAVE`` is "true" (case/whitespace-insensitive)."""
    flag = os.environ.get("BYPASS_SHEET_SAVE", "")
    normalized_flag = flag.strip().lower()
    return normalized_flag == "true"
def log_stage_exception(message, *args):
    """Log the active exception and also print its traceback to stdout.

    Must be called from inside an ``except`` block. The printed copy is
    flushed immediately.
    """
    logger.exception(message, *args)
    current_traceback = traceback.format_exc()
    print(current_traceback, flush=True)
def safe_evaluate_case(case, prompts, case_index):
    """Run ``evaluate_case`` and convert any failure into a RuntimeError.

    The original exception is logged and chained as ``__cause__``.

    Raises:
        RuntimeError: if evaluate_case raises any Exception.
    """
    try:
        outcome = evaluate_case(case, prompts, case_index=case_index)
    except Exception as exc:
        logger.exception("Case %s failed during evaluation.", case_index)
        raise RuntimeError(f"Evaluation failed for case {case_index}.") from exc
    return outcome
def normalize_documents_entry(entry):
    """Turn a dataset ``documents`` entry into a list of non-empty snippet strings.

    - list: each item is stringified and stripped, and blank items are dropped.
    - anything else: stringified and stripped. If it contains "---" it is split
      on that separator into snippets; otherwise it is a single snippet.
      Blank input yields [].
    """
    if isinstance(entry, list):
        stripped_items = (str(item).strip() for item in entry)
        return [item for item in stripped_items if item]
    text = str(entry).strip()
    if "---" not in text:
        return [text] if text else []
    pieces = (piece.strip() for piece in text.split("---"))
    return [piece for piece in pieces if piece]
def get_sheet_column_letter(index):
    """Convert a 1-based column number to spreadsheet letters (1 -> "A", 27 -> "AA").

    Non-positive numbers yield "".
    """
    if index <= 0:
        return ""
    # Bijective base-26: shift to 0-based before each division.
    quotient, remainder = divmod(index - 1, 26)
    return get_sheet_column_letter(quotient) + chr(ord("A") + remainder)
def ensure_submission_sheet_headers(sheet):
    """Overwrite row 1 of ``sheet`` with SHEET_HEADERS.

    The written range runs from A1 to the last header column (e.g. A1:J1).

    Arguments go to ``Worksheet.update`` by keyword. gspread 6 changed the
    positional order from ``(range_name, values)`` to ``(values, range_name)``,
    so a positional call only works on one of the two versions. Keywords work
    on both.
    """
    expected_headers = SHEET_HEADERS
    end_column = get_sheet_column_letter(len(expected_headers))
    sheet.update(range_name=f"A1:{end_column}1", values=[expected_headers])
def get_next_submission_row_index(sheet):
    """Return the 1-based sheet row just below the last non-blank data row.

    Row 1 is the header and is never counted. A row counts as data when any
    cell is non-blank after stripping. With no data rows the answer is 2.
    Blank rows in between do not stop the scan, so gaps are kept rather than
    filled.
    """
    all_rows = sheet.get_all_values()
    last_filled_row = 1
    for row_number, cells in enumerate(all_rows[1:], start=2):
        if any(str(cell).strip() for cell in cells):
            last_filled_row = row_number
    return last_filled_row + 1
def serialize_sheet_value(value):
    """Convert a value to cell text: None -> "", dict/list -> sorted JSON, else str()."""
    if value is None:
        return ""
    if isinstance(value, (dict, list)):
        return serialize_json(value)
    return str(value)
def normalize_expected_mapping(entry, prefix, case_index):
    """Normalize an expected-answer object in the current schema.

    Returns a dict holding the whitespace-normalized ``target_firm`` plus each
    counsel field. A counsel field that is missing, None or blank becomes
    "unknown".

    Raises:
        ValueError: if ``target_firm`` is missing, blank, or the string "none".
    """
    missing_target_message = (
        f"{prefix} case {case_index} is missing a non-empty '{TARGET_FIRM_FIELD}'."
    )
    raw_target_firm = entry.get(TARGET_FIRM_FIELD)
    if raw_target_firm is None:
        raise ValueError(missing_target_message)
    target_firm = normalize_whitespace(raw_target_firm)
    if not target_firm or target_firm.lower() == "none":
        raise ValueError(missing_target_message)

    normalized = {TARGET_FIRM_FIELD: target_firm}
    for field_name in COUNSEL_FIELDS:
        raw_value = entry.get(field_name, "unknown")
        cleaned = "" if raw_value is None else str(raw_value).strip()
        normalized[field_name] = cleaned or "unknown"
    return normalized
def extract_target_candidate_from_question(question):
    """Pull the firm name out of an "Is <firm> ...?" style question.

    Tries specific phrasings first, then a catch-all that takes everything
    after "Is " (minus an optional trailing "?"). Matching is case-insensitive.
    Returns "" when no pattern matches.
    """
    candidate_patterns = (
        r"^Is\s+(.+?)\s+present\s+in\s+the\s+agreement\??$",
        r"^Is\s+(.+?)\s+mentioned\s+in\s+the\s+agreement\??$",
        r"^Is\s+(.+?)\s+acting\s+as\s+counsel.*\??$",
        r"^Is\s+(.+?)\s+anywhere\s+in\s+this\s+Asset\s+Purchase\s+Agreement\??$",
        r"^Is\s+(.+?)\s+in\s+the\s+agreement\??$",
        # Catch-all, tried last.
        r"^Is\s+(.+?)\??$",
    )
    normalized_question = normalize_whitespace(question)
    for candidate_pattern in candidate_patterns:
        found = re.match(candidate_pattern, normalized_question, re.IGNORECASE)
        if found is not None:
            return normalize_whitespace(found.group(1))
    return ""
def extract_entity_candidates(text):
    """Return capitalized entity names ending in a corporate suffix, in order of appearance.

    A candidate is a run of capitalized words (an "&" may sit between words)
    followed by LLP, LLC, LP, Inc, Corporation, Corp or Ltd. Each match is
    whitespace-normalized.
    """
    entity_regex = re.compile(
        r"\b([A-Z][A-Za-z0-9'.,-]*(?:\s+(?:&|[A-Z][A-Za-z0-9'.,-]*))*\s+"
        r"(?:LLP|LLC|LP|Inc\.?|Corporation|Corp\.?|Ltd\.?))\b"
    )
    # The pattern has exactly one capturing group, so findall yields group 1.
    return [normalize_whitespace(name) for name in entity_regex.findall(str(text))]
def _legacy_counsel_value(entry, key):
    """Read one legacy counsel key; missing, null or blank values become "unknown"."""
    raw_value = entry.get(key)
    if raw_value is None:
        # JSON null must not become the literal string "None".
        return "unknown"
    return normalize_whitespace(raw_value) or "unknown"


def _first_candidate_containing_target(target_norm, candidates):
    """Return the first candidate whose normalized form contains target_norm, else None."""
    if not target_norm:
        return None
    for candidate in candidates:
        candidate_text = normalize_whitespace(candidate)
        if candidate_text and target_norm in normalize_counsel_value(candidate_text):
            return candidate_text
    return None


def build_legacy_expected_mapping(entry, question, docs):
    """Convert a legacy expected-answer object into the current schema.

    Legacy objects use ``buyer_firm`` / ``seller_firm`` / ``third_party``. They
    are mapped to the three counsel fields. A missing, null or blank value
    becomes "unknown"; before, a null came through as the literal "None".

    The target firm comes from the question ("Is <firm> ...?"). It is resolved
    to the first counsel value that contains it. Failing that, it is resolved
    to the first corporate entity named in ``docs`` that contains it. If
    neither matches, the raw question candidate is used, or "unknown target"
    when the question yields none.

    Args:
        entry: legacy expected-answer dict.
        question: the case question text.
        docs: list of snippet strings for the case.

    Returns:
        Dict with the three counsel fields followed by ``target_firm``.
    """
    normalized = {
        "buyer_counsel": _legacy_counsel_value(entry, "buyer_firm"),
        "seller_counsel": _legacy_counsel_value(entry, "seller_firm"),
        "third_party_counsel": _legacy_counsel_value(entry, "third_party"),
    }
    target_candidate = extract_target_candidate_from_question(question)
    target_norm = normalize_counsel_value(target_candidate)

    target_firm = _first_candidate_containing_target(target_norm, list(normalized.values()))
    if target_firm is None:
        # Only scan the snippets when no counsel value matched.
        doc_candidates = [
            candidate for doc in docs for candidate in extract_entity_candidates(doc)
        ]
        target_firm = _first_candidate_containing_target(target_norm, doc_candidates)
    if target_firm is None:
        target_firm = target_candidate or "unknown target"

    normalized[TARGET_FIRM_FIELD] = target_firm
    return normalized
def normalize_expected_entry(entry, prefix, case_index, question="", docs=None):
    """Normalize one expected-answer entry from an eval dataset.

    - dict, or a string holding a JSON object: normalized into the current
      schema. Legacy objects have no ``target_firm`` but carry at least one
      legacy key; they go through build_legacy_expected_mapping instead.
    - any other string: returned unchanged (the pipeline treats string
      expectations as irrelevant cases).

    Raises:
        ValueError: for non-string/non-dict entries, or a "{...}" string that
            is not valid JSON.
    """
    docs = docs or []
    if isinstance(entry, dict):
        mapping = entry
    elif isinstance(entry, str):
        stripped = entry.strip()
        if not (stripped.startswith("{") and stripped.endswith("}")):
            return entry
        try:
            mapping = json.loads(stripped)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"{prefix} case {case_index} contains invalid JSON in expected entry: {exc}"
            ) from exc
        if not isinstance(mapping, dict):
            return entry
    else:
        raise ValueError(
            f"{prefix} case {case_index} expected entry must be a string or object."
        )

    legacy_keys = {"buyer_firm", "seller_firm", "third_party", "contains_target_firm"}
    if TARGET_FIRM_FIELD not in mapping and legacy_keys.intersection(mapping.keys()):
        return build_legacy_expected_mapping(mapping, question, docs)
    return normalize_expected_mapping(mapping, prefix, case_index)
def sanitize_input(text, max_length=500):
    """Keep only whitelisted characters, strip, and cap at ``max_length`` characters.

    Allowed: ASCII letters/digits, whitespace and the punctuation . , ! ? @ : - + ' & ( ) /.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9\s.,!?@:\-+'&()/]")
    kept = disallowed.sub("", text)
    trimmed = kept.strip()
    return trimmed[:max_length]
def sanitize_prompt(text, max_length=8000):
    """Strip surrounding whitespace from a prompt and cap it at ``max_length`` characters.

    ``max_length`` defaults to the previous hard-coded 8000. It is a parameter
    for consistency with sanitize_input.
    """
    return text.strip()[:max_length]
def normalize_email(email):
    """Return ``email`` trimmed of surrounding whitespace and lower-cased."""
    trimmed = email.strip()
    return trimmed.lower()
def validate_email(email):
    """Return True when ``email`` is shaped like ``local@domain.tld``.

    This is a permissive shape check, not RFC 5322 validation. It uses
    ``re.fullmatch`` so the entire string must match. The previous
    ``re.match`` plus ``$`` accepted an address followed by a trailing newline,
    because ``$`` also matches just before a final newline.
    """
    email_regex = r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+"
    return re.fullmatch(email_regex, email) is not None
def extract_user_message(text):
    """Return the stripped text between the first <user_message>...</user_message> pair.

    Tag matching is case-insensitive and may span lines. Returns None for empty
    input or when no complete tag pair is present.
    """
    if not text:
        return None
    tag_pattern = re.compile(
        rf"{re.escape(IRRELEVANT_TAG_START)}(.*?){re.escape(IRRELEVANT_TAG_END)}",
        re.IGNORECASE | re.DOTALL,
    )
    found = tag_pattern.search(text)
    return found.group(1).strip() if found else None
def normalize_whitespace(text):
    """Stringify ``text``, collapse whitespace runs to single spaces, and trim the ends."""
    collapsed = re.sub(r"\s+", " ", str(text))
    return collapsed.strip()
def normalize_counsel_value(value):
    """Canonicalize a firm name for comparison.

    Lower-cases it, spells "&" as "and", turns every non [a-z0-9] character
    into a space, and collapses whitespace. None yields "".
    """
    if value is None:
        return ""
    lowered = re.sub(r"\s+", " ", str(value)).strip().lower()
    with_and = lowered.replace("&", " and ")
    alnum_only = re.sub(r"[^a-z0-9\s]", " ", with_and)
    return " ".join(alnum_only.split())
def value_is_unknown(value):
    """Return True when ``value`` is "unknown" (case- and whitespace-insensitive)."""
    normalized = re.sub(r"\s+", " ", str(value)).strip()
    return normalized.lower() == "unknown"
def expected_is_unknown(expected_value):
    """Return True when an expected value is blank or "unknown" (any case)."""
    normalized = re.sub(r"\s+", " ", str(expected_value)).strip()
    return normalized.lower() == "unknown" or normalized == ""
def truncate_text(text, max_chars=2000):
    """Return ``text`` as a string no longer than ``max_chars`` characters.

    Falsy input (None, "", 0) becomes "". Text that is too long is cut and
    ends with "..." so the result is exactly ``max_chars`` long. If
    ``max_chars`` is shorter than the ellipsis, the text is hard-cut instead.
    The old slice ``text[:max_chars - 3]`` went negative there and returned
    more characters than allowed.
    """
    text = str(text or "")
    if len(text) <= max_chars:
        return text
    ellipsis = "..."
    if max_chars < len(ellipsis):
        return text[: max(max_chars, 0)]
    return f"{text[: max_chars - len(ellipsis)]}{ellipsis}"
def serialize_json(value):
    """Serialize ``value`` to compact JSON with sorted keys and ASCII-only output."""
    return json.dumps(value, sort_keys=True, ensure_ascii=True)
def normalize_newlines(text):
    """Convert CRLF and lone CR line endings to LF; falsy input becomes ""."""
    return re.sub(r"\r\n?", "\n", str(text or ""))
def normalize_sheet_block_text(text):
    """Normalize line endings to LF, drop trailing spaces per line, and trim the block."""
    unified = str(text or "").replace("\r\n", "\n").replace("\r", "\n")
    trimmed_lines = map(str.rstrip, unified.split("\n"))
    return "\n".join(trimmed_lines).strip()
def format_score(value):
    """Format a score with up to 4 decimals, dropping trailing zeros but keeping at least ".0"."""
    trimmed = f"{float(value):.4f}".rstrip("0").rstrip(".")
    return trimmed if "." in trimmed else trimmed + ".0"
def join_sheet_sections(blocks):
    """Join the non-empty blocks, separating them with SHEET_SECTION_DIVIDER on its own paragraph."""
    separator = "\n\n" + SHEET_SECTION_DIVIDER + "\n\n"
    return separator.join(filter(None, blocks))
def build_case_run_header(case_index, run_index):
    """Return a "CASE <n> | RUN <m>" header line."""
    return "CASE {} | RUN {}".format(case_index, run_index)
def format_run_indices(run_indices):
    """Render run numbers as prose: [] -> "", [1] -> "1", [1, 2, 3] -> "1, 2 AND 3"."""
    labels = [str(run_index) for run_index in run_indices]
    if len(labels) <= 1:
        return labels[0] if labels else ""
    # For two labels the comma-joined head is a single item, giving "a AND b".
    return f"{', '.join(labels[:-1])} AND {labels[-1]}"
def build_grouped_case_run_header(case_index, run_indices):
    """Return a header naming several runs of one case, e.g. "CASE 1 | RUN 1 AND 2"."""
    runs_label = format_run_indices(run_indices)
    return "CASE {} | RUN {}".format(case_index, runs_label)
def truncate_sheet_block(text, max_chars):
    """Normalize a multi-line block for a sheet cell, then cap it at ``max_chars``."""
    return truncate_text(normalize_sheet_block_text(text), max_chars=max_chars)
def format_pretty_expected_answer(expected):
    """Describe the expected final answer for display.

    String expectations mark irrelevant cases and get a fixed explanation.
    Anything else is pretty-printed as sorted JSON, capped at
    EXPECTED_OUTPUT_MAX_CHARS.
    """
    if not isinstance(expected, str):
        pretty_json = json.dumps(expected, ensure_ascii=True, sort_keys=True, indent=2)
        return truncate_text(pretty_json, max_chars=EXPECTED_OUTPUT_MAX_CHARS)
    return "Irrelevant case. No final JSON should be produced because LLM 1 should stop the pipeline."
def format_llm1_expected_text(expected):
    """Describe the Step 1 output expected for this case (irrelevant vs relevant)."""
    irrelevant = isinstance(expected, str)
    return (
        "Irrelevant case: output must be wrapped in <user_message>...</user_message>."
        if irrelevant
        else "Relevant case: output must be exactly TARGET_FIRM: <canonical firm name>."
    )
def format_llm2_user_output(llm2_outputs):
    """Render per-snippet LLM 2 outputs as "<id>:\\n<text>" paragraphs.

    Each output is capped at STAGE_TWO_OUTPUT_MAX_CHARS. An empty or missing
    mapping yields "Not run.".
    """
    if not llm2_outputs:
        return "Not run."
    paragraphs = (
        f"{snippet_id}:\n"
        f"{truncate_sheet_block(snippet_output, max_chars=STAGE_TWO_OUTPUT_MAX_CHARS)}"
        for snippet_id, snippet_output in llm2_outputs.items()
    )
    return "\n\n".join(paragraphs)
def format_teacher_forced_note(run_result):
    """Explain which Step 1 context LLM 2/LLM 3 consumed for this run."""
    context = run_result["effective_step1_context"]
    if not run_result["used_teacher_forced_step1"]:
        return f"LLM 2 and LLM 3 used the submitted Step 1 context: {context}"
    return (
        "LLM 1 did not match. LLM 2 and LLM 3 were re-run with the teacher-forced "
        f"Step 1 context: {context}"
    )
def parse_cited_value(value):
    """Split an answer string into its base value and ``[^N]`` citations.

    Returns a dict with:
        raw: the whitespace-normalized input.
        base_value: ``raw`` with the well-formed citations removed.
        citations: snippet numbers from the well-formed ``[^N]`` markers, in order.
        malformed_citation: True when a ``[^...]`` marker is not ``[^<digits>]``,
            or a ``[^`` is never closed.
    """
    raw_text = normalize_whitespace(value)
    citations = [int(match.group(1)) for match in CITATION_PATTERN.finditer(raw_text)]
    bracketed_marker = r"\[\^[^\]]*\]"
    malformed_citation = any(
        not CITATION_PATTERN.fullmatch(token)
        for token in re.findall(bracketed_marker, raw_text)
    )
    # A "[^" left over after removing every closed marker is unclosed. The old
    # check only caught this when no closed marker existed at all, so e.g.
    # "Firm [^1] [^2" slipped through.
    if not malformed_citation and "[^" in re.sub(bracketed_marker, "", raw_text):
        malformed_citation = True
    base_value = normalize_whitespace(CITATION_PATTERN.sub("", raw_text))
    return {
        "raw": raw_text,
        "base_value": base_value,
        "citations": citations,
        "malformed_citation": malformed_citation,
    }
def snippet_supports_value(snippet_text, value):
    """Return True when the normalized ``value`` is a substring of the normalized snippet.

    An empty normalized value never counts as supported.
    """
    needle = normalize_counsel_value(value)
    return bool(needle) and needle in normalize_counsel_value(snippet_text)
def get_supporting_snippet_numbers(docs, value):
    """Return the 1-based numbers of the snippets in ``docs`` that mention ``value``."""
    return [
        snippet_number
        for snippet_number, snippet in enumerate(docs, start=1)
        if snippet_supports_value(snippet, value)
    ]
def append_citations(value, snippet_numbers):
    """Append ``[^N]`` markers to ``value``.

    ``value`` is returned unchanged when it is "unknown" or when there are no
    snippet numbers.
    """
    if value_is_unknown(value) or not snippet_numbers:
        return value
    markers = [f"[^{snippet_number}]" for snippet_number in snippet_numbers]
    return f"{value} {' '.join(markers)}".strip()
def dedupe_preserving_order(values):
    """Drop values whose normalized firm name is empty or already seen.

    The first occurrence of each name is kept, in its original order and
    whitespace-normalized.
    """
    seen_keys = set()
    unique_values = []
    for raw_value in values:
        key = normalize_counsel_value(raw_value)
        if key and key not in seen_keys:
            seen_keys.add(key)
            unique_values.append(normalize_whitespace(raw_value))
    return unique_values
def find_contradicting_entity(expected, docs, target_firm):
    """Find a firm other than ``target_firm`` that the snippets actually mention.

    Known (non-"unknown") expected counsel values are checked first. If none
    of them is supported by a snippet, entity names extracted from the
    snippets themselves are tried, deduplicated.

    Returns:
        ``(company, snippet_numbers)`` for the first supported alternative, or
        ``("", [])`` when there is none.
    """
    target_norm = normalize_counsel_value(target_firm)

    def first_supported_alternative(candidates):
        # Skip the target itself; return the first candidate with snippet support.
        for candidate in candidates:
            if normalize_counsel_value(candidate) == target_norm:
                continue
            supporting = get_supporting_snippet_numbers(docs, candidate)
            if supporting:
                return candidate, supporting
        return None

    counsel_values = (
        normalize_whitespace(expected.get(field_name, "unknown")) for field_name in COUNSEL_FIELDS
    )
    found = first_supported_alternative(
        value for value in counsel_values if not expected_is_unknown(value)
    )
    if found:
        return found

    doc_candidates = [
        candidate for doc in docs for candidate in extract_entity_candidates(doc)
    ]
    found = first_supported_alternative(dedupe_preserving_order(doc_candidates))
    return found if found else ("", [])
def build_expected_user_question_spec(expected, docs, target_firm):
    """Derive the expected ``user_question`` verdict from the snippets.

    - "true": the target firm appears in at least one snippet (cite those).
    - "false": another firm is supported by the snippets (cite that firm's snippets).
    - "unknown": neither applies.

    Returns:
        Dict with ``verdict``, ``company`` and ``citations`` (snippet numbers).
    """
    target_hits = get_supporting_snippet_numbers(docs, target_firm)
    if target_hits:
        verdict, company, citations = "true", normalize_whitespace(target_firm), target_hits
    else:
        other_company, other_hits = find_contradicting_entity(expected, docs, target_firm)
        if other_company and other_hits:
            verdict, company, citations = "false", other_company, other_hits
        else:
            verdict, company, citations = "unknown", "", []
    return {"verdict": verdict, "company": company, "citations": citations}
def format_expected_user_question(spec):
    """Render an expected user-question spec, e.g. "true [^1] [^3]" or "unknown"."""
    verdict = spec["verdict"]
    if verdict == "unknown":
        return "unknown"
    return append_citations(verdict, spec.get("citations", []))
def parse_user_question_answer(value):
    """Parse the ``user_question`` answer into a verdict plus citation data.

    Returns the parse_cited_value dict extended with ``verdict``:
    "unknown" (with an empty ``evidence`` entry), "true" for true/yes,
    "false" for false/no, or None when the base value is not a recognized answer.
    """
    parsed = parse_cited_value(value)
    base_value = parsed["base_value"]
    if value_is_unknown(base_value):
        return {**parsed, "verdict": "unknown", "evidence": ""}
    boolean_match = BOOLEAN_ANSWER_PATTERN.match(base_value)
    if boolean_match is None:
        return {**parsed, "verdict": None}
    is_affirmative = boolean_match.group(1).lower() in {"true", "yes"}
    return {**parsed, "verdict": "true" if is_affirmative else "false"}
def validate_citations_for_value(parsed_value, docs, expected_value):
    """Check each cited snippet number against ``docs``.

    Returns:
        ``(bad_references, unsupported_citations)``. The first list holds
        numbers outside 1..len(docs). The second holds in-range numbers whose
        snippet does not mention ``expected_value``.
    """
    bad_references = []
    unsupported_citations = []
    snippet_count = len(docs)
    for citation_number in parsed_value["citations"]:
        if not 1 <= citation_number <= snippet_count:
            bad_references.append(citation_number)
        elif not snippet_supports_value(docs[citation_number - 1], expected_value):
            unsupported_citations.append(citation_number)
    return bad_references, unsupported_citations
def build_expected_step1_output(expected):
    """Return the ideal LLM 1 output.

    That is "TARGET_FIRM: <firm>" for relevant cases, or the tagged
    "irrelevant" marker when ``expected`` is a string.
    """
    if not isinstance(expected, str):
        return f"TARGET_FIRM: {expected[TARGET_FIRM_FIELD]}"
    return f"{IRRELEVANT_TAG_START}irrelevant{IRRELEVANT_TAG_END}"
def build_expected_llm3_answer(expected, docs, target_firm):
    """Build the ideal LLM 3 JSON answer, with citations, for display/comparison.

    String expectations (irrelevant cases) are returned unchanged. Otherwise
    each counsel field gets the expected value plus ``[^N]`` markers for every
    snippet that mentions it, and ``user_question`` gets the rendered expected
    verdict.
    """
    if isinstance(expected, str):
        return expected

    def cited(field_name):
        expected_value = str(expected.get(field_name, "unknown"))
        return append_citations(expected_value, get_supporting_snippet_numbers(docs, expected_value))

    formatted = {field_name: cited(field_name) for field_name in COUNSEL_FIELDS}
    question_spec = build_expected_user_question_spec(expected, docs, target_firm)
    formatted[USER_QUESTION_FIELD] = format_expected_user_question(question_spec)
    return formatted
def format_prompts_cell(prompts):
    """Render the three submitted prompts as one sheet cell.

    Each prompt is capped at PROMPT_CELL_MAX_CHARS, sections are separated by
    the sheet divider, and the whole cell is capped at SHEET_CELL_MAX_CHARS.
    """
    stage_prompts = (
        ("LLM 1", "prompt_1"),
        ("LLM 2", "prompt_2"),
        ("LLM 3", "prompt_3"),
    )
    sections = []
    for label, prompt_key in stage_prompts:
        prompt_text = truncate_sheet_block(prompts[prompt_key], max_chars=PROMPT_CELL_MAX_CHARS)
        sections.append(f"{label}\nPrompt:\n{prompt_text}")
    combined = join_sheet_sections(sections)
    return truncate_text(combined, max_chars=SHEET_CELL_MAX_CHARS)
def parse_relevant_stage_one_output(text):
    """Parse an LLM 1 output of the exact form "TARGET_FIRM: <name>".

    The normalized text must be a single line matching that form with a
    non-blank name. Returns ``{"target_firm": name}``, or None otherwise.
    """
    found = re.fullmatch(r"TARGET_FIRM:\s*(.+)", normalize_sheet_block_text(text))
    target_firm = found.group(1).strip() if found else ""
    return {"target_firm": target_firm} if target_firm else None
def is_error_output(text, stage_name):
    """Return True when ``text`` is the placeholder written after a failed stage call."""
    error_prefix = f"Error during {stage_name} call:"
    return str(text).startswith(error_prefix)
def run_chat_completion(system_prompt, user_prompt):
    """Send one system+user chat request and return the stripped reply text.

    The request runs while holding the shared OpenAI request semaphore, which
    caps concurrency across threads. The ``with`` block releases it even when
    client creation or the request raises.

    Raises:
        RuntimeError: if the client cannot be created, or if the reply has no
            text content. Before, a None content surfaced as an opaque
            AttributeError from ``.strip()``.
        Any exception raised by the OpenAI client.
    """
    with get_openai_request_semaphore():
        response = get_openai_client().chat.completions.create(
            model=get_model_name(),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
    content = response.choices[0].message.content
    if content is None:
        raise RuntimeError("Model response contained no text content.")
    return content.strip()
def run_step_one(prompts, question):
    """Run LLM 1 (relevance routing / target extraction) on the question.

    A failed call is captured as an "Error during LLM1 call: ..." string
    rather than raised.

    Returns:
        Dict with ``llm1_output``, ``is_irrelevant`` (reply contained the
        user_message tags) and ``user_message`` (tag content or None).
    """
    try:
        stage_output = run_chat_completion(prompts["prompt_1"], question)
    except Exception as exc:
        stage_output = f"Error during LLM1 call: {exc}"
    tagged_message = extract_user_message(stage_output)
    return {
        "llm1_output": stage_output,
        "is_irrelevant": tagged_message is not None,
        "user_message": tagged_message,
    }
def build_llm2_user_prompt(step1_context, snippet_id, doc):
    """Compose the LLM 2 user message for one snippet."""
    sections = [
        f"Target firm context:\n{step1_context}",
        f"Snippet ID: {snippet_id}\nSnippet Text:\n{doc}",
    ]
    return "\n\n".join(sections)
def run_llm2_snippet(prompts, step1_context, snippet_id, doc):
    """Run LLM 2 on one snippet and return ``(snippet_id, output)``.

    A failed call yields an "Error during LLM2 call: ..." string rather than
    raising.
    """
    try:
        analysis = run_chat_completion(
            prompts["prompt_2"],
            build_llm2_user_prompt(step1_context, snippet_id, doc),
        )
    except Exception as exc:
        analysis = f"Error during LLM2 call: {exc}"
    return snippet_id, analysis
def _run_llm2_for_all_snippets(prompts, docs, step1_context):
    """Run LLM 2 over every snippet; return outputs keyed "S1".."Sn" in snippet order."""
    outputs = {f"S{position}": "" for position in range(1, len(docs) + 1)}
    if not docs:
        return outputs
    jobs = [(f"S{position}", doc) for position, doc in enumerate(docs, start=1)]
    worker_count = min(get_llm2_max_workers(), len(docs))
    if worker_count == 1:
        for snippet_id, doc in jobs:
            outputs[snippet_id] = run_llm2_snippet(prompts, step1_context, snippet_id, doc)[1]
        return outputs
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        pending = [
            executor.submit(run_llm2_snippet, prompts, step1_context, snippet_id, doc)
            for snippet_id, doc in jobs
        ]
        # Results arrive in completion order, but the pre-seeded dict keeps S1..Sn ordering.
        for future in as_completed(pending):
            snippet_id, snippet_output = future.result()
            outputs[snippet_id] = snippet_output
    return outputs


def run_downstream_stages(prompts, docs, step1_context):
    """Run LLM 2 per snippet, then LLM 3 over all snippet analyses.

    LLM 2 calls fan out across up to ``get_llm2_max_workers()`` threads. A
    failed LLM 3 call is captured as an "Error during LLM3 call: ..." string.

    Returns:
        Dict with ``snippet_count``, ``llm2_outputs`` (id -> text) and ``llm3_output``.
    """
    llm2_outputs = _run_llm2_for_all_snippets(prompts, docs, step1_context)
    analyses = "\n".join(
        f"{snippet_id}\n{snippet_output}" for snippet_id, snippet_output in llm2_outputs.items()
    )
    llm3_user_prompt = (
        f"Target firm context:\n{step1_context}\n\n"
        "Per-snippet analyses:\n"
        f"{analyses}"
    )
    try:
        llm3_output = run_chat_completion(prompts["prompt_3"], llm3_user_prompt)
    except Exception as exc:
        llm3_output = f"Error during LLM3 call: {exc}"
    return {
        "snippet_count": len(docs),
        "llm2_outputs": llm2_outputs,
        "llm3_output": llm3_output,
    }
def parse_final_answer(answer):
    """Decode the LLM 3 answer.

    Returns:
        ``(parsed_dict, None)`` on success, or ``(None, error_message)`` when the
        text is not JSON or not a JSON object.
    """
    try:
        decoded = json.loads(answer)
    except json.JSONDecodeError as exc:
        return None, f"Invalid JSON: {exc}"
    if isinstance(decoded, dict):
        return decoded, None
    return None, "Final answer must be a JSON object."
def build_field_result(correct, issues=None, penalty=None):
    """Package a per-field grade.

    ``penalty`` defaults to 0.0 when correct and 0.2 otherwise, and is rounded
    to 4 decimals.
    """
    if penalty is None:
        effective_penalty = 0.0 if correct else 0.2
    else:
        effective_penalty = penalty
    return {
        "correct": correct,
        "issues": issues or [],
        "penalty": round(float(effective_penalty), 4),
    }
def build_prefixed_failure_tags(field_name, issues):
    """Turn field issues into failure tags of the form "<issue>_<field_name>"."""
    tags = []
    for issue in issues:
        tags.append(f"{issue}_{field_name}")
    return tags
def grade_counsel_field(field_name, parsed_answer, expected_answer, docs):
    """Grade one counsel field of the final JSON answer.

    Checks the answer against the expected value ("unknown" handling, firm-name
    equality after normalization) and checks its ``[^N]`` citations. Each
    citation must reference a real snippet that mentions the submitted value.

    Args:
        field_name: one of COUNSEL_FIELDS.
        parsed_answer: decoded final JSON object.
        expected_answer: normalized expected mapping for the case.
        docs: snippet texts; ``[^N]`` refers to ``docs[N - 1]``.

    Returns:
        A build_field_result(...) dict. Any issue makes the field incorrect
        with the default 0.2 penalty.
    """
    if field_name not in parsed_answer:
        return build_field_result(False, ["missing_field"])
    parsed_value = parse_cited_value(parsed_answer.get(field_name))
    expected_value = str(expected_answer.get(field_name, "unknown"))
    expected_unknown = expected_is_unknown(expected_value)
    actual_unknown = value_is_unknown(parsed_value["base_value"])
    issues = []
    if parsed_value["malformed_citation"]:
        issues.append("malformed_citation")
    if actual_unknown and expected_unknown:
        # Correct abstention; "unknown" must not carry citations.
        if parsed_value["citations"]:
            issues.append("unexpected_citation")
        return build_field_result(not issues, issues)
    if actual_unknown:
        issues.append("used_unknown")
        return build_field_result(False, issues)
    if expected_unknown:
        issues.append("hallucinated_value")
        return build_field_result(False, issues)
    if normalize_counsel_value(parsed_value["base_value"]) != normalize_counsel_value(expected_value):
        issues.append("wrong_value")
    if not parsed_value["citations"]:
        issues.append("missing_citation")
    else:
        # Same validation helper as the user-question grader, checked against
        # the value the model actually submitted.
        bad_references, unsupported_citations = validate_citations_for_value(
            parsed_value,
            docs,
            parsed_value["base_value"],
        )
        if bad_references:
            issues.append("bad_citation_reference")
        if unsupported_citations:
            issues.append("unsupported_citation")
    return build_field_result(not issues, issues)
def grade_user_question_field(parsed_answer, expected_user_question, docs, target_firm):
    """Grade the ``user_question`` verdict field of the final JSON answer.

    The penalty is capped at 0.2. A wrong verdict costs 0.1, and any citation
    problem (missing, stray, malformed, out of range or unsupported) costs
    another 0.1. Omitting the field, or answering "unknown" when a verdict was
    expected, costs the full 0.2.

    Args:
        parsed_answer: decoded final JSON object.
        expected_user_question: spec from build_expected_user_question_spec
            (``verdict``, ``company``, ``citations``).
        docs: snippet texts; ``[^N]`` refers to ``docs[N - 1]``.
        target_firm: citation target used when the spec names no company.

    Returns:
        A build_field_result(...) dict.
    """
    if USER_QUESTION_FIELD not in parsed_answer:
        return build_field_result(False, ["missing_field"], penalty=0.2)
    parsed_value = parse_user_question_answer(parsed_answer.get(USER_QUESTION_FIELD))
    issues = []
    penalty = 0.0
    if parsed_value["malformed_citation"]:
        issues.append("malformed_citation")
    expected_verdict = expected_user_question["verdict"]
    expected_company = normalize_whitespace(expected_user_question.get("company", ""))
    if expected_verdict == "unknown":
        if parsed_value["verdict"] != "unknown":
            issues.append("wrong_value")
            penalty += 0.1
        if parsed_value["citations"]:
            issues.append("unexpected_citation")
        # Stray or malformed citations share one 0.1 penalty, as in the
        # known-verdict branch below. Without this, a malformed citation was
        # tagged but free, so the field was incorrect while grade_case (which
        # passes on zero total penalty) still passed the case.
        if parsed_value["citations"] or parsed_value["malformed_citation"]:
            penalty += 0.1
        return build_field_result(not issues, issues, penalty=min(0.2, penalty))
    if parsed_value["verdict"] == "unknown":
        issues.append("used_unknown")
        return build_field_result(False, issues, penalty=0.2)
    if parsed_value["verdict"] != expected_verdict:
        issues.append("wrong_value")
        penalty += 0.1
    if not parsed_value["citations"]:
        issues.append("missing_citation")
        penalty += 0.1
    else:
        citation_target = expected_company or target_firm
        bad_references, unsupported_citations = validate_citations_for_value(
            parsed_value,
            docs,
            citation_target,
        )
        if bad_references:
            issues.append("bad_citation_reference")
        if unsupported_citations:
            issues.append("unsupported_citation")
        if parsed_value["malformed_citation"] or bad_references or unsupported_citations:
            penalty += 0.1
    return build_field_result(not issues, issues, penalty=min(0.2, penalty))
def grade_case(case_run, expected):
    """Score one case run against its expected answer.

    - String ``expected`` (irrelevant case): 1.0 if LLM 1 routed it as
      irrelevant, else 0.0 with a ``relevance_false_positive`` tag.
    - Unparsable LLM 3 output: 0.0, every field tagged ``invalid_json``.
    - Otherwise: 1.0 minus the summed per-field penalties (floored at 0).

    Returns:
        Dict with ``case_score``, ``failure_tags`` (sorted, unique),
        ``field_results`` and ``passed`` (True only with zero total penalty).
    """
    if isinstance(expected, str):
        routed_irrelevant = case_run["is_irrelevant"]
        score = 1.0 if routed_irrelevant else 0.0
        return {
            "case_score": score,
            "failure_tags": [] if routed_irrelevant else ["relevance_false_positive"],
            "field_results": {},
            "passed": score == 1.0,
        }

    parsed_answer, parse_error = parse_final_answer(case_run["llm3_output"])
    if parse_error:
        return {
            "case_score": 0.0,
            "failure_tags": ["invalid_json"],
            "field_results": {
                field_name: build_field_result(False, ["invalid_json"])
                for field_name in FINAL_SCHEMA_KEYS
            },
            "passed": False,
        }

    docs = case_run["docs"]
    field_results = {
        field_name: grade_counsel_field(field_name, parsed_answer, expected, docs)
        for field_name in COUNSEL_FIELDS
    }
    target_firm = case_run["effective_target_firm"]
    field_results[USER_QUESTION_FIELD] = grade_user_question_field(
        parsed_answer,
        build_expected_user_question_spec(expected, docs, target_firm),
        docs,
        target_firm,
    )

    total_penalty = 0.0
    failure_tags = set()
    # Insertion order (counsel fields, then user_question) keeps the float sum order stable.
    for field_name, field_result in field_results.items():
        total_penalty += field_result["penalty"]
        failure_tags.update(build_prefixed_failure_tags(field_name, field_result["issues"]))

    return {
        "case_score": round(max(0.0, 1.0 - total_penalty), 4),
        "failure_tags": sorted(failure_tags),
        "field_results": field_results,
        "passed": total_penalty == 0.0,
    }
def _llm1_stage_result(score, reasoning, parsed_output=None):
    """Package an LLM 1 grade; any reasoning other than "matched" becomes an llm1_* tag."""
    return {
        "score": score,
        "reasoning": reasoning,
        "failure_tags": [] if reasoning == "matched" else [f"llm1_{reasoning}"],
        "parsed_output": parsed_output,
    }


def grade_llm1_stage(llm1_output, expected):
    """Grade the LLM 1 (routing / target extraction) output.

    Scores: 1.0 for the right route, plus the right firm in the exact
    "TARGET_FIRM: <name>" format for relevant cases. 0.5 for the right firm in
    the wrong format, or the right format with the wrong firm. 0.0 for a wrong
    route, a failed call, or the wrong firm in the wrong format.

    Returns:
        Dict with ``score``, ``reasoning``, ``failure_tags`` and ``parsed_output``
        (the parsed {"target_firm": ...} when the format matched, else None).
    """
    is_irrelevant = extract_user_message(llm1_output) is not None
    parsed_output = parse_relevant_stage_one_output(llm1_output)

    if isinstance(expected, str):
        if is_irrelevant:
            return _llm1_stage_result(1.0, "matched")
        failed = is_error_output(llm1_output, "LLM1")
        return _llm1_stage_result(0.0, "call_failed" if failed else "wrong_route")

    if is_irrelevant:
        return _llm1_stage_result(0.0, "wrong_route")

    if parsed_output is None:
        if is_error_output(llm1_output, "LLM1"):
            return _llm1_stage_result(0.0, "call_failed")
        expected_target = normalize_counsel_value(expected[TARGET_FIRM_FIELD])
        if expected_target and expected_target in normalize_counsel_value(llm1_output):
            return _llm1_stage_result(0.5, "right_answer_wrong_format")
        return _llm1_stage_result(0.0, "wrong_answer_wrong_format")

    submitted_firm = normalize_counsel_value(parsed_output["target_firm"])
    if submitted_firm != normalize_counsel_value(expected[TARGET_FIRM_FIELD]):
        return _llm1_stage_result(0.5, "wrong_answer_right_format", parsed_output)
    return _llm1_stage_result(1.0, "matched", parsed_output)
def summarize_llm2_stage(case_run, expected):
    """Summarize how many per-snippet LLM 2 calls produced usable output.

    Outputs that are empty, missing, or error placeholders count as failures
    and are listed by snippet id.
    """
    if isinstance(expected, str):
        return "Not run. Irrelevant case."
    total_snippets = case_run["snippet_count"]
    if total_snippets == 0:
        return "No snippets were provided."
    outputs = case_run["llm2_outputs"]
    failed_ids = []
    succeeded = 0
    for position in range(1, total_snippets + 1):
        snippet_id = f"S{position}"
        snippet_output = outputs.get(snippet_id, "")
        if snippet_output and not is_error_output(snippet_output, "LLM2"):
            succeeded += 1
        else:
            failed_ids.append(snippet_id)
    summary = f"{succeeded}/{total_snippets} snippet calls completed successfully."
    if not failed_ids:
        return summary
    return summary + f" Errored snippets: {', '.join(failed_ids)}."
def grade_llm3_stage(case_run, expected, semantic_result):
    """Report the LLM 3 stage score, explained by its failure tags.

    Without failure tags, the reasoning is "skipped_irrelevant" for a
    correctly-routed irrelevant case and "clean" otherwise.
    """
    failure_tags = semantic_result["failure_tags"]
    if failure_tags:
        reasoning = ", ".join(failure_tags)
    else:
        skipped = isinstance(expected, str) and case_run["is_irrelevant"]
        reasoning = "skipped_irrelevant" if skipped else "clean"
    return {"score": semantic_result["case_score"], "reasoning": reasoning}
def summarize_field_accuracy(case_results):
    """Aggregate per-field correctness across graded runs.

    Each item of ``case_results`` must carry a ``field_results`` mapping of
    field name to a dict with a truthy/falsy ``correct`` entry. Every key in
    ``FINAL_SCHEMA_KEYS`` appears in the output, even if no run reported it.
    A field name outside ``FINAL_SCHEMA_KEYS`` raises ``KeyError``.

    Returns:
        ``{field: {"correct": int, "total": int, "accuracy": float}}``, where
        ``accuracy`` is a percentage rounded to 2 decimals (``0.0`` when
        ``total`` is zero).
    """
    correct_counts = dict.fromkeys(FINAL_SCHEMA_KEYS, 0)
    total_counts = dict.fromkeys(FINAL_SCHEMA_KEYS, 0)
    for case_result in case_results:
        for field_name, field_summary in case_result["field_results"].items():
            total_counts[field_name] += 1
            if field_summary["correct"]:
                correct_counts[field_name] += 1
    summary = {}
    for field_name, total in total_counts.items():
        correct = correct_counts[field_name]
        summary[field_name] = {
            "correct": correct,
            "total": total,
            "accuracy": round((correct / total) * 100, 2) if total else 0.0,
        }
    return summary
def average_run_score(runs, key):
    """Return the mean of ``run[key]`` across ``runs``, rounded to 4 decimals.

    An empty ``runs`` yields ``0.0``.
    """
    if not runs:
        return 0.0
    total = sum(run[key] for run in runs)
    mean = total / len(runs)
    return round(mean, 4)
def total_run_score(runs, key):
    """Return the sum of ``run[key]`` across ``runs``, rounded to 4 decimals.

    An empty ``runs`` yields the float ``0.0``. A plain ``sum`` would give the
    int ``0`` instead.
    """
    if not runs:
        return 0.0
    values = [run[key] for run in runs]
    return round(sum(values), 4)
def summarize_reasoning_counts(runs, key):
    """Return ``"label xN; ..."`` counts of ``run[key]`` in first-seen order."""
    # Counter is a dict subclass, so iteration follows first-insertion order.
    # That is exactly the order in which each label first appears in runs.
    tally = Counter(run[key] for run in runs)
    return "; ".join(f"{label} x{count}" for label, count in tally.items())
def explain_reasoning_label(label):
    """Translate an evaluator reasoning label into a plain-English sentence.

    A label containing ``", "`` is treated as several labels. It is split on
    commas, each non-blank part is explained, and the parts are joined with
    ``"; "``. An unknown label falls back to its own text with underscores
    replaced by spaces.
    """
    explanations = {
        "matched": "The model followed the required instruction and produced the expected type of output.",
        "wrong_format": "The model answered the right task but did not use the exact output format the evaluator requires.",
        "right_answer_wrong_format": "The model identified the right answer but did not use the exact required Step 1 format.",
        "wrong_answer_right_format": "The model used the required Step 1 format but identified the wrong target firm.",
        "wrong_answer_wrong_format": "The model gave neither the right answer nor the required Step 1 format.",
        "wrong_route": "The model chose the wrong kind of response for this question, such as treating an irrelevant query as relevant or the reverse.",
        "call_failed": "The model call failed, so no valid answer was produced for this step.",
        "invalid_json": "The Step 3 answer was not valid JSON in the required schema.",
        "relevance_false_positive": "The system treated a question as relevant even though it should have been rejected as irrelevant.",
        "skipped_irrelevant": "This step was skipped because the question was correctly identified as irrelevant.",
        "clean": "The answer passed the evaluator checks for this step.",
    }
    def describe(name):
        return explanations.get(name, name.replace("_", " "))
    if ", " not in label:
        return describe(label)
    stripped = (piece.strip() for piece in label.split(","))
    return "; ".join(describe(piece) for piece in stripped if piece)
def summarize_run_groups(runs, key):
    """Collapse adjacent runs that share ``run[key]`` into ``"RUN <indices>: <value>"`` lines.

    Run numbers are 1-based. Only consecutive runs are merged, so a value that
    reappears after a different one starts a new line.
    """
    segments = []  # (value, [1-based run numbers]) per consecutive stretch
    for position, run in enumerate(runs, start=1):
        current = run[key]
        if segments and segments[-1][0] == current:
            segments[-1][1].append(position)
        else:
            segments.append((current, [position]))
    return [f"RUN {format_run_indices(indices)}: {value}" for value, indices in segments]
def build_explained_run_groups(runs, key):
    """Group adjacent runs sharing ``run[key]`` and explain each group's label.

    Each returned entry is a three-line string: the ``RUN <indices>`` header,
    ``Label: <value>``, and ``Meaning:`` followed by the output of
    ``explain_reasoning_label``. Only consecutive runs are merged.
    """
    segments = []  # (value, [1-based run numbers]) per consecutive stretch
    for position, run in enumerate(runs, start=1):
        current = run[key]
        if segments and segments[-1][0] == current:
            segments[-1][1].append(position)
        else:
            segments.append((current, [position]))
    explained = []
    for value, indices in segments:
        explained.append(
            f"RUN {format_run_indices(indices)}\n"
            f"Label: {value}\n"
            f"Meaning: {explain_reasoning_label(value)}"
        )
    return explained
def build_score_breakdown_cell(case_results):
    """Render the per-case LLM 1 / LLM 3 score breakdown for the results sheet.

    For each case the cell shows:
    - the summed per-run score for each stage, out of the number of runs;
    - the explained reasoning labels, grouped by consecutive runs.

    Cases are separated with ``join_sheet_sections``, and the whole cell is
    truncated to ``SHEET_CELL_MAX_CHARS``.
    """
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        runs = case_result["runs"]
        run_numbers = list(range(1, len(runs) + 1))
        run_count = len(run_numbers)
        llm1_lines = build_explained_run_groups(runs, "llm1_reasoning")
        llm3_lines = build_explained_run_groups(runs, "llm3_reasoning")
        lines = [f"CASE {case_number}", f"RUNS: {format_run_indices(run_numbers)}", ""]
        lines += [
            "LLM 1",
            f"Score: {format_score(total_run_score(runs, 'llm1_score'))} / {run_count}",
            "Reasoning:",
            *llm1_lines,
            "",
        ]
        lines += [
            "LLM 3",
            f"Score: {format_score(total_run_score(runs, 'llm3_score'))} / {run_count}",
            "Reasoning:",
            *llm3_lines,
        ]
        sections.append("\n".join(lines))
    combined = join_sheet_sections(sections)
    return truncate_text(combined, max_chars=SHEET_CELL_MAX_CHARS)
def build_llm2_output_cell(case_results):
    """Render every run's Step 2 (per-snippet) output for the results sheet.

    Each case/run pair gets:
    - a header;
    - the teacher-forcing note, or a skipped notice for irrelevant (string
      ``expected``) cases;
    - the Step 2 summary and the formatted snippet outputs.

    The cell is truncated to ``SHEET_CELL_MAX_CHARS``.
    """
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        for run_number, run_result in enumerate(case_result["runs"], start=1):
            header = build_case_run_header(case_number, run_number)
            if isinstance(case_result["expected"], str):
                context_note = "Downstream stages were skipped because this was graded as irrelevant."
            else:
                context_note = format_teacher_forced_note(run_result)
            lines = [
                header,
                context_note,
                "",
                f"LLM 2 Summary: {run_result['llm2_reasoning']}",
                "",
                "LLM 2 Output:",
                format_llm2_user_output(run_result["llm2_outputs"]),
            ]
            sections.append("\n".join(lines))
    combined = join_sheet_sections(sections)
    return truncate_text(combined, max_chars=SHEET_CELL_MAX_CHARS)
def build_output_vs_expected_cell(case_results):
    """Render LLM 1 / LLM 3 outputs next to their expected answers for the sheet.

    Within a case, runs whose rendered (normalized) body is identical are
    merged under a single header listing all of their run numbers. Distinct
    bodies keep the order in which they first appear. The cell is truncated
    to ``SHEET_CELL_MAX_CHARS``.
    """
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        expected = case_result["expected"]
        expected_step1_text = format_llm1_expected_text(expected)
        # normalized body -> run numbers; dict insertion order keeps first-seen order
        runs_by_body = {}
        for run_number, run_result in enumerate(case_result["runs"], start=1):
            step1_lines = [
                "LLM 1",
                f"Status: {run_result['llm1_reasoning']}",
                "User Output:",
                truncate_sheet_block(run_result["llm1_output"], max_chars=STAGE_ONE_OUTPUT_MAX_CHARS),
                "",
                "Expected:",
                expected_step1_text,
            ]
            if isinstance(expected, str):
                context_note = "Downstream stages were skipped because this case was irrelevant."
            else:
                context_note = format_teacher_forced_note(run_result)
            step3_lines = [
                "LLM 3",
                f"Reasoning: {run_result['llm3_reasoning']}",
                "User Output:",
                truncate_sheet_block(
                    run_result["llm3_output"] or "Not run.",
                    max_chars=STAGE_THREE_OUTPUT_MAX_CHARS,
                ),
                "",
                "Expected:",
                format_pretty_expected_answer(run_result["llm3_expected"]),
            ]
            body = normalize_sheet_block_text(
                "\n".join([*step1_lines, "", context_note, "", *step3_lines])
            )
            runs_by_body.setdefault(body, []).append(run_number)
        for body, run_numbers in runs_by_body.items():
            header = build_grouped_case_run_header(case_number, run_numbers)
            sections.append(header + "\n" + body)
    combined = join_sheet_sections(sections)
    return truncate_text(combined, max_chars=SHEET_CELL_MAX_CHARS)
def estimate_case_llm_call_count(case):
    """Estimate how many model calls grading ``case`` will make.

    An irrelevant case (string ``expected``) is counted as one call per run,
    because ``evaluate_case`` skips the downstream stages. A relevant case is
    counted as ``len(case["docs"]) + 2`` calls per run (Step 1, presumably one
    Step 2 call per doc, and Step 3). Both are multiplied by
    ``EVAL_REPEAT_COUNT``.
    """
    if isinstance(case["expected"], str):
        calls_per_run = 1
    else:
        calls_per_run = len(case["docs"]) + 2
    return EVAL_REPEAT_COUNT * calls_per_run
def estimate_submission_llm_call_count(cases):
    """Return the total estimated model calls for grading every case in ``cases``."""
    total = 0
    for case in cases:
        total += estimate_case_llm_call_count(case)
    return total
def _run_case_once(case, prompts, expected_relevant):
    """Execute and grade one pass of the three-step pipeline for ``case``.

    Returns the per-run record that ``evaluate_case`` stores in ``runs``.
    """
    step_one = run_step_one(prompts, case["question"])
    raw_step1_output = step_one["llm1_output"]
    expected = case["expected"]
    llm1_grade = grade_llm1_stage(raw_step1_output, expected)
    teacher_forced = False
    step1_context = raw_step1_output
    if expected_relevant:
        if llm1_grade["score"] < 1.0:
            # Teacher forcing: an imperfect Step 1 answer is replaced by the
            # reference output before Steps 2 and 3 run.
            step1_context = build_expected_step1_output(expected)
            teacher_forced = True
        downstream = run_downstream_stages(prompts, case["docs"], step1_context)
        target_firm = parse_relevant_stage_one_output(step1_context)["target_firm"]
    else:
        # Irrelevant cases never reach Steps 2 and 3.
        downstream = {
            "snippet_count": len(case["docs"]),
            "llm2_outputs": {},
            "llm3_output": "",
        }
        target_firm = ""
    case_run = {
        "question": case["question"],
        "docs": case["docs"],
        "snippet_count": downstream["snippet_count"],
        "llm1_output": raw_step1_output,
        "llm2_outputs": downstream["llm2_outputs"],
        "llm3_output": downstream["llm3_output"],
        "is_irrelevant": step_one["is_irrelevant"],
        "user_message": step_one["user_message"],
        "effective_step1_context": step1_context,
        "effective_target_firm": target_firm,
        "used_teacher_forced_step1": teacher_forced,
    }
    semantic = grade_case(case_run, expected)
    llm3_grade = grade_llm3_stage(case_run, expected, semantic)
    llm2_summary = summarize_llm2_stage(case_run, expected)
    llm3_expected = build_expected_llm3_answer(expected, case["docs"], target_firm)
    failure_tags = sorted(set(llm1_grade["failure_tags"] + semantic["failure_tags"]))
    llm1_score = llm1_grade["score"]
    llm3_score = llm3_grade["score"]
    truncated_llm2 = {
        snippet_id: truncate_sheet_block(text, max_chars=STAGE_TWO_OUTPUT_MAX_CHARS)
        for snippet_id, text in case_run["llm2_outputs"].items()
    }
    return {
        "llm1_output": truncate_sheet_block(raw_step1_output, max_chars=STAGE_ONE_OUTPUT_MAX_CHARS),
        "llm2_outputs": truncated_llm2,
        "llm3_output": truncate_sheet_block(case_run["llm3_output"], max_chars=STAGE_THREE_OUTPUT_MAX_CHARS),
        "is_irrelevant": step_one["is_irrelevant"],
        "failure_tags": failure_tags,
        "field_results": semantic["field_results"],
        "passed": llm1_score == 1.0 and llm3_score == 1.0,
        "llm1_score": llm1_score,
        "llm1_reasoning": llm1_grade["reasoning"],
        "llm2_reasoning": llm2_summary,
        "llm3_score": llm3_score,
        "llm3_reasoning": llm3_grade["reasoning"],
        "used_teacher_forced_step1": teacher_forced,
        "effective_step1_context": truncate_sheet_block(step1_context, max_chars=STAGE_ONE_OUTPUT_MAX_CHARS),
        "effective_target_firm": target_firm,
        "llm3_expected": llm3_expected,
        "total_points": round(llm1_score + llm3_score, 4),
        "total_points_without_llm1": round(llm3_score, 4),
    }
def evaluate_case(case, prompts, case_index=None):
    """Grade ``case`` ``EVAL_REPEAT_COUNT`` times and aggregate the run scores.

    Args:
        case: Mapping with ``question``, ``docs`` and ``expected``. A string
            ``expected`` marks an irrelevant case.
        prompts: The candidate's three system prompts.
        case_index: Optional 1-based index, used only in the completion log.

    Returns:
        A dict containing:
        - ``runs``, the per-run records;
        - the mean LLM 1, LLM 3 and total scores;
        - ``passed``, true only when the mean total is 2.0 (every run earned
          both points).
    """
    started_at = time.perf_counter()
    expected_relevant = not isinstance(case["expected"], str)
    runs = [_run_case_once(case, prompts, expected_relevant) for _ in range(EVAL_REPEAT_COUNT)]
    llm1_mean = average_run_score(runs, "llm1_score")
    llm3_mean = average_run_score(runs, "llm3_score")
    total_mean = round(llm1_mean + llm3_mean, 4)
    result = {
        "expected": case["expected"],
        "runs": runs,
        "average_llm1_score": llm1_mean,
        "average_llm3_score": llm3_mean,
        "average_total_points": total_mean,
        "case_score": total_mean,
        "passed": abs(total_mean - 2.0) < 1e-9,
    }
    elapsed_seconds = time.perf_counter() - started_at
    logger.info(
        "Completed case %s: relevant=%s snippets=%s repeat_count=%s "
        "expected_llm_calls=%s elapsed_seconds=%.2f",
        "unknown" if case_index is None else case_index,
        expected_relevant,
        len(case["docs"]),
        EVAL_REPEAT_COUNT,
        estimate_case_llm_call_count(case),
        elapsed_seconds,
    )
    return result
def grade_submission(eval_set, prompts, mode):
    """Grade every case in ``eval_set`` and summarize the whole submission.

    How cases run:
    - Each case goes through ``safe_evaluate_case``.
    - They run sequentially when only one worker is allowed, otherwise in a
      thread pool.
    - Results are stored in input order either way.

    Scoring:
    - Each case is worth up to 2 points (mean LLM 1 + mean LLM 3).
    - ``score_percent`` is total points over ``2 * total_cases``.

    Returns:
        A summary dict with ``submission_type``, ``score_percent``,
        ``passed_cases``, ``total_cases``, ``failure_summary``,
        ``field_accuracy``, ``case_results``, ``results`` and
        ``results_without_llm_1``.
    """
    cases = eval_set["cases"]
    if not cases:
        return {
            "submission_type": mode,
            "score_percent": 0.0,
            "passed_cases": 0,
            "total_cases": 0,
            "failure_summary": {},
            "field_accuracy": summarize_field_accuracy([]),
            "case_results": [],
            "results": 0.0,
            "results_without_llm_1": 0.0,
        }
    case_count = len(cases)
    worker_count = min(get_eval_max_workers(), case_count)
    started_at = time.perf_counter()
    expected_llm_calls = estimate_submission_llm_call_count(cases)
    logger.info(
        "Starting %s grading: cases=%s repeat_count=%s snippet_counts=%s "
        "expected_llm_calls=%s case_workers=%s llm2_workers=%s openai_global_limit=%s",
        mode,
        case_count,
        EVAL_REPEAT_COUNT,
        [len(case["docs"]) for case in cases],
        expected_llm_calls,
        worker_count,
        get_llm2_max_workers(),
        get_openai_max_concurrent_requests(),
    )
    if worker_count == 1:
        case_results = [
            safe_evaluate_case(case, prompts, number)
            for number, case in enumerate(cases, start=1)
        ]
    else:
        case_results = [None] * case_count
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            position_of = {
                executor.submit(safe_evaluate_case, case, prompts, position + 1): position
                for position, case in enumerate(cases)
            }
            # Futures finish in any order; slot each result back by its index.
            for future in as_completed(position_of):
                case_results[position_of[future]] = future.result()
    all_runs = [run for case_result in case_results for run in case_result["runs"]]
    points = round(sum(item["average_total_points"] for item in case_results), 4)
    points_without_llm1 = round(sum(item["average_llm3_score"] for item in case_results), 4)
    passed_cases = sum(1 for item in case_results if item["passed"])
    max_points = len(case_results) * 2.0
    score_percent = round((points / max_points) * 100, 2) if max_points else 0.0
    tag_counts = Counter(tag for run in all_runs for tag in run["failure_tags"])
    failure_summary = dict(sorted(tag_counts.items()))
    field_accuracy = summarize_field_accuracy(all_runs)
    elapsed_seconds = time.perf_counter() - started_at
    logger.info(
        "Completed %s grading: cases=%s repeat_count=%s expected_llm_calls=%s "
        "elapsed_seconds=%.2f",
        mode,
        len(case_results),
        EVAL_REPEAT_COUNT,
        expected_llm_calls,
        elapsed_seconds,
    )
    return {
        "submission_type": mode,
        "score_percent": score_percent,
        "passed_cases": passed_cases,
        "total_cases": len(case_results),
        "failure_summary": failure_summary,
        "field_accuracy": field_accuracy,
        "case_results": case_results,
        "results": points,
        "results_without_llm_1": points_without_llm1,
    }
def get_attempt_status(email):
    """Report whether ``email`` has already used its practice and/or final attempt.

    Two sources are combined: the in-process ``LOCAL_ATTEMPTS`` map and the
    results sheet (columns B:D, header row skipped).

    How sheet rows are read:
    - A row whose first cell is a known submission type is read as
      ``[type, name, email]``.
    - Any other non-empty row is treated as a ``final`` row whose first cell
      is the email. This presumably supports an older sheet layout.

    Returns:
        ``{"practice_used": bool, "final_used": bool}``.
    """
    wanted = normalize_email(email)
    local_flags = LOCAL_ATTEMPTS.get(wanted, {})
    practice_used = bool(local_flags.get("practice_used"))
    final_used = bool(local_flags.get("final_used"))
    sheet = get_google_sheet(ensure_headers=False)
    for row in sheet.get("B:D")[1:]:
        if len(row) >= 3 and row[0] in SUBMISSION_TYPES:
            row_type, row_email = row[0], normalize_email(row[2])
        elif row:
            row_type, row_email = "final", normalize_email(row[0])
        else:
            continue
        if row_email != wanted:
            continue
        if row_type == "practice":
            practice_used = True
        elif row_type == "final":
            final_used = True
    return {"practice_used": practice_used, "final_used": final_used}
def save_attempt(record):
    """Append ``record`` to the results sheet as a single new row.

    Values are serialized in ``SHEET_COLUMNS`` order, and missing keys become
    ``""``. The row is written raw (``RAW``) and inserted rather than
    overwriting (``INSERT_ROWS``).
    """
    sheet = get_google_sheet()
    values = []
    for column_key, _header in SHEET_COLUMNS:
        values.append(serialize_sheet_value(record.get(column_key, "")))
    sheet.append_row(
        values,
        value_input_option="RAW",
        insert_data_option="INSERT_ROWS",
        table_range="A1",
    )
def record_local_attempt(email, submission_type):
    """Mark ``submission_type`` as used for ``email`` in the in-process ``LOCAL_ATTEMPTS`` map.

    Flags already set for ``email`` are preserved. The entry is replaced with
    a new dict rather than mutated in place.
    """
    previous = LOCAL_ATTEMPTS.get(email, {})
    practice_flag = previous.get("practice_used", False) or submission_type == "practice"
    final_flag = previous.get("final_used", False) or submission_type == "final"
    LOCAL_ATTEMPTS[email] = {"practice_used": practice_flag, "final_used": final_flag}
def build_attempt_record(name, email, prompts, submission_result):
    """Assemble the sheet-row payload for one graded attempt.

    Keys match the ``SHEET_COLUMNS`` keys consumed by ``save_attempt``. The
    timestamp is the current UTC time in ISO 8601 format.
    """
    submitted_at = datetime.now(timezone.utc).isoformat()
    case_results = submission_result["case_results"]
    record = {
        "timestamp": submitted_at,
        "submission_type": submission_result["submission_type"],
        "name": name,
        "email": email,
        "results": submission_result["results"],
        "results_without_llm_1": submission_result["results_without_llm_1"],
    }
    record["llm_score_breakdown"] = build_score_breakdown_cell(case_results)
    record["llm_output_vs_expected"] = build_output_vs_expected_cell(case_results)
    record["llm_2_output"] = build_llm2_output_cell(case_results)
    record["prompts"] = format_prompts_cell(prompts)
    return record
def format_field_accuracy(field_accuracy):
    """Render per-field accuracy as ``- <Field>: <pct>% (<correct>/<total>)`` lines.

    ``field_accuracy`` must hold an entry with ``accuracy``, ``correct`` and
    ``total`` for every key in ``FINAL_SCHEMA_KEYS``, as produced by
    ``summarize_field_accuracy``.
    """
    display_names = {
        "buyer_counsel": "Buyer counsel",
        "seller_counsel": "Seller counsel",
        "third_party_counsel": "Third-party counsel",
        "user_question": "User question",
    }
    def render(field_name):
        stats = field_accuracy[field_name]
        return (
            f"- {display_names[field_name]}: {stats['accuracy']}% "
            f"({stats['correct']}/{stats['total']})"
        )
    return "\n".join(render(field_name) for field_name in FINAL_SCHEMA_KEYS)
def build_user_facing_reasoning_message(label, llm_name):
    """Return the candidate-facing explanation for one grader reasoning label.

    Args:
        label: Reasoning label from the grader. For LLM 3 this may be a
            comma-separated list of failure tags.
        llm_name: ``"LLM 1"`` selects the Step 1 wording. Any other value uses
            the Step 3 (LLM 3) wording.

    Returns:
        A human-readable sentence, or several joined with ``"; "``.
    """
    if llm_name == "LLM 1":
        llm1_messages = {
            "matched": "No issue in these runs.",
            "right_answer_wrong_format": "The model identified the right answer but did not use the required Step 1 format.",
            "wrong_answer_right_format": "The model used the required Step 1 format but identified the wrong target firm.",
            "wrong_answer_wrong_format": "The model gave neither the right output nor the right format.",
            "wrong_route": "The model chose the wrong route for this question and should have handled relevance differently.",
            "call_failed": "The Step 1 model call failed, so no valid answer was produced.",
        }
        # Look up first and compute the fallback only on a miss.
        # dict.get(label, explain_reasoning_label(label)) would build the
        # fallback on every call, even for known labels.
        if label in llm1_messages:
            return llm1_messages[label]
        return explain_reasoning_label(label)
    llm3_messages = {
        "matched": "No issue in these runs.",
        "clean": "The final JSON was valid and passed the evaluator checks.",
        "skipped_irrelevant": "This case was correctly rejected as irrelevant before Step 3.",
        "invalid_json": "The final answer was not valid JSON in the required schema.",
        "relevance_false_positive": (
            "This question should have been rejected as irrelevant earlier, but the pipeline continued as if it were relevant."
        ),
    }
    if label in llm3_messages:
        return llm3_messages[label]
    # Anything else is treated as a comma-separated list of failure tags.
    # Each tag is explained as a field-scoped issue when possible, otherwise
    # with the generic label explanation.
    explained = []
    for part in (piece.strip() for piece in str(label).split(",")):
        if not part:
            continue
        explained.append(explain_output_issue_tag(part) or explain_reasoning_label(part))
    return "; ".join(explained) if explained else "The final output had evaluator-detected issues."
def split_field_issue_tag(tag):
    """Split a ``"<issue>_<field>"`` tag into ``(issue, field)``.

    Field names from ``FINAL_SCHEMA_KEYS`` are tried longest-first, so the
    longer name is preferred when several would match. Returns
    ``(None, None)`` when no field suffix matches.
    """
    candidates = sorted(FINAL_SCHEMA_KEYS, key=len, reverse=True)
    for field_name in candidates:
        suffix = "_" + field_name
        if not tag.endswith(suffix):
            continue
        issue_end = len(tag) - len(suffix)
        return tag[:issue_end], field_name
    return None, None
def explain_output_issue_tag(tag):
    """Describe a field-scoped Step 3 failure tag in plain English.

    Returns ``"<Field label>: <message>"`` for tags of the form
    ``<issue>_<field>`` with a recognised issue name. Returns ``None`` for
    anything else.
    """
    issue_name, field_name = split_field_issue_tag(tag)
    if not (issue_name and field_name):
        return None
    field_label = {
        "buyer_counsel": "Buyer counsel",
        "seller_counsel": "Seller counsel",
        "third_party_counsel": "Third-party counsel",
        "user_question": "User question",
    }[field_name]
    issue_messages = {
        "missing_field": "the field was missing from the final JSON.",
        "malformed_citation": "the citation format was malformed.",
        "unexpected_citation": "a citation was included where it should not have been.",
        "used_unknown": 'the answer said "unknown" even though the evidence supported a conclusion.',
        "hallucinated_value": "a value was given even though the expected answer was unknown.",
        "wrong_value": "the answer value was wrong.",
        "missing_citation": "the answer needed a citation but did not include one.",
        "bad_citation_reference": "the citation pointed to a snippet number that does not exist.",
        "unsupported_citation": "the citation pointed to a snippet that does not support the answer.",
    }
    message = issue_messages.get(issue_name)
    return f"{field_label}: {message}" if message else None
def build_practice_run_issue_lines(runs, key, llm_name):
    """List the issue bullets shown in practice feedback for one stage.

    Runs labelled ``matched``, ``clean`` or ``skipped_irrelevant`` are
    ignored. The remaining runs are grouped while their labels match the
    previous non-ignored run. Returns ``["- No issues detected."]`` when
    nothing is left.
    """
    no_issue_labels = {"matched", "clean", "skipped_irrelevant"}
    groups = []  # (label, [1-based run numbers])
    for position, run in enumerate(runs, start=1):
        label = run[key]
        if label in no_issue_labels:
            continue
        if groups and groups[-1][0] == label:
            groups[-1][1].append(position)
        else:
            groups.append((label, [position]))
    if not groups:
        return ["- No issues detected."]
    return [
        f"- RUN {format_run_indices(indices)}: "
        f"{build_user_facing_reasoning_message(label, llm_name)}"
        for label, indices in groups
    ]
def build_practice_case_feedback(case_results):
    """Render per-case LLM 1 and LLM 3 scores and issue bullets for practice feedback.

    Each stage shows its summed run score out of the number of runs,
    followed by the output of ``build_practice_run_issue_lines``.
    """
    stages = (
        ("LLM 1", "llm1_score", "llm1_reasoning"),
        ("LLM 3", "llm3_score", "llm3_reasoning"),
    )
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        runs = case_result["runs"]
        run_count = len(runs)
        lines = [f"CASE {case_number}"]
        for title, score_key, reasoning_key in stages:
            lines.extend(
                [
                    "",
                    title,
                    f"Score: {format_score(total_run_score(runs, score_key))} / {run_count}",
                    "Issues:",
                    *build_practice_run_issue_lines(runs, reasoning_key, title),
                ]
            )
        sections.append("\n".join(lines))
    return join_sheet_sections(sections)
def build_practice_scoring_note(submission_result, repeat_count=None):
    """Explain to the candidate how the practice score was produced.

    Args:
        submission_result: Graded submission. Only ``total_cases`` is read.
        repeat_count: Number of times each case was run. Defaults to
            ``EVAL_REPEAT_COUNT``, so the note always matches the evaluator
            instead of a hard-coded number.

    Returns:
        A two-sentence note, correctly pluralised for singular counts.
    """
    if repeat_count is None:
        repeat_count = EVAL_REPEAT_COUNT
    total_cases = submission_result["total_cases"]
    case_noun = "case" if total_cases == 1 else "cases"
    time_noun = "time" if repeat_count == 1 else "times"
    return (
        f"This practice score is based on {total_cases} hidden calibration {case_noun}. "
        f"Each case is run {repeat_count} {time_noun} to check prompt consistency."
    )
def format_practice_feedback(name, submission_result):
    """Compose the practice-run message shown in the Results box.

    Contains the score, the scoring note, the per-case breakdown and a
    reminder that the final submission is still open.
    """
    paragraphs = [
        f"Practice run complete for {name}.",
        f"Score: {submission_result['score_percent']}%\n"
        f"{build_practice_scoring_note(submission_result)}",
        build_practice_case_feedback(submission_result["case_results"]),
        "Your final submission is still available.",
    ]
    return "\n\n".join(paragraphs)
def load_eval_set_for_mode(submission_type):
    """Load the hidden eval set for ``submission_type``.

    ``"practice"`` uses the ``PRACTICE`` prefix; any other value uses ``FINAL``.
    """
    if submission_type == "practice":
        return load_eval_set("PRACTICE")
    return load_eval_set("FINAL")
def validate_submission_inputs(email, name):
    """Return an error message for invalid form input, or ``None`` if it is acceptable.

    The email is checked before the name, so only the first problem is reported.
    """
    error = None
    if not validate_email(email):
        error = "Invalid email address. Please enter a valid email."
    elif not name.strip():
        error = "Please enter your full name."
    return error
def submit_attempt(submission_type, email, name, prompt_1, prompt_2, prompt_3):
    """Validate, grade and persist one practice or final submission.

    Each stage is wrapped so that a failure is logged with
    ``log_stage_exception`` and returned as a user-facing error response
    instead of propagating. The stages are: sheet read, dataset load,
    grading, record build and sheet write.

    Args:
        submission_type: ``"practice"`` or ``"final"``.
        email: Candidate email as typed in the form.
        name: Candidate name as typed in the form.
        prompt_1, prompt_2, prompt_3: System prompts for Steps 1-3.

    Returns:
        The response produced by ``build_submission_response``.

    Raises:
        ValueError: If ``submission_type`` is not in ``SUBMISSION_TYPES``.
    """
    if submission_type not in SUBMISSION_TYPES:
        raise ValueError(f"Unsupported submission type: {submission_type}")
    validation_error = validate_submission_inputs(email, name)
    if validation_error:
        return build_submission_response(False, validation_error)
    normalized_email = normalize_email(email)
    clean_name = sanitize_input(name)
    prompts = {
        "prompt_1": sanitize_prompt(prompt_1),
        "prompt_2": sanitize_prompt(prompt_2),
        "prompt_3": sanitize_prompt(prompt_3),
    }
    bypass_sheet_save = should_bypass_sheet_save()
    try:
        attempt_status = get_attempt_status(normalized_email)
    except Exception:
        log_stage_exception(
            "Unexpected error while reading sheet state for %s submission for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "sheet_read"),
        )
    # Attempt limits: a final submission closes practice too; practice is one-shot.
    if submission_type == "practice":
        if attempt_status["final_used"]:
            return build_submission_response(
                False,
                (
                    f"A final submission has already been received for {normalized_email}. "
                    "Practice is no longer available."
                ),
                disable_practice=True,
                disable_final=True,
            )
        if attempt_status["practice_used"]:
            return build_submission_response(
                False,
                (
                    f"Practice has already been used for {normalized_email}. "
                    "Your final submission is still available."
                ),
                disable_practice=True,
                disable_final=False,
            )
    if submission_type == "final" and attempt_status["final_used"]:
        return build_submission_response(
            False,
            (
                f"Final submission already received for {normalized_email}. "
                "You can only submit the final once."
            ),
            disable_practice=True,
            disable_final=True,
        )
    try:
        eval_set = load_eval_set_for_mode(submission_type)
    except Exception:
        log_stage_exception(
            "Unexpected error while loading %s dataset for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "dataset_load"),
        )
    if not eval_set["cases"]:
        return build_submission_response(
            False,
            f"No hidden cases configured for the {submission_type} dataset.",
        )
    try:
        submission_result = grade_submission(eval_set, prompts, submission_type)
    except Exception:
        log_stage_exception(
            "Unexpected error while grading %s submission for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "grading"),
        )
    try:
        record = build_attempt_record(clean_name, normalized_email, prompts, submission_result)
    except Exception:
        log_stage_exception(
            "Unexpected error while building %s record for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "record_build"),
        )
    bypass_notice = build_bypass_save_notice() if bypass_sheet_save else ""
    if not bypass_sheet_save:
        try:
            save_attempt(record)
            # The attempt is recorded locally only after the sheet write
            # succeeds, presumably so a failed write does not use up the attempt.
            record_local_attempt(normalized_email, submission_type)
        except Exception:
            log_stage_exception(
                "Unexpected error while saving %s submission for %s.",
                submission_type,
                normalized_email,
            )
            return build_submission_response(
                False,
                build_save_error_message(submission_type, "sheet_write"),
            )
    else:
        # Use the module logger, like the rest of the file, rather than print().
        # Attempts are not persisted in this mode, so this is a warning.
        logger.warning(
            "BYPASS_SHEET_SAVE enabled for %s submission %s.",
            submission_type,
            normalized_email,
        )
    # With the bypass active nothing was saved, so the buttons stay enabled.
    if submission_type == "practice":
        notice = "Practice run saved. Your final submission is still available."
        if bypass_notice:
            notice = bypass_notice
        return build_submission_response(
            True,
            format_practice_feedback(clean_name, submission_result),
            notice=notice,
            disable_practice=not bypass_sheet_save,
            disable_final=False,
        )
    notice = "Final submission received. You can close the page."
    if bypass_notice:
        notice = bypass_notice
    return build_submission_response(
        True,
        f"Thank you for your submission, {clean_name}!",
        notice=notice,
        disable_practice=not bypass_sheet_save,
        disable_final=not bypass_sheet_save,
    )
def build_submission_callback_result(result):
    """Map a submission response onto the four Gradio outputs.

    The outputs are, in order: the results text, the practice-button update,
    the final-button update, and the notice panel (shown only when a notice
    is present).
    """
    message = result["message"]
    practice_update = gr.update(interactive=not result["disable_practice"])
    final_update = gr.update(interactive=not result["disable_final"])
    notice_text = result["notice"]
    notice_update = gr.update(value=notice_text, visible=bool(notice_text))
    return message, practice_update, final_update, notice_update
def handle_submission(submission_type, email, name, s1, s2, s3):
    """Gradio callback: run ``submit_attempt`` and always return a well-formed UI update.

    Any exception escaping ``submit_attempt`` is logged and converted into a
    generic ``callback`` stage error, so the UI never sees a raw traceback.
    """
    try:
        outcome = submit_attempt(submission_type, email, name, s1, s2, s3)
    except Exception:
        log_stage_exception("Unhandled callback failure for %s submission.", submission_type)
        error_text = build_stage_error_message(submission_type, "callback")
        outcome = build_submission_response(False, error_text)
    return build_submission_callback_result(outcome)
def build_interface():
    """Build the Gradio Blocks UI for the applicant prompt-evaluation task.

    The page explains the three-step pipeline and collects the candidate's
    email, name and three system prompts. The "Practice Run" and
    "Submit Final" buttons are wired to ``handle_submission``.

    Both buttons update the same four outputs:
    - the results textbox;
    - both buttons (their interactivity);
    - a notice panel that is hidden by default.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` instance.
    """
    with gr.Blocks(css=APP_CSS) as demo:
        gr.Markdown(
"""
# Applicant Task: Target Company & Law Firm Identification
Draft prompts for a strict 3-step legal review pipeline over snippets from Asset Purchase Agreements [SEC Agreement Example](https://www.sec.gov/Archives/edgar/data/28452/000119312505012401/dex101.htm)
> This evaluation system uses a single model (default: `gpt-5-mini`) for all LLM steps.
## What you need to do
- Decide whether the query is relevant and normalize the target firm name.
- Inspect each snippet independently and record only what that snippet actually supports.
- Combine those snippet-level findings into one final JSON answer with citations, including a cited answer to the original user question.
## The 3-step pipeline
### Step 1: Relevance and target-firm normalization
Decide whether the question belongs in this deal at all, and if it does, standardize the firm name you will track through the rest of the pipeline.
- If the query is irrelevant, return `<user_message>...</user_message>`.
- If the query is relevant, return only:
```text
TARGET_FIRM: <canonical firm name>
```
Please ensure your final output uses the exact key `TARGET_FIRM:` as shown above, alongside the firm name.
### Step 2: Snippet-by-snippet analysis
Step 2 is not the final answer; it is a working note for one snippet at a time.
- Your Step 2 prompt runs independently on each evidence unit, so the model only sees one snippet per call.
- The app passes:
- the Step 1 output
- `Snippet ID: S1..Sn`
- the snippet text
- A good Step 2 prompt says what this snippet supports, what it does not support, and where the evidence is still uncertain.
### Step 3: Reconciliation and final answer
Step 3 receives the Step 1 output plus all Step 2 notes and must turn them into one answer that could survive a hostile redline, including a direct answer to the original user question.
- Step 3 receives the Step 1 output plus all Step 2 outputs.
- It must return valid JSON with this exact schema:
```json
{
"buyer_counsel": "string with citations like \"Firm Name [^2]\" or \"unknown\"",
"seller_counsel": "string with citations like \"Firm Name [^4]\" or \"unknown\"",
"third_party_counsel": "string with citations like \"Firm Name [^1]\" or \"unknown\"",
"user_question": "string with citations like \"true [^2]\", \"false [^4]\", or \"unknown\""
}
```
Use `"unknown"` when a counsel field or the user-question answer cannot be supported by the evidence. When you provide a firm name, include snippet citations such as `[^2]` that point to the relevant Step 2 snippet IDs. The `user_question` field should answer the original query using only `true`, `false`, or `unknown`: if the answer is `true` or `false`, include supporting snippet citations and do not add any extra text.
"""
        )
        gr.Markdown(
"""
## Workflow At A Glance
This visual shows how the query moves from `LLM1` to `LLM2` to `LLM3`, and how the final JSON is assembled from the APA snippets.
"""
        )
        gr.Image(
            value="mermaid_diagram.png",
            label="Pipeline overview",
            show_label=True,
            interactive=False,
        )
        with gr.Accordion("Example Workflow", open=False):
            gr.Markdown(
"""
**User query**
```text
Is Kirkland & Ellis LLP acting as counsel anywhere in this Asset Purchase Agreement?
```
**Step 1 relevant output**
```text
TARGET_FIRM: Kirkland & Ellis LLP
```
**Example evidence units**
- `S1`: the opening paragraph names the parties and mentions Kirkland & Ellis LLP as transaction counsel to the buyer.
- `S2`: the notices section identifies buyer counsel as Kirkland & Ellis LLP and seller counsel as Wachtell, Lipton, Rosen & Katz.
- `S3`: a boilerplate clause contains no counsel information and should not drive the final answer.
- `S4`: a representative provision states that Gibson, Dunn & Crutcher LLP advises the securityholders' representative.
- `S5`: a later notice block again confirms seller counsel as Wachtell, Lipton, Rosen & Katz.
**Final JSON shape**
```json
{
"buyer_counsel": "Kirkland & Ellis LLP [^1] [^2]",
"seller_counsel": "Wachtell, Lipton, Rosen & Katz [^2] [^5]",
"third_party_counsel": "Gibson, Dunn & Crutcher LLP [^4]",
"user_question": "true [^1] [^2]"
}
```
"""
            )
        with gr.Accordion("Practice and Final Submission", open=False):
            gr.Markdown(
"""
- You may use **one optional practice run** per email to test your prompts against a hidden calibration set.
- The practice run uses 3 hidden calibration cases.
- Each case is run 3 times to check prompt consistency.
- For each run, LLM 1 can earn up to 1 point for correct routing and target-firm normalization, and LLM 3 can earn up to 1 point for a correct final JSON answer with supported citations.
- Step 2 is not scored directly, but it strongly affects the LLM 3 score because Step 3 relies on the snippet-level analysis.
- Practice returns aggregate feedback only: score percentage, an LLM 1 summary, and an LLM 3 summary.
- You may then revise your prompts or keep them as they are.
- You may submit **one final submission** per email against a separate hidden holdout set.
- After the final submission, practice is no longer available.
- No structured decoding is used for you, so your prompts must make Step 3 produce reliable JSON on their own.
"""
            )
        gr.Markdown(
"""
Enter your name and email exactly as listed in your CV. Both buttons below use the same three prompt boxes.
You have **one** chance to run the practice set and get feedback, and **one** chance to run the final set. After you click a button, wait for the results to load before clicking again or refreshing the page.
**Good Luck!**
"""
        )
        email_input = gr.Textbox(label="Email", placeholder="your.email@example.com")
        name_input = gr.Textbox(label="First Name, Last Name", placeholder="John Smith")
        system_prompt_input_1 = gr.Textbox(
            label="System Prompt for Step 1",
            placeholder="Enter your Step 1 prompt here...",
            lines=6,
        )
        system_prompt_input_2 = gr.Textbox(
            label="System Prompt for Step 2",
            placeholder="Enter your Step 2 prompt here...",
            lines=10,
        )
        system_prompt_input_3 = gr.Textbox(
            label="System Prompt for Step 3",
            placeholder="Enter your Step 3 prompt here...",
            lines=6,
        )
        gr.Markdown(
"""
<div class="submission-note">
<b>Please note:</b><br>
Each run may take a couple of minutes.<br>
After you click a button, wait for the result and do not click it again.
</div>
"""
        )
        with gr.Row():
            practice_button = gr.Button("Practice Run")
            final_button = gr.Button("Submit Final")
        output_text = gr.Textbox(label="Results", lines=18)
        feedback_md = gr.Markdown("", visible=False)
        # Handler names are kept as-is: Gradio derives each event's default
        # API name from the callback's function name.
        def practice_submit_and_update(email, name, s1, s2, s3):
            return handle_submission("practice", email, name, s1, s2, s3)
        def final_submit_and_update(email, name, s1, s2, s3):
            return handle_submission("final", email, name, s1, s2, s3)
        def wire_submit_button(button, handler):
            # Both buttons read the same form fields and update the same outputs.
            # Fresh lists are built per call so the two listeners never share one.
            button.click(
                fn=handler,
                inputs=[
                    email_input,
                    name_input,
                    system_prompt_input_1,
                    system_prompt_input_2,
                    system_prompt_input_3,
                ],
                outputs=[output_text, practice_button, final_button, feedback_md],
            )
        wire_submit_button(practice_button, practice_submit_and_update)
        wire_submit_button(final_button, final_submit_and_update)
    return demo
if __name__ == "__main__":
    # Serve on all network interfaces (0.0.0.0), port 7860, with SSR disabled.
    interface = build_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )