| import json |
| import logging |
| import os |
| import re |
| import threading |
| import time |
| import traceback |
| from collections import Counter |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from datetime import datetime, timezone |
|
|
| import gradio as gr |
| import gspread |
| from google.oauth2.service_account import Credentials |
| from openai import OpenAI |
|
|
# OAuth scopes passed to the Google service-account credentials in
# get_google_sheet().
SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]


# Fallbacks used by the get_* helpers below when the matching environment
# variable is not set.
DEFAULT_MODEL_NAME = "gpt-5-mini"
DEFAULT_EVAL_MAX_WORKERS = 4
DEFAULT_LLM2_MAX_WORKERS = 8
DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS = 16
# NOTE(review): EVAL_REPEAT_COUNT and SUBMISSION_TYPES are not referenced in
# the visible part of the file; presumably consumed further down -- confirm.
EVAL_REPEAT_COUNT = 3
SUBMISSION_TYPES = {"practice", "final"}
# Tags that wrap the LLM 1 output for an irrelevant question
# (see extract_user_message()).
IRRELEVANT_TAG_START = "<user_message>"
IRRELEVANT_TAG_END = "</user_message>"
# Separator placed between blocks by join_sheet_sections().
SHEET_SECTION_DIVIDER = "=" * 50
# Character budgets applied when text is truncated for spreadsheet cells.
STAGE_ONE_OUTPUT_MAX_CHARS = 4000
STAGE_TWO_OUTPUT_MAX_CHARS = 2000
STAGE_THREE_OUTPUT_MAX_CHARS = 4000
PROMPT_CELL_MAX_CHARS = 8000
EXPECTED_OUTPUT_MAX_CHARS = 4000
SHEET_CELL_MAX_CHARS = 49000


# Keys of the expected-answer mapping and of the final LLM 3 JSON object.
COUNSEL_FIELDS = ("buyer_counsel", "seller_counsel", "third_party_counsel")
TARGET_FIRM_FIELD = "target_firm"
USER_QUESTION_FIELD = "user_question"
FINAL_SCHEMA_KEYS = (*COUNSEL_FIELDS, USER_QUESTION_FIELD)
# A well-formed citation marker looks like "[^3]"; group 1 is the snippet number.
CITATION_PATTERN = re.compile(r"\[\^(\d+)\]")
# Whole-string yes/no style answers accepted for the user-question field.
BOOLEAN_ANSWER_PATTERN = re.compile(r"^(true|false|yes|no)\b$", re.IGNORECASE)
|
|
# NOTE(review): LOCAL_ATTEMPTS is not used in the visible part of the file;
# presumably an in-process record of submission attempts -- confirm.
LOCAL_ATTEMPTS = {}
# Created on first use by get_openai_client().
_OPENAI_CLIENT = None
# Semaphore and the limit it was built with; both are (re)created under the
# lock by get_openai_request_semaphore().
_OPENAI_REQUEST_SEMAPHORE = None
_OPENAI_REQUEST_SEMAPHORE_LIMIT = None
_OPENAI_REQUEST_SEMAPHORE_LOCK = threading.Lock()
# Log level comes from LOG_LEVEL (default "INFO"); a name that is not an
# attribute of the logging module falls back to logging.INFO.
_LOG_LEVEL_NAME = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=getattr(logging, _LOG_LEVEL_NAME, logging.INFO))
logger = logging.getLogger(__name__)
|
|
|
|
# (key, header label) pairs for the "Submissions" worksheet, in column order.
SHEET_COLUMNS = [
    ("timestamp", "Submitted At"),
    ("submission_type", "Submission Type"),
    ("name", "Candidate Name"),
    ("email", "Candidate Email"),
    ("results", "Results"),
    ("results_without_llm_1", "Results LLM3"),
    ("llm_score_breakdown", "LLM Score Breakdown"),
    ("llm_output_vs_expected", "LLM Output vs Expected"),
    ("llm_2_output", "LLM 2 Output"),
    ("prompts", "Prompts"),
]
# Header row written by ensure_submission_sheet_headers().
SHEET_HEADERS = [label for _, label in SHEET_COLUMNS]
|
|
# Styling for the ".submission-note" element. Colours are CSS custom
# properties with light defaults; they are overridden both for the OS-level
# dark preference and for a ".dark" ancestor class.
# NOTE(review): the indentation inside this string was reconstructed from a
# whitespace-mangled source; whitespace is not significant to CSS.
APP_CSS = """
.submission-note {
    background: var(--submission-note-bg, #fff7e6);
    border: 1px solid var(--submission-note-border, #ffe5b4);
    border-radius: 8px;
    color: var(--submission-note-text, #3d2b00);
    margin-bottom: 1em;
    padding: 16px;
}

@media (prefers-color-scheme: dark) {
    .submission-note {
        --submission-note-bg: #2b2111;
        --submission-note-border: #6b4b18;
        --submission-note-text: #f7e8c5;
    }
}

.dark .submission-note {
    --submission-note-bg: #2b2111;
    --submission-note-border: #6b4b18;
    --submission-note-text: #f7e8c5;
}
"""
|
|
|
|
def get_model_name():
    """Return the OpenAI model name.

    Reads ``OPENAI_MODEL_NAME`` from the environment and falls back to
    ``DEFAULT_MODEL_NAME`` when the variable is not set.
    """
    fallback = DEFAULT_MODEL_NAME
    configured = os.environ.get("OPENAI_MODEL_NAME")
    return fallback if configured is None else configured
|
|
|
|
def get_positive_int_env(name, default):
    """Read environment variable ``name`` as an integer of at least 1.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or not an integer.

    Returns:
        The parsed value clamped to a minimum of 1. When the variable holds
        a non-integer string, a warning is logged and ``default`` is returned
        unchanged (without clamping).
    """
    raw_value = os.environ.get(name, str(default))
    try:
        parsed = int(raw_value)
    except ValueError:
        logger.warning("Invalid integer for %s=%r; using default %s.", name, raw_value, default)
        return default
    return parsed if parsed >= 1 else 1
|
|
|
|
def get_eval_max_workers():
    """Return the worker count for case evaluation.

    ``EVAL_CASE_MAX_WORKERS`` wins whenever it is present in the environment;
    otherwise ``EVAL_MAX_WORKERS`` is consulted. Both fall back to
    ``DEFAULT_EVAL_MAX_WORKERS``.
    """
    if "EVAL_CASE_MAX_WORKERS" in os.environ:
        variable_name = "EVAL_CASE_MAX_WORKERS"
    else:
        variable_name = "EVAL_MAX_WORKERS"
    return get_positive_int_env(variable_name, DEFAULT_EVAL_MAX_WORKERS)
|
|
|
|
def get_llm2_max_workers():
    """Return the LLM 2 worker count from ``LLM2_MAX_WORKERS`` (default ``DEFAULT_LLM2_MAX_WORKERS``)."""
    worker_count = get_positive_int_env("LLM2_MAX_WORKERS", DEFAULT_LLM2_MAX_WORKERS)
    return worker_count
|
|
|
|
def get_openai_max_concurrent_requests():
    """Return the OpenAI concurrency cap from ``OPENAI_MAX_CONCURRENT_REQUESTS``.

    Falls back to ``DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS``.
    """
    variable_name = "OPENAI_MAX_CONCURRENT_REQUESTS"
    return get_positive_int_env(variable_name, DEFAULT_OPENAI_MAX_CONCURRENT_REQUESTS)
|
|
|
|
def get_openai_request_semaphore():
    """Return the shared semaphore that bounds concurrent OpenAI requests.

    The semaphore is built on first use and rebuilt whenever the configured
    limit differs from the limit it was created with. Creation happens under
    ``_OPENAI_REQUEST_SEMAPHORE_LOCK``.
    """
    global _OPENAI_REQUEST_SEMAPHORE
    global _OPENAI_REQUEST_SEMAPHORE_LIMIT

    wanted_limit = get_openai_max_concurrent_requests()
    with _OPENAI_REQUEST_SEMAPHORE_LOCK:
        needs_rebuild = (
            _OPENAI_REQUEST_SEMAPHORE is None
            or _OPENAI_REQUEST_SEMAPHORE_LIMIT != wanted_limit
        )
        if needs_rebuild:
            _OPENAI_REQUEST_SEMAPHORE = threading.BoundedSemaphore(wanted_limit)
            _OPENAI_REQUEST_SEMAPHORE_LIMIT = wanted_limit
        return _OPENAI_REQUEST_SEMAPHORE
|
|
|
|
def get_openai_client():
    """Return the module-wide OpenAI client, creating it on first use.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is missing or empty.
    """
    global _OPENAI_CLIENT

    # Fast path: the client was already built by an earlier call.
    if _OPENAI_CLIENT is not None:
        return _OPENAI_CLIENT

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set.")
    _OPENAI_CLIENT = OpenAI(api_key=api_key)
    return _OPENAI_CLIENT
|
|
|
|
def get_google_sheet(ensure_headers=True):
    """Open the "Submissions" worksheet of the configured spreadsheet.

    Args:
        ensure_headers: When true, the header row is rewritten through
            ``ensure_submission_sheet_headers`` before returning.

    Returns:
        The worksheet object returned by gspread.

    Raises:
        RuntimeError: If ``GOOGLE_CREDS_JSON`` or ``SPREADSHEET_ID`` is unset
            or empty.
    """
    creds_json = os.environ.get("GOOGLE_CREDS_JSON")
    spreadsheet_id = os.environ.get("SPREADSHEET_ID")
    if not (creds_json and spreadsheet_id):
        raise RuntimeError("GOOGLE_CREDS_JSON and SPREADSHEET_ID must be set.")

    service_account_info = json.loads(creds_json)
    credentials = Credentials.from_service_account_info(
        service_account_info,
        scopes=SCOPES,
    )
    client = gspread.authorize(credentials)
    worksheet = client.open_by_key(spreadsheet_id).worksheet("Submissions")
    if ensure_headers:
        ensure_submission_sheet_headers(worksheet)
    return worksheet
|
|
|
|
def load_eval_set(prefix):
    """Load an evaluation set from ``{prefix}_*_JSON`` environment variables.

    Reads ``{prefix}_QUESTIONS_JSON``, ``{prefix}_DOCUMENTS_JSON`` and
    ``{prefix}_EXPECTED_JSON``. If any of them is unset or empty an empty
    set is returned.

    Args:
        prefix: Environment-variable prefix identifying the dataset.

    Returns:
        ``{"cases": [...]}`` where each case has ``question`` (str), ``docs``
        (normalized snippet list) and ``expected`` (normalized expected entry).

    Raises:
        ValueError: If any field is not a JSON list or the lists differ in
            length. Errors from ``json.loads`` and from
            ``normalize_expected_entry`` propagate.
    """
    raw_fields = [
        os.environ.get(f"{prefix}_QUESTIONS_JSON"),
        os.environ.get(f"{prefix}_DOCUMENTS_JSON"),
        os.environ.get(f"{prefix}_EXPECTED_JSON"),
    ]
    if not all(raw_fields):
        return {"cases": []}

    questions, documents, expected = [json.loads(raw) for raw in raw_fields]

    if not all(isinstance(field, list) for field in (questions, documents, expected)):
        raise ValueError(f"{prefix} dataset must be a JSON list for all fields.")

    if not (len(questions) == len(documents) == len(expected)):
        raise ValueError(
            f"{prefix} dataset lengths do not match for questions, documents, and expected answers."
        )

    cases = []
    triples = zip(questions, documents, expected)
    for case_index, (question, docs_entry, expected_entry) in enumerate(triples, start=1):
        question_text = str(question)
        normalized_docs = normalize_documents_entry(docs_entry)
        normalized_expected = normalize_expected_entry(
            expected_entry,
            prefix,
            case_index,
            question=question_text,
            docs=normalized_docs,
        )
        cases.append(
            {
                "question": question_text,
                "docs": normalized_docs,
                "expected": normalized_expected,
            }
        )

    return {"cases": cases}
|
|
|
|
def build_submission_response(
    ok,
    message,
    notice="",
    disable_practice=False,
    disable_final=False,
):
    """Bundle the outcome of a submission into a plain dict.

    Returns:
        A dict with the keys ``ok``, ``message``, ``notice``,
        ``disable_practice`` and ``disable_final`` holding the arguments
        unchanged.
    """
    return dict(
        ok=ok,
        message=message,
        notice=notice,
        disable_practice=disable_practice,
        disable_final=disable_final,
    )
|
|
|
|
def build_internal_error_message(submission_type):
    """Return the generic internal-error message for a submission.

    The run is called a "practice run" when ``submission_type`` is
    ``"practice"`` and a "submission" otherwise.
    """
    label = "submission"
    if submission_type == "practice":
        label = "practice run"
    sentences = [
        f"We hit an internal error while processing your {label}.",
        "Please try again in a minute.",
    ]
    return " ".join(sentences)
|
|
|
|
def build_stage_error_message(submission_type, stage):
    """Return an internal-error message that names the failing ``stage``.

    The run is called a "practice run" when ``submission_type`` is
    ``"practice"`` and a "submission" otherwise.
    """
    label = "submission"
    if submission_type == "practice":
        label = "practice run"
    sentences = [
        f"We hit an internal error while processing your {label}.",
        f"Stage: {stage}.",
    ]
    return " ".join(sentences)
|
|
|
|
def build_save_error_message(submission_type, stage="sheet_write"):
    """Return the message shown when a run could not be saved.

    Args:
        submission_type: ``"practice"`` yields "practice run"; any other
            value yields "submission".
        stage: Name of the failing stage, included in the message.
    """
    label = "submission"
    if submission_type == "practice":
        label = "practice run"
    sentences = [
        f"We could not save your {label} right now.",
        f"Stage: {stage}. Please check the Space logs and spreadsheet permissions, then try again.",
    ]
    return " ".join(sentences)
