import json
import logging
import os
import re
import threading
import time
import traceback
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone

import gradio as gr
import gspread
from google.oauth2.service_account import Credentials
from openai import OpenAI

SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]
DEFAULT_MODEL_NAME = "gpt-5-mini"
DEFAULT_EVAL_MAX_WORKERS = 4
DEFAULT_LLM2_MAX_WORKERS = 8
DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS = 16
EVAL_REPEAT_COUNT = 3
SUBMISSION_TYPES = {"practice", "final"}
# NOTE(review): both tags are empty. With empty tags the regex built in
# extract_user_message() degenerates to "(.*?)", which matches any non-empty
# text, so every LLM 1 output would be routed as "irrelevant". These look like
# markup that was stripped from the file (format_llm1_expected_text() likewise
# reads "wrapped in ...."); confirm the intended tag strings.
IRRELEVANT_TAG_START = ""
IRRELEVANT_TAG_END = ""
SHEET_SECTION_DIVIDER = "=" * 50
STAGE_ONE_OUTPUT_MAX_CHARS = 4000
STAGE_TWO_OUTPUT_MAX_CHARS = 2000
STAGE_THREE_OUTPUT_MAX_CHARS = 4000
PROMPT_CELL_MAX_CHARS = 8000
EXPECTED_OUTPUT_MAX_CHARS = 4000
SHEET_CELL_MAX_CHARS = 49000
COUNSEL_FIELDS = ("buyer_counsel", "seller_counsel", "third_party_counsel")
TARGET_FIRM_FIELD = "target_firm"
USER_QUESTION_FIELD = "user_question"
# Keys graded in the final LLM 3 JSON answer: the counsel fields plus the user question.
FINAL_SCHEMA_KEYS = (*COUNSEL_FIELDS, USER_QUESTION_FIELD)
# Footnote-style citation such as "[^2]"; group 1 is the snippet number.
CITATION_PATTERN = re.compile(r"\[\^(\d+)\]")
BOOLEAN_ANSWER_PATTERN = re.compile(r"^(true|false|yes|no)\b$", re.IGNORECASE)
LOCAL_ATTEMPTS = {}

# Lazily-initialised OpenAI client and the semaphore that throttles requests to it.
_OPENAI_CLIENT = None
_OPENAI_REQUEST_SEMAPHORE = None
_OPENAI_REQUEST_SEMAPHORE_LIMIT = None
_OPENAI_REQUEST_SEMAPHORE_LOCK = threading.Lock()

_LOG_LEVEL_NAME = os.environ.get("LOG_LEVEL", "INFO").upper()
_LOG_LEVEL = getattr(logging, _LOG_LEVEL_NAME, logging.INFO)
if not isinstance(_LOG_LEVEL, int):
    # getattr() also resolves non-level attributes of the logging module
    # (e.g. LOG_LEVEL=ROOT -> logging.root, LOG_LEVEL=BASIC_FORMAT -> a format
    # string), which basicConfig()/setLevel() reject with TypeError/ValueError
    # at import time. Fall back to INFO instead of crashing the app.
    _LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(__name__)

# (internal key, header label) pairs; the sheet's header row follows this order.
SHEET_COLUMNS = [
    ("timestamp", "Submitted At"),
    ("submission_type", "Submission Type"),
    ("name", "Candidate Name"),
    ("email", "Candidate Email"),
    ("results", "Results"),
    ("results_without_llm_1", "Results LLM3"),
    ("llm_score_breakdown", "LLM Score Breakdown"),
    ("llm_output_vs_expected", "LLM Output vs Expected"),
    ("llm_2_output", "LLM 2 Output"),
    ("prompts", "Prompts"),
]
SHEET_HEADERS = [label for _, label in SHEET_COLUMNS]

# Styles for the ".submission-note" banner, with dark-mode overrides for both the
# OS colour-scheme preference and a ".dark" ancestor class. Built by implicit
# concatenation so the stylesheet text is identical to the single-line original.
APP_CSS = (
    " .submission-note {"
    " background: var(--submission-note-bg, #fff7e6);"
    " border: 1px solid var(--submission-note-border, #ffe5b4);"
    " border-radius: 8px;"
    " color: var(--submission-note-text, #3d2b00);"
    " margin-bottom: 1em;"
    " padding: 16px;"
    " }"
    " @media (prefers-color-scheme: dark) {"
    " .submission-note {"
    " --submission-note-bg: #2b2111;"
    " --submission-note-border: #6b4b18;"
    " --submission-note-text: #f7e8c5;"
    " }"
    " }"
    " .dark .submission-note {"
    " --submission-note-bg: #2b2111;"
    " --submission-note-border: #6b4b18;"
    " --submission-note-text: #f7e8c5;"
    " }"
    " "
)


def get_model_name():
    """Return the OpenAI model name from OPENAI_MODEL_NAME, or DEFAULT_MODEL_NAME."""
    return os.environ.get("OPENAI_MODEL_NAME", DEFAULT_MODEL_NAME)


def get_positive_int_env(name, default):
    """Read environment variable ``name`` as an integer clamped to at least 1.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset (then clamped like any
            other value) or is not a valid integer (then returned as-is, with a
            warning logged).

    Returns:
        The parsed value, never below 1 when parsing succeeds.
    """
    raw_value = os.environ.get(name, str(default))
    try:
        value = int(raw_value)
    except ValueError:
        logger.warning("Invalid integer for %s=%r; using default %s.", name, raw_value, default)
        return default
    return max(1, value)


def get_eval_max_workers():
    """Return the per-case evaluation worker count.

    EVAL_CASE_MAX_WORKERS takes precedence over EVAL_MAX_WORKERS whenever it is
    present in the environment.
    """
    if "EVAL_CASE_MAX_WORKERS" in os.environ:
        return get_positive_int_env("EVAL_CASE_MAX_WORKERS", DEFAULT_EVAL_MAX_WORKERS)
    return get_positive_int_env("EVAL_MAX_WORKERS", DEFAULT_EVAL_MAX_WORKERS)


def get_llm2_max_workers():
    """Return the thread-pool size used for per-snippet LLM 2 calls."""
    return get_positive_int_env("LLM2_MAX_WORKERS", DEFAULT_LLM2_MAX_WORKERS)


def get_openai_max_concurrent_requests():
    """Return the cap on simultaneous OpenAI requests across all threads."""
    return get_positive_int_env(
        "OPENAI_MAX_CONCURRENT_REQUESTS",
        DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS,
    )


def get_openai_request_semaphore():
    """Return the shared BoundedSemaphore that throttles OpenAI requests.

    The semaphore is rebuilt whenever OPENAI_MAX_CONCURRENT_REQUESTS changes.
    Callers that already acquired the previous instance keep a reference to it
    and release that one, so a rebuild never over-releases. However, requests
    in flight on the old and new semaphores can briefly exceed the new limit.
    """
    global _OPENAI_REQUEST_SEMAPHORE
    global _OPENAI_REQUEST_SEMAPHORE_LIMIT
    limit = get_openai_max_concurrent_requests()
    with _OPENAI_REQUEST_SEMAPHORE_LOCK:
        if _OPENAI_REQUEST_SEMAPHORE is None or _OPENAI_REQUEST_SEMAPHORE_LIMIT != limit:
            _OPENAI_REQUEST_SEMAPHORE = threading.BoundedSemaphore(limit)
            _OPENAI_REQUEST_SEMAPHORE_LIMIT = limit
        return _OPENAI_REQUEST_SEMAPHORE


def get_openai_client():
    """Return the module-wide OpenAI client, creating it on first use.

    Raises:
        RuntimeError: If OPENAI_API_KEY is not set.

    Initialisation is not guarded by a lock: concurrent first calls may each
    construct a client, and the last assignment wins.
    """
    global _OPENAI_CLIENT
    if _OPENAI_CLIENT is None:
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is not set.")
        _OPENAI_CLIENT = OpenAI(api_key=api_key)
    return _OPENAI_CLIENT


def get_google_sheet(ensure_headers=True):
    """Open the "Submissions" worksheet of the spreadsheet named by SPREADSHEET_ID.

    Args:
        ensure_headers: When true, row 1 is (re)written with SHEET_HEADERS via
            ensure_submission_sheet_headers() before the sheet is returned.

    Returns:
        The gspread worksheet object.

    Raises:
        RuntimeError: If GOOGLE_CREDS_JSON or SPREADSHEET_ID is unset.
        json.JSONDecodeError: If GOOGLE_CREDS_JSON is not valid JSON.
    """
    creds_json = os.environ.get("GOOGLE_CREDS_JSON")
    spreadsheet_id = os.environ.get("SPREADSHEET_ID")
    if not creds_json or not spreadsheet_id:
        raise RuntimeError("GOOGLE_CREDS_JSON and SPREADSHEET_ID must be set.")
    creds = Credentials.from_service_account_info(
        json.loads(creds_json),
        scopes=SCOPES,
    )
    gc = gspread.authorize(creds)
    spreadsheet = gc.open_by_key(spreadsheet_id)
    sheet = spreadsheet.worksheet("Submissions")
    if ensure_headers:
        ensure_submission_sheet_headers(sheet)
    return sheet