|
|
|
|
def build_bypass_save_notice():
    """Return the notice shown when a graded run is not written to Google Sheets."""
    notice = "Debug mode is active: this run was graded but not saved to Google Sheets."
    return notice
|
|
|
|
def should_bypass_sheet_save():
    """Return True when ``BYPASS_SHEET_SAVE`` is "true" (case-insensitive, surrounding whitespace ignored)."""
    flag = os.environ.get("BYPASS_SHEET_SAVE", "")
    normalized_flag = flag.strip().lower()
    return normalized_flag == "true"
|
|
|
|
def log_stage_exception(message, *args):
    """Log the active exception and also print its traceback to stdout.

    Args:
        message: Logging format string.
        *args: Arguments for ``message``, passed through to the logger.
    """
    logger.exception(message, *args)
    # The traceback is printed and flushed in addition to the logger call.
    formatted_traceback = traceback.format_exc()
    print(formatted_traceback, flush=True)
|
|
|
|
def safe_evaluate_case(case, prompts, case_index):
    """Evaluate one case, converting any failure into a ``RuntimeError``.

    Raises:
        RuntimeError: Chained to the original exception, which is logged
            first with the case index.
    """
    try:
        outcome = evaluate_case(case, prompts, case_index=case_index)
    except Exception as error:
        logger.exception("Case %s failed during evaluation.", case_index)
        raise RuntimeError(f"Evaluation failed for case {case_index}.") from error
    return outcome
|
|
|
|
def normalize_documents_entry(entry):
    """Turn one raw documents entry into a list of non-empty snippet strings.

    Args:
        entry: Either a list of snippets or a single value. A single value
            is stringified and split on ``"---"``.

    Returns:
        A list of stripped, non-empty snippet strings. ``None`` (JSON
        ``null``) -- as the whole entry or as a list item -- contributes
        nothing, rather than becoming the literal snippet ``"None"``.
    """
    if entry is None:
        return []

    if isinstance(entry, list):
        stripped_items = (str(item).strip() for item in entry if item is not None)
        return [item for item in stripped_items if item]

    text = str(entry).strip()
    if not text:
        return []

    # Splitting a string without "---" yields the string itself, so one code
    # path covers both the delimited and the single-snippet case.
    parts = (part.strip() for part in text.split("---"))
    return [part for part in parts if part]
|
|
|
|
def get_sheet_column_letter(index):
    """Convert a 1-based column index to spreadsheet letters (1 -> "A", 27 -> "AA").

    Returns an empty string for an index of zero or less.
    """
    reversed_letters = []
    remaining = index
    while remaining > 0:
        # Subtract one first: the letter system has no zero digit.
        remaining, offset = divmod(remaining - 1, 26)
        reversed_letters.append(chr(65 + offset))
    return "".join(reversed(reversed_letters))
|
|
|
|
def ensure_submission_sheet_headers(sheet):
    """Write ``SHEET_HEADERS`` into the first row of ``sheet``.

    The target range starts at A1 and ends at the column matching the
    number of headers.
    """
    last_column = get_sheet_column_letter(len(SHEET_HEADERS))
    header_range = f"A1:{last_column}1"
    sheet.update(header_range, [SHEET_HEADERS])
|
|
|
|
def get_next_submission_row_index(sheet):
    """Return the 1-based row index following the last non-empty data row.

    Row 1 is treated as the header and never counts as data, so the result
    is at least 2. A row counts as data when any cell is non-blank after
    stripping.
    """
    last_data_row = 1
    for row_number, cells in enumerate(sheet.get_all_values(), start=1):
        has_content = any(str(cell).strip() for cell in cells)
        if row_number > 1 and has_content:
            last_data_row = row_number
    return last_data_row + 1
|
|
|
|
def serialize_sheet_value(value):
    """Convert a value to the string stored in a sheet cell.

    Dicts and lists are JSON-serialized, ``None`` becomes an empty string,
    and everything else goes through ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, (dict, list)):
        return serialize_json(value)
    return str(value)
|
|
|
|
def normalize_expected_mapping(entry, prefix, case_index):
    """Normalize an expected-answer mapping that uses the current schema.

    Args:
        entry: Mapping with a target firm and optional counsel fields.
        prefix: Dataset prefix, used in error messages.
        case_index: 1-based case number, used in error messages.

    Returns:
        A dict with ``TARGET_FIRM_FIELD`` plus every field in
        ``COUNSEL_FIELDS``; missing, ``None`` or blank counsel values become
        ``"unknown"``.

    Raises:
        ValueError: If the target firm is missing, blank, or the text "none".
    """
    missing_target_message = (
        f"{prefix} case {case_index} is missing a non-empty '{TARGET_FIRM_FIELD}'."
    )

    raw_target_firm = entry.get(TARGET_FIRM_FIELD)
    if raw_target_firm is None:
        raise ValueError(missing_target_message)

    target_firm = normalize_whitespace(raw_target_firm)
    if target_firm.lower() == "none" or not target_firm:
        raise ValueError(missing_target_message)

    normalized = {TARGET_FIRM_FIELD: target_firm}
    for field_name in COUNSEL_FIELDS:
        field_value = entry.get(field_name, "unknown")
        cleaned = "" if field_value is None else str(field_value).strip()
        normalized[field_name] = cleaned if cleaned else "unknown"

    return normalized
|
|
|
|
def extract_target_candidate_from_question(question):
    """Pull the firm name out of an "Is <firm> ...?" style question.

    The specific phrasings are tried in order; the last pattern is a
    catch-all that takes everything after "Is". Matching ignores case.

    Returns:
        The whitespace-normalized candidate, or ``""`` when the question
        does not start with "Is".
    """
    stripped = normalize_whitespace(question)
    ordered_patterns = (
        r"^Is\s+(.+?)\s+present\s+in\s+the\s+agreement\??$",
        r"^Is\s+(.+?)\s+mentioned\s+in\s+the\s+agreement\??$",
        r"^Is\s+(.+?)\s+acting\s+as\s+counsel.*\??$",
        r"^Is\s+(.+?)\s+anywhere\s+in\s+this\s+Asset\s+Purchase\s+Agreement\??$",
        r"^Is\s+(.+?)\s+in\s+the\s+agreement\??$",
        # Catch-all: anything following "Is", minus an optional trailing "?".
        r"^Is\s+(.+?)\??$",
    )
    for pattern in ordered_patterns:
        found = re.match(pattern, stripped, re.IGNORECASE)
        if found is not None:
            return normalize_whitespace(found.group(1))
    return ""
|
|
|
|
def extract_entity_candidates(text):
    """Find capitalized names ending in a corporate suffix (LLP, LLC, Inc., ...).

    Returns:
        Whitespace-normalized matches in order of appearance; duplicates are
        kept.
    """
    entity_pattern = (
        r"\b([A-Z][A-Za-z0-9'.,-]*(?:\s+(?:&|[A-Z][A-Za-z0-9'.,-]*))*\s+"
        r"(?:LLP|LLC|LP|Inc\.?|Corporation|Corp\.?|Ltd\.?))\b"
    )
    # With a single capture group, findall yields that group for each match.
    raw_names = re.findall(entity_pattern, str(text))
    return [normalize_whitespace(name) for name in raw_names]
|
|
|
|
def build_legacy_expected_mapping(entry, question, docs):
    """Convert a legacy expected entry into the current schema.

    Legacy entries use ``buyer_firm`` / ``seller_firm`` / ``third_party`` and
    carry no target firm, so the target is inferred from the question.

    Args:
        entry: Legacy mapping.
        question: Question text the target firm is extracted from.
        docs: Snippets searched for a canonical spelling of the target firm.

    Returns:
        A dict with the three counsel fields plus ``TARGET_FIRM_FIELD``. The
        target is, in order of preference: the first counsel value containing
        the question's candidate, the first entity found in ``docs``
        containing it, the raw candidate, or ``"unknown target"``.
    """

    def _legacy_value(key):
        # A JSON null must become "unknown", not the literal text "None";
        # this matches normalize_expected_mapping.
        raw = entry.get(key)
        if raw is None:
            return "unknown"
        return normalize_whitespace(raw) or "unknown"

    normalized = {
        "buyer_counsel": _legacy_value("buyer_firm"),
        "seller_counsel": _legacy_value("seller_firm"),
        "third_party_counsel": _legacy_value("third_party"),
    }

    target_candidate = extract_target_candidate_from_question(question)
    target_norm = normalize_counsel_value(target_candidate)

    def _first_match(candidates):
        # First candidate whose normalized form contains the target, or None.
        if not target_norm:
            return None
        for candidate in candidates:
            candidate_text = normalize_whitespace(candidate)
            if candidate_text and target_norm in normalize_counsel_value(candidate_text):
                return candidate_text
        return None

    # Snapshot the counsel values so the dict is not mutated while iterated.
    matched = _first_match(list(normalized.values()))
    if matched is None:
        doc_candidates = [
            candidate for doc in docs for candidate in extract_entity_candidates(doc)
        ]
        matched = _first_match(doc_candidates)

    if matched is None:
        matched = target_candidate or "unknown target"
    normalized[TARGET_FIRM_FIELD] = matched
    return normalized
|
|
|
|
def normalize_expected_entry(entry, prefix, case_index, question="", docs=None):
    """Normalize one expected-answer entry of an evaluation set.

    Args:
        entry: A mapping, a string holding a JSON object, or a plain string.
        prefix: Dataset prefix, used in error messages.
        case_index: 1-based case number, used in error messages.
        question: Question text, needed only for legacy mappings.
        docs: Snippets, needed only for legacy mappings.

    Returns:
        A normalized mapping for dict-like entries. A string that is not a
        brace-wrapped JSON object is returned unchanged.

    Raises:
        ValueError: If ``entry`` is neither a string nor a mapping, or if a
            brace-wrapped string is not valid JSON.
    """
    docs = docs or []

    if isinstance(entry, dict):
        mapping = entry
    elif isinstance(entry, str):
        stripped = entry.strip()
        looks_like_object = stripped.startswith("{") and stripped.endswith("}")
        if not looks_like_object:
            return entry
        try:
            mapping = json.loads(stripped)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"{prefix} case {case_index} contains invalid JSON in expected entry: {exc}"
            ) from exc
        if not isinstance(mapping, dict):
            return entry
    else:
        raise ValueError(
            f"{prefix} case {case_index} expected entry must be a string or object."
        )

    # Entries without a target firm but with any legacy key use the old schema.
    legacy_keys = {"buyer_firm", "seller_firm", "third_party", "contains_target_firm"}
    uses_legacy_schema = TARGET_FIRM_FIELD not in mapping and bool(
        legacy_keys.intersection(mapping.keys())
    )
    if uses_legacy_schema:
        return build_legacy_expected_mapping(mapping, question, docs)
    return normalize_expected_mapping(mapping, prefix, case_index)
|
|
|
|
def sanitize_input(text, max_length=500):
    """Drop characters outside a small allow-list, then strip and truncate.

    Kept characters: ASCII letters and digits, whitespace, and the
    punctuation ``. , ! ? @ : - + ' & ( ) /``.

    Args:
        text: Raw input string.
        max_length: Maximum length of the returned string.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9\s.,!?@:\-+'&()/]")
    kept = disallowed.sub("", text)
    trimmed = kept.strip()
    return trimmed[:max_length]
|
|
|
|
def sanitize_prompt(text, max_length=8000):
    """Strip surrounding whitespace from a prompt and cap its length.

    Args:
        text: Raw prompt string.
        max_length: Maximum number of characters kept. The default of 8000
            is the limit that used to be hard-coded; it is now a parameter,
            mirroring ``sanitize_input``.

    Returns:
        The stripped prompt truncated to ``max_length`` characters.
    """
    return text.strip()[:max_length]
|
|
|
|
def normalize_email(email):
    """Return ``email`` without surrounding whitespace and in lower case."""
    trimmed = email.strip()
    return trimmed.lower()
|
|
|
|
def validate_email(email):
    """Return True when ``email`` has the shape ``local@domain.tld``.

    The whole string must match: ``re.fullmatch`` is used because
    ``re.match`` with a trailing ``$`` also accepts an address followed by
    a newline. Non-string input is reported as invalid instead of raising.
    """
    if not isinstance(email, str):
        return False
    email_regex = r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+"
    return re.fullmatch(email_regex, email) is not None
|
|
|
|
def extract_user_message(text):
    """Return the text between the first pair of irrelevant-case tags.

    Tags are matched case-insensitively and the enclosed text may span
    lines.

    Returns:
        The stripped inner text, or ``None`` when ``text`` is empty or holds
        no complete tag pair.
    """
    if not text:
        return None

    tag_pattern = re.escape(IRRELEVANT_TAG_START) + r"(.*?)" + re.escape(IRRELEVANT_TAG_END)
    found = re.search(tag_pattern, text, re.IGNORECASE | re.DOTALL)
    return found.group(1).strip() if found else None
|
|
|
|
def normalize_whitespace(text):
    """Stringify ``text``, collapse whitespace runs to one space, and strip the ends."""
    as_text = str(text)
    collapsed = re.sub(r"\s+", " ", as_text)
    return collapsed.strip()
|
|
|
|
def normalize_counsel_value(value):
    """Reduce a firm name to a lower-case comparison key.

    ``&`` is spelled out as "and", every character other than ``a-z``,
    ``0-9`` or whitespace becomes a space, and whitespace is collapsed.
    ``None`` yields an empty string.
    """
    if value is None:
        return ""

    lowered = normalize_whitespace(value).lower()
    spelled_out = lowered.replace("&", " and ")
    alnum_only = re.sub(r"[^a-z0-9\s]", " ", spelled_out)
    return " ".join(alnum_only.split())
|
|
|
|
def value_is_unknown(value):
    """Return True when ``value`` is the word "unknown", ignoring case and surrounding whitespace."""
    canonical = normalize_whitespace(value).lower()
    return canonical == "unknown"
|
|
|
|
def expected_is_unknown(expected_value):
    """Return True when an expected value is "unknown" (any case) or blank."""
    canonical = normalize_whitespace(expected_value).lower()
    return canonical in ("unknown", "")
|
|
|
|
def truncate_text(text, max_chars=2000):
    """Shorten ``text`` to at most ``max_chars`` characters.

    Args:
        text: Value to shorten; falsy values are treated as an empty string.
        max_chars: Maximum length of the result.

    Returns:
        The text unchanged when it already fits. Otherwise the text is cut
        and ends in ``"..."``. When ``max_chars`` is below 3 there is no
        room for the ellipsis, so the text is simply cut to ``max_chars``
        characters (an empty string for zero or negative limits); the
        previous negative slice returned a result longer than the limit.
    """
    text = str(text or "")
    if len(text) <= max_chars:
        return text
    if max_chars < 3:
        return text[: max(max_chars, 0)]
    return f"{text[: max_chars - 3]}..."
|
|
|
|
def serialize_json(value):
    """Serialize ``value`` as ASCII-only JSON with sorted keys."""
    dump_options = {"ensure_ascii": True, "sort_keys": True}
    return json.dumps(value, **dump_options)
|
|
|
|
def normalize_newlines(text):
    """Convert CRLF and lone CR line endings to LF; falsy input becomes ""."""
    as_text = str(text or "")
    # One pass: a CR with an optional following LF collapses to a single LF.
    return re.sub(r"\r\n?", "\n", as_text)
|
|
|
|
def normalize_sheet_block_text(text):
    """Normalize newlines, drop trailing whitespace per line, and strip the block."""
    unified = normalize_newlines(text)
    right_trimmed = "\n".join(line.rstrip() for line in unified.split("\n"))
    return right_trimmed.strip()
|
|
|
|
def format_score(value):
    """Format a score with up to four decimals and at least one.

    Trailing zeros are dropped, but whole numbers keep a ``.0`` suffix
    (``1`` -> ``"1.0"``, ``0.25`` -> ``"0.25"``).
    """
    fixed = f"{float(value):.4f}"
    trimmed = fixed.rstrip("0").rstrip(".")
    if "." in trimmed:
        return trimmed
    return f"{trimmed}.0"
|
|
|
|
def join_sheet_sections(blocks):
    """Join the non-empty blocks, separated by ``SHEET_SECTION_DIVIDER`` on its own line with blank lines around it."""
    separator = f"\n\n{SHEET_SECTION_DIVIDER}\n\n"
    return separator.join(block for block in blocks if block)
|
|
|
|
def build_case_run_header(case_index, run_index):
    """Return the header line for a single case/run pair."""
    return " | ".join((f"CASE {case_index}", f"RUN {run_index}"))
|
|
|
|
def format_run_indices(run_indices):
    """Render run indices as "1", "1 AND 2" or "1, 2 AND 3".

    Returns an empty string when there are no indices.
    """
    labels = list(map(str, run_indices))
    if len(labels) <= 1:
        return labels[0] if labels else ""
    *leading, final = labels
    return f"{', '.join(leading)} AND {final}"
|
|
|
|
def build_grouped_case_run_header(case_index, run_indices):
    """Return the header line for a case whose listed runs share one block."""
    run_label = format_run_indices(run_indices)
    return f"CASE {case_index} | RUN {run_label}"
|
|
|
|
def truncate_sheet_block(text, max_chars):
    """Normalize a text block for the sheet, then truncate it to ``max_chars``."""
    return truncate_text(normalize_sheet_block_text(text), max_chars=max_chars)
|
|
|
|
def format_pretty_expected_answer(expected):
    """Render an expected answer for display.

    A string ``expected`` marks an irrelevant case and produces a fixed
    explanatory sentence. Anything else is pretty-printed as JSON and
    truncated to ``EXPECTED_OUTPUT_MAX_CHARS``.
    """
    if not isinstance(expected, str):
        pretty_json = json.dumps(expected, ensure_ascii=True, sort_keys=True, indent=2)
        return truncate_text(pretty_json, max_chars=EXPECTED_OUTPUT_MAX_CHARS)
    return "Irrelevant case. No final JSON should be produced because LLM 1 should stop the pipeline."
|
|
|
|
def format_llm1_expected_text(expected):
    """Describe the output format LLM 1 is expected to produce.

    A string ``expected`` marks an irrelevant case; any other value is
    treated as a relevant case.
    """
    case_is_irrelevant = isinstance(expected, str)
    return (
        "Irrelevant case: output must be wrapped in <user_message>...</user_message>."
        if case_is_irrelevant
        else "Relevant case: output must be exactly TARGET_FIRM: <canonical firm name>."
    )
|
|
|
|
def format_llm2_user_output(llm2_outputs):
    """Render the per-snippet LLM 2 outputs as labelled text blocks.

    Args:
        llm2_outputs: Mapping of snippet id to output text.

    Returns:
        ``"Not run."`` for an empty mapping; otherwise one ``"<id>:"`` block
        per snippet, each output truncated to ``STAGE_TWO_OUTPUT_MAX_CHARS``,
        with blocks separated by a blank line.
    """
    if not llm2_outputs:
        return "Not run."

    rendered = []
    for snippet_id, snippet_output in llm2_outputs.items():
        body = truncate_sheet_block(snippet_output, max_chars=STAGE_TWO_OUTPUT_MAX_CHARS)
        rendered.append(f"{snippet_id}:\n{body}")
    return "\n\n".join(rendered)
|
|
|
|
def format_teacher_forced_note(run_result):
    """Explain which Step 1 context LLM 2 and LLM 3 ran with.

    Args:
        run_result: Mapping with ``used_teacher_forced_step1`` and
            ``effective_step1_context``.
    """
    context = run_result["effective_step1_context"]
    if not run_result["used_teacher_forced_step1"]:
        return f"LLM 2 and LLM 3 used the submitted Step 1 context: {context}"
    return (
        "LLM 1 did not match. LLM 2 and LLM 3 were re-run with the teacher-forced "
        f"Step 1 context: {context}"
    )
|
|
|
|
def parse_cited_value(value):
    """Split an answer such as ``"Acme LLP [^1] [^3]"`` into value and citations.

    Args:
        value: Raw field value; it is stringified and whitespace-normalized.

    Returns:
        A dict with:
            ``raw``: the normalized input text.
            ``base_value``: the text with well-formed citations removed.
            ``citations``: snippet numbers from well-formed ``[^N]`` markers.
            ``malformed_citation``: True when a bracketed ``[^...]`` token is
                not a well-formed citation, or when a ``[^`` opener is left
                over outside every bracketed token. The latter used to be
                detected only when the text held no bracketed token at all,
                so ``"Acme [^1] [^"`` went unnoticed.
    """
    raw_text = normalize_whitespace(value)
    citations = [int(match.group(1)) for match in CITATION_PATTERN.finditer(raw_text)]

    bracket_token_pattern = r"\[\^[^\]]*\]"
    bracket_tokens = re.findall(bracket_token_pattern, raw_text)
    malformed_citation = any(
        CITATION_PATTERN.fullmatch(token) is None for token in bracket_tokens
    )
    if not malformed_citation:
        # Look for an unclosed "[^" in the text between bracketed tokens.
        leftovers = re.split(bracket_token_pattern, raw_text)
        malformed_citation = any("[^" in piece for piece in leftovers)

    base_value = normalize_whitespace(CITATION_PATTERN.sub("", raw_text))
    return {
        "raw": raw_text,
        "base_value": base_value,
        "citations": citations,
        "malformed_citation": malformed_citation,
    }
|
|
|
|
def snippet_supports_value(snippet_text, value):
    """Return True when the normalized ``value`` occurs inside the normalized snippet.

    A value that normalizes to an empty string is never supported.
    """
    needle = normalize_counsel_value(value)
    return bool(needle) and needle in normalize_counsel_value(snippet_text)
|
|
|
|
def get_supporting_snippet_numbers(docs, value):
    """Return the 1-based numbers of the snippets in ``docs`` that support ``value``."""
    return [
        number
        for number, doc in enumerate(docs, start=1)
        if snippet_supports_value(doc, value)
    ]
|
|
|
|
def append_citations(value, snippet_numbers):
    """Append ``[^N]`` markers for ``snippet_numbers`` to ``value``.

    ``value`` is returned unchanged when it is "unknown" or when there are
    no snippet numbers.
    """
    if not snippet_numbers or value_is_unknown(value):
        return value
    markers = [f"[^{snippet_number}]" for snippet_number in snippet_numbers]
    return f"{value} {' '.join(markers)}".strip()
|
|
|
|
def dedupe_preserving_order(values):
    """Drop blank values and duplicates, keeping first occurrences in order.

    Two values are duplicates when their ``normalize_counsel_value`` keys
    match. Survivors are returned whitespace-normalized.
    """
    first_seen = {}
    for value in values:
        key = normalize_counsel_value(value)
        if key and key not in first_seen:
            # Dicts keep insertion order, so first occurrences stay in order.
            first_seen[key] = normalize_whitespace(value)
    return list(first_seen.values())
|
|
|
|
def find_contradicting_entity(expected, docs, target_firm):
    """Find an entity other than the target firm that appears in the snippets.

    Expected counsel values are tried first, in ``COUNSEL_FIELDS`` order;
    then entity names extracted from the snippets themselves.

    Returns:
        ``(entity, snippet_numbers)`` for the first candidate with at least
        one supporting snippet, or ``("", [])`` when there is none.
    """
    target_norm = normalize_counsel_value(target_firm)

    for field_name in COUNSEL_FIELDS:
        counsel = normalize_whitespace(expected.get(field_name, "unknown"))
        is_usable = (
            not expected_is_unknown(counsel)
            and normalize_counsel_value(counsel) != target_norm
        )
        if not is_usable:
            continue
        supporting = get_supporting_snippet_numbers(docs, counsel)
        if supporting:
            return counsel, supporting

    extracted = [name for doc in docs for name in extract_entity_candidates(doc)]
    for entity in dedupe_preserving_order(extracted):
        if normalize_counsel_value(entity) == target_norm:
            continue
        supporting = get_supporting_snippet_numbers(docs, entity)
        if supporting:
            return entity, supporting

    return "", []
|
|
|
|
def build_expected_user_question_spec(expected, docs, target_firm):
    """Derive the expected verdict for the user's yes/no question.

    Returns:
        A dict with ``verdict``, ``company`` and ``citations``:
        ``"true"`` with the target firm and its supporting snippets when the
        target appears in ``docs``; otherwise ``"false"`` with a
        contradicting entity and its snippets when one is found; otherwise
        ``"unknown"`` with an empty company and no citations.
    """
    target_snippets = get_supporting_snippet_numbers(docs, target_firm)
    if target_snippets:
        return dict(
            verdict="true",
            company=normalize_whitespace(target_firm),
            citations=target_snippets,
        )

    other_company, other_snippets = find_contradicting_entity(expected, docs, target_firm)
    if other_company and other_snippets:
        return dict(verdict="false", company=other_company, citations=other_snippets)

    return dict(verdict="unknown", company="", citations=[])
|
|
|
|
def format_expected_user_question(spec):
    """Render a user-question spec as the expected answer string.

    An "unknown" verdict is returned bare; other verdicts get their
    citations appended.
    """
    verdict = spec["verdict"]
    if verdict == "unknown":
        return "unknown"
    return append_citations(verdict, spec.get("citations", []))
|
|
|
|
def parse_user_question_answer(value):
    """Parse the user-question field into a verdict plus citation details.

    Returns:
        The ``parse_cited_value`` dict extended with ``verdict``:
        ``"unknown"`` (this branch also adds an empty ``evidence`` key),
        ``"true"`` for true/yes, ``"false"`` for false/no, or ``None`` when
        the base value is not one of those words.
    """
    parsed = parse_cited_value(value)
    base_value = parsed["base_value"]

    if value_is_unknown(base_value):
        return dict(parsed, verdict="unknown", evidence="")

    boolean_match = BOOLEAN_ANSWER_PATTERN.match(base_value)
    if boolean_match is None:
        return dict(parsed, verdict=None)

    is_affirmative = boolean_match.group(1).lower() in {"true", "yes"}
    return dict(parsed, verdict="true" if is_affirmative else "false")
|
|
|
|
def validate_citations_for_value(parsed_value, docs, expected_value):
    """Check each cited snippet number against ``docs``.

    Returns:
        ``(bad_references, unsupported_citations)``: numbers outside
        ``1..len(docs)``, and in-range numbers whose snippet does not
        support ``expected_value``.
    """
    bad_references, unsupported_citations = [], []
    snippet_total = len(docs)
    for number in parsed_value["citations"]:
        if not 1 <= number <= snippet_total:
            bad_references.append(number)
        elif not snippet_supports_value(docs[number - 1], expected_value):
            unsupported_citations.append(number)
    return bad_references, unsupported_citations
|
|
|
|
def build_expected_step1_output(expected):
    """Return the reference LLM 1 output for a case.

    A string ``expected`` (irrelevant case) yields the word "irrelevant"
    wrapped in the irrelevant-case tags; otherwise the result is a
    ``TARGET_FIRM:`` line built from the expected target firm.
    """
    if not isinstance(expected, str):
        return f"TARGET_FIRM: {expected[TARGET_FIRM_FIELD]}"
    return f"{IRRELEVANT_TAG_START}irrelevant{IRRELEVANT_TAG_END}"