def load_eval_set(prefix):
    """Load an evaluation set from ``{prefix}_*_JSON`` environment variables.

    The QUESTIONS, DOCUMENTS and EXPECTED variables must each hold a JSON list,
    and the lists must have the same length; items are zipped into cases.

    Args:
        prefix: Environment-variable prefix for the dataset.

    Returns:
        ``{"cases": [...]}``, where each case has "question" (str), "docs"
        (normalised snippet list) and "expected" (normalised expected entry).
        If any of the three variables is missing or empty, ``{"cases": []}`` is
        returned; a warning is logged when only some of them are set.

    Raises:
        ValueError: If a field is not a list or the lengths differ (JSON parse
            errors propagate as json.JSONDecodeError, a ValueError subclass).
    """
    questions_str = os.environ.get(f"{prefix}_QUESTIONS_JSON")
    documents_str = os.environ.get(f"{prefix}_DOCUMENTS_JSON")
    expected_str = os.environ.get(f"{prefix}_EXPECTED_JSON")
    if not questions_str or not documents_str or not expected_str:
        configured = [
            env_name
            for env_name, env_value in (
                (f"{prefix}_QUESTIONS_JSON", questions_str),
                (f"{prefix}_DOCUMENTS_JSON", documents_str),
                (f"{prefix}_EXPECTED_JSON", expected_str),
            )
            if env_value
        ]
        if configured:
            # A partial configuration is almost certainly a deployment mistake;
            # keep the "no cases" result but make it visible in the logs.
            logger.warning(
                "Eval set %s is only partially configured (%s set); treating it as empty.",
                prefix,
                ", ".join(configured),
            )
        return {"cases": []}
    questions = json.loads(questions_str)
    documents = json.loads(documents_str)
    expected = json.loads(expected_str)
    if not (
        isinstance(questions, list)
        and isinstance(documents, list)
        and isinstance(expected, list)
    ):
        raise ValueError(f"{prefix} dataset must be a JSON list for all fields.")
    if len(questions) != len(documents) or len(questions) != len(expected):
        raise ValueError(
            f"{prefix} dataset lengths do not match for questions, documents, and expected answers."
        )
    cases = []
    for case_index, (question, docs_entry, expected_entry) in enumerate(
        zip(questions, documents, expected),
        start=1,
    ):
        normalized_docs = normalize_documents_entry(docs_entry)
        cases.append(
            {
                "question": str(question),
                "docs": normalized_docs,
                "expected": normalize_expected_entry(
                    expected_entry,
                    prefix,
                    case_index,
                    question=str(question),
                    docs=normalized_docs,
                ),
            }
        )
    return {"cases": cases}


def build_submission_response(
    ok,
    message,
    notice="",
    disable_practice=False,
    disable_final=False,
):
    """Bundle a submission outcome into the dict returned to the UI layer."""
    return {
        "ok": ok,
        "message": message,
        "notice": notice,
        "disable_practice": disable_practice,
        "disable_final": disable_final,
    }


def _submission_label(submission_type):
    """Return the user-facing noun for ``submission_type`` in error messages."""
    return "practice run" if submission_type == "practice" else "submission"


def build_internal_error_message(submission_type):
    """Return a generic retry-later message for an unexpected failure."""
    label = _submission_label(submission_type)
    return (
        f"We hit an internal error while processing your {label}. "
        "Please try again in a minute."
    )


def build_stage_error_message(submission_type, stage):
    """Return an internal-error message that names the failing pipeline stage."""
    label = _submission_label(submission_type)
    return (
        f"We hit an internal error while processing your {label}. "
        f"Stage: {stage}."
    )


def build_save_error_message(submission_type, stage="sheet_write"):
    """Return the message shown when the result could not be persisted."""
    label = _submission_label(submission_type)
    return (
        f"We could not save your {label} right now. "
        f"Stage: {stage}. Please check the Space logs and spreadsheet permissions, then try again."
    )


def build_bypass_save_notice():
    """Return the notice shown when sheet saving is bypassed for debugging."""
    return "Debug mode is active: this run was graded but not saved to Google Sheets."