|
|
|
|
def build_expected_llm3_answer(expected, docs, target_firm):
    """Build the reference LLM 3 answer, including citation markers.

    A string ``expected`` (irrelevant case) is returned unchanged. Otherwise
    each counsel field gets the citations of its supporting snippets, and
    the user-question field is derived from the expected verdict.
    """
    if isinstance(expected, str):
        return expected

    formatted = {}
    for field_name in COUNSEL_FIELDS:
        counsel = str(expected.get(field_name, "unknown"))
        formatted[field_name] = append_citations(
            counsel,
            get_supporting_snippet_numbers(docs, counsel),
        )

    question_spec = build_expected_user_question_spec(expected, docs, target_firm)
    formatted[USER_QUESTION_FIELD] = format_expected_user_question(question_spec)
    return formatted
|
|
|
|
def format_prompts_cell(prompts):
    """Render the three submitted prompts into a single sheet cell.

    Each prompt is truncated to ``PROMPT_CELL_MAX_CHARS`` and labelled
    "LLM 1" to "LLM 3"; the joined cell is capped at
    ``SHEET_CELL_MAX_CHARS``.
    """
    labelled_keys = (
        ("LLM 1", "prompt_1"),
        ("LLM 2", "prompt_2"),
        ("LLM 3", "prompt_3"),
    )
    blocks = []
    for label, prompt_key in labelled_keys:
        prompt_text = truncate_sheet_block(prompts[prompt_key], max_chars=PROMPT_CELL_MAX_CHARS)
        blocks.append(f"{label}\nPrompt:\n{prompt_text}")
    return truncate_text(join_sheet_sections(blocks), max_chars=SHEET_CELL_MAX_CHARS)
|
|
|
|
def parse_relevant_stage_one_output(text):
    """Parse an LLM 1 output of the exact form ``TARGET_FIRM: <name>``.

    Returns:
        ``{"target_firm": <name>}``, or ``None`` when the normalized text
        does not match that single-line form.
    """
    found = re.fullmatch(r"TARGET_FIRM:\s*(.+)", normalize_sheet_block_text(text))
    target_firm = found.group(1).strip() if found else ""
    if not target_firm:
        return None
    return {"target_firm": target_firm}
|
|
|
|
def is_error_output(text, stage_name):
    """Return True when ``text`` starts with the error prefix recorded for a failed ``stage_name`` call."""
    error_prefix = f"Error during {stage_name} call:"
    return str(text).startswith(error_prefix)
|
|
|
|
def run_chat_completion(system_prompt, user_prompt):
    """Send one system + user message pair to the chat model.

    The request runs while holding the shared request semaphore, which
    bounds the number of concurrent OpenAI calls.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.

    Returns:
        The stripped text of the first choice.

    Raises:
        RuntimeError: If the response has no choices or the first choice has
            no text content. These cases previously surfaced as an opaque
            ``IndexError`` / ``AttributeError``.
        Exception: Anything raised by the client call propagates.
    """
    # ``with`` releases the semaphore on every exit path, as the former
    # acquire/try/finally did.
    with get_openai_request_semaphore():
        response = get_openai_client().chat.completions.create(
            model=get_model_name(),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )

    if not response.choices:
        raise RuntimeError("OpenAI response contained no choices.")
    content = response.choices[0].message.content
    if content is None:
        # NOTE(review): the API types ``content`` as optional; confirm which
        # responses (refusals, tool calls) leave it empty for this model.
        raise RuntimeError("OpenAI response contained no message content.")
    return content.strip()
|
|
|
|
def run_step_one(prompts, question):
    """Run LLM 1 on the question and detect the irrelevant route.

    A failed call does not raise; the output becomes an
    ``"Error during LLM1 call: ..."`` string instead.

    Returns:
        A dict with ``llm1_output``, ``is_irrelevant`` (True when the output
        holds an irrelevant-case tag pair) and ``user_message`` (the tagged
        text, or ``None``).
    """
    try:
        llm1_output = run_chat_completion(prompts["prompt_1"], question)
    except Exception as error:
        # Best effort: the failure is kept as text so grading can continue.
        llm1_output = f"Error during LLM1 call: {error}"

    tagged_message = extract_user_message(llm1_output)
    return dict(
        llm1_output=llm1_output,
        is_irrelevant=tagged_message is not None,
        user_message=tagged_message,
    )
|
|
|
|
def build_llm2_user_prompt(step1_context, snippet_id, doc):
    """Assemble the user message for one LLM 2 snippet analysis.

    The message holds the Step 1 context, a blank line, the snippet id and
    the snippet text.
    """
    sections = (
        f"Target firm context:\n{step1_context}",
        "",
        f"Snippet ID: {snippet_id}",
        f"Snippet Text:\n{doc}",
    )
    return "\n".join(sections)
|
|
|
|
def run_llm2_snippet(prompts, step1_context, snippet_id, doc):
    """Run LLM 2 on a single snippet.

    Returns:
        ``(snippet_id, output)``. A failed call does not raise; ``output``
        becomes an ``"Error during LLM2 call: ..."`` string instead.
    """
    try:
        user_prompt = build_llm2_user_prompt(step1_context, snippet_id, doc)
        snippet_output = run_chat_completion(prompts["prompt_2"], user_prompt)
    except Exception as error:
        # Best effort: keep the failure as text so other snippets still run.
        snippet_output = f"Error during LLM2 call: {error}"
    return snippet_id, snippet_output
|
|
|
|
def run_downstream_stages(prompts, docs, step1_context):
    """Run LLM 2 on every snippet, then LLM 3 on the collected analyses.

    Args:
        prompts: Mapping that provides ``prompt_2`` and ``prompt_3``.
        docs: Snippet texts; snippet N is labelled ``S<N>``.
        step1_context: Target-firm context string passed to both stages.

    Returns:
        A dict with ``snippet_count``, ``llm2_outputs`` (snippet id to
        output text) and ``llm3_output``. A failed LLM 3 call is recorded as
        an ``"Error during LLM3 call: ..."`` string rather than raised.

    NOTE(review): the source's indentation was lost. As reconstructed here,
    only the LLM 2 fan-out is guarded by ``if docs:`` and LLM 3 is called
    even when ``docs`` is empty -- confirm against the original file.
    """
    # Pre-populate one slot per snippet so the dict is in S1..Sn order even
    # though threaded results arrive in completion order.
    llm2_outputs = {f"S{index}": "" for index in range(1, len(docs) + 1)}
    llm3_output = ""

    if docs:
        max_workers = min(get_llm2_max_workers(), len(docs))
        if max_workers == 1:
            # Single worker: run snippets sequentially without a pool.
            for index, doc in enumerate(docs, start=1):
                snippet_id = f"S{index}"
                _, snippet_output = run_llm2_snippet(prompts, step1_context, snippet_id, doc)
                llm2_outputs[snippet_id] = snippet_output
        else:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(
                        run_llm2_snippet,
                        prompts,
                        step1_context,
                        f"S{index}",
                        doc,
                    )
                    for index, doc in enumerate(docs, start=1)
                ]
                # run_llm2_snippet catches its own errors and returns them as
                # text, so result() yields an (id, output) pair.
                for future in as_completed(futures):
                    snippet_id, snippet_output = future.result()
                    llm2_outputs[snippet_id] = snippet_output

    # Flatten the per-snippet analyses into "<id>\n<output>" blocks for LLM 3.
    serialized_llm2 = []
    for snippet_id, snippet_output in llm2_outputs.items():
        serialized_llm2.append(f"{snippet_id}\n{snippet_output}")

    try:
        llm3_output = run_chat_completion(
            prompts["prompt_3"],
            (
                f"Target firm context:\n{step1_context}\n\n"
                "Per-snippet analyses:\n"
                # chr(10) is a newline, written this way inside the f-string.
                f"{chr(10).join(serialized_llm2)}"
            ),
        )
    except Exception as exc:
        llm3_output = f"Error during LLM3 call: {exc}"

    return {
        "snippet_count": len(docs),
        "llm2_outputs": llm2_outputs,
        "llm3_output": llm3_output,
    }
|
|
|
|
def parse_final_answer(answer):
    """Parse the final LLM 3 answer, which must be a JSON object.

    Args:
        answer: Raw output text.

    Returns:
        ``(parsed, None)`` on success, or ``(None, error_message)`` when the
        text is not valid JSON or is not a JSON object. Input that is not a
        string (for example ``None``) is reported as invalid JSON instead of
        raising ``TypeError``.
    """
    try:
        parsed = json.loads(answer)
    except (json.JSONDecodeError, TypeError) as exc:
        return None, f"Invalid JSON: {exc}"

    if not isinstance(parsed, dict):
        return None, "Final answer must be a JSON object."

    return parsed, None
|
|
|
|
def build_field_result(correct, issues=None, penalty=None):
    """Build the grading record for one field.

    Args:
        correct: Whether the field was graded as correct.
        issues: Issue tags; defaults to an empty list.
        penalty: Score penalty. When omitted it is 0.0 for a correct field
            and 0.2 otherwise.

    Returns:
        A dict with ``correct``, ``issues`` and ``penalty`` (rounded to four
        decimals).
    """
    effective_penalty = penalty
    if effective_penalty is None:
        effective_penalty = 0.2 if not correct else 0.0
    return dict(
        correct=correct,
        issues=issues or [],
        penalty=round(float(effective_penalty), 4),
    )
|
|
|
|
def build_prefixed_failure_tags(field_name, issues):
    """Return one ``"<issue>_<field_name>"`` tag per issue, in order."""
    suffix = f"_{field_name}"
    return [f"{issue}{suffix}" for issue in issues]
|
|
|
|
def grade_counsel_field(field_name, parsed_answer, expected_answer, docs):
    """Grade one counsel field of the final answer.

    Args:
        field_name: Field to grade.
        parsed_answer: Parsed final-answer object.
        expected_answer: Expected mapping; a missing field counts as
            "unknown".
        docs: Snippets that citation numbers refer to (1-based).

    Returns:
        A ``build_field_result`` dict. Possible issues: ``missing_field``,
        ``malformed_citation``, ``unexpected_citation``, ``used_unknown``,
        ``hallucinated_value``, ``wrong_value``, ``missing_citation``,
        ``bad_citation_reference`` and ``unsupported_citation``.
    """
    if field_name not in parsed_answer:
        return build_field_result(False, ["missing_field"])

    answer = parse_cited_value(parsed_answer.get(field_name))
    answer_value = answer["base_value"]
    expected_value = str(expected_answer.get(field_name, "unknown"))
    expects_unknown = expected_is_unknown(expected_value)
    answered_unknown = value_is_unknown(answer_value)

    issues = ["malformed_citation"] if answer["malformed_citation"] else []

    # Any "unknown" on either side is settled without comparing values.
    if answered_unknown or expects_unknown:
        if answered_unknown and expects_unknown:
            if answer["citations"]:
                issues.append("unexpected_citation")
            return build_field_result(not issues, issues)
        issues.append("used_unknown" if answered_unknown else "hallucinated_value")
        return build_field_result(False, issues)

    if normalize_counsel_value(answer_value) != normalize_counsel_value(expected_value):
        issues.append("wrong_value")

    if answer["citations"]:
        # Citations are checked against the value the answer actually gave.
        bad_references, unsupported_citations = validate_citations_for_value(
            answer,
            docs,
            answer_value,
        )
        if bad_references:
            issues.append("bad_citation_reference")
        if unsupported_citations:
            issues.append("unsupported_citation")
    else:
        issues.append("missing_citation")

    return build_field_result(not issues, issues)
|
|
|
|
def grade_user_question_field(parsed_answer, expected_user_question, docs, target_firm):
    """Grade the user-question field of the final answer.

    The verdict and the citations are worth 0.1 each; the total penalty is
    capped at 0.2.

    Args:
        parsed_answer: Parsed final-answer object.
        expected_user_question: Spec with ``verdict`` and optionally
            ``company``.
        docs: Snippets that citation numbers refer to (1-based).
        target_firm: Citation target used when the spec names no company.

    Returns:
        A ``build_field_result`` dict.
    """
    if USER_QUESTION_FIELD not in parsed_answer:
        return build_field_result(False, ["missing_field"], penalty=0.2)

    answer = parse_user_question_answer(parsed_answer.get(USER_QUESTION_FIELD))
    has_malformed_citation = answer["malformed_citation"]
    issues = ["malformed_citation"] if has_malformed_citation else []
    penalty = 0.0

    expected_verdict = expected_user_question["verdict"]
    expected_company = normalize_whitespace(expected_user_question.get("company", ""))
    given_verdict = answer["verdict"]
    given_citations = answer["citations"]

    if expected_verdict == "unknown":
        if given_verdict != "unknown":
            issues.append("wrong_value")
            penalty += 0.1
        if given_citations:
            issues.append("unexpected_citation")
            penalty += 0.1
        return build_field_result(not issues, issues, penalty=min(0.2, penalty))

    if given_verdict == "unknown":
        issues.append("used_unknown")
        return build_field_result(False, issues, penalty=0.2)

    if given_verdict != expected_verdict:
        issues.append("wrong_value")
        penalty += 0.1

    if given_citations:
        bad_references, unsupported_citations = validate_citations_for_value(
            answer,
            docs,
            expected_company or target_firm,
        )
        if bad_references:
            issues.append("bad_citation_reference")
        if unsupported_citations:
            issues.append("unsupported_citation")
        # All citation problems together cost a single 0.1.
        if has_malformed_citation or bad_references or unsupported_citations:
            penalty += 0.1
    else:
        issues.append("missing_citation")
        penalty += 0.1

    return build_field_result(not issues, issues, penalty=min(0.2, penalty))
|
|
|
|
def grade_case(case_run, expected):
    """Grade one pipeline run against its expected answer.

    Args:
        case_run: Run record. ``is_irrelevant`` is read for irrelevant
            cases; ``llm3_output``, ``docs`` and ``effective_target_firm``
            for relevant ones.
        expected: A string for an irrelevant case, otherwise the expected
            mapping.

    Returns:
        A dict with ``case_score`` (1.0 minus the summed field penalties,
        floored at 0.0), sorted unique ``failure_tags``, per-field
        ``field_results`` and ``passed``.
    """
    case_result = {"case_score": 0.0, "failure_tags": [], "field_results": {}}

    # Irrelevant case: only the routing decision of LLM 1 is scored.
    if isinstance(expected, str):
        if case_run["is_irrelevant"]:
            case_result["case_score"] = 1.0
        else:
            case_result["failure_tags"] = ["relevance_false_positive"]
        case_result["passed"] = case_result["case_score"] == 1.0
        return case_result

    parsed_answer, parse_error = parse_final_answer(case_run["llm3_output"])
    if parse_error:
        case_result["failure_tags"] = ["invalid_json"]
        case_result["field_results"] = {
            field_name: build_field_result(False, ["invalid_json"])
            for field_name in FINAL_SCHEMA_KEYS
        }
        case_result["passed"] = False
        return case_result

    field_results = {}
    for field_name in COUNSEL_FIELDS:
        field_results[field_name] = grade_counsel_field(
            field_name,
            parsed_answer,
            expected,
            case_run["docs"],
        )

    target_firm = case_run["effective_target_firm"]
    question_spec = build_expected_user_question_spec(expected, case_run["docs"], target_firm)
    field_results[USER_QUESTION_FIELD] = grade_user_question_field(
        parsed_answer,
        question_spec,
        case_run["docs"],
        target_firm,
    )

    # Aggregate in field order: counsel fields first, user question last.
    total_penalty = 0.0
    failure_tags = []
    for field_name, field_result in field_results.items():
        total_penalty += field_result["penalty"]
        if field_result["issues"]:
            failure_tags.extend(build_prefixed_failure_tags(field_name, field_result["issues"]))

    case_result["case_score"] = round(max(0.0, 1.0 - total_penalty), 4)
    case_result["failure_tags"] = sorted(set(failure_tags))
    case_result["field_results"] = field_results
    case_result["passed"] = total_penalty == 0.0
    return case_result
|
|
|
|
def grade_llm1_stage(llm1_output, expected):
    """Grade the Step 1 (relevance / target-firm) output of a single run.

    Args:
        llm1_output: Raw text returned by the Step 1 model call.
        expected: Expected answer for the case; a string marks an irrelevant
            case, otherwise it is indexed by ``TARGET_FIRM_FIELD``.

    Returns:
        Dict with ``score`` (0.0, 0.5 or 1.0), a ``reasoning`` label,
        ``failure_tags`` (the label prefixed with ``llm1_``, empty on a match)
        and ``parsed_output`` (only set when the output parsed).
    """

    def _verdict(score, reasoning, parsed=None):
        # "matched" is the only label that carries no failure tag.
        tags = [] if reasoning == "matched" else [f"llm1_{reasoning}"]
        return {
            "score": score,
            "reasoning": reasoning,
            "failure_tags": tags,
            "parsed_output": parsed,
        }

    routed_irrelevant = extract_user_message(llm1_output) is not None
    parsed_output = parse_relevant_stage_one_output(llm1_output)

    if isinstance(expected, str):
        # The case should have been rejected as irrelevant.
        if routed_irrelevant:
            return _verdict(1.0, "matched")
        if is_error_output(llm1_output, "LLM1"):
            return _verdict(0.0, "call_failed")
        return _verdict(0.0, "wrong_route")

    if routed_irrelevant:
        # A relevant case was rejected.
        return _verdict(0.0, "wrong_route")

    if parsed_output is None:
        if is_error_output(llm1_output, "LLM1"):
            return _verdict(0.0, "call_failed")
        # Half credit when the expected firm still appears in the raw text.
        wanted_firm = normalize_counsel_value(expected[TARGET_FIRM_FIELD])
        raw_normalized = normalize_counsel_value(llm1_output)
        if wanted_firm and wanted_firm in raw_normalized:
            return _verdict(0.5, "right_answer_wrong_format")
        return _verdict(0.0, "wrong_answer_wrong_format")

    found_firm = normalize_counsel_value(parsed_output["target_firm"])
    wanted_firm = normalize_counsel_value(expected[TARGET_FIRM_FIELD])
    if found_firm != wanted_firm:
        return _verdict(0.5, "wrong_answer_right_format", parsed_output)

    return _verdict(1.0, "matched", parsed_output)
|
|
|
|
def summarize_llm2_stage(case_run, expected):
    """Describe how many Step 2 snippet calls produced usable output.

    Args:
        case_run: Run record. Reads ``snippet_count`` and ``llm2_outputs``
            (looked up by ``S1``, ``S2``, ... ids).
        expected: Expected answer for the case; a string marks an irrelevant
            case, for which Step 2 is reported as not run.

    Returns:
        A one-line human-readable summary.
    """
    if isinstance(expected, str):
        return "Not run. Irrelevant case."

    snippet_total = case_run["snippet_count"]
    if snippet_total == 0:
        return "No snippets were provided."

    def _failed(snippet_id):
        # A missing/empty output counts as a failure, as does an error marker.
        output = case_run["llm2_outputs"].get(snippet_id, "")
        return not output or is_error_output(output, "LLM2")

    failed_ids = [
        snippet_id
        for snippet_id in (f"S{position}" for position in range(1, snippet_total + 1))
        if _failed(snippet_id)
    ]
    succeeded = len(range(1, snippet_total + 1)) - len(failed_ids)

    parts = [f"{succeeded}/{snippet_total} snippet calls completed successfully."]
    if failed_ids:
        parts.append(f" Errored snippets: {', '.join(failed_ids)}.")
    return "".join(parts)
|
|
|
|
def grade_llm3_stage(case_run, expected, semantic_result):
    """Turn the semantic grading result into a Step 3 score and reasoning label.

    Args:
        case_run: Run record; ``is_irrelevant`` is read for string ``expected``.
        expected: Expected answer for the case (a string for irrelevant cases).
        semantic_result: Dict with ``failure_tags`` and ``case_score``.

    Returns:
        Dict with ``score`` (the case score) and ``reasoning``: the failure
        tags joined by ", ", ``"skipped_irrelevant"``, or ``"clean"``.
    """
    tags = semantic_result["failure_tags"]
    if tags:
        label = ", ".join(tags)
    else:
        correctly_skipped = isinstance(expected, str) and case_run["is_irrelevant"]
        label = "skipped_irrelevant" if correctly_skipped else "clean"

    return {"score": semantic_result["case_score"], "reasoning": label}
|
|
|
|
def summarize_field_accuracy(case_results):
    """Aggregate per-field correctness across graded results.

    Args:
        case_results: Iterable of dicts whose ``field_results`` maps a field
            name from ``FINAL_SCHEMA_KEYS`` to a dict with a ``correct`` flag.

    Returns:
        Dict keyed by every name in ``FINAL_SCHEMA_KEYS`` with ``correct``,
        ``total`` and ``accuracy`` (percentage rounded to 2 places, 0.0 when
        the field was never graded).
    """
    # tally[name] == [correct_count, graded_count]
    tally = {schema_key: [0, 0] for schema_key in FINAL_SCHEMA_KEYS}

    for graded in case_results:
        for schema_key, outcome in graded["field_results"].items():
            bucket = tally[schema_key]
            bucket[1] += 1
            if outcome["correct"]:
                bucket[0] += 1

    return {
        schema_key: {
            "correct": hits,
            "total": seen,
            "accuracy": round((hits / seen) * 100, 2) if seen else 0.0,
        }
        for schema_key, (hits, seen) in tally.items()
    }
|
|
|
|
def average_run_score(runs, key):
    """Return the mean of ``run[key]`` over ``runs``, rounded to 4 places.

    Returns 0.0 when ``runs`` is empty.
    """
    if runs:
        values = [run[key] for run in runs]
        return round(sum(values) / len(runs), 4)
    return 0.0
|
|
|
|
def total_run_score(runs, key):
    """Return the sum of ``run[key]`` over ``runs``, rounded to 4 places.

    Returns 0.0 when ``runs`` is empty.
    """
    if runs:
        values = [run[key] for run in runs]
        return round(sum(values), 4)
    return 0.0
|
|
|
|
def summarize_reasoning_counts(runs, key):
    """Summarize how often each ``run[key]`` label occurs, in first-seen order.

    Args:
        runs: Iterable of run dicts.
        key: Name of the label entry to count in each run.

    Returns:
        A string such as ``"clean x2; invalid_json x1"``; empty when ``runs``
        is empty.
    """
    # Counter is a dict subclass, so it iterates in first-insertion order.
    # That makes a separate ordered list of labels (and its repeated linear
    # membership scan) unnecessary.
    counts = Counter(run[key] for run in runs)
    return "; ".join(f"{label} x{count}" for label, count in counts.items())
|
|
|
|
def explain_reasoning_label(label):
    """Translate a reasoning label into a plain-English explanation.

    A label containing ``", "`` is treated as several comma-separated labels
    whose explanations are joined with ``"; "``. Unknown labels fall back to
    the label text with underscores replaced by spaces.
    """
    explanations = {
        "matched": "The model followed the required instruction and produced the expected type of output.",
        "wrong_format": "The model answered the right task but did not use the exact output format the evaluator requires.",
        "right_answer_wrong_format": "The model identified the right answer but did not use the exact required Step 1 format.",
        "wrong_answer_right_format": "The model used the required Step 1 format but identified the wrong target firm.",
        "wrong_answer_wrong_format": "The model gave neither the right answer nor the required Step 1 format.",
        "wrong_route": "The model chose the wrong kind of response for this question, such as treating an irrelevant query as relevant or the reverse.",
        "call_failed": "The model call failed, so no valid answer was produced for this step.",
        "invalid_json": "The Step 3 answer was not valid JSON in the required schema.",
        "relevance_false_positive": "The system treated a question as relevant even though it should have been rejected as irrelevant.",
        "skipped_irrelevant": "This step was skipped because the question was correctly identified as irrelevant.",
        "clean": "The answer passed the evaluator checks for this step.",
    }

    def _describe(single_label):
        return explanations.get(single_label, single_label.replace("_", " "))

    if ", " not in label:
        return _describe(label)

    pieces = (piece.strip() for piece in label.split(","))
    return "; ".join(_describe(piece) for piece in pieces if piece)
|
|
|
|
def summarize_run_groups(runs, key):
    """Collapse consecutive runs sharing the same ``run[key]`` into one line each.

    Returns:
        List of ``"RUN <indices>: <value>"`` strings, one per streak of equal
        adjacent values; run indices are 1-based.
    """
    streaks = []  # (value, [run numbers]) per streak of equal adjacent values
    for run_number, run in enumerate(runs, start=1):
        current = run[key]
        if streaks and streaks[-1][0] == current:
            streaks[-1][1].append(run_number)
        else:
            streaks.append((current, [run_number]))

    lines = []
    for value, run_numbers in streaks:
        lines.append(f"RUN {format_run_indices(run_numbers)}: {value}")
    return lines
|
|
|
|
def build_explained_run_groups(runs, key):
    """Group consecutive runs by ``run[key]`` and explain each group's label.

    Returns:
        List of three-line blocks (``RUN ...``, ``Label: ...``,
        ``Meaning: ...``), one per streak of equal adjacent labels; run
        indices are 1-based.
    """
    streaks = []  # (label, [run numbers]) per streak of equal adjacent labels
    for run_number, run in enumerate(runs, start=1):
        current = run[key]
        if streaks and streaks[-1][0] == current:
            streaks[-1][1].append(run_number)
        else:
            streaks.append((current, [run_number]))

    blocks = []
    for label, run_numbers in streaks:
        header = f"RUN {format_run_indices(run_numbers)}"
        meaning = f"Meaning: {explain_reasoning_label(label)}"
        blocks.append("\n".join((header, f"Label: {label}", meaning)))
    return blocks
|
|
|
|
def build_score_breakdown_cell(case_results):
    """Render the per-case LLM 1 / LLM 3 score breakdown for a sheet cell.

    Args:
        case_results: Sequence of case dicts, each with a ``runs`` list.

    Returns:
        The joined case sections, truncated to ``SHEET_CELL_MAX_CHARS``.
    """
    stages = (
        ("LLM 1", "llm1_score", "llm1_reasoning"),
        ("LLM 3", "llm3_score", "llm3_reasoning"),
    )
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        runs = case_result["runs"]
        run_numbers = list(range(1, len(runs) + 1))
        lines = [
            f"CASE {case_number}",
            f"RUNS: {format_run_indices(run_numbers)}",
        ]
        for heading, score_key, reasoning_key in stages:
            stage_points = format_score(total_run_score(runs, score_key))
            lines.extend(
                [
                    "",
                    heading,
                    f"Score: {stage_points} / {len(run_numbers)}",
                    "Reasoning:",
                ]
            )
            lines.extend(build_explained_run_groups(runs, reasoning_key))
        sections.append("\n".join(lines))

    cell_text = join_sheet_sections(sections)
    return truncate_text(cell_text, max_chars=SHEET_CELL_MAX_CHARS)
|
|
|
|
def build_llm2_output_cell(case_results):
    """Render every run's Step 2 summary and outputs for a sheet cell.

    Args:
        case_results: Sequence of case dicts with ``expected`` and ``runs``.

    Returns:
        One section per (case, run), joined and truncated to
        ``SHEET_CELL_MAX_CHARS``.
    """
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        for run_number, run in enumerate(case_result["runs"], start=1):
            header = build_case_run_header(case_number, run_number)
            if isinstance(case_result["expected"], str):
                # String `expected` marks an irrelevant case.
                note = "Downstream stages were skipped because this was graded as irrelevant."
            else:
                note = format_teacher_forced_note(run)
            lines = (
                header,
                note,
                "",
                f"LLM 2 Summary: {run['llm2_reasoning']}",
                "",
                "LLM 2 Output:",
                format_llm2_user_output(run["llm2_outputs"]),
            )
            sections.append("\n".join(lines))

    cell_text = join_sheet_sections(sections)
    return truncate_text(cell_text, max_chars=SHEET_CELL_MAX_CHARS)