def should_bypass_sheet_save(): return os.environ.get("BYPASS_SHEET_SAVE", "").strip().lower() == "true" def log_stage_exception(message, *args): logger.exception(message, *args) print(traceback.format_exc(), flush=True) def safe_evaluate_case(case, prompts, case_index): try: return evaluate_case(case, prompts, case_index=case_index) except Exception as exc: logger.exception("Case %s failed during evaluation.", case_index) raise RuntimeError(f"Evaluation failed for case {case_index}.") from exc def normalize_documents_entry(entry): if isinstance(entry, list): documents = [str(item).strip() for item in entry if str(item).strip()] return documents text = str(entry).strip() if not text: return [] if "---" in text: parts = [part.strip() for part in text.split("---")] return [part for part in parts if part] return [text] def get_sheet_column_letter(index): letters = "" while index > 0: index, remainder = divmod(index - 1, 26) letters = chr(65 + remainder) + letters return letters def ensure_submission_sheet_headers(sheet): expected_headers = SHEET_HEADERS end_column = get_sheet_column_letter(len(expected_headers)) sheet.update(f"A1:{end_column}1", [expected_headers]) def get_next_submission_row_index(sheet): rows = sheet.get_all_values() last_data_row = 1 for row_index, row in enumerate(rows, start=1): if row_index == 1: continue if any(str(cell).strip() for cell in row): last_data_row = row_index return last_data_row + 1 def serialize_sheet_value(value): if isinstance(value, (dict, list)): return serialize_json(value) return "" if value is None else str(value) def normalize_expected_mapping(entry, prefix, case_index): raw_target_firm = entry.get(TARGET_FIRM_FIELD) if raw_target_firm is None: raise ValueError( f"{prefix} case {case_index} is missing a non-empty '{TARGET_FIRM_FIELD}'." 
) target_firm = normalize_whitespace(raw_target_firm) if not target_firm or target_firm.lower() == "none": raise ValueError( f"{prefix} case {case_index} is missing a non-empty '{TARGET_FIRM_FIELD}'." ) normalized = {TARGET_FIRM_FIELD: target_firm} for field_name in COUNSEL_FIELDS: field_value = entry.get(field_name, "unknown") if field_value is None: normalized[field_name] = "unknown" continue normalized[field_name] = str(field_value).strip() or "unknown" return normalized def extract_target_candidate_from_question(question): stripped = normalize_whitespace(question) patterns = [ r"^Is\s+(.+?)\s+present\s+in\s+the\s+agreement\??$", r"^Is\s+(.+?)\s+mentioned\s+in\s+the\s+agreement\??$", r"^Is\s+(.+?)\s+acting\s+as\s+counsel.*\??$", r"^Is\s+(.+?)\s+anywhere\s+in\s+this\s+Asset\s+Purchase\s+Agreement\??$", r"^Is\s+(.+?)\s+in\s+the\s+agreement\??$", ] for pattern in patterns: match = re.match(pattern, stripped, re.IGNORECASE) if match: return normalize_whitespace(match.group(1)) match = re.match(r"^Is\s+(.+?)\??$", stripped, re.IGNORECASE) if match: return normalize_whitespace(match.group(1)) return "" def extract_entity_candidates(text): pattern = re.compile( r"\b([A-Z][A-Za-z0-9'.,-]*(?:\s+(?:&|[A-Z][A-Za-z0-9'.,-]*))*\s+" r"(?:LLP|LLC|LP|Inc\.?|Corporation|Corp\.?|Ltd\.?))\b" ) return [normalize_whitespace(match.group(1)) for match in pattern.finditer(str(text))] def build_legacy_expected_mapping(entry, question, docs): normalized = { "buyer_counsel": normalize_whitespace(entry.get("buyer_firm", "unknown")) or "unknown", "seller_counsel": normalize_whitespace(entry.get("seller_firm", "unknown")) or "unknown", "third_party_counsel": normalize_whitespace(entry.get("third_party", "unknown")) or "unknown", } target_candidate = extract_target_candidate_from_question(question) target_norm = normalize_counsel_value(target_candidate) for candidate in normalized.values(): candidate_text = normalize_whitespace(candidate) if not candidate_text: continue if target_norm and 
target_norm in normalize_counsel_value(candidate_text): normalized[TARGET_FIRM_FIELD] = candidate_text break else: doc_candidates = [] for doc in docs: doc_candidates.extend(extract_entity_candidates(doc)) for candidate_text in doc_candidates: if target_norm and target_norm in normalize_counsel_value(candidate_text): normalized[TARGET_FIRM_FIELD] = candidate_text break else: normalized[TARGET_FIRM_FIELD] = target_candidate or "unknown target" return normalized def normalize_expected_entry(entry, prefix, case_index, question="", docs=None): docs = docs or [] if isinstance(entry, dict): legacy_keys = {"buyer_firm", "seller_firm", "third_party", "contains_target_firm"} if TARGET_FIRM_FIELD not in entry and legacy_keys.intersection(entry.keys()): return build_legacy_expected_mapping(entry, question, docs) return normalize_expected_mapping(entry, prefix, case_index) if not isinstance(entry, str): raise ValueError( f"{prefix} case {case_index} expected entry must be a string or object." ) stripped = entry.strip() if stripped.startswith("{") and stripped.endswith("}"): try: parsed = json.loads(stripped) except json.JSONDecodeError as exc: raise ValueError( f"{prefix} case {case_index} contains invalid JSON in expected entry: {exc}" ) from exc if isinstance(parsed, dict): legacy_keys = {"buyer_firm", "seller_firm", "third_party", "contains_target_firm"} if TARGET_FIRM_FIELD not in parsed and legacy_keys.intersection(parsed.keys()): return build_legacy_expected_mapping(parsed, question, docs) return normalize_expected_mapping(parsed, prefix, case_index) return entry def sanitize_input(text, max_length=500): clean_text = re.sub(r"[^a-zA-Z0-9\s.,!?@:\-+'&()/]", "", text) return clean_text.strip()[:max_length] def sanitize_prompt(text): return text.strip()[:8000] def normalize_email(email): return email.strip().lower() def validate_email(email): email_regex = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$" return re.match(email_regex, email) is not None def 
extract_user_message(text): if not text: return None pattern = re.compile( rf"{re.escape(IRRELEVANT_TAG_START)}(.*?){re.escape(IRRELEVANT_TAG_END)}", re.IGNORECASE | re.DOTALL, ) match = pattern.search(text) if not match: return None return match.group(1).strip() def normalize_whitespace(text): return re.sub(r"\s+", " ", str(text)).strip() def normalize_counsel_value(value): if value is None: return "" text = normalize_whitespace(value).lower() text = text.replace("&", " and ") text = re.sub(r"[^a-z0-9\s]", " ", text) return " ".join(text.split()) def value_is_unknown(value): return normalize_whitespace(value).lower() == "unknown" def expected_is_unknown(expected_value): return value_is_unknown(expected_value) or normalize_whitespace(expected_value) == "" def truncate_text(text, max_chars=2000): text = str(text or "") if len(text) <= max_chars: return text return f"{text[: max_chars - 3]}..." def serialize_json(value): return json.dumps(value, ensure_ascii=True, sort_keys=True) def normalize_newlines(text): return str(text or "").replace("\r\n", "\n").replace("\r", "\n") def normalize_sheet_block_text(text): normalized = normalize_newlines(text) lines = [line.rstrip() for line in normalized.split("\n")] return "\n".join(lines).strip() def format_score(value): formatted = f"{float(value):.4f}".rstrip("0").rstrip(".") if "." 
not in formatted: formatted += ".0" return formatted def join_sheet_sections(blocks): cleaned_blocks = [block for block in blocks if block] return f"\n\n{SHEET_SECTION_DIVIDER}\n\n".join(cleaned_blocks) def build_case_run_header(case_index, run_index): return f"CASE {case_index} | RUN {run_index}" def format_run_indices(run_indices): labels = [str(run_index) for run_index in run_indices] if not labels: return "" if len(labels) == 1: return labels[0] if len(labels) == 2: return f"{labels[0]} AND {labels[1]}" return f"{', '.join(labels[:-1])} AND {labels[-1]}" def build_grouped_case_run_header(case_index, run_indices): return f"CASE {case_index} | RUN {format_run_indices(run_indices)}" def truncate_sheet_block(text, max_chars): normalized = normalize_sheet_block_text(text) return truncate_text(normalized, max_chars=max_chars) def format_pretty_expected_answer(expected): if isinstance(expected, str): return "Irrelevant case. No final JSON should be produced because LLM 1 should stop the pipeline." return truncate_text( json.dumps(expected, ensure_ascii=True, sort_keys=True, indent=2), max_chars=EXPECTED_OUTPUT_MAX_CHARS, ) def format_llm1_expected_text(expected): if isinstance(expected, str): return "Irrelevant case: output must be wrapped in ...." return "Relevant case: output must be exactly TARGET_FIRM: ." def format_llm2_user_output(llm2_outputs): if not llm2_outputs: return "Not run." blocks = [] for snippet_id, snippet_output in llm2_outputs.items(): blocks.append( "\n".join( [ f"{snippet_id}:", truncate_sheet_block(snippet_output, max_chars=STAGE_TWO_OUTPUT_MAX_CHARS), ] ) ) return "\n\n".join(blocks) def format_teacher_forced_note(run_result): if run_result["used_teacher_forced_step1"]: return ( "LLM 1 did not match. 
LLM 2 and LLM 3 were re-run with the teacher-forced " f"Step 1 context: {run_result['effective_step1_context']}" ) return f"LLM 2 and LLM 3 used the submitted Step 1 context: {run_result['effective_step1_context']}" def parse_cited_value(value): raw_text = normalize_whitespace(value) citations = [int(match.group(1)) for match in CITATION_PATTERN.finditer(raw_text)] malformed_citation = False for token in re.findall(r"\[\^[^\]]*\]", raw_text): if not CITATION_PATTERN.fullmatch(token): malformed_citation = True break if "[^" in raw_text and not re.findall(r"\[\^[^\]]*\]", raw_text): malformed_citation = True base_value = normalize_whitespace(CITATION_PATTERN.sub("", raw_text)) return { "raw": raw_text, "base_value": base_value, "citations": citations, "malformed_citation": malformed_citation, } def snippet_supports_value(snippet_text, value): normalized_value = normalize_counsel_value(value) if not normalized_value: return False return normalized_value in normalize_counsel_value(snippet_text) def get_supporting_snippet_numbers(docs, value): matches = [] for index, doc in enumerate(docs, start=1): if snippet_supports_value(doc, value): matches.append(index) return matches def append_citations(value, snippet_numbers): if value_is_unknown(value) or not snippet_numbers: return value citations = " ".join(f"[^{snippet_number}]" for snippet_number in snippet_numbers) return f"{value} {citations}".strip() def dedupe_preserving_order(values): seen = set() ordered = [] for value in values: normalized = normalize_counsel_value(value) if not normalized or normalized in seen: continue seen.add(normalized) ordered.append(normalize_whitespace(value)) return ordered def find_contradicting_entity(expected, docs, target_firm): target_norm = normalize_counsel_value(target_firm) for field_name in COUNSEL_FIELDS: candidate = normalize_whitespace(expected.get(field_name, "unknown")) if expected_is_unknown(candidate): continue if normalize_counsel_value(candidate) == target_norm: continue 
snippet_numbers = get_supporting_snippet_numbers(docs, candidate) if snippet_numbers: return candidate, snippet_numbers doc_candidates = [] for doc in docs: doc_candidates.extend(extract_entity_candidates(doc)) for candidate in dedupe_preserving_order(doc_candidates): if normalize_counsel_value(candidate) == target_norm: continue snippet_numbers = get_supporting_snippet_numbers(docs, candidate) if snippet_numbers: return candidate, snippet_numbers return "", [] def build_expected_user_question_spec(expected, docs, target_firm): target_snippet_numbers = get_supporting_snippet_numbers(docs, target_firm) if target_snippet_numbers: return { "verdict": "true", "company": normalize_whitespace(target_firm), "citations": target_snippet_numbers, } contradicting_company, contradicting_snippet_numbers = find_contradicting_entity( expected, docs, target_firm, ) if contradicting_company and contradicting_snippet_numbers: return { "verdict": "false", "company": contradicting_company, "citations": contradicting_snippet_numbers, } return { "verdict": "unknown", "company": "", "citations": [], } def format_expected_user_question(spec): verdict = spec["verdict"] citations = spec.get("citations", []) if verdict == "unknown": return "unknown" return append_citations(verdict, citations) def parse_user_question_answer(value): parsed = parse_cited_value(value) base_value = parsed["base_value"] if value_is_unknown(base_value): return { **parsed, "verdict": "unknown", "evidence": "", } match = BOOLEAN_ANSWER_PATTERN.match(base_value) if match: token = match.group(1).lower() verdict = "true" if token in {"true", "yes"} else "false" return { **parsed, "verdict": verdict, } return { **parsed, "verdict": None, } def validate_citations_for_value(parsed_value, docs, expected_value): bad_references = [] unsupported_citations = [] for citation_number in parsed_value["citations"]: if citation_number < 1 or citation_number > len(docs): bad_references.append(citation_number) continue if not 
snippet_supports_value(docs[citation_number - 1], expected_value): unsupported_citations.append(citation_number) return bad_references, unsupported_citations def build_expected_step1_output(expected): if isinstance(expected, str): return f"{IRRELEVANT_TAG_START}irrelevant{IRRELEVANT_TAG_END}" return f"TARGET_FIRM: {expected[TARGET_FIRM_FIELD]}" def build_expected_llm3_answer(expected, docs, target_firm): if isinstance(expected, str): return expected formatted = {} for field_name in COUNSEL_FIELDS: expected_value = str(expected.get(field_name, "unknown")) snippet_numbers = get_supporting_snippet_numbers(docs, expected_value) formatted[field_name] = append_citations(expected_value, snippet_numbers) formatted[USER_QUESTION_FIELD] = format_expected_user_question( build_expected_user_question_spec(expected, docs, target_firm) ) return formatted def format_prompts_cell(prompts): blocks = [] for label, prompt_key in ( ("LLM 1", "prompt_1"), ("LLM 2", "prompt_2"), ("LLM 3", "prompt_3"), ): blocks.append( "\n".join( [ label, "Prompt:", truncate_sheet_block(prompts[prompt_key], max_chars=PROMPT_CELL_MAX_CHARS), ] ) ) return truncate_text( join_sheet_sections(blocks), max_chars=SHEET_CELL_MAX_CHARS, ) def parse_relevant_stage_one_output(text): normalized = normalize_sheet_block_text(text) match = re.fullmatch(r"TARGET_FIRM:\s*(.+)", normalized) if not match: return None target_firm = match.group(1).strip() if not target_firm: return None return { "target_firm": target_firm, } def is_error_output(text, stage_name): return str(text).startswith(f"Error during {stage_name} call:") def run_chat_completion(system_prompt, user_prompt): semaphore = get_openai_request_semaphore() semaphore.acquire() try: response = get_openai_client().chat.completions.create( model=get_model_name(), messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], ) return response.choices[0].message.content.strip() finally: semaphore.release() def 
run_step_one(prompts, question): try: llm1_output = run_chat_completion(prompts["prompt_1"], question) except Exception as exc: llm1_output = f"Error during LLM1 call: {exc}" user_message = extract_user_message(llm1_output) return { "llm1_output": llm1_output, "is_irrelevant": user_message is not None, "user_message": user_message, } def build_llm2_user_prompt(step1_context, snippet_id, doc): return ( f"Target firm context:\n{step1_context}\n\n" f"Snippet ID: {snippet_id}\n" f"Snippet Text:\n{doc}" ) def run_llm2_snippet(prompts, step1_context, snippet_id, doc): try: return snippet_id, run_chat_completion( prompts["prompt_2"], build_llm2_user_prompt(step1_context, snippet_id, doc), ) except Exception as exc: return snippet_id, f"Error during LLM2 call: {exc}" def run_downstream_stages(prompts, docs, step1_context): llm2_outputs = {f"S{index}": "" for index in range(1, len(docs) + 1)} llm3_output = "" if docs: max_workers = min(get_llm2_max_workers(), len(docs)) if max_workers == 1: for index, doc in enumerate(docs, start=1): snippet_id = f"S{index}" _, snippet_output = run_llm2_snippet(prompts, step1_context, snippet_id, doc) llm2_outputs[snippet_id] = snippet_output else: with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ executor.submit( run_llm2_snippet, prompts, step1_context, f"S{index}", doc, ) for index, doc in enumerate(docs, start=1) ] for future in as_completed(futures): snippet_id, snippet_output = future.result() llm2_outputs[snippet_id] = snippet_output serialized_llm2 = [] for snippet_id, snippet_output in llm2_outputs.items(): serialized_llm2.append(f"{snippet_id}\n{snippet_output}") try: llm3_output = run_chat_completion( prompts["prompt_3"], ( f"Target firm context:\n{step1_context}\n\n" "Per-snippet analyses:\n" f"{chr(10).join(serialized_llm2)}" ), ) except Exception as exc: llm3_output = f"Error during LLM3 call: {exc}" return { "snippet_count": len(docs), "llm2_outputs": llm2_outputs, "llm3_output": llm3_output, } def 
parse_final_answer(answer): try: parsed = json.loads(answer) except json.JSONDecodeError as exc: return None, f"Invalid JSON: {exc}" if not isinstance(parsed, dict): return None, "Final answer must be a JSON object." return parsed, None def build_field_result(correct, issues=None, penalty=None): if penalty is None: penalty = 0.0 if correct else 0.2 return { "correct": correct, "issues": issues or [], "penalty": round(float(penalty), 4), } def build_prefixed_failure_tags(field_name, issues): return [f"{issue}_{field_name}" for issue in issues] def grade_counsel_field(field_name, parsed_answer, expected_answer, docs): if field_name not in parsed_answer: return build_field_result(False, ["missing_field"]) parsed_value = parse_cited_value(parsed_answer.get(field_name)) expected_value = str(expected_answer.get(field_name, "unknown")) expected_unknown = expected_is_unknown(expected_value) actual_unknown = value_is_unknown(parsed_value["base_value"]) issues = [] if parsed_value["malformed_citation"]: issues.append("malformed_citation") if actual_unknown and expected_unknown: if parsed_value["citations"]: issues.append("unexpected_citation") return build_field_result(not issues, issues) if actual_unknown and not expected_unknown: issues.append("used_unknown") return build_field_result(False, issues) if not actual_unknown and expected_unknown: issues.append("hallucinated_value") return build_field_result(False, issues) if normalize_counsel_value(parsed_value["base_value"]) != normalize_counsel_value(expected_value): issues.append("wrong_value") if not parsed_value["citations"]: issues.append("missing_citation") else: bad_references = [] unsupported_citations = [] for citation_number in parsed_value["citations"]: if citation_number < 1 or citation_number > len(docs): bad_references.append(citation_number) continue if not snippet_supports_value(docs[citation_number - 1], parsed_value["base_value"]): unsupported_citations.append(citation_number) if bad_references: 
issues.append("bad_citation_reference") if unsupported_citations: issues.append("unsupported_citation") return build_field_result(not issues, issues) def grade_user_question_field(parsed_answer, expected_user_question, docs, target_firm): if USER_QUESTION_FIELD not in parsed_answer: return build_field_result(False, ["missing_field"], penalty=0.2) parsed_value = parse_user_question_answer(parsed_answer.get(USER_QUESTION_FIELD)) issues = [] penalty = 0.0 if parsed_value["malformed_citation"]: issues.append("malformed_citation") expected_verdict = expected_user_question["verdict"] expected_company = normalize_whitespace(expected_user_question.get("company", "")) if expected_verdict == "unknown": if parsed_value["verdict"] != "unknown": issues.append("wrong_value") penalty += 0.1 if parsed_value["citations"]: issues.append("unexpected_citation") penalty += 0.1 return build_field_result(not issues, issues, penalty=min(0.2, penalty)) if parsed_value["verdict"] == "unknown": issues.append("used_unknown") return build_field_result(False, issues, penalty=0.2) if parsed_value["verdict"] != expected_verdict: issues.append("wrong_value") penalty += 0.1 if not parsed_value["citations"]: issues.append("missing_citation") penalty += 0.1 else: citation_target = expected_company or target_firm bad_references, unsupported_citations = validate_citations_for_value( parsed_value, docs, citation_target, ) if bad_references: issues.append("bad_citation_reference") if unsupported_citations: issues.append("unsupported_citation") if parsed_value["malformed_citation"] or bad_references or unsupported_citations: penalty += 0.1 return build_field_result(not issues, issues, penalty=min(0.2, penalty)) def grade_case(case_run, expected): case_result = {"case_score": 0.0, "failure_tags": [], "field_results": {}} if isinstance(expected, str): if case_run["is_irrelevant"]: case_result["case_score"] = 1.0 else: case_result["failure_tags"] = ["relevance_false_positive"] case_result["passed"] = 
case_result["case_score"] == 1.0 return case_result parsed_answer, parse_error = parse_final_answer(case_run["llm3_output"]) if parse_error: case_result["failure_tags"] = ["invalid_json"] case_result["field_results"] = { field_name: build_field_result(False, ["invalid_json"]) for field_name in FINAL_SCHEMA_KEYS } case_result["passed"] = False return case_result total_penalty = 0.0 failure_tags = [] field_results = {} for field_name in COUNSEL_FIELDS: field_result = grade_counsel_field(field_name, parsed_answer, expected, case_run["docs"]) field_results[field_name] = field_result total_penalty += field_result["penalty"] if field_result["issues"]: failure_tags.extend(build_prefixed_failure_tags(field_name, field_result["issues"])) user_question_result = grade_user_question_field( parsed_answer, build_expected_user_question_spec(expected, case_run["docs"], case_run["effective_target_firm"]), case_run["docs"], case_run["effective_target_firm"], ) field_results[USER_QUESTION_FIELD] = user_question_result total_penalty += user_question_result["penalty"] if user_question_result["issues"]: failure_tags.extend( build_prefixed_failure_tags(USER_QUESTION_FIELD, user_question_result["issues"]) ) case_result["case_score"] = round(max(0.0, 1.0 - total_penalty), 4) case_result["failure_tags"] = sorted(set(failure_tags)) case_result["field_results"] = field_results case_result["passed"] = total_penalty == 0.0 return case_result def grade_llm1_stage(llm1_output, expected): is_irrelevant = extract_user_message(llm1_output) is not None parsed_output = parse_relevant_stage_one_output(llm1_output) if isinstance(expected, str): if is_irrelevant: return { "score": 1.0, "reasoning": "matched", "failure_tags": [], "parsed_output": None, } reasoning = "call_failed" if is_error_output(llm1_output, "LLM1") else "wrong_route" return { "score": 0.0, "reasoning": reasoning, "failure_tags": [f"llm1_{reasoning}"], "parsed_output": None, } if is_irrelevant: return { "score": 0.0, "reasoning": 
"wrong_route", "failure_tags": ["llm1_wrong_route"], "parsed_output": None, } if parsed_output is None: if is_error_output(llm1_output, "LLM1"): reasoning = "call_failed" score = 0.0 else: expected_target = normalize_counsel_value(expected[TARGET_FIRM_FIELD]) output_norm = normalize_counsel_value(llm1_output) if expected_target and expected_target in output_norm: reasoning = "right_answer_wrong_format" score = 0.5 else: reasoning = "wrong_answer_wrong_format" score = 0.0 return { "score": score, "reasoning": reasoning, "failure_tags": [f"llm1_{reasoning}"], "parsed_output": None, } if normalize_counsel_value(parsed_output["target_firm"]) != normalize_counsel_value( expected[TARGET_FIRM_FIELD] ): return { "score": 0.5, "reasoning": "wrong_answer_right_format", "failure_tags": ["llm1_wrong_answer_right_format"], "parsed_output": parsed_output, } return { "score": 1.0, "reasoning": "matched", "failure_tags": [], "parsed_output": parsed_output, } def summarize_llm2_stage(case_run, expected): if isinstance(expected, str): return "Not run. Irrelevant case." total_snippets = case_run["snippet_count"] if total_snippets == 0: return "No snippets were provided." error_ids = [] successful_snippets = 0 for index in range(1, total_snippets + 1): snippet_id = f"S{index}" snippet_output = case_run["llm2_outputs"].get(snippet_id, "") if not snippet_output or is_error_output(snippet_output, "LLM2"): error_ids.append(snippet_id) else: successful_snippets += 1 summary = f"{successful_snippets}/{total_snippets} snippet calls completed successfully." if error_ids: summary += f" Errored snippets: {', '.join(error_ids)}." 
return summary def grade_llm3_stage(case_run, expected, semantic_result): if semantic_result["failure_tags"]: reasoning = ", ".join(semantic_result["failure_tags"]) elif isinstance(expected, str) and case_run["is_irrelevant"]: reasoning = "skipped_irrelevant" else: reasoning = "clean" return { "score": semantic_result["case_score"], "reasoning": reasoning, } def summarize_field_accuracy(case_results): totals = { field_name: {"correct": 0, "total": 0} for field_name in FINAL_SCHEMA_KEYS } for case_result in case_results: for field_name, field_summary in case_result["field_results"].items(): totals[field_name]["total"] += 1 if field_summary["correct"]: totals[field_name]["correct"] += 1 summary = {} for field_name, counts in totals.items(): total = counts["total"] accuracy = round((counts["correct"] / total) * 100, 2) if total else 0.0 summary[field_name] = { "correct": counts["correct"], "total": total, "accuracy": accuracy, } return summary def average_run_score(runs, key): if not runs: return 0.0 return round(sum(run[key] for run in runs) / len(runs), 4) def total_run_score(runs, key): if not runs: return 0.0 return round(sum(run[key] for run in runs), 4) def summarize_reasoning_counts(runs, key): counts = Counter(run[key] for run in runs) ordered_labels = [] for run in runs: label = run[key] if label not in ordered_labels: ordered_labels.append(label) return "; ".join(f"{label} x{counts[label]}" for label in ordered_labels) def explain_reasoning_label(label): explanations = { "matched": "The model followed the required instruction and produced the expected type of output.", "wrong_format": "The model answered the right task but did not use the exact output format the evaluator requires.", "right_answer_wrong_format": "The model identified the right answer but did not use the exact required Step 1 format.", "wrong_answer_right_format": "The model used the required Step 1 format but identified the wrong target firm.", "wrong_answer_wrong_format": "The model gave 
neither the right answer nor the required Step 1 format.", "wrong_route": "The model chose the wrong kind of response for this question, such as treating an irrelevant query as relevant or the reverse.", "call_failed": "The model call failed, so no valid answer was produced for this step.", "invalid_json": "The Step 3 answer was not valid JSON in the required schema.", "relevance_false_positive": "The system treated a question as relevant even though it should have been rejected as irrelevant.", "skipped_irrelevant": "This step was skipped because the question was correctly identified as irrelevant.", "clean": "The answer passed the evaluator checks for this step.", } if ", " in label: parts = [part.strip() for part in label.split(",") if part.strip()] explained_parts = [explanations.get(part, part.replace("_", " ")) for part in parts] return "; ".join(explained_parts) return explanations.get(label, label.replace("_", " ")) def summarize_run_groups(runs, key): grouped = [] for run_index, run in enumerate(runs, start=1): value = run[key] if grouped and grouped[-1]["value"] == value: grouped[-1]["run_indices"].append(run_index) continue grouped.append({"value": value, "run_indices": [run_index]}) return [ f"RUN {format_run_indices(group['run_indices'])}: {group['value']}" for group in grouped ] def build_explained_run_groups(runs, key): grouped = [] for run_index, run in enumerate(runs, start=1): value = run[key] if grouped and grouped[-1]["value"] == value: grouped[-1]["run_indices"].append(run_index) continue grouped.append({"value": value, "run_indices": [run_index]}) return [ "\n".join( [ f"RUN {format_run_indices(group['run_indices'])}", f"Label: {group['value']}", f"Meaning: {explain_reasoning_label(group['value'])}", ] ) for group in grouped ] def build_score_breakdown_cell(case_results): blocks = [] for case_index, case_result in enumerate(case_results, start=1): all_runs = list(range(1, len(case_result["runs"]) + 1)) llm1_reasoning_lines = 
build_explained_run_groups(case_result["runs"], "llm1_reasoning") llm3_reasoning_lines = build_explained_run_groups(case_result["runs"], "llm3_reasoning") block_lines = [ f"CASE {case_index}", f"RUNS: {format_run_indices(all_runs)}", "", "LLM 1", f"Score: {format_score(total_run_score(case_result['runs'], 'llm1_score'))} / {len(all_runs)}", "Reasoning:", *llm1_reasoning_lines, "", "LLM 3", f"Score: {format_score(total_run_score(case_result['runs'], 'llm3_score'))} / {len(all_runs)}", "Reasoning:", *llm3_reasoning_lines, ] blocks.append("\n".join(block_lines)) return truncate_text( join_sheet_sections(blocks), max_chars=SHEET_CELL_MAX_CHARS, ) def build_llm2_output_cell(case_results): blocks = [] for case_index, case_result in enumerate(case_results, start=1): for run_index, run_result in enumerate(case_result["runs"], start=1): blocks.append( "\n".join( [ build_case_run_header(case_index, run_index), format_teacher_forced_note(run_result) if not isinstance(case_result["expected"], str) else "Downstream stages were skipped because this was graded as irrelevant.", "", f"LLM 2 Summary: {run_result['llm2_reasoning']}", "", "LLM 2 Output:", format_llm2_user_output(run_result["llm2_outputs"]), ] ) ) return truncate_text( join_sheet_sections(blocks), max_chars=SHEET_CELL_MAX_CHARS, ) def build_output_vs_expected_cell(case_results): blocks = [] for case_index, case_result in enumerate(case_results, start=1): llm1_expected = format_llm1_expected_text(case_result["expected"]) grouped_blocks = [] for run_index, run_result in enumerate(case_result["runs"], start=1): llm3_output = run_result["llm3_output"] or "Not run." 
block_body = "\n".join( [ "LLM 1", f"Status: {run_result['llm1_reasoning']}", "User Output:", truncate_sheet_block(run_result["llm1_output"], max_chars=STAGE_ONE_OUTPUT_MAX_CHARS), "", "Expected:", llm1_expected, "", ( format_teacher_forced_note(run_result) if not isinstance(case_result["expected"], str) else "Downstream stages were skipped because this case was irrelevant." ), "", "LLM 3", f"Reasoning: {run_result['llm3_reasoning']}", "User Output:", truncate_sheet_block(llm3_output, max_chars=STAGE_THREE_OUTPUT_MAX_CHARS), "", "Expected:", format_pretty_expected_answer(run_result["llm3_expected"]), ] ) normalized_body = normalize_sheet_block_text(block_body) for grouped_block in grouped_blocks: if grouped_block["body"] == normalized_body: grouped_block["run_indices"].append(run_index) break else: grouped_blocks.append({"body": normalized_body, "run_indices": [run_index]}) for grouped_block in grouped_blocks: blocks.append( "\n".join( [ build_grouped_case_run_header(case_index, grouped_block["run_indices"]), grouped_block["body"], ] ) ) return truncate_text( join_sheet_sections(blocks), max_chars=SHEET_CELL_MAX_CHARS, ) def estimate_case_llm_call_count(case): if isinstance(case["expected"], str): return EVAL_REPEAT_COUNT return EVAL_REPEAT_COUNT * (len(case["docs"]) + 2) def estimate_submission_llm_call_count(cases): return sum(estimate_case_llm_call_count(case) for case in cases) def evaluate_case(case, prompts, case_index=None): started_at = time.perf_counter() runs = [] for _ in range(EVAL_REPEAT_COUNT): step_one_result = run_step_one(prompts, case["question"]) llm1_result = grade_llm1_stage(step_one_result["llm1_output"], case["expected"]) expected_relevant = not isinstance(case["expected"], str) used_teacher_forced_step1 = False effective_step1_context = step_one_result["llm1_output"] if expected_relevant: if llm1_result["score"] < 1.0: effective_step1_context = build_expected_step1_output(case["expected"]) used_teacher_forced_step1 = True downstream_result = 
run_downstream_stages( prompts, case["docs"], effective_step1_context, ) else: downstream_result = { "snippet_count": len(case["docs"]), "llm2_outputs": {}, "llm3_output": "", } effective_target_firm = "" if expected_relevant: effective_target_firm = parse_relevant_stage_one_output(effective_step1_context)["target_firm"] case_run = { "question": case["question"], "docs": case["docs"], "snippet_count": downstream_result["snippet_count"], "llm1_output": step_one_result["llm1_output"], "llm2_outputs": downstream_result["llm2_outputs"], "llm3_output": downstream_result["llm3_output"], "is_irrelevant": step_one_result["is_irrelevant"], "user_message": step_one_result["user_message"], "effective_step1_context": effective_step1_context, "effective_target_firm": effective_target_firm, "used_teacher_forced_step1": used_teacher_forced_step1, } semantic_result = grade_case(case_run, case["expected"]) llm3_result = grade_llm3_stage(case_run, case["expected"], semantic_result) llm2_reasoning = summarize_llm2_stage(case_run, case["expected"]) llm3_expected = build_expected_llm3_answer( case["expected"], case["docs"], effective_target_firm, ) failure_tags = sorted(set(llm1_result["failure_tags"] + semantic_result["failure_tags"])) runs.append( { "llm1_output": truncate_sheet_block( step_one_result["llm1_output"], max_chars=STAGE_ONE_OUTPUT_MAX_CHARS, ), "llm2_outputs": { snippet_id: truncate_sheet_block( snippet_output, max_chars=STAGE_TWO_OUTPUT_MAX_CHARS, ) for snippet_id, snippet_output in case_run["llm2_outputs"].items() }, "llm3_output": truncate_sheet_block( case_run["llm3_output"], max_chars=STAGE_THREE_OUTPUT_MAX_CHARS, ), "is_irrelevant": step_one_result["is_irrelevant"], "failure_tags": failure_tags, "field_results": semantic_result["field_results"], "passed": llm1_result["score"] == 1.0 and llm3_result["score"] == 1.0, "llm1_score": llm1_result["score"], "llm1_reasoning": llm1_result["reasoning"], "llm2_reasoning": llm2_reasoning, "llm3_score": llm3_result["score"], 
"llm3_reasoning": llm3_result["reasoning"], "used_teacher_forced_step1": used_teacher_forced_step1, "effective_step1_context": truncate_sheet_block( effective_step1_context, max_chars=STAGE_ONE_OUTPUT_MAX_CHARS, ), "effective_target_firm": effective_target_firm, "llm3_expected": llm3_expected, "total_points": round(llm1_result["score"] + llm3_result["score"], 4), "total_points_without_llm1": round(llm3_result["score"], 4), } ) average_llm1_score = average_run_score(runs, "llm1_score") average_llm3_score = average_run_score(runs, "llm3_score") average_total_points = round(average_llm1_score + average_llm3_score, 4) result = { "expected": case["expected"], "runs": runs, "average_llm1_score": average_llm1_score, "average_llm3_score": average_llm3_score, "average_total_points": average_total_points, "case_score": average_total_points, "passed": abs(average_total_points - 2.0) < 1e-9, } elapsed_seconds = time.perf_counter() - started_at logger.info( "Completed case %s: relevant=%s snippets=%s repeat_count=%s " "expected_llm_calls=%s elapsed_seconds=%.2f", case_index if case_index is not None else "unknown", not isinstance(case["expected"], str), len(case["docs"]), EVAL_REPEAT_COUNT, estimate_case_llm_call_count(case), elapsed_seconds, ) return result def grade_submission(eval_set, prompts, mode): case_results = [] cases = eval_set["cases"] if not cases: return { "submission_type": mode, "score_percent": 0.0, "passed_cases": 0, "total_cases": 0, "failure_summary": {}, "field_accuracy": summarize_field_accuracy([]), "case_results": [], "results": 0.0, "results_without_llm_1": 0.0, } max_workers = min(get_eval_max_workers(), len(cases)) started_at = time.perf_counter() expected_llm_calls = estimate_submission_llm_call_count(cases) snippet_counts = [len(case["docs"]) for case in cases] logger.info( "Starting %s grading: cases=%s repeat_count=%s snippet_counts=%s " "expected_llm_calls=%s case_workers=%s llm2_workers=%s openai_global_limit=%s", mode, len(cases), 
EVAL_REPEAT_COUNT, snippet_counts, expected_llm_calls, max_workers, get_llm2_max_workers(), get_openai_max_concurrent_requests(), ) ordered_results = [None] * len(cases) if max_workers == 1: for index, case in enumerate(cases): ordered_results[index] = safe_evaluate_case(case, prompts, index + 1) else: with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_index = { executor.submit(safe_evaluate_case, case, prompts, index + 1): index for index, case in enumerate(cases) } for future in as_completed(future_to_index): index = future_to_index[future] ordered_results[index] = future.result() case_results = ordered_results total_cases = len(case_results) total_results = round( sum(case_result["average_total_points"] for case_result in case_results), 4, ) total_results_without_llm_1 = round( sum(case_result["average_llm3_score"] for case_result in case_results), 4, ) run_results = [ run_result for case_result in case_results for run_result in case_result["runs"] ] passed_cases = sum(1 for case_result in case_results if case_result["passed"]) max_points = total_cases * 2.0 score_percent = round((total_results / max_points) * 100, 2) if max_points else 0.0 failure_summary = dict( sorted( Counter( tag for run_result in run_results for tag in run_result["failure_tags"] ).items() ) ) field_accuracy = summarize_field_accuracy(run_results) elapsed_seconds = time.perf_counter() - started_at logger.info( "Completed %s grading: cases=%s repeat_count=%s expected_llm_calls=%s " "elapsed_seconds=%.2f", mode, total_cases, EVAL_REPEAT_COUNT, expected_llm_calls, elapsed_seconds, ) return { "submission_type": mode, "score_percent": score_percent, "passed_cases": passed_cases, "total_cases": total_cases, "failure_summary": failure_summary, "field_accuracy": field_accuracy, "case_results": case_results, "results": total_results, "results_without_llm_1": total_results_without_llm_1, } def get_attempt_status(email): normalized_email = normalize_email(email) status = 
{"practice_used": False, "final_used": False} local_status = LOCAL_ATTEMPTS.get(normalized_email, {}) if local_status.get("practice_used"): status["practice_used"] = True if local_status.get("final_used"): status["final_used"] = True sheet = get_google_sheet(ensure_headers=False) rows = sheet.get("B:D") for row in rows[1:]: row_email = None row_type = None if len(row) >= 3 and row[0] in SUBMISSION_TYPES: row_type = row[0] row_email = normalize_email(row[2]) elif row: row_type = "final" row_email = normalize_email(row[0]) if row_email != normalized_email: continue if row_type == "practice": status["practice_used"] = True if row_type == "final": status["final_used"] = True return status def save_attempt(record): sheet = get_google_sheet() row = [serialize_sheet_value(record.get(key, "")) for key, _ in SHEET_COLUMNS] sheet.append_row( row, value_input_option="RAW", insert_data_option="INSERT_ROWS", table_range="A1", ) def record_local_attempt(email, submission_type): LOCAL_ATTEMPTS[email] = { "practice_used": LOCAL_ATTEMPTS.get(email, {}).get("practice_used", False) or submission_type == "practice", "final_used": LOCAL_ATTEMPTS.get(email, {}).get("final_used", False) or submission_type == "final", } def build_attempt_record(name, email, prompts, submission_result): return { "timestamp": datetime.now(timezone.utc).isoformat(), "submission_type": submission_result["submission_type"], "name": name, "email": email, "results": submission_result["results"], "results_without_llm_1": submission_result["results_without_llm_1"], "llm_score_breakdown": build_score_breakdown_cell(submission_result["case_results"]), "llm_output_vs_expected": build_output_vs_expected_cell( submission_result["case_results"] ), "llm_2_output": build_llm2_output_cell(submission_result["case_results"]), "prompts": format_prompts_cell(prompts), } def format_field_accuracy(field_accuracy): labels = { "buyer_counsel": "Buyer counsel", "seller_counsel": "Seller counsel", "third_party_counsel": "Third-party 
counsel", "user_question": "User question", } lines = [] for field_name in FINAL_SCHEMA_KEYS: summary = field_accuracy[field_name] lines.append( ( f"- {labels[field_name]}: {summary['accuracy']}% " f"({summary['correct']}/{summary['total']})" ) ) return "\n".join(lines) def build_user_facing_reasoning_message(label, llm_name): if llm_name == "LLM 1": messages = { "matched": "No issue in these runs.", "right_answer_wrong_format": "The model identified the right answer but did not use the required Step 1 format.", "wrong_answer_right_format": "The model used the required Step 1 format but identified the wrong target firm.", "wrong_answer_wrong_format": "The model gave neither the right output nor the right format.", "wrong_route": "The model chose the wrong route for this question and should have handled relevance differently.", "call_failed": "The Step 1 model call failed, so no valid answer was produced.", } return messages.get(label, explain_reasoning_label(label)) if label == "matched": return "No issue in these runs." if label == "clean": return "The final JSON was valid and passed the evaluator checks." if label == "skipped_irrelevant": return "This case was correctly rejected as irrelevant before Step 3." if label == "invalid_json": return "The final answer was not valid JSON in the required schema." if label == "relevance_false_positive": return ( "This question should have been rejected as irrelevant earlier, but the pipeline continued as if it were relevant." ) parts = [part.strip() for part in str(label).split(",") if part.strip()] explained = [] for part in parts: detail = explain_output_issue_tag(part) if detail: explained.append(detail) else: explained.append(explain_reasoning_label(part)) return "; ".join(explained) if explained else "The final output had evaluator-detected issues." 
def split_field_issue_tag(tag):
    """Split an evaluator tag of the form ``<issue>_<field>`` into its parts.

    Field names come from ``FINAL_SCHEMA_KEYS`` and are tried longest first,
    so a shorter field name can never claim a tag that ends with a longer one.

    Args:
        tag: A failure tag such as ``"wrong_value_buyer_counsel"``.

    Returns:
        ``(issue_name, field_name)`` when ``tag`` ends with ``_<field_name>``,
        otherwise ``(None, None)``. ``issue_name`` is ``""`` when the tag is
        exactly ``_<field_name>``.
    """
    for field_name in sorted(FINAL_SCHEMA_KEYS, key=len, reverse=True):
        suffix = f"_{field_name}"
        if tag.endswith(suffix):
            return tag[: -len(suffix)], field_name
    return None, None


def explain_output_issue_tag(tag):
    """Translate a per-field failure tag into a user-facing sentence.

    Args:
        tag: A tag of the form ``<issue>_<field>`` (see ``split_field_issue_tag``).

    Returns:
        ``"<Field label>: <message>"``, or ``None`` when the tag has no known
        field suffix, an empty issue part, or an issue name not listed below.
    """
    field_labels = {
        "buyer_counsel": "Buyer counsel",
        "seller_counsel": "Seller counsel",
        "third_party_counsel": "Third-party counsel",
        "user_question": "User question",
    }
    issue_name, field_name = split_field_issue_tag(tag)
    if not issue_name or not field_name:
        return None
    field_label = field_labels[field_name]
    issue_messages = {
        "missing_field": "the field was missing from the final JSON.",
        "malformed_citation": "the citation format was malformed.",
        "unexpected_citation": "a citation was included where it should not have been.",
        "used_unknown": 'the answer said "unknown" even though the evidence supported a conclusion.',
        "hallucinated_value": "a value was given even though the expected answer was unknown.",
        "wrong_value": "the answer value was wrong.",
        "missing_citation": "the answer needed a citation but did not include one.",
        "bad_citation_reference": "the citation pointed to a snippet number that does not exist.",
        "unsupported_citation": "the citation pointed to a snippet that does not support the answer.",
    }
    message = issue_messages.get(issue_name)
    if not message:
        return None
    return f"{field_label}: {message}"


# Reasoning labels that mean "nothing to report" in practice feedback.
_NO_ISSUE_LABELS = frozenset({"matched", "clean", "skipped_irrelevant"})


def build_practice_run_issue_lines(runs, key, llm_name):
    """Build bullet lines describing the problematic runs of one stage.

    Runs whose ``run[key]`` label is in ``_NO_ISSUE_LABELS`` are omitted. A run
    joins the most recently recorded group when it carries the same label,
    even if omitted (clean) runs lie between them.

    Args:
        runs: Per-run result dicts for a single case.
        key: Which reasoning field to read (e.g. ``"llm1_reasoning"``).
        llm_name: Stage name forwarded to ``build_user_facing_reasoning_message``.

    Returns:
        A list of ``"- RUN ...: ..."`` strings, or ``["- No issues detected."]``.
    """
    grouped = []
    for run_index, run in enumerate(runs, start=1):
        value = run[key]
        if value in _NO_ISSUE_LABELS:
            continue
        if grouped and grouped[-1]["value"] == value:
            grouped[-1]["run_indices"].append(run_index)
            continue
        grouped.append({"value": value, "run_indices": [run_index]})
    if not grouped:
        return ["- No issues detected."]
    return [
        f"- RUN {format_run_indices(group['run_indices'])}: "
        f"{build_user_facing_reasoning_message(group['value'], llm_name)}"
        for group in grouped
    ]


def build_practice_case_feedback(case_results):
    """Render per-case LLM 1 / LLM 3 scores and issue lists for practice feedback.

    Args:
        case_results: Case result dicts, each with a ``"runs"`` list.

    Returns:
        One text block per case, joined with ``join_sheet_sections``.
    """
    blocks = []
    for case_index, case_result in enumerate(case_results, start=1):
        runs = case_result["runs"]
        run_count = len(runs)
        llm1_score = total_run_score(runs, "llm1_score")
        llm3_score = total_run_score(runs, "llm3_score")
        blocks.append(
            "\n".join(
                [
                    f"CASE {case_index}",
                    "",
                    "LLM 1",
                    f"Score: {format_score(llm1_score)} / {run_count}",
                    "Issues:",
                    *build_practice_run_issue_lines(runs, "llm1_reasoning", "LLM 1"),
                    "",
                    "LLM 3",
                    f"Score: {format_score(llm3_score)} / {run_count}",
                    "Issues:",
                    *build_practice_run_issue_lines(runs, "llm3_reasoning", "LLM 3"),
                ]
            )
        )
    return join_sheet_sections(blocks)


def build_practice_scoring_note(submission_result):
    """Explain how the practice score was computed.

    The repeat count is taken from ``EVAL_REPEAT_COUNT`` (the value
    ``evaluate_case`` loops over), so the note cannot drift from the grader.
    """
    total_cases = submission_result["total_cases"]
    return (
        f"This practice score is based on {total_cases} hidden calibration cases. "
        f"Each case is run {EVAL_REPEAT_COUNT} times to check prompt consistency."
    )


def format_practice_feedback(name, submission_result):
    """Assemble the full practice-run message shown to the candidate."""
    case_results = submission_result["case_results"]
    return (
        f"Practice run complete for {name}.\n\n"
        f"Score: {submission_result['score_percent']}%\n"
        f"{build_practice_scoring_note(submission_result)}\n\n"
        f"{build_practice_case_feedback(case_results)}\n\n"
        "Your final submission is still available."
    )


def load_eval_set_for_mode(submission_type):
    """Load the hidden dataset: ``"practice"`` uses PRACTICE, anything else FINAL."""
    prefix = "PRACTICE" if submission_type == "practice" else "FINAL"
    return load_eval_set(prefix)


def validate_submission_inputs(email, name):
    """Return a user-facing validation error message, or ``None`` if inputs are OK.

    A ``None`` name is treated like an empty one, so it produces the
    "enter your full name" message instead of an ``AttributeError``.
    """
    if not validate_email(email):
        return "Invalid email address. Please enter a valid email."
    if not (name or "").strip():
        return "Please enter your full name."
    return None


def _build_attempt_limit_response(submission_type, normalized_email, attempt_status):
    """Return a rejection response if the attempt limits forbid this submission.

    Practice is refused once a final exists or practice was already used.
    A final is refused once a final exists. Returns ``None`` when allowed.
    """
    if submission_type == "practice":
        if attempt_status["final_used"]:
            return build_submission_response(
                False,
                (
                    f"A final submission has already been received for {normalized_email}. "
                    "Practice is no longer available."
                ),
                disable_practice=True,
                disable_final=True,
            )
        if attempt_status["practice_used"]:
            return build_submission_response(
                False,
                (
                    f"Practice has already been used for {normalized_email}. "
                    "Your final submission is still available."
                ),
                disable_practice=True,
                disable_final=False,
            )
    if submission_type == "final" and attempt_status["final_used"]:
        return build_submission_response(
            False,
            (
                f"Final submission already received for {normalized_email}. "
                "You can only submit the final once."
            ),
            disable_practice=True,
            disable_final=True,
        )
    return None


def submit_attempt(submission_type, email, name, prompt_1, prompt_2, prompt_3):
    """Validate, grade, persist, and report one practice or final submission.

    Each stage (sheet read, dataset load, grading, record build, sheet write)
    is wrapped so a failure is logged via ``log_stage_exception`` and turned
    into a failed response instead of propagating.

    Args:
        submission_type: Must be a member of ``SUBMISSION_TYPES``.
        email, name: Raw candidate identity inputs.
        prompt_1, prompt_2, prompt_3: Raw system prompts for the three steps.

    Returns:
        The dict produced by ``build_submission_response``.

    Raises:
        ValueError: If ``submission_type`` is not a supported type.
    """
    if submission_type not in SUBMISSION_TYPES:
        raise ValueError(f"Unsupported submission type: {submission_type}")
    validation_error = validate_submission_inputs(email, name)
    if validation_error:
        return build_submission_response(False, validation_error)

    normalized_email = normalize_email(email)
    clean_name = sanitize_input(name)
    prompts = {
        "prompt_1": sanitize_prompt(prompt_1),
        "prompt_2": sanitize_prompt(prompt_2),
        "prompt_3": sanitize_prompt(prompt_3),
    }
    bypass_sheet_save = should_bypass_sheet_save()

    try:
        attempt_status = get_attempt_status(normalized_email)
    except Exception:
        log_stage_exception(
            "Unexpected error while reading sheet state for %s submission for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "sheet_read"),
        )

    limit_response = _build_attempt_limit_response(
        submission_type, normalized_email, attempt_status
    )
    if limit_response is not None:
        return limit_response

    try:
        eval_set = load_eval_set_for_mode(submission_type)
    except Exception:
        log_stage_exception(
            "Unexpected error while loading %s dataset for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "dataset_load"),
        )
    if not eval_set["cases"]:
        return build_submission_response(
            False,
            f"No hidden cases configured for the {submission_type} dataset.",
        )

    try:
        submission_result = grade_submission(eval_set, prompts, submission_type)
    except Exception:
        log_stage_exception(
            "Unexpected error while grading %s submission for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "grading"),
        )

    try:
        record = build_attempt_record(clean_name, normalized_email, prompts, submission_result)
    except Exception:
        log_stage_exception(
            "Unexpected error while building %s record for %s.",
            submission_type,
            normalized_email,
        )
        return build_submission_response(
            False,
            build_stage_error_message(submission_type, "record_build"),
        )

    bypass_notice = build_bypass_save_notice() if bypass_sheet_save else ""
    if not bypass_sheet_save:
        try:
            save_attempt(record)
            # Only remember the attempt locally once the sheet write succeeded,
            # so a failed write leaves the candidate free to retry.
            record_local_attempt(normalized_email, submission_type)
        except Exception:
            log_stage_exception(
                "Unexpected error while saving %s submission for %s.",
                submission_type,
                normalized_email,
            )
            return build_submission_response(
                False,
                build_save_error_message(submission_type, "sheet_write"),
            )
    else:
        # Route through the module logger so LOG_LEVEL governs this message too.
        logger.info(
            "BYPASS_SHEET_SAVE enabled for %s submission %s.",
            submission_type,
            normalized_email,
        )

    if submission_type == "practice":
        notice = bypass_notice or "Practice run saved. Your final submission is still available."
        return build_submission_response(
            True,
            format_practice_feedback(clean_name, submission_result),
            notice=notice,
            disable_practice=not bypass_sheet_save,
            disable_final=False,
        )

    notice = bypass_notice or "Final submission received. You can close the page."
    return build_submission_response(
        True,
        f"Thank you for your submission, {clean_name}!",
        notice=notice,
        disable_practice=not bypass_sheet_save,
        disable_final=not bypass_sheet_save,
    )


def build_submission_callback_result(result):
    """Convert a submission response into Gradio outputs.

    The tuple order matches ``outputs=[output_text, practice_button,
    final_button, feedback_md]`` in ``build_interface``.
    """
    return (
        result["message"],
        gr.update(interactive=not result["disable_practice"]),
        gr.update(interactive=not result["disable_final"]),
        gr.update(value=result["notice"], visible=bool(result["notice"])),
    )


def handle_submission(submission_type, email, name, s1, s2, s3):
    """Gradio callback boundary: never let an exception escape to the UI."""
    try:
        result = submit_attempt(submission_type, email, name, s1, s2, s3)
    except Exception:
        log_stage_exception("Unhandled callback failure for %s submission.", submission_type)
        result = build_submission_response(
            False,
            build_stage_error_message(submission_type, "callback"),
        )
    return build_submission_callback_result(result)


def build_interface():
    """Build the Gradio Blocks UI: instructions, prompt inputs, and submit buttons."""
    with gr.Blocks(css=APP_CSS) as demo:
        gr.Markdown(
            """ # Applicant Task: Target Company & Law Firm Identification Draft prompts for a strict 3-step legal review pipeline over snippets from Asset Purchase Agreements [SEC Agreement Example](https://www.sec.gov/Archives/edgar/data/28452/000119312505012401/dex101.htm) > This evaluation system uses (default: `gpt-5-mini`) for all LLM steps. ## What you need to do - Decide whether the query is relevant and normalize the target firm name. - Inspect each snippet independently and record only what that snippet actually supports. - Combine those snippet-level findings into one final JSON answer with citations, including a cited answer to the original user question. ## The 3-step pipeline ### Step 1: Relevance and target-firm normalization Decide whether the question belongs in this deal at all, and if it does, standardize the firm name you will track through the rest of the pipeline. - If the query is irrelevant, return `...`.
- If the query is relevant, return only: ```text TARGET_FIRM: ``` Please ensure your final output uses the exact key `TARGET_FIRM:` as shown above, alongside the firm name. ### Step 2: Snippet-by-snippet analysis Step 2 is not the final answer; it is a working note for one snippet at a time. - Your Step 2 prompt runs independently on each evidence unit, so the model only sees one snippet per call. - The app passes: - the Step 1 output - `Snippet ID: S1..Sn` - the snippet text - A good Step 2 prompt says what this snippet supports, what it does not support, and where the evidence is still uncertain. ### Step 3: Reconciliation and final answer Step 3 receives the Step 1 output plus all Step 2 notes and must turn them into one answer that could survive a hostile redline, including a direct answer to the original user question. - Step 3 receives the Step 1 output plus all Step 2 outputs. - It must return valid JSON with this exact schema: ```json { "buyer_counsel": "string with citations like \"Firm Name [^2]\" or \"unknown\"", "seller_counsel": "string with citations like \"Firm Name [^4]\" or \"unknown\"", "third_party_counsel": "string with citations like \"Firm Name [^1]\" or \"unknown\"", "user_question": "string with citations like \"true [^2]\", \"false [^4]\", or \"unknown\"" } ``` Use `"unknown"` when a counsel field or the user-question answer cannot be supported by the evidence. When you provide a firm name, include snippet citations such as `[^2]` that point to the relevant Step 2 snippet IDs. The `user_question` field should answer the original query using only `true`, `false`, or `unknown`: if the answer is `true` or `false`, include supporting snippet citations and do not add any extra text. """
        )
        gr.Markdown(
            """ ## Workflow At A Glance This visual shows how the query moves from `LLM1` to `LLM2` to `LLM3`, and how the final JSON is assembled from the APA snippets.
"""
        )
        gr.Image(
            value="mermaid_diagram.png",
            label="Pipeline overview",
            show_label=True,
            interactive=False,
        )
        with gr.Accordion("Example Workflow", open=False):
            gr.Markdown(
                """ **User query** ```text Is Kirkland & Ellis LLP acting as counsel anywhere in this Asset Purchase Agreement? ``` **Step 1 relevant output** ```text TARGET_FIRM: Kirkland & Ellis LLP ``` **Example evidence units** - `S1`: the opening paragraph names the parties and mentions Kirkland & Ellis LLP as transaction counsel to the buyer. - `S2`: the notices section identifies buyer counsel as Kirkland & Ellis LLP and seller counsel as Wachtell, Lipton, Rosen & Katz. - `S3`: a boilerplate clause contains no counsel information and should not drive the final answer. - `S4`: a representative provision states that Gibson, Dunn & Crutcher LLP advises the securityholders' representative. - `S5`: a later notice block again confirms seller counsel as Wachtell, Lipton, Rosen & Katz. **Final JSON shape** ```json { "buyer_counsel": "Kirkland & Ellis LLP [^1] [^2]", "seller_counsel": "Wachtell, Lipton, Rosen & Katz [^2] [^5]", "third_party_counsel": "Gibson, Dunn & Crutcher LLP [^4]", "user_question": "true [^1] [^2]" } ``` """
            )
        with gr.Accordion("Practice and Final Submission", open=False):
            gr.Markdown(
                """ - You may use **one optional practice run** per email to test your prompts against a hidden calibration set. - The practice run uses 3 hidden calibration cases. - Each case is run 3 times to check prompt consistency. - For each run, LLM 1 can earn up to 1 point for correct routing and target-firm normalization, and LLM 3 can earn up to 1 point for a correct final JSON answer with supported citations. - Step 2 is not scored directly, but it strongly affects the LLM 3 score because Step 3 relies on the snippet-level analysis. - Practice returns aggregate feedback only: score percentage, an LLM 1 summary, and an LLM 3 summary. - You may then revise your prompts or keep them as they are.
- You may submit **one final submission** per email against a separate hidden holdout set. - After the final submission, practice is no longer available. - No structured decoding is used for you, so your prompts must make Step 3 produce reliable JSON on their own. """
            )
        gr.Markdown(
            """ Enter your name and email exactly as listed in your CV. Both buttons below use the same three prompt boxes. You have **one** chance to run the practice set and get feedback, and **one** chance to run the final set. After you click a button, wait for the results to load before clicking again or refreshing the page. **Good Luck!** """
        )
        email_input = gr.Textbox(label="Email", placeholder="your.email@example.com")
        name_input = gr.Textbox(label="First Name, Last Name", placeholder="John Smith")
        system_prompt_input_1 = gr.Textbox(
            label="System Prompt for Step 1",
            placeholder="Enter your Step 1 prompt here...",
            lines=6,
        )
        system_prompt_input_2 = gr.Textbox(
            label="System Prompt for Step 2",
            placeholder="Enter your Step 2 prompt here...",
            lines=10,
        )
        system_prompt_input_3 = gr.Textbox(
            label="System Prompt for Step 3",
            placeholder="Enter your Step 3 prompt here...",
            lines=6,
        )
        gr.Markdown(
            """
Please note:
Each run may take a couple of minutes.
After you click a button, wait for the result and do not click it again.
"""
        )
        with gr.Row():
            practice_button = gr.Button("Practice Run")
            final_button = gr.Button("Submit Final")
        output_text = gr.Textbox(label="Results", lines=18)
        feedback_md = gr.Markdown("", visible=False)

        def practice_submit_and_update(email, name, s1, s2, s3):
            return handle_submission("practice", email, name, s1, s2, s3)

        def final_submit_and_update(email, name, s1, s2, s3):
            return handle_submission("final", email, name, s1, s2, s3)

        # Both buttons read the same prompt boxes and update the same widgets.
        submission_inputs = [
            email_input,
            name_input,
            system_prompt_input_1,
            system_prompt_input_2,
            system_prompt_input_3,
        ]
        # Order must match the tuple built by build_submission_callback_result.
        submission_outputs = [output_text, practice_button, final_button, feedback_md]
        practice_button.click(
            fn=practice_submit_and_update,
            inputs=submission_inputs,
            outputs=submission_outputs,
        )
        final_button.click(
            fn=final_submit_and_update,
            inputs=submission_inputs,
            outputs=submission_outputs,
        )
    return demo


if __name__ == "__main__":
    interface = build_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)