|
|
|
|
def build_output_vs_expected_cell(case_results):
    """Render actual vs. expected output for LLM 1 and LLM 3, per case.

    Runs of the same case whose rendered text is identical after
    normalization are merged under a single header listing all their run
    indices, in order of first appearance.

    Args:
        case_results: Sequence of case dicts with ``expected`` and ``runs``.

    Returns:
        The joined sections, truncated to ``SHEET_CELL_MAX_CHARS``.
    """
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        expected = case_result["expected"]
        llm1_expected = format_llm1_expected_text(expected)

        # Normalized body text -> run numbers that produced it. The body is
        # joined into text below, so it is used directly as a dict key.
        runs_by_body = {}
        for run_number, run in enumerate(case_result["runs"], start=1):
            llm3_shown = run["llm3_output"] or "Not run."
            lines = [
                "LLM 1",
                f"Status: {run['llm1_reasoning']}",
                "User Output:",
                truncate_sheet_block(run["llm1_output"], max_chars=STAGE_ONE_OUTPUT_MAX_CHARS),
                "",
                "Expected:",
                llm1_expected,
                "",
            ]
            if isinstance(expected, str):
                lines.append("Downstream stages were skipped because this case was irrelevant.")
            else:
                lines.append(format_teacher_forced_note(run))
            lines.extend(
                [
                    "",
                    "LLM 3",
                    f"Reasoning: {run['llm3_reasoning']}",
                    "User Output:",
                    truncate_sheet_block(llm3_shown, max_chars=STAGE_THREE_OUTPUT_MAX_CHARS),
                    "",
                    "Expected:",
                    format_pretty_expected_answer(run["llm3_expected"]),
                ]
            )
            body = normalize_sheet_block_text("\n".join(lines))
            runs_by_body.setdefault(body, []).append(run_number)

        for body, run_numbers in runs_by_body.items():
            header = build_grouped_case_run_header(case_number, run_numbers)
            sections.append("\n".join((header, body)))

    cell_text = join_sheet_sections(sections)
    return truncate_text(cell_text, max_chars=SHEET_CELL_MAX_CHARS)
|
|
|
|
def estimate_case_llm_call_count(case):
    """Estimate how many LLM calls evaluating ``case`` will make.

    A case with a string ``expected`` is counted as one call per repeat;
    otherwise each repeat is counted as one call per doc plus two.
    """
    calls_per_repeat = 1 if isinstance(case["expected"], str) else len(case["docs"]) + 2
    return EVAL_REPEAT_COUNT * calls_per_repeat
|
|
|
|
def estimate_submission_llm_call_count(cases):
    """Return the summed LLM-call estimate over all ``cases`` (0 when empty)."""
    estimated_calls = 0
    for case in cases:
        estimated_calls += estimate_case_llm_call_count(case)
    return estimated_calls
|
|
|
|
def evaluate_case(case, prompts, case_index=None):
    """Run one case through the pipeline ``EVAL_REPEAT_COUNT`` times and grade it.

    For a relevant case (non-string ``expected``) whose Step 1 score is below
    1.0, the expected Step 1 output is substituted before the downstream
    stages run, and the run is flagged as teacher-forced. Irrelevant cases
    skip the downstream stages entirely.

    Args:
        case: Dict with ``question``, ``docs`` and ``expected``.
        prompts: Candidate prompts forwarded to the pipeline stages.
        case_index: Optional case number, used only in the completion log.

    Returns:
        Dict with ``expected``, the per-run records in ``runs``, the averaged
        LLM 1 / LLM 3 scores, ``average_total_points``, ``case_score`` and
        ``passed`` (true when the average total is 2.0).
    """
    clock_start = time.perf_counter()
    expected = case["expected"]
    docs = case["docs"]
    is_relevant_case = not isinstance(expected, str)

    runs = []
    for _ in range(EVAL_REPEAT_COUNT):
        step_one = run_step_one(prompts, case["question"])
        raw_step1_output = step_one["llm1_output"]
        llm1_result = grade_llm1_stage(raw_step1_output, expected)

        step1_context = raw_step1_output
        teacher_forced = False
        target_firm = ""
        if is_relevant_case:
            if llm1_result["score"] < 1.0:
                # Imperfect Step 1: continue from the expected Step 1 output.
                step1_context = build_expected_step1_output(expected)
                teacher_forced = True
            downstream = run_downstream_stages(prompts, docs, step1_context)
            target_firm = parse_relevant_stage_one_output(step1_context)["target_firm"]
        else:
            downstream = {
                "snippet_count": len(docs),
                "llm2_outputs": {},
                "llm3_output": "",
            }

        case_run = {
            "question": case["question"],
            "docs": docs,
            "snippet_count": downstream["snippet_count"],
            "llm1_output": raw_step1_output,
            "llm2_outputs": downstream["llm2_outputs"],
            "llm3_output": downstream["llm3_output"],
            "is_irrelevant": step_one["is_irrelevant"],
            "user_message": step_one["user_message"],
            "effective_step1_context": step1_context,
            "effective_target_firm": target_firm,
            "used_teacher_forced_step1": teacher_forced,
        }

        semantic_result = grade_case(case_run, expected)
        llm3_result = grade_llm3_stage(case_run, expected, semantic_result)
        llm2_reasoning = summarize_llm2_stage(case_run, expected)
        llm3_expected = build_expected_llm3_answer(expected, docs, target_firm)
        merged_tags = sorted(set(llm1_result["failure_tags"] + semantic_result["failure_tags"]))

        llm1_score = llm1_result["score"]
        llm3_score = llm3_result["score"]
        truncated_llm2 = {}
        for snippet_id, snippet_output in case_run["llm2_outputs"].items():
            truncated_llm2[snippet_id] = truncate_sheet_block(
                snippet_output,
                max_chars=STAGE_TWO_OUTPUT_MAX_CHARS,
            )

        # Stored outputs are truncated; grading above used the full text.
        run_record = {
            "llm1_output": truncate_sheet_block(
                raw_step1_output,
                max_chars=STAGE_ONE_OUTPUT_MAX_CHARS,
            ),
            "llm2_outputs": truncated_llm2,
            "llm3_output": truncate_sheet_block(
                case_run["llm3_output"],
                max_chars=STAGE_THREE_OUTPUT_MAX_CHARS,
            ),
            "is_irrelevant": step_one["is_irrelevant"],
            "failure_tags": merged_tags,
            "field_results": semantic_result["field_results"],
            "passed": llm1_score == 1.0 and llm3_score == 1.0,
            "llm1_score": llm1_score,
            "llm1_reasoning": llm1_result["reasoning"],
            "llm2_reasoning": llm2_reasoning,
            "llm3_score": llm3_score,
            "llm3_reasoning": llm3_result["reasoning"],
            "used_teacher_forced_step1": teacher_forced,
            "effective_step1_context": truncate_sheet_block(
                step1_context,
                max_chars=STAGE_ONE_OUTPUT_MAX_CHARS,
            ),
            "effective_target_firm": target_firm,
            "llm3_expected": llm3_expected,
            "total_points": round(llm1_score + llm3_score, 4),
            "total_points_without_llm1": round(llm3_score, 4),
        }
        runs.append(run_record)

    mean_llm1 = average_run_score(runs, "llm1_score")
    mean_llm3 = average_run_score(runs, "llm3_score")
    mean_total = round(mean_llm1 + mean_llm3, 4)

    result = {
        "expected": expected,
        "runs": runs,
        "average_llm1_score": mean_llm1,
        "average_llm3_score": mean_llm3,
        "average_total_points": mean_total,
        "case_score": mean_total,
        "passed": abs(mean_total - 2.0) < 1e-9,
    }

    logger.info(
        "Completed case %s: relevant=%s snippets=%s repeat_count=%s "
        "expected_llm_calls=%s elapsed_seconds=%.2f",
        "unknown" if case_index is None else case_index,
        is_relevant_case,
        len(docs),
        EVAL_REPEAT_COUNT,
        estimate_case_llm_call_count(case),
        time.perf_counter() - clock_start,
    )
    return result
|
|
|
|
def grade_submission(eval_set, prompts, mode):
    """Grade every case of an eval set and aggregate the submission result.

    Cases are evaluated serially when only one worker is available and on a
    thread pool otherwise; results keep the order of ``eval_set["cases"]``.

    Args:
        eval_set: Dict with a ``cases`` list.
        prompts: Candidate prompts forwarded to each case evaluation.
        mode: Submission type, echoed back as ``submission_type``.

    Returns:
        Dict with ``submission_type``, ``score_percent`` (out of 2 points per
        case), ``passed_cases``, ``total_cases``, ``failure_summary`` (tag ->
        count over all runs, sorted by tag), ``field_accuracy``,
        ``case_results``, ``results`` and ``results_without_llm_1``.
    """
    cases = eval_set["cases"]
    if not cases:
        return {
            "submission_type": mode,
            "score_percent": 0.0,
            "passed_cases": 0,
            "total_cases": 0,
            "failure_summary": {},
            "field_accuracy": summarize_field_accuracy([]),
            "case_results": [],
            "results": 0.0,
            "results_without_llm_1": 0.0,
        }

    worker_count = min(get_eval_max_workers(), len(cases))
    clock_start = time.perf_counter()
    expected_llm_calls = estimate_submission_llm_call_count(cases)
    docs_per_case = [len(case["docs"]) for case in cases]
    logger.info(
        "Starting %s grading: cases=%s repeat_count=%s snippet_counts=%s "
        "expected_llm_calls=%s case_workers=%s llm2_workers=%s openai_global_limit=%s",
        mode,
        len(cases),
        EVAL_REPEAT_COUNT,
        docs_per_case,
        expected_llm_calls,
        worker_count,
        get_llm2_max_workers(),
        get_openai_max_concurrent_requests(),
    )

    case_results = [None] * len(cases)
    if worker_count == 1:
        for slot, case in enumerate(cases):
            case_results[slot] = safe_evaluate_case(case, prompts, slot + 1)
    else:
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            slot_by_future = {}
            for slot, case in enumerate(cases):
                future = pool.submit(safe_evaluate_case, case, prompts, slot + 1)
                slot_by_future[future] = slot
            # Futures finish in any order; the slot restores case order.
            for finished in as_completed(slot_by_future):
                case_results[slot_by_future[finished]] = finished.result()

    total_cases = len(case_results)
    total_results = round(
        sum([entry["average_total_points"] for entry in case_results]),
        4,
    )
    total_results_without_llm_1 = round(
        sum([entry["average_llm3_score"] for entry in case_results]),
        4,
    )
    run_results = []
    for entry in case_results:
        run_results.extend(entry["runs"])

    passed_cases = len([entry for entry in case_results if entry["passed"]])
    max_points = total_cases * 2.0
    if max_points:
        score_percent = round((total_results / max_points) * 100, 2)
    else:
        score_percent = 0.0

    tag_counts = Counter()
    for run in run_results:
        tag_counts.update(run["failure_tags"])
    failure_summary = {tag: tag_counts[tag] for tag in sorted(tag_counts)}

    field_accuracy = summarize_field_accuracy(run_results)
    logger.info(
        "Completed %s grading: cases=%s repeat_count=%s expected_llm_calls=%s "
        "elapsed_seconds=%.2f",
        mode,
        total_cases,
        EVAL_REPEAT_COUNT,
        expected_llm_calls,
        time.perf_counter() - clock_start,
    )

    return {
        "submission_type": mode,
        "score_percent": score_percent,
        "passed_cases": passed_cases,
        "total_cases": total_cases,
        "failure_summary": failure_summary,
        "field_accuracy": field_accuracy,
        "case_results": case_results,
        "results": total_results,
        "results_without_llm_1": total_results_without_llm_1,
    }
|
|
|
|
def get_attempt_status(email):
    """Report which submission types have already been used for ``email``.

    Combines the in-process ``LOCAL_ATTEMPTS`` record with the rows read from
    columns B:D of the Google Sheet (the first row is skipped).

    Returns:
        Dict with boolean ``practice_used`` and ``final_used``.
    """
    normalized_email = normalize_email(email)
    local_record = LOCAL_ATTEMPTS.get(normalized_email, {})
    status = {
        "practice_used": bool(local_record.get("practice_used")),
        "final_used": bool(local_record.get("final_used")),
    }

    sheet_rows = get_google_sheet(ensure_headers=False).get("B:D")
    for row in sheet_rows[1:]:
        if len(row) >= 3 and row[0] in SUBMISSION_TYPES:
            row_type, row_email = row[0], normalize_email(row[2])
        elif row:
            # Any other non-empty row is counted as a final attempt keyed by
            # its first cell. presumably an older sheet layout; verify.
            row_type, row_email = "final", normalize_email(row[0])
        else:
            continue

        if row_email != normalized_email:
            continue
        if row_type == "practice":
            status["practice_used"] = True
        elif row_type == "final":
            status["final_used"] = True

    return status
|
|
|
|
def save_attempt(record):
    """Append ``record`` to the Google Sheet as one row.

    Cells follow the key order of ``SHEET_COLUMNS``; keys missing from
    ``record`` are written as the serialized empty string.
    """
    sheet = get_google_sheet()
    cells = []
    for column_key, _header in SHEET_COLUMNS:
        cells.append(serialize_sheet_value(record.get(column_key, "")))
    append_options = {
        "value_input_option": "RAW",
        "insert_data_option": "INSERT_ROWS",
        "table_range": "A1",
    }
    sheet.append_row(cells, **append_options)
|
|
|
|
def record_local_attempt(email, submission_type):
    """Mark ``submission_type`` as used for ``email`` in ``LOCAL_ATTEMPTS``.

    Flags already set for the email are kept; the entry is replaced with a
    fresh dict holding ``practice_used`` and ``final_used``.
    """
    earlier = LOCAL_ATTEMPTS.get(email, {})
    practice_used = earlier.get("practice_used", False) or submission_type == "practice"
    final_used = earlier.get("final_used", False) or submission_type == "final"
    LOCAL_ATTEMPTS[email] = {
        "practice_used": practice_used,
        "final_used": final_used,
    }
|
|
|
|
def build_attempt_record(name, email, prompts, submission_result):
    """Assemble the record dict that is saved for one graded submission.

    Args:
        name: Candidate name to store.
        email: Candidate email to store.
        prompts: Candidate prompts, rendered via ``format_prompts_cell``.
        submission_result: Grading result; reads ``submission_type``,
            ``results``, ``results_without_llm_1`` and ``case_results``.

    Returns:
        Dict with a UTC ISO-8601 ``timestamp``, identity and score entries,
        and the rendered breakdown / output cells.
    """
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "submission_type": submission_result["submission_type"],
        "name": name,
        "email": email,
        "results": submission_result["results"],
        "results_without_llm_1": submission_result["results_without_llm_1"],
    }
    case_results = submission_result["case_results"]
    record["llm_score_breakdown"] = build_score_breakdown_cell(case_results)
    record["llm_output_vs_expected"] = build_output_vs_expected_cell(case_results)
    record["llm_2_output"] = build_llm2_output_cell(case_results)
    record["prompts"] = format_prompts_cell(prompts)
    return record
|
|
|
|
def format_field_accuracy(field_accuracy):
    """Render per-field accuracy as one bullet line per ``FINAL_SCHEMA_KEYS`` entry.

    Args:
        field_accuracy: Dict keyed by field name with ``accuracy``,
            ``correct`` and ``total`` entries.

    Returns:
        Newline-joined lines like ``"- Buyer counsel: 50.0% (1/2)"``.
    """
    labels = {
        "buyer_counsel": "Buyer counsel",
        "seller_counsel": "Seller counsel",
        "third_party_counsel": "Third-party counsel",
        "user_question": "User question",
    }

    def _bullet(schema_key):
        stats = field_accuracy[schema_key]
        ratio = f"({stats['correct']}/{stats['total']})"
        return f"- {labels[schema_key]}: {stats['accuracy']}% {ratio}"

    return "\n".join(_bullet(schema_key) for schema_key in FINAL_SCHEMA_KEYS)
|
|
|
|
def build_user_facing_reasoning_message(label, llm_name):
    """Return the candidate-facing explanation for a reasoning label.

    Args:
        label: Reasoning label; for stages other than LLM 1 it may be several
            comma-separated issue tags.
        llm_name: Stage name; ``"LLM 1"`` selects the Step 1 wording, any
            other value selects the final-output wording.

    Returns:
        A human-readable sentence, or several joined with ``"; "``.
    """
    if llm_name == "LLM 1":
        messages = {
            "matched": "No issue in these runs.",
            "right_answer_wrong_format": "The model identified the right answer but did not use the required Step 1 format.",
            "wrong_answer_right_format": "The model used the required Step 1 format but identified the wrong target firm.",
            "wrong_answer_wrong_format": "The model gave neither the right output nor the right format.",
            "wrong_route": "The model chose the wrong route for this question and should have handled relevance differently.",
            "call_failed": "The Step 1 model call failed, so no valid answer was produced.",
        }
        # Only compute the generic explanation when it is needed. Passing it
        # as the `.get()` default would evaluate it on every call.
        if label in messages:
            return messages[label]
        return explain_reasoning_label(label)

    if label == "matched":
        return "No issue in these runs."
    if label == "clean":
        return "The final JSON was valid and passed the evaluator checks."
    if label == "skipped_irrelevant":
        return "This case was correctly rejected as irrelevant before Step 3."
    if label == "invalid_json":
        return "The final answer was not valid JSON in the required schema."
    if label == "relevance_false_positive":
        return (
            "This question should have been rejected as irrelevant earlier, but the pipeline continued as if it were relevant."
        )

    # Anything else is a comma-separated list of tags: prefer the field-level
    # explanation and fall back to the generic one.
    parts = [part.strip() for part in str(label).split(",") if part.strip()]
    explained = [
        explain_output_issue_tag(part) or explain_reasoning_label(part)
        for part in parts
    ]
    return "; ".join(explained) if explained else "The final output had evaluator-detected issues."
|
|
|
|
def split_field_issue_tag(tag):
    """Split a ``<issue>_<field>`` tag into ``(issue, field)``.

    Field names from ``FINAL_SCHEMA_KEYS`` are tried longest first. Returns
    ``(None, None)`` when the tag does not end with ``_<field>`` for any of
    them.
    """
    # NOTE(review): this matches the field name as a SUFFIX, while the tags
    # are produced by `build_prefixed_failure_tags`. Confirm that helper
    # emits `<issue>_<field>` rather than `<field>_<issue>`.
    longest_first = sorted(FINAL_SCHEMA_KEYS, key=len, reverse=True)
    matched_field = next(
        (candidate for candidate in longest_first if tag.endswith(f"_{candidate}")),
        None,
    )
    if matched_field is None:
        return None, None
    issue_part = tag[: -(len(matched_field) + 1)]
    return issue_part, matched_field
|
|
|
|
def explain_output_issue_tag(tag):
    """Explain a field-level issue tag as ``"<Field label>: <message>"``.

    Returns ``None`` when the tag cannot be split into an issue and a field,
    or when the issue has no known message.
    """
    issue_name, field_name = split_field_issue_tag(tag)
    if not (issue_name and field_name):
        return None

    field_label = {
        "buyer_counsel": "Buyer counsel",
        "seller_counsel": "Seller counsel",
        "third_party_counsel": "Third-party counsel",
        "user_question": "User question",
    }[field_name]
    issue_messages = {
        "missing_field": "the field was missing from the final JSON.",
        "malformed_citation": "the citation format was malformed.",
        "unexpected_citation": "a citation was included where it should not have been.",
        "used_unknown": 'the answer said "unknown" even though the evidence supported a conclusion.',
        "hallucinated_value": "a value was given even though the expected answer was unknown.",
        "wrong_value": "the answer value was wrong.",
        "missing_citation": "the answer needed a citation but did not include one.",
        "bad_citation_reference": "the citation pointed to a snippet number that does not exist.",
        "unsupported_citation": "the citation pointed to a snippet that does not support the answer.",
    }
    detail = issue_messages.get(issue_name)
    return f"{field_label}: {detail}" if detail else None
|
|
|
|
def build_practice_run_issue_lines(runs, key, llm_name):
    """List the issues found across runs as bullet lines for practice feedback.

    Runs labelled ``matched``, ``clean`` or ``skipped_irrelevant`` are
    ignored. The remaining runs are merged whenever a label equals that of
    the previously kept run; run indices are 1-based.

    Returns:
        Bullet lines ``"- RUN <indices>: <message>"``, or a single
        ``"- No issues detected."`` line when nothing is left.
    """
    healthy_labels = {"matched", "clean", "skipped_irrelevant"}
    streaks = []  # (label, [run numbers]) for labels that signal a problem
    for run_number, run in enumerate(runs, start=1):
        current = run[key]
        if current in healthy_labels:
            continue
        if streaks and streaks[-1][0] == current:
            streaks[-1][1].append(run_number)
        else:
            streaks.append((current, [run_number]))

    if not streaks:
        return ["- No issues detected."]

    bullet_lines = []
    for issue_label, run_numbers in streaks:
        run_text = format_run_indices(run_numbers)
        message = build_user_facing_reasoning_message(issue_label, llm_name)
        bullet_lines.append(f"- RUN {run_text}: {message}")
    return bullet_lines
|
|
|
|
def build_practice_case_feedback(case_results):
    """Render per-case practice feedback: LLM 1 / LLM 3 scores and issues.

    Args:
        case_results: Sequence of case dicts, each with a ``runs`` list.

    Returns:
        The case sections joined by ``join_sheet_sections``.
    """
    stages = (
        ("LLM 1", "llm1_score", "llm1_reasoning"),
        ("LLM 3", "llm3_score", "llm3_reasoning"),
    )
    sections = []
    for case_number, case_result in enumerate(case_results, start=1):
        runs = case_result["runs"]
        run_total = len(runs)
        stage_points = {
            heading: total_run_score(runs, score_key)
            for heading, score_key, _reasoning_key in stages
        }

        lines = [f"CASE {case_number}"]
        for heading, _score_key, reasoning_key in stages:
            lines.append("")
            lines.append(heading)
            lines.append(f"Score: {format_score(stage_points[heading])} / {run_total}")
            lines.append("Issues:")
            lines.extend(build_practice_run_issue_lines(runs, reasoning_key, heading))
        sections.append("\n".join(lines))
    return join_sheet_sections(sections)
|
|
|
|
def build_practice_scoring_note(submission_result):
    """Return the sentence explaining how the practice score was computed.

    Args:
        submission_result: Grading result; reads ``total_cases``.

    Returns:
        A two-sentence note stating the number of hidden cases and how many
        times each case is run.
    """
    total_cases = submission_result["total_cases"]
    # Take the repeat count from the constant that drives evaluation so the
    # message stays correct if the count is changed.
    repeat_word = "time" if EVAL_REPEAT_COUNT == 1 else "times"
    return (
        f"This practice score is based on {total_cases} hidden calibration cases. "
        f"Each case is run {EVAL_REPEAT_COUNT} {repeat_word} to check prompt consistency."
    )
|
|
|
|
def format_practice_feedback(name, submission_result):
    """Compose the full practice-run message shown to the candidate.

    Args:
        name: Candidate name used in the opening line.
        submission_result: Grading result; reads ``score_percent`` and
            ``case_results`` and is passed to the scoring-note builder.

    Returns:
        Header, score with scoring note, per-case feedback and a closing
        reminder, separated by blank lines.
    """
    case_results = submission_result["case_results"]
    header = f"Practice run complete for {name}."
    score_line = f"Score: {submission_result['score_percent']}%"
    scoring_note = build_practice_scoring_note(submission_result)
    paragraphs = [
        header,
        f"{score_line}\n{scoring_note}",
        build_practice_case_feedback(case_results),
        "Your final submission is still available.",
    ]
    return "\n\n".join(paragraphs)
|
|
|
|
def load_eval_set_for_mode(submission_type):
    """Load the eval set for a submission type.

    ``"practice"`` loads with the ``PRACTICE`` prefix; any other value loads
    with the ``FINAL`` prefix.
    """
    if submission_type == "practice":
        return load_eval_set("PRACTICE")
    return load_eval_set("FINAL")
|
|
|
|
def validate_submission_inputs(email, name):
    """Validate the candidate's email and name before a submission is processed.

    Args:
        email: Email address as entered by the candidate.
        name: Full name as entered by the candidate; may be ``None``.

    Returns:
        An error message for the first problem found, or ``None`` when both
        inputs are acceptable.
    """
    if not validate_email(email):
        return "Invalid email address. Please enter a valid email."

    # A missing (None) or non-string name is treated like a blank one, so the
    # candidate sees the validation message rather than an AttributeError
    # from `.strip()`.
    if not isinstance(name, str) or not name.strip():
        return "Please enter your full name."

    return None
|
|
|
|
def submit_attempt(submission_type, email, name, prompt_1, prompt_2, prompt_3):
    """Validate, grade and persist one practice or final submission.

    The stages run in order: input validation, attempt-status lookup,
    attempt-limit checks, dataset load, grading, record build, sheet save
    (skipped when the bypass flag is set). A failure in any stage is logged
    and returned as an unsuccessful response rather than raised.

    Args:
        submission_type: Must be a member of ``SUBMISSION_TYPES``.
        email: Candidate email as entered.
        name: Candidate name as entered.
        prompt_1: Candidate prompt for Step 1.
        prompt_2: Candidate prompt for Step 2.
        prompt_3: Candidate prompt for Step 3.

    Returns:
        The dict produced by ``build_submission_response``.

    Raises:
        ValueError: If ``submission_type`` is not supported.
    """
    if submission_type not in SUBMISSION_TYPES:
        raise ValueError(f"Unsupported submission type: {submission_type}")

    input_problem = validate_submission_inputs(email, name)
    if input_problem:
        return build_submission_response(False, input_problem)

    normalized_email = normalize_email(email)
    clean_name = sanitize_input(name)
    prompts = {
        prompt_key: sanitize_prompt(raw_prompt)
        for prompt_key, raw_prompt in (
            ("prompt_1", prompt_1),
            ("prompt_2", prompt_2),
            ("prompt_3", prompt_3),
        )
    }
    bypass_sheet_save = should_bypass_sheet_save()

    # Stage: read which attempts were already used.
    try:
        attempt_status = get_attempt_status(normalized_email)
    except Exception:
        log_stage_exception(
            "Unexpected error while reading sheet state for %s submission for %s.",
            submission_type,
            normalized_email,
        )
        error_text = build_stage_error_message(submission_type, "sheet_read")
        return build_submission_response(False, error_text)

    # Stage: enforce one practice and one final attempt per email.
    if submission_type == "practice":
        if attempt_status["final_used"]:
            return build_submission_response(
                False,
                (
                    f"A final submission has already been received for {normalized_email}. "
                    "Practice is no longer available."
                ),
                disable_practice=True,
                disable_final=True,
            )
        if attempt_status["practice_used"]:
            return build_submission_response(
                False,
                (
                    f"Practice has already been used for {normalized_email}. "
                    "Your final submission is still available."
                ),
                disable_practice=True,
                disable_final=False,
            )
    elif submission_type == "final" and attempt_status["final_used"]:
        return build_submission_response(
            False,
            (
                f"Final submission already received for {normalized_email}. "
                "You can only submit the final once."
            ),
            disable_practice=True,
            disable_final=True,
        )

    # Stage: load the hidden dataset.
    try:
        eval_set = load_eval_set_for_mode(submission_type)
    except Exception:
        log_stage_exception(
            "Unexpected error while loading %s dataset for %s.",
            submission_type,
            normalized_email,
        )
        error_text = build_stage_error_message(submission_type, "dataset_load")
        return build_submission_response(False, error_text)

    if not eval_set["cases"]:
        return build_submission_response(
            False,
            f"No hidden cases configured for the {submission_type} dataset.",
        )

    # Stage: grade.
    try:
        submission_result = grade_submission(eval_set, prompts, submission_type)
    except Exception:
        log_stage_exception(
            "Unexpected error while grading %s submission for %s.",
            submission_type,
            normalized_email,
        )
        error_text = build_stage_error_message(submission_type, "grading")
        return build_submission_response(False, error_text)

    # Stage: build the sheet record.
    try:
        record = build_attempt_record(clean_name, normalized_email, prompts, submission_result)
    except Exception:
        log_stage_exception(
            "Unexpected error while building %s record for %s.",
            submission_type,
            normalized_email,
        )
        error_text = build_stage_error_message(submission_type, "record_build")
        return build_submission_response(False, error_text)

    # Stage: persist, unless saving is bypassed.
    if bypass_sheet_save:
        bypass_notice = build_bypass_save_notice()
        print(
            f"BYPASS_SHEET_SAVE enabled for {submission_type} submission {normalized_email}.",
            flush=True,
        )
    else:
        bypass_notice = ""
        try:
            save_attempt(record)
            record_local_attempt(normalized_email, submission_type)
        except Exception:
            log_stage_exception(
                "Unexpected error while saving %s submission for %s.",
                submission_type,
                normalized_email,
            )
            error_text = build_save_error_message(submission_type, "sheet_write")
            return build_submission_response(False, error_text)

    # Buttons are only disabled when the attempt was actually saved.
    attempt_saved = not bypass_sheet_save

    if submission_type == "practice":
        return build_submission_response(
            True,
            format_practice_feedback(clean_name, submission_result),
            notice=bypass_notice or "Practice run saved. Your final submission is still available.",
            disable_practice=attempt_saved,
            disable_final=False,
        )

    return build_submission_response(
        True,
        f"Thank you for your submission, {clean_name}!",
        notice=bypass_notice or "Final submission received. You can close the page.",
        disable_practice=attempt_saved,
        disable_final=attempt_saved,
    )
|
|
|
|
def build_submission_callback_result(result):
    """Convert a submission response mapping into Gradio callback outputs.

    Args:
        result: Mapping that provides the keys ``"message"``,
            ``"disable_practice"``, ``"disable_final"`` and ``"notice"``.

    Returns:
        A 4-tuple, in this order:
          1. the message text,
          2. an update whose ``interactive`` flag is the negation of
             ``disable_practice``,
          3. an update whose ``interactive`` flag is the negation of
             ``disable_final``,
          4. an update that sets the notice value and makes it visible only
             when the notice is truthy.

    Raises:
        KeyError: If ``result`` lacks any of the four keys.
    """
    # Keys are read in the same order the tuple is assembled.
    message_text = result["message"]
    practice_enabled = not result["disable_practice"]
    final_enabled = not result["disable_final"]
    notice_text = result["notice"]

    practice_update = gr.update(interactive=practice_enabled)
    final_update = gr.update(interactive=final_enabled)
    # An empty notice keeps the notice component hidden.
    notice_update = gr.update(value=notice_text, visible=bool(notice_text))

    return (message_text, practice_update, final_update, notice_update)
|
|
|
|
def handle_submission(submission_type, email, name, s1, s2, s3):
    """Run one submission and shape the outcome for the Gradio callback.

    Any exception escaping ``submit_attempt`` is logged and replaced by a
    failure response carrying the ``"callback"`` stage error message, so the
    UI callback still receives a well-formed result.

    Args:
        submission_type: Submission kind, forwarded to ``submit_attempt``.
        email: Candidate email as typed in the form.
        name: Candidate name as typed in the form.
        s1: Prompt text for step 1.
        s2: Prompt text for step 2.
        s3: Prompt text for step 3.

    Returns:
        The tuple produced by ``build_submission_callback_result``.
    """
    stage_prompts = (s1, s2, s3)
    try:
        outcome = submit_attempt(submission_type, email, name, *stage_prompts)
    except Exception:
        # Last-resort boundary: log the failure, then fall back to a generic
        # failure response instead of letting the exception reach the UI.
        log_stage_exception("Unhandled callback failure for %s submission.", submission_type)
        failure_text = build_stage_error_message(submission_type, "callback")
        outcome = build_submission_response(False, failure_text)

    # Deliberately outside the try: errors while building the callback tuple
    # are not converted into a failure response.
    return build_submission_callback_result(outcome)
|
|
|
|
def build_interface():
    """Assemble the Gradio Blocks UI for the applicant prompt-writing task.

    Components are created top to bottom in this order: task instructions, a
    workflow blurb with the pipeline image, two collapsed accordions (worked
    example and submission rules), a short reminder, the email / name /
    three prompt text boxes, a waiting-time note, the two submit buttons,
    the results box and a hidden notice area. Both buttons send the same
    five inputs to ``handle_submission`` and write to the same four outputs.

    Returns:
        The configured ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    # ------------------------------------------------------------------
    # Page copy. Kept together so the layout code below stays compact.
    #
    # NOTE(review): these literals are indented along with the code, which
    # presumes gr.Markdown strips common leading whitespace before
    # rendering -- confirm against the installed Gradio version.
    #
    # NOTE(review): task_copy is a normal (non-raw) string, so every `\"`
    # in its JSON schema example becomes a plain `"` at runtime -- confirm
    # that the rendered example is meant to show unescaped inner quotes.
    # ------------------------------------------------------------------
    task_copy = """
    # Applicant Task: Target Company & Law Firm Identification
    Draft prompts for a strict 3-step legal review pipeline over snippets from Asset Purchase Agreements [SEC Agreement Example](https://www.sec.gov/Archives/edgar/data/28452/000119312505012401/dex101.htm)
    > This evaluation system uses (default: `gpt-5-mini`) for all LLM steps.

    ## What you need to do

    - Decide whether the query is relevant and normalize the target firm name.
    - Inspect each snippet independently and record only what that snippet actually supports.
    - Combine those snippet-level findings into one final JSON answer with citations, including a cited answer to the original user question.

    ## The 3-step pipeline

    ### Step 1: Relevance and target-firm normalization
    Decide whether the question belongs in this deal at all, and if it does, standardize the firm name you will track through the rest of the pipeline.

    - If the query is irrelevant, return `<user_message>...</user_message>`.
    - If the query is relevant, return only:

    ```text
    TARGET_FIRM: <canonical firm name>
    ```
    Please ensure your final output uses the exact key `TARGET_FIRM:` as shown above, alongside the firm name.

    ### Step 2: Snippet-by-snippet analysis
    Step 2 is not the final answer; it is a working note for one snippet at a time.

    - Your Step 2 prompt runs independently on each evidence unit, so the model only sees one snippet per call.
    - The app passes:
    - the Step 1 output
    - `Snippet ID: S1..Sn`
    - the snippet text
    - A good Step 2 prompt says what this snippet supports, what it does not support, and where the evidence is still uncertain.

    ### Step 3: Reconciliation and final answer
    Step 3 receives the Step 1 output plus all Step 2 notes and must turn them into one answer that could survive a hostile redline, including a direct answer to the original user question.

    - Step 3 receives the Step 1 output plus all Step 2 outputs.
    - It must return valid JSON with this exact schema:

    ```json
    {
    "buyer_counsel": "string with citations like \"Firm Name [^2]\" or \"unknown\"",
    "seller_counsel": "string with citations like \"Firm Name [^4]\" or \"unknown\"",
    "third_party_counsel": "string with citations like \"Firm Name [^1]\" or \"unknown\"",
    "user_question": "string with citations like \"true [^2]\", \"false [^4]\", or \"unknown\""
    }
    ```

    Use `"unknown"` when a counsel field or the user-question answer cannot be supported by the evidence. When you provide a firm name, include snippet citations such as `[^2]` that point to the relevant Step 2 snippet IDs. The `user_question` field should answer the original query using only `true`, `false`, or `unknown`: if the answer is `true` or `false`, include supporting snippet citations and do not add any extra text.
    """

    workflow_copy = """
    ## Workflow At A Glance

    This visual shows how the query moves from `LLM1` to `LLM2` to `LLM3`, and how the final JSON is assembled from the APA snippets.
    """

    example_copy = """
    **User query**
    ```text
    Is Kirkland & Ellis LLP acting as counsel anywhere in this Asset Purchase Agreement?
    ```

    **Step 1 relevant output**
    ```text
    TARGET_FIRM: Kirkland & Ellis LLP
    ```

    **Example evidence units**
    - `S1`: the opening paragraph names the parties and mentions Kirkland & Ellis LLP as transaction counsel to the buyer.
    - `S2`: the notices section identifies buyer counsel as Kirkland & Ellis LLP and seller counsel as Wachtell, Lipton, Rosen & Katz.
    - `S3`: a boilerplate clause contains no counsel information and should not drive the final answer.
    - `S4`: a representative provision states that Gibson, Dunn & Crutcher LLP advises the securityholders' representative.
    - `S5`: a later notice block again confirms seller counsel as Wachtell, Lipton, Rosen & Katz.

    **Final JSON shape**
    ```json
    {
    "buyer_counsel": "Kirkland & Ellis LLP [^1] [^2]",
    "seller_counsel": "Wachtell, Lipton, Rosen & Katz [^2] [^5]",
    "third_party_counsel": "Gibson, Dunn & Crutcher LLP [^4]",
    "user_question": "true [^1] [^2]"
    }
    ```
    """

    rules_copy = """
    - You may use **one optional practice run** per email to test your prompts against a hidden calibration set.
    - The practice run uses 3 hidden calibration cases.
    - Each case is run 3 times to check prompt consistency.
    - For each run, LLM 1 can earn up to 1 point for correct routing and target-firm normalization, and LLM 3 can earn up to 1 point for a correct final JSON answer with supported citations.
    - Step 2 is not scored directly, but it strongly affects the LLM 3 score because Step 3 relies on the snippet-level analysis.
    - Practice returns aggregate feedback only: score percentage, an LLM 1 summary, and an LLM 3 summary.
    - You may then revise your prompts or keep them as they are.
    - You may submit **one final submission** per email against a separate hidden holdout set.
    - After the final submission, practice is no longer available.
    - No structured decoding is used for you, so your prompts must make Step 3 produce reliable JSON on their own.
    """

    reminder_copy = """
    Enter your name and email exactly as listed in your CV. Both buttons below use the same three prompt boxes.

    You have **one** chance to run the practice set and get feedback, and **one** chance to run the final set. After you click a button, wait for the results to load before clicking again or refreshing the page.

    **Good Luck!**
    """

    waiting_copy = """
    <div class="submission-note">
    <b>Please note:</b><br>
    Each run may take a couple of minutes.<br>
    After you click a button, wait for the result and do not click it again.
    </div>
    """

    # (step number, visible text rows) for the three prompt boxes.
    prompt_box_specs = ((1, 6), (2, 10), (3, 6))

    with gr.Blocks(css=APP_CSS) as demo:
        # --- Instructions and reference material ------------------------
        gr.Markdown(task_copy)

        gr.Markdown(workflow_copy)
        gr.Image(
            value="mermaid_diagram.png",
            label="Pipeline overview",
            show_label=True,
            interactive=False,
        )

        with gr.Accordion("Example Workflow", open=False):
            gr.Markdown(example_copy)

        with gr.Accordion("Practice and Final Submission", open=False):
            gr.Markdown(rules_copy)

        gr.Markdown(reminder_copy)

        # --- Candidate inputs -------------------------------------------
        email_input = gr.Textbox(label="Email", placeholder="your.email@example.com")
        name_input = gr.Textbox(label="First Name, Last Name", placeholder="John Smith")
        prompt_inputs = [
            gr.Textbox(
                label=f"System Prompt for Step {step}",
                placeholder=f"Enter your Step {step} prompt here...",
                lines=rows,
            )
            for step, rows in prompt_box_specs
        ]

        gr.Markdown(waiting_copy)

        # --- Actions and results ----------------------------------------
        with gr.Row():
            practice_button = gr.Button("Practice Run")
            final_button = gr.Button("Submit Final")

        output_text = gr.Textbox(label="Results", lines=18)
        # Hidden until a callback supplies a non-empty notice.
        feedback_md = gr.Markdown("", visible=False)

        # The two callbacks keep their names and signatures, and carry
        # comments rather than docstrings, because Gradio can derive
        # endpoint metadata from the function object itself.
        def practice_submit_and_update(email, name, s1, s2, s3):
            # Practice run: forwards the form values unchanged.
            return handle_submission("practice", email, name, s1, s2, s3)

        def final_submit_and_update(email, name, s1, s2, s3):
            # Final run: forwards the form values unchanged.
            return handle_submission("final", email, name, s1, s2, s3)

        # Both buttons read the same five fields and update the same four
        # components; each event gets its own list objects.
        form_fields = (email_input, name_input, *prompt_inputs)
        result_targets = (output_text, practice_button, final_button, feedback_md)
        button_callbacks = (
            (practice_button, practice_submit_and_update),
            (final_button, final_submit_and_update),
        )
        for button, callback in button_callbacks:
            button.click(
                fn=callback,
                inputs=list(form_fields),
                outputs=list(result_targets),
            )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860 with server-side rendering switched off.
    interface = build_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )
|
|