# RegTech / server.py — uploaded by bardd ("Upload 48 files", commit 916ecde).
#!/usr/bin/env python3
from __future__ import annotations
import base64
import binascii
import json
import os
import re
import subprocess
import uuid
from contextlib import contextmanager, nullcontext
from concurrent.futures import ThreadPoolExecutor, as_completed
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from time import monotonic
from typing import Any, Iterator
try:
import langsmith as ls
from langsmith import Client as LangSmithClient
except ImportError: # pragma: no cover - runtime optional in local dev until installed.
ls = None
LangSmithClient = None
# Filesystem layout: static assets, uploaded ad images, and per-stage prompt
# templates all live next to this server module.
APP_DIR = Path(__file__).resolve().parent
STATIC_DIR = APP_DIR / "static"
UPLOADS_DIR = APP_DIR / "uploads"
PROMPTS_DIR = APP_DIR / "prompts"
# HTTP bind address/port (overridable via environment).
HOST = os.environ.get("HOST", "127.0.0.1")
PORT = int(os.environ.get("PORT", "8080"))
# Gemini CLI invocation settings. The model is pinned and never user-selectable.
GEMINI_TIMEOUT_SEC = int(os.environ.get("GEMINI_TIMEOUT_SEC", "90"))
GEMINI_CLI_BINARY = os.environ.get("GEMINI_CLI_BINARY", "gemini")
LOCKED_GEMINI_MODEL = "gemini-3-flash-preview"
# Upload and concurrency limits.
MAX_IMAGE_BYTES = int(os.environ.get("MAX_IMAGE_BYTES", str(8 * 1024 * 1024)))
MAX_BATCH_IMAGES = int(os.environ.get("MAX_BATCH_IMAGES", "20"))
MAX_PARALLEL_WORKERS = max(1, int(os.environ.get("MAX_PARALLEL_WORKERS", "4")))
PIPELINE_STAGE_WORKERS = max(1, int(os.environ.get("PIPELINE_STAGE_WORKERS", "4")))
# How many validator-driven retry passes may run after the first pipeline pass.
VALIDATION_RETRY_PASSES = max(0, int(os.environ.get("VALIDATION_RETRY_PASSES", "1")))
# LangSmith tracing configuration. Tracing is active only when the package is
# importable, an API key is present, and LANGSMITH_TRACING is not "false".
LANGSMITH_PROJECT = os.environ.get("LANGSMITH_PROJECT", "regtechdemo-hf-v2")
LANGSMITH_TRACE_USER_AD_COPY = (
    os.environ.get("LANGSMITH_TRACE_USER_AD_COPY", "true").strip().lower() == "true"
)
LANGSMITH_TRACE_RAW_REQUEST = (
    os.environ.get("LANGSMITH_TRACE_RAW_REQUEST", "true").strip().lower() == "true"
)
LANGSMITH_ENABLED = (
    ls is not None
    and bool(os.environ.get("LANGSMITH_API_KEY", "").strip())
    and os.environ.get("LANGSMITH_TRACING", "true").strip().lower() == "true"
)
LANGSMITH_CLIENT = LangSmithClient() if LANGSMITH_ENABLED and LangSmithClient is not None else None
# Accepted upload MIME types and the file extension each is stored with.
ALLOWED_IMAGE_MIME_TO_EXT = {
    "image/png": "png",
    "image/jpeg": "jpg",
    "image/jpg": "jpg",
    "image/webp": "webp",
    "image/gif": "gif",
}
# Matches base64 data URLs; whitespace inside the payload is tolerated here
# and stripped before decoding.
DATA_URL_RE = re.compile(r"^data:(?P<mime>[-\w.+/]+);base64,(?P<data>[A-Za-z0-9+/=\s]+)$")
DEFAULT_SYSTEM_PROMPT = (
    "You are a UK fintech ad compliance screening assistant. "
    "Return only valid JSON and nothing else."
)
# Shape hint for the JSON the model is asked to emit.
JSON_SCHEMA_HINT = {
    "risk_level": "low | medium | high",
    "summary": "short sentence",
    "violations": [
        {
            "issue": "what is risky",
            "rule_refs": ["FCA handbook or principle area"],
            "why": "why this is a risk",
            "fix": "specific rewrite guidance",
        }
    ],
    "safe_rewrite": "optional ad rewrite",
}
# Prompt template file per pipeline stage (resolved under PROMPTS_DIR).
PROMPT_FILE_MAP = {
    "legal_basis": "legal_basis.md",
    "fca": "fca.md",
    "cma": "cma.md",
    "pra": "pra.md",
    "validation": "validation.md",
}
# Stage ordering: legal basis runs first, then the regulator modules.
PIPELINE_STAGE_ORDER = ["legal_basis", "fca", "cma", "pra"]
REGULATOR_STAGE_ORDER = ["fca", "cma", "pra"]
ALL_REVIEW_STAGES = set(PIPELINE_STAGE_ORDER)
# In-process cache of prompt templates keyed by stage name.
PROMPT_CACHE: dict[str, str] = {}
# Warn loudly when tracing was requested but the dependency is missing.
if os.environ.get("LANGSMITH_API_KEY") and ls is None:
    print("LANGSMITH_API_KEY is set but the langsmith package is not installed.", flush=True)
def sanitize_for_langsmith(value: Any, ad_text: str = "") -> Any:
    """Recursively prepare *value* for inclusion in a LangSmith trace.

    Dict keys are coerced to strings, tuples become lists, and — when the
    operator has disabled ad-copy tracing — any occurrence of *ad_text*
    inside string values is masked. All other values pass through untouched.
    """
    recurse = sanitize_for_langsmith
    if isinstance(value, dict):
        return {str(key): recurse(item, ad_text=ad_text) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [recurse(item, ad_text=ad_text) for item in value]
    if isinstance(value, str) and ad_text and not LANGSMITH_TRACE_USER_AD_COPY:
        return value.replace(ad_text, "[REDACTED_USER_AD_COPY]")
    return value
@contextmanager
def traced_stage(
    name: str,
    run_type: str,
    *,
    inputs: Any | None = None,
    metadata: dict[str, Any] | None = None,
    tags: list[str] | None = None,
) -> Iterator[tuple[Any | None, dict[str, Any]]]:
    """Context manager that wraps a pipeline step in a LangSmith run.

    Yields ``(run, outputs)``: ``run`` is the LangSmith run object (or None
    when tracing is disabled) and ``outputs`` is a dict the caller fills in,
    which is recorded on the run when the block exits. On exception the error
    text is added under "error" and the exception is re-raised.
    """
    outputs: dict[str, Any] = {}
    # Tracing off (or langsmith not importable): still yield an outputs dict
    # so callers can write to it unconditionally.
    if not LANGSMITH_ENABLED or ls is None:
        yield None, outputs
        return
    kwargs: dict[str, Any] = {"name": name, "run_type": run_type}
    if inputs is not None:
        kwargs["inputs"] = inputs
    if metadata:
        kwargs["metadata"] = metadata
    if tags:
        kwargs["tags"] = tags
    with ls.trace(**kwargs) as run:
        try:
            yield run, outputs
        except Exception as err:
            # Record the failure on the run before propagating it.
            outputs.setdefault("error", str(err))
            run.end(outputs=outputs)
            raise
        else:
            run.end(outputs=outputs)
def flush_langsmith() -> None:
    """Best-effort flush of any buffered LangSmith trace data."""
    client = LANGSMITH_CLIENT
    if client is None or not hasattr(client, "flush"):
        return
    try:
        client.flush()
    except Exception as err:  # pragma: no cover - best-effort cleanup only.
        print(f"LangSmith flush failed: {err}", flush=True)
def load_prompt_template(stage_name: str) -> str:
    """Return the prompt markdown for *stage_name*, caching reads from disk.

    Raises RuntimeError for an unknown stage or a missing prompt file.
    """
    cached = PROMPT_CACHE.get(stage_name)
    if cached is not None:
        return cached
    filename = PROMPT_FILE_MAP.get(stage_name)
    if not filename:
        raise RuntimeError(f"Unknown prompt stage '{stage_name}'.")
    prompt_path = PROMPTS_DIR / filename
    if not prompt_path.exists():
        raise RuntimeError(f"Prompt file missing for stage '{stage_name}': {prompt_path}")
    template = prompt_path.read_text(encoding="utf-8").strip()
    PROMPT_CACHE[stage_name] = template
    return template
def infer_input_mode(ad_text: str, image_at_path: str | None) -> str:
    """Classify the submission as "text", "image", or "text+image"."""
    has_text = bool(ad_text.strip())
    if image_at_path:
        return "text+image" if has_text else "image"
    # No image: even an empty submission is reported as text mode.
    return "text"
def get_operator_override(system_prompt: str) -> str:
    """Return the operator's custom system prompt, or "" when absent or default."""
    trimmed = system_prompt.strip()
    if not trimmed or trimmed == DEFAULT_SYSTEM_PROMPT:
        return ""
    return trimmed
def build_submission_block(
    *,
    ad_text: str,
    extra_context: str,
    image_at_path: str | None,
) -> str:
    """Assemble the "Submission" section shared by every stage prompt."""
    sections: list[str] = [
        "Submission",
        f"Input mode: {infer_input_mode(ad_text, image_at_path)}",
        "",
    ]
    if image_at_path:
        sections.extend(
            [
                "Creative image reference:",
                f"@{image_at_path}",
                "Analyze the full image and all visible text in context.",
                "",
            ]
        )
    copy_text = ad_text.strip()
    sections.extend(["Ad copy:", copy_text if copy_text else "[Not provided]"])
    context_text = extra_context.strip()
    if context_text:
        sections.extend(["", "Extra context:", context_text])
    return "\n".join(sections)
def build_parallel_stage_prompt(
    stage_name: str,
    *,
    ad_text: str,
    extra_context: str,
    image_at_path: str | None,
    system_prompt: str,
    pass_number: int,
    prior_passes: list[dict[str, Any]] | None = None,
    retry_context: dict[str, Any] | None = None,
    request_id: str | None = None,
) -> str:
    """Build the full prompt text for one parallel pipeline stage.

    Combines the stage's markdown template, the pass number, the submission
    block, and (when present) prior-pass history, validator retry context,
    and any operator system-prompt override. The build itself is recorded as
    a LangSmith tool run.
    """
    with traced_stage(
        f"build_{stage_name}_prompt",
        "tool",
        inputs=sanitize_for_langsmith(
            {
                "stage": stage_name,
                "ad_text": ad_text,
                "extra_context": extra_context,
                "image_at_path": image_at_path,
                "system_prompt": system_prompt,
                "pass_number": pass_number,
                "prior_passes": prior_passes or [],
                "retry_context": retry_context or {},
            },
            ad_text=ad_text,
        ),
        metadata={"request_id": request_id, "stage": stage_name, "pass_number": pass_number},
        tags=["prompt-build", stage_name],
    ) as (_run, outputs):
        operator_override = get_operator_override(system_prompt)
        prompt = [
            load_prompt_template(stage_name),
            "",
            f"Pipeline pass: {pass_number}",
            "This runtime uses Gemini CLI. When the prompt requires `google_web_search`, you must use it before finalizing if the tool is available.",
            "",
            build_submission_block(
                ad_text=ad_text,
                extra_context=extra_context,
                image_at_path=image_at_path,
            ),
        ]
        # Optional sections are appended in a fixed order: history, retry, override.
        if prior_passes:
            prompt += [
                "",
                "Prior pipeline pass history JSON:",
                json.dumps(prior_passes, ensure_ascii=True, indent=2),
            ]
        if retry_context:
            prompt += [
                "",
                "Validator retry context JSON:",
                json.dumps(retry_context, ensure_ascii=True, indent=2),
            ]
        if operator_override:
            prompt += ["", "Additional operator instructions:", operator_override]
        full_prompt = "\n".join(prompt).strip()
        # Record the final (possibly redacted) prompt on the trace.
        outputs["prompt"] = sanitize_for_langsmith(full_prompt, ad_text=ad_text)
        return full_prompt
def build_validation_prompt(
    *,
    ad_text: str,
    extra_context: str,
    image_at_path: str | None,
    system_prompt: str,
    pass_number: int,
    legal_basis_output: dict[str, Any],
    module_outputs: dict[str, dict[str, Any]],
    prior_passes: list[dict[str, Any]] | None = None,
    retry_context: dict[str, Any] | None = None,
    request_id: str | None = None,
) -> str:
    """Build the validator stage prompt.

    Like build_parallel_stage_prompt, but also embeds the legal-basis output
    and the per-regulator module outputs as JSON so the validator can
    arbitrate them. Traced to LangSmith as a tool run.
    """
    with traced_stage(
        "build_validation_prompt",
        "tool",
        inputs=sanitize_for_langsmith(
            {
                "ad_text": ad_text,
                "extra_context": extra_context,
                "image_at_path": image_at_path,
                "system_prompt": system_prompt,
                "pass_number": pass_number,
                "legal_basis_output": legal_basis_output,
                "module_outputs": module_outputs,
                "prior_passes": prior_passes or [],
                "retry_context": retry_context or {},
            },
            ad_text=ad_text,
        ),
        metadata={"request_id": request_id, "pass_number": pass_number},
        tags=["prompt-build", "validation"],
    ) as (_run, outputs):
        operator_override = get_operator_override(system_prompt)
        prompt = [
            load_prompt_template("validation"),
            "",
            f"Pipeline pass: {pass_number}",
            "This runtime uses Gemini CLI. When the prompt requires `google_web_search`, you must use it before finalizing if the tool is available.",
            "",
            "Legal basis output JSON:",
            json.dumps(legal_basis_output, ensure_ascii=True, indent=2),
            "",
            "Module outputs JSON:",
            json.dumps(module_outputs, ensure_ascii=True, indent=2),
            "",
            build_submission_block(
                ad_text=ad_text,
                extra_context=extra_context,
                image_at_path=image_at_path,
            ),
        ]
        # Optional sections appended in fixed order: history, retry, override.
        if prior_passes:
            prompt += [
                "",
                "Prior pipeline pass history JSON:",
                json.dumps(prior_passes, ensure_ascii=True, indent=2),
            ]
        if retry_context:
            prompt += [
                "",
                "Validator retry context JSON:",
                json.dumps(retry_context, ensure_ascii=True, indent=2),
            ]
        if operator_override:
            prompt += ["", "Additional operator instructions:", operator_override]
        full_prompt = "\n".join(prompt).strip()
        outputs["prompt"] = sanitize_for_langsmith(full_prompt, ad_text=ad_text)
        return full_prompt
def gemini_cmd_candidates(prompt: str) -> list[list[str]]:
    """Return argv variants to try, covering long and short CLI flag spellings."""
    # Model is intentionally locked and never exposed to users.
    model_flags = ("--model", "-m")
    prompt_flags = ("-p", "--prompt")
    return [
        [GEMINI_CLI_BINARY, model_flag, LOCKED_GEMINI_MODEL, prompt_flag, prompt]
        for prompt_flag in prompt_flags
        for model_flag in model_flags
    ]
def is_flag_parse_error(stderr: str, stdout: str) -> bool:
    """Heuristically detect whether CLI output indicates a flag-parsing failure."""
    haystack = (stderr + "\n" + stdout).lower()
    markers = (
        "unknown option",
        "unknown argument",
        "invalid option",
        "unrecognized option",
        "unrecognized argument",
        "unexpected argument",
        "did you mean",
    )
    return any(marker in haystack for marker in markers)
def run_gemini(
    prompt: str,
    *,
    ad_text: str = "",
    request_id: str | None = None,
    trace_name: str = "gemini_cli_subprocess",
    trace_metadata: dict[str, Any] | None = None,
) -> str:
    """Run the Gemini CLI with *prompt* and return its stripped stdout.

    Tries several flag spellings (see gemini_cmd_candidates) because CLI
    versions differ; the next spelling is only attempted when the failure
    looks like flag parsing. Raises RuntimeError with the CLI's error text
    when all attempts fail; FileNotFoundError and subprocess.TimeoutExpired
    propagate to the caller (mapped to statuses in run_single_check).
    """
    attempts = gemini_cmd_candidates(prompt)
    child_env = os.environ.copy()
    metadata = {
        "request_id": request_id,
        "model": LOCKED_GEMINI_MODEL,
        "cli_binary": GEMINI_CLI_BINARY,
        "timeout_sec": GEMINI_TIMEOUT_SEC,
    }
    if trace_metadata:
        metadata.update(trace_metadata)
    with traced_stage(
        trace_name,
        "llm",
        inputs=sanitize_for_langsmith(
            {
                "prompt": prompt,
                "attempt_count": len(attempts),
            },
            ad_text=ad_text,
        ),
        metadata=metadata,
        tags=["gemini-cli", "llm"],
    ) as (_run, outputs):
        last_error = "Gemini CLI invocation failed."
        # Keep only GEMINI_API_KEY in the child environment: the CLI warns when
        # both GEMINI_API_KEY and GOOGLE_API_KEY are set. Fix: previously
        # GOOGLE_API_KEY was removed only when GEMINI_API_KEY was absent, so
        # the both-set case still warned; now it is dropped whenever
        # GEMINI_API_KEY is available.
        if not child_env.get("GEMINI_API_KEY") and child_env.get("GOOGLE_API_KEY"):
            child_env["GEMINI_API_KEY"] = child_env["GOOGLE_API_KEY"]
        if child_env.get("GEMINI_API_KEY"):
            child_env.pop("GOOGLE_API_KEY", None)
        for idx, cmd in enumerate(attempts):
            proc = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=str(APP_DIR),
                env=child_env,
                timeout=GEMINI_TIMEOUT_SEC,
                check=False,
            )
            # Trace only the most recent attempt; earlier ones are superseded.
            outputs["last_attempt"] = {
                "index": idx + 1,
                "cmd": sanitize_for_langsmith(cmd, ad_text=ad_text),
                "returncode": proc.returncode,
                "stdout": sanitize_for_langsmith(proc.stdout or "", ad_text=ad_text),
                "stderr": sanitize_for_langsmith(proc.stderr or "", ad_text=ad_text),
            }
            if proc.returncode == 0:
                final_output = (proc.stdout or "").strip()
                outputs["raw_output"] = sanitize_for_langsmith(final_output, ad_text=ad_text)
                return final_output
            stderr = (proc.stderr or "").strip()
            stdout = (proc.stdout or "").strip()
            details = stderr if stderr else stdout
            last_error = details or f"Gemini CLI exited with code {proc.returncode}."
            # Only retry different flag shapes if this appears to be flag parsing trouble.
            if idx < len(attempts) - 1 and is_flag_parse_error(stderr, stdout):
                continue
            break
        outputs["final_error"] = sanitize_for_langsmith(last_error, ad_text=ad_text)
        raise RuntimeError(last_error)
def try_parse_json(text: str, *, ad_text: str = "", request_id: str | None = None) -> Any | None:
    """Parse model output as JSON, tolerating markdown code fences.

    Returns the parsed value, or None for empty input or invalid JSON
    (the decode error is recorded on the trace instead of raised).
    """
    with traced_stage(
        "try_parse_json",
        "parser",
        inputs=sanitize_for_langsmith({"raw_text": text}, ad_text=ad_text),
        metadata={"request_id": request_id},
        tags=["parser"],
    ) as (_run, outputs):
        candidate = text.strip()
        if not candidate:
            outputs["parsed"] = None
            return None
        # Strip a surrounding ``` fence (optionally tagged "json") if present.
        if candidate.startswith("```"):
            fence_lines = candidate.splitlines()
            if len(fence_lines) >= 3 and fence_lines[-1].strip().startswith("```"):
                candidate = "\n".join(fence_lines[1:-1]).strip()
                if candidate.lower().startswith("json"):
                    candidate = candidate[4:].strip()
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError as err:
            outputs["parse_error"] = str(err)
            return None
        outputs["parsed"] = sanitize_for_langsmith(parsed, ad_text=ad_text)
        return parsed
def safe_filename_stem(raw_name: str) -> str:
    """Derive a filesystem-safe stem (max 40 chars) from an uploaded filename.

    Falls back to "ad-image" for empty input or names with no usable characters.
    """
    stem = Path(raw_name).stem if raw_name else "ad-image"
    sanitized = re.sub(r"[^A-Za-z0-9_-]+", "-", stem).strip("-")
    return sanitized[:40] if sanitized else "ad-image"
def save_image_from_data_url(
    image_data_url: str,
    image_filename: str,
    *,
    request_id: str | None = None,
) -> str:
    """Decode a base64 image data URL and persist it under UPLOADS_DIR.

    Returns the app-relative reference ("uploads/<name>") used by prompts.
    Raises ValueError for malformed data URLs, unsupported MIME types,
    invalid/empty base64 payloads, or images larger than MAX_IMAGE_BYTES.
    """
    with traced_stage(
        "save_image_from_data_url",
        "tool",
        inputs={
            "image_filename": image_filename,
            # Only a preview of the (potentially huge) data URL is traced.
            "data_url_preview": image_data_url[:240],
            "data_url_length": len(image_data_url),
        },
        metadata={"request_id": request_id},
        tags=["image-save"],
    ) as (_run, outputs):
        match = DATA_URL_RE.match(image_data_url.strip())
        if not match:
            raise ValueError("Image must be a valid base64 data URL (data:image/...;base64,...).")
        mime_type = match.group("mime").lower()
        extension = ALLOWED_IMAGE_MIME_TO_EXT.get(mime_type)
        if not extension:
            allowed = ", ".join(sorted(ALLOWED_IMAGE_MIME_TO_EXT))
            raise ValueError(f"Unsupported image type '{mime_type}'. Allowed: {allowed}.")
        # b64decode(validate=True) rejects whitespace, so strip it first.
        base64_payload = re.sub(r"\s+", "", match.group("data"))
        try:
            image_bytes = base64.b64decode(base64_payload, validate=True)
        except (ValueError, binascii.Error):
            raise ValueError("Image base64 payload is invalid.") from None
        if not image_bytes:
            raise ValueError("Image payload is empty.")
        if len(image_bytes) > MAX_IMAGE_BYTES:
            raise ValueError(f"Image is too large. Max size is {MAX_IMAGE_BYTES} bytes.")
        UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
        # Random hex suffix avoids collisions between same-named uploads.
        final_name = f"{safe_filename_stem(image_filename)}-{uuid.uuid4().hex[:10]}.{extension}"
        image_path = UPLOADS_DIR / final_name
        image_path.write_bytes(image_bytes)
        image_ref = f"uploads/{final_name}"
        outputs["image_ref"] = image_ref
        outputs["mime_type"] = mime_type
        outputs["bytes_written"] = len(image_bytes)
        return image_ref
def normalize_image_inputs(
    payload: dict[str, Any],
    *,
    ad_text: str = "",
    request_id: str | None = None,
) -> list[dict[str, str]]:
    """Normalize request image inputs to a list of {"data_url", "filename"} dicts.

    Supports either a batch "images" list (capped at MAX_BATCH_IMAGES) or the
    legacy single image_data_url/image_filename fields; the batch list takes
    precedence. Raises ValueError on malformed entries. Returns [] when no
    images were supplied.
    """
    with traced_stage(
        "normalize_image_inputs",
        "tool",
        inputs=sanitize_for_langsmith(payload, ad_text=ad_text),
        metadata={"request_id": request_id},
        tags=["image-normalize"],
    ) as (_run, outputs):
        images_field = payload.get("images")
        single_data_url = str(payload.get("image_data_url", "")).strip()
        single_filename = str(payload.get("image_filename", "")).strip()
        normalized: list[dict[str, str]] = []
        if isinstance(images_field, list) and images_field:
            if len(images_field) > MAX_BATCH_IMAGES:
                raise ValueError(f"Too many images. Max is {MAX_BATCH_IMAGES}.")
            for idx, item in enumerate(images_field):
                if not isinstance(item, dict):
                    raise ValueError(f"images[{idx}] must be an object.")
                data_url = str(item.get("data_url", "")).strip()
                # Missing filenames get a positional default.
                filename = str(item.get("filename", "")).strip() or f"image-{idx + 1}.png"
                if not data_url:
                    raise ValueError(f"images[{idx}].data_url is required.")
                normalized.append({"data_url": data_url, "filename": filename})
        elif single_data_url:
            normalized.append(
                {
                    "data_url": single_data_url,
                    "filename": single_filename or "image.png",
                }
            )
        outputs["normalized"] = sanitize_for_langsmith(normalized, ad_text=ad_text)
        return normalized
def run_single_check(
    prompt: str,
    *,
    ad_text: str = "",
    request_id: str | None = None,
    trace_name: str = "run_single_check",
    trace_metadata: dict[str, Any] | None = None,
) -> tuple[bool, int, dict[str, Any]]:
    """Run one Gemini prompt and return ``(ok, http_status, payload)``.

    Success yields ``(True, 200, {"parsed_output", "raw_output"})`` where
    parsed_output is None when the output is not valid JSON. Failures map to
    HTTP-ish statuses: 500 for a missing CLI binary or runtime error, 504
    for a subprocess timeout.
    """
    with traced_stage(
        trace_name,
        "chain",
        inputs=sanitize_for_langsmith({"prompt": prompt}, ad_text=ad_text),
        metadata={"request_id": request_id, **(trace_metadata or {})},
        tags=["single-check"],
    ) as (_run, outputs):
        try:
            raw_output = run_gemini(
                prompt,
                ad_text=ad_text,
                request_id=request_id,
                trace_name="gemini_cli_subprocess",
                trace_metadata=trace_metadata,
            )
            parsed_output = try_parse_json(raw_output, ad_text=ad_text, request_id=request_id)
            outputs["parsed_output"] = sanitize_for_langsmith(parsed_output, ad_text=ad_text)
            outputs["raw_output"] = sanitize_for_langsmith(raw_output, ad_text=ad_text)
            return True, 200, {"parsed_output": parsed_output, "raw_output": raw_output}
        except FileNotFoundError:
            # The CLI binary itself is missing from PATH.
            error = f"Gemini CLI not found. Install it and ensure '{GEMINI_CLI_BINARY}' is on PATH."
            outputs["error"] = error
            return False, 500, {"error": error}
        except subprocess.TimeoutExpired:
            error = f"Gemini CLI timed out after {GEMINI_TIMEOUT_SEC}s."
            outputs["error"] = error
            return False, 504, {"error": error}
        except RuntimeError as err:
            # run_gemini raises RuntimeError with the CLI's own error text.
            outputs["error"] = str(err)
            return False, 500, {"error": str(err)}
def run_single_image_check(
    index: int,
    total: int,
    image_ref: str,
    ad_text: str,
    extra_context: str,
    system_prompt: str,
    request_id: str,
) -> dict[str, Any]:
    """Run the full review pipeline for one image of a batch (index of total).

    Delegates to run_review_pipeline (defined elsewhere in this module) and
    reshapes its result into the per-image batch record. Progress and timing
    are printed to stdout for operator visibility.
    """
    with traced_stage(
        "run_single_image_check",
        "chain",
        inputs=sanitize_for_langsmith(
            {
                "index": index,
                "total": total,
                "image_ref": image_ref,
                "ad_text": ad_text,
                "extra_context": extra_context,
            },
            ad_text=ad_text,
        ),
        metadata={"request_id": request_id, "image_index": index, "image_ref": image_ref},
        tags=["bulk-image-check"],
    ) as (_run, outputs):
        print(f"[batch {index}/{total}] starting check for {image_ref}", flush=True)
        started = monotonic()
        result = run_review_pipeline(
            ad_text=ad_text,
            extra_context=extra_context,
            system_prompt=system_prompt,
            image_at_path=image_ref,
            request_id=request_id,
            trace_name="run_single_image_pipeline",
            trace_metadata={"image_index": index, "image_ref": image_ref},
        )
        elapsed = monotonic() - started
        status_text = "ok" if result.get("ok") else "failed"
        print(f"[batch {index}/{total}] {status_text} in {elapsed:.1f}s", flush=True)
        outputs["elapsed_sec"] = round(elapsed, 3)
        outputs["result"] = sanitize_for_langsmith(result, ad_text=ad_text)
        # Flatten the pipeline result into the batch-response record shape.
        return {
            "index": index,
            "ok": bool(result.get("ok")),
            "image_reference": image_ref,
            "parsed_output": result.get("parsed_output"),
            "raw_output": result.get("raw_output"),
            "error": result.get("error"),
            "pipeline_output": result.get("pipeline_output"),
        }
def severity_rank(severity: str) -> int:
    """Map a severity label to a sortable rank: CRITICAL=3, HIGH=2, ADVISORY=1, other=0."""
    ranks = {"CRITICAL": 3, "HIGH": 2, "ADVISORY": 1}
    return ranks.get(str(severity or "").upper(), 0)
def dedupe_preserve_order(values: list[str]) -> list[str]:
    """Strip entries, drop blanks and duplicates, preserving first-seen order."""
    # dicts preserve insertion order, so keys double as an ordered set.
    kept: dict[str, None] = {}
    for raw in values:
        stripped = raw.strip()
        if stripped:
            kept.setdefault(stripped, None)
    return list(kept)
def stage_result(
    stage_name: str,
    ok: bool,
    status: int,
    result: dict[str, Any],
) -> dict[str, Any]:
    """Shape one stage's outcome into the uniform stage-result record.

    parsed_output is forced to None unless it is a dict.
    """
    parsed = result.get("parsed_output")
    if not isinstance(parsed, dict):
        parsed = None
    return {
        "stage": stage_name,
        "ok": ok,
        "status": status,
        "parsed_output": parsed,
        "raw_output": result.get("raw_output"),
        "error": result.get("error"),
    }
def run_named_stage(
    stage_name: str,
    prompt: str,
    *,
    ad_text: str,
    request_id: str,
    trace_metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute one named pipeline stage and wrap its outcome via stage_result()."""
    merged_metadata = {"stage": stage_name, **(trace_metadata or {})}
    ok, status, payload = run_single_check(
        prompt,
        ad_text=ad_text,
        request_id=request_id,
        trace_name=f"stage_{stage_name}",
        trace_metadata=merged_metadata,
    )
    return stage_result(stage_name, ok, status, payload)
def normalize_stage_name(stage_name: str) -> str:
    """Lowercase and validate a stage name; returns "" for unknown stages."""
    candidate = str(stage_name or "").strip().lower()
    if candidate not in ALL_REVIEW_STAGES:
        return ""
    return candidate
def normalize_module_name(module_name: str) -> str:
    """Lowercase and validate a regulator module name; returns "" when unknown."""
    candidate = str(module_name or "").strip().lower()
    if candidate not in REGULATOR_STAGE_ORDER:
        return ""
    return candidate
def normalize_applicability(value: Any) -> str:
    """Coerce a model-supplied value to "apply", "not_apply", or "uncertain"."""
    candidate = str(value or "").strip().lower()
    if candidate in {"apply", "not_apply", "uncertain"}:
        return candidate
    # Anything unrecognized (including None/empty) defaults to uncertain.
    return "uncertain"
def normalize_confidence(value: Any) -> float:
    """Clamp a confidence value into [0, 100] rounded to 2 dp; 0.0 when unparseable."""
    try:
        numeric = float(value)
    except (TypeError, ValueError):
        return 0.0
    return round(min(max(numeric, 0.0), 100.0), 2)
def normalize_string_list(value: Any) -> list[str]:
    """Coerce a value to a deduped list of non-empty stripped strings; [] if not a list."""
    if not isinstance(value, list):
        return []
    stripped = [str(entry).strip() for entry in value]
    return dedupe_preserve_order([entry for entry in stripped if entry])
def normalize_source_verification(value: Any) -> dict[str, Any]:
    """Normalize a stage's source_verification payload.

    Non-dict input yields an empty record flagged for manual review. URL
    fields are accepted under several aliases, with the older per-category
    fields used only when no primary alias produced anything.
    """
    if not isinstance(value, dict):
        return {
            "verification_timestamp": "",
            "official_urls": [],
            "google_web_search_used": False,
            "manual_review_required": True,
        }
    # Primary URL field, with legacy aliases accepted in priority order.
    primary = (
        value.get("official_urls")
        or value.get("source_urls")
        or value.get("urls")
        or []
    )
    urls = normalize_string_list(primary)
    if not urls:
        # Fall back to concatenating the older per-category URL fields.
        combined: list[str] = []
        for legacy_key in ("handbook_urls", "policy_urls", "policy_statement_urls", "legislation_urls"):
            combined += normalize_string_list(value.get(legacy_key))
        urls = dedupe_preserve_order(combined)
    return {
        "verification_timestamp": str(value.get("verification_timestamp") or ""),
        "official_urls": urls,
        "google_web_search_used": bool(value.get("google_web_search_used", False)),
        "manual_review_required": bool(value.get("manual_review_required", False)),
    }
def normalize_finding(
    finding: dict[str, Any],
    *,
    default_module: str,
    default_authority_type: str = "unknown",
) -> dict[str, Any]:
    """Fill in defaults so every finding carries the full uniform field set."""

    def _text(key: str, fallback: str) -> str:
        # Falsy values (None, "") fall back, matching `or`-style defaulting.
        return str(finding.get(key) or fallback)

    return {
        "module": default_module,
        "issue": _text("issue", "Unspecified issue"),
        "rule_ref": _text("rule_ref", "Unknown"),
        "source_url": _text("source_url", ""),
        "authority_type": _text("authority_type", default_authority_type),
        "severity": _text("severity", "ADVISORY").upper(),
        "confidence": normalize_confidence(finding.get("confidence")),
        "why": _text("why", "No explanation provided."),
        "fix": _text("fix", "No fix provided."),
    }
def default_legal_basis_output(ad_text: str, image_at_path: str | None) -> dict[str, Any]:
    """Conservative fallback legal-basis payload used when the stage fails.

    Everything regulator-related is marked uncertain and manual review is forced.
    """
    unresolved_verification = {
        "verification_timestamp": "",
        "official_urls": [],
        "google_web_search_used": False,
        "manual_review_required": True,
    }
    fallback_finding = {
        "module": "legal_basis",
        "issue": "Legal basis could not be verified",
        "rule_ref": "Perimeter / exemption verification required",
        "source_url": "",
        "authority_type": "verification",
        "severity": "ADVISORY",
        "confidence": 0.0,
        "why": "The legal-basis stage failed or returned invalid JSON, so regulator applicability is uncertain.",
        "fix": "Re-run with verified official sources or escalate to manual review.",
    }
    return {
        "module": "legal_basis",
        "summary": "Legal basis could not be determined reliably.",
        "input_mode": infer_input_mode(ad_text, image_at_path),
        "product_type": "unknown",
        "channel": "unknown",
        "audience": "unknown",
        "promotion_scope": "uncertain",
        "claimed_exemptions": [],
        "applicability": {regulator: "uncertain" for regulator in ("fca", "cma", "pra")},
        "legal_basis_findings": [fallback_finding],
        "source_verification": unresolved_verification,
        "manual_review_required": True,
    }
def coerce_legal_basis_output(
    stage: dict[str, Any],
    *,
    ad_text: str,
    image_at_path: str | None,
) -> dict[str, Any]:
    """Coerce the legal-basis stage result into the canonical payload shape.

    Falls back to default_legal_basis_output() entirely when parsed_output is
    not a dict, and field-by-field when individual values are missing/falsy.
    A failed stage (ok falsy) always forces manual_review_required.
    """
    parsed = stage.get("parsed_output")
    fallback = default_legal_basis_output(ad_text, image_at_path)
    if not isinstance(parsed, dict):
        return fallback
    # Keep only well-formed exemption claims, with status restricted to the
    # known vocabulary.
    claimed_exemptions: list[dict[str, Any]] = []
    for item in parsed.get("claimed_exemptions", []):
        if not isinstance(item, dict):
            continue
        status = str(item.get("status") or "uncertain").strip().lower()
        if status not in {"claimed", "not_claimed", "uncertain"}:
            status = "uncertain"
        claimed_exemptions.append(
            {
                "name": str(item.get("name") or "Unknown"),
                "status": status,
                "evidence": str(item.get("evidence") or ""),
            }
        )
    legal_basis_findings: list[dict[str, Any]] = []
    for finding in parsed.get("legal_basis_findings", []):
        if isinstance(finding, dict):
            legal_basis_findings.append(
                normalize_finding(
                    finding,
                    default_module="legal_basis",
                    default_authority_type="verification",
                )
            )
    source_verification = normalize_source_verification(parsed.get("source_verification"))
    # Manual review if the stage asked for it, its sources require it, or the
    # stage itself failed.
    manual_review_required = bool(
        parsed.get("manual_review_required", False)
        or source_verification.get("manual_review_required", False)
        or not stage.get("ok")
    )
    return {
        "module": "legal_basis",
        "summary": str(parsed.get("summary") or fallback["summary"]),
        "input_mode": str(parsed.get("input_mode") or infer_input_mode(ad_text, image_at_path)),
        "product_type": str(parsed.get("product_type") or "unknown"),
        "channel": str(parsed.get("channel") or "unknown"),
        "audience": str(parsed.get("audience") or "unknown"),
        "promotion_scope": str(parsed.get("promotion_scope") or "uncertain"),
        "claimed_exemptions": claimed_exemptions,
        # Per-regulator applicability defaults to uncertain when the block is
        # absent or not a dict.
        "applicability": {
            "fca": normalize_applicability(parsed.get("applicability", {}).get("fca") if isinstance(parsed.get("applicability"), dict) else None),
            "cma": normalize_applicability(parsed.get("applicability", {}).get("cma") if isinstance(parsed.get("applicability"), dict) else None),
            "pra": normalize_applicability(parsed.get("applicability", {}).get("pra") if isinstance(parsed.get("applicability"), dict) else None),
        },
        "legal_basis_findings": legal_basis_findings or fallback["legal_basis_findings"],
        "source_verification": source_verification,
        "manual_review_required": manual_review_required,
    }
def coerce_module_output(module_name: str, stage: dict[str, Any]) -> dict[str, Any]:
    """Coerce a regulator module's stage result into the canonical module shape.

    Invalid/missing parsed_output yields a conservative fallback flagged for
    manual review; a failed stage (ok falsy) also forces manual review.
    """
    parsed = stage.get("parsed_output")
    fallback = {
        "module": module_name,
        "applicability": "uncertain",
        "why_applicable": f"{module_name.upper()} applicability could not be verified.",
        "summary": f"{module_name.upper()} module did not return valid JSON.",
        "findings": [],
        "safe_rewrite": "",
        "source_verification": {
            "verification_timestamp": "",
            "official_urls": [],
            "google_web_search_used": False,
            "manual_review_required": True,
        },
        "manual_review_required": True,
    }
    if not isinstance(parsed, dict):
        return fallback
    findings: list[dict[str, Any]] = []
    for finding in parsed.get("findings", []):
        if isinstance(finding, dict):
            findings.append(
                normalize_finding(
                    finding,
                    default_module=module_name,
                )
            )
    source_verification = normalize_source_verification(parsed.get("source_verification"))
    return {
        # A bogus self-reported module name falls back to the expected one.
        "module": normalize_module_name(str(parsed.get("module") or module_name)) or module_name,
        "applicability": normalize_applicability(parsed.get("applicability")),
        "why_applicable": str(parsed.get("why_applicable") or ""),
        "summary": str(parsed.get("summary") or f"{module_name.upper()} module completed."),
        "findings": findings,
        "safe_rewrite": str(parsed.get("safe_rewrite") or ""),
        "source_verification": source_verification,
        "manual_review_required": bool(
            parsed.get("manual_review_required", False)
            or source_verification.get("manual_review_required", False)
            or not stage.get("ok")
        ),
    }
def synthesize_validation_output(
    legal_basis_output: dict[str, Any],
    module_outputs: dict[str, dict[str, Any]],
    *,
    pass_number: int,
) -> dict[str, Any]:
    """Deterministically arbitrate stage outputs into a final validation verdict.

    Used as the code-side fallback/reference for the validator model: merges
    legal-basis findings with per-regulator findings (only where the module
    effectively applies), detects applicability conflicts, dedupes findings
    and source URLs, and derives the verdict (FAIL / MANUAL_REVIEW / PASS),
    risk level, and retry plan.
    """
    validated_findings: list[dict[str, Any]] = []
    conflicts: list[str] = []
    safe_rewrite = ""
    source_urls = list(legal_basis_output.get("source_verification", {}).get("official_urls", []))
    google_web_search_used = bool(
        legal_basis_output.get("source_verification", {}).get("google_web_search_used", False)
    )
    # Start from the legal-basis view of which regulators apply.
    applicability_summary = {
        module_name: normalize_applicability(
            legal_basis_output.get("applicability", {}).get(module_name)
        )
        for module_name in REGULATOR_STAGE_ORDER
    }
    manual_review_required = bool(legal_basis_output.get("manual_review_required", False))
    # Legal-basis findings are always carried forward.
    for finding in legal_basis_output.get("legal_basis_findings", []):
        if isinstance(finding, dict):
            validated_findings.append(
                {
                    "module": "legal_basis",
                    "issue": str(finding.get("issue") or "Unspecified issue"),
                    "rule_ref": str(finding.get("rule_ref") or "Unknown"),
                    "source_url": str(finding.get("source_url") or ""),
                    "severity": str(finding.get("severity") or "ADVISORY").upper(),
                    "confidence": normalize_confidence(finding.get("confidence")),
                    "why": str(finding.get("why") or "No explanation provided."),
                    "fix": str(finding.get("fix") or "No fix provided."),
                }
            )
    for module_name in REGULATOR_STAGE_ORDER:
        module_output = module_outputs.get(module_name)
        if not module_output:
            continue
        module_applicability = normalize_applicability(module_output.get("applicability"))
        source_verification = module_output.get("source_verification", {})
        source_urls.extend(source_verification.get("official_urls", []))
        google_web_search_used = google_web_search_used or bool(source_verification.get("google_web_search_used", False))
        # Legal basis wins unless it is uncertain, in which case the module's
        # own self-assessment is used.
        legal_basis_applicability = applicability_summary.get(module_name, "uncertain")
        effective_applicability = legal_basis_applicability
        if effective_applicability == "uncertain" and module_applicability != "uncertain":
            effective_applicability = module_applicability
        applicability_summary[module_name] = module_applicability
        # A hard disagreement between the two definite views is a conflict.
        if (
            legal_basis_applicability != "uncertain"
            and module_applicability != "uncertain"
            and legal_basis_applicability != module_applicability
        ):
            conflicts.append(
                f"{module_name.upper()} applicability conflict: legal_basis={legal_basis_applicability}, module={module_applicability}."
            )
            manual_review_required = True
        if effective_applicability != "apply":
            # Findings from a non-applicable module are suspicious: flag them
            # instead of counting them.
            if module_output.get("findings"):
                conflicts.append(
                    f"{module_name.upper()} returned findings while applicability is {effective_applicability}."
                )
                manual_review_required = True
            manual_review_required = manual_review_required or bool(module_output.get("manual_review_required", False))
            continue
        # First applicable module with a rewrite supplies the safe rewrite.
        if not safe_rewrite and module_output.get("safe_rewrite"):
            safe_rewrite = str(module_output.get("safe_rewrite"))
        for finding in module_output.get("findings", []):
            if not isinstance(finding, dict):
                continue
            validated_findings.append(
                {
                    "module": module_name,
                    "issue": str(finding.get("issue") or "Unspecified issue"),
                    "rule_ref": str(finding.get("rule_ref") or "Unknown"),
                    "source_url": str(finding.get("source_url") or ""),
                    "severity": str(finding.get("severity") or "ADVISORY").upper(),
                    "confidence": normalize_confidence(finding.get("confidence")),
                    "why": str(finding.get("why") or "No explanation provided."),
                    "fix": str(finding.get("fix") or "No fix provided."),
                }
            )
        manual_review_required = manual_review_required or bool(module_output.get("manual_review_required", False))
    # Dedupe findings on (module, issue, rule_ref), keeping first occurrence.
    deduped_findings: list[dict[str, Any]] = []
    seen_finding_keys: set[tuple[str, str, str]] = set()
    for finding in validated_findings:
        key = (
            str(finding.get("module") or ""),
            str(finding.get("issue") or ""),
            str(finding.get("rule_ref") or ""),
        )
        if key in seen_finding_keys:
            continue
        seen_finding_keys.add(key)
        deduped_findings.append(finding)
    validated_findings = deduped_findings
    source_urls = dedupe_preserve_order([url for url in source_urls if url])
    applicability_uncertain = any(
        applicability_summary.get(module_name) == "uncertain" for module_name in REGULATOR_STAGE_ORDER
    )
    if applicability_uncertain:
        manual_review_required = True
    # Verdict derivation: findings → FAIL, otherwise manual review or PASS.
    has_high = any(severity_rank(item.get("severity", "")) >= 2 for item in validated_findings)
    if validated_findings:
        risk_level = "high" if has_high else "medium"
        overall_verdict = "FAIL"
        summary = "Validated issues remain after legal-basis and regulator arbitration."
    elif manual_review_required:
        risk_level = "medium"
        overall_verdict = "MANUAL_REVIEW"
        summary = "No definitive breach set can be returned safely; manual review is required."
    else:
        risk_level = "low"
        overall_verdict = "PASS"
        summary = "No material issues identified after legal-basis and regulator arbitration."
    # Retry only while the retry budget allows and a correctable gap exists.
    retry_required = pass_number <= VALIDATION_RETRY_PASSES and bool(
        conflicts or applicability_uncertain or not google_web_search_used or not source_urls
    )
    retry_guidance: list[str] = []
    if conflicts:
        retry_guidance.append("Resolve applicability conflicts between legal basis and regulator modules.")
    if applicability_uncertain:
        retry_guidance.append("Verify whether any claimed exemption or perimeter route is actually available.")
    if not google_web_search_used:
        retry_guidance.append("Use google_web_search and cite official sources before finalizing.")
    if not source_urls:
        retry_guidance.append("Return official source URLs for legal basis and cited rules.")
    return {
        "overall_verdict": overall_verdict,
        "risk_level": risk_level,
        "summary": summary,
        "applicability_summary": applicability_summary,
        "validated_findings": validated_findings,
        "safe_rewrite": safe_rewrite,
        "conflicts": dedupe_preserve_order(conflicts),
        "retry_required": retry_required,
        "retry_targets": list(PIPELINE_STAGE_ORDER) if retry_required else [],
        "retry_reason": "; ".join(dedupe_preserve_order(retry_guidance)),
        "retry_guidance": dedupe_preserve_order(retry_guidance),
        "source_verification": {
            "verification_timestamp": "",
            "official_urls": source_urls,
            "google_web_search_used": google_web_search_used,
            "manual_review_required": manual_review_required,
        },
        "manual_review_required": manual_review_required,
    }
def coerce_validation_output(
    stage: dict[str, Any],
    *,
    legal_basis_output: dict[str, Any],
    module_outputs: dict[str, dict[str, Any]],
    pass_number: int,
) -> dict[str, Any]:
    """Normalize the validation stage's model output into the canonical schema.

    Every field is merged against a deterministic fallback synthesized from the
    legal-basis and regulator-module outputs, so the returned dict is always
    fully populated even when the model output is missing, malformed, or only
    partially valid.

    Args:
        stage: Raw stage result dict; only its ``parsed_output`` is consulted.
        legal_basis_output: Coerced legal-basis stage output (fallback input).
        module_outputs: Coerced per-regulator outputs keyed by module name.
        pass_number: 1-based pipeline pass number; used to cap retries.

    Returns:
        A validation-output dict with verdict, risk level, findings, retry
        metadata, and source-verification details.
    """
    parsed = stage.get("parsed_output")
    # The fallback is computed unconditionally so the field-level merges below
    # can borrow individual values from it even when `parsed` is a valid dict.
    fallback = synthesize_validation_output(legal_basis_output, module_outputs, pass_number=pass_number)
    if not isinstance(parsed, dict):
        return fallback
    applicability_summary_raw = parsed.get("applicability_summary")
    # Start from the fallback summary and overlay only recognized modules.
    applicability_summary = dict(fallback["applicability_summary"])
    if isinstance(applicability_summary_raw, dict):
        for module_name in REGULATOR_STAGE_ORDER:
            applicability_summary[module_name] = normalize_applicability(applicability_summary_raw.get(module_name))
    validated_findings: list[dict[str, Any]] = []
    for finding in parsed.get("validated_findings", []):
        if isinstance(finding, dict):
            normalized_module = str(finding.get("module") or "").strip().lower()
            # Unknown module names are attributed to legal_basis rather than dropped.
            if normalized_module not in {"legal_basis", *REGULATOR_STAGE_ORDER}:
                normalized_module = "legal_basis"
            validated_findings.append(
                {
                    "module": normalized_module,
                    "issue": str(finding.get("issue") or "Unspecified issue"),
                    "rule_ref": str(finding.get("rule_ref") or "Unknown"),
                    "source_url": str(finding.get("source_url") or ""),
                    "severity": str(finding.get("severity") or "ADVISORY").upper(),
                    "confidence": normalize_confidence(finding.get("confidence")),
                    "why": str(finding.get("why") or "No explanation provided."),
                    "fix": str(finding.get("fix") or "No fix provided."),
                }
            )
    # An empty model finding list defers entirely to the synthesized fallback.
    if not validated_findings:
        validated_findings = fallback["validated_findings"]
    risk_level = str(parsed.get("risk_level") or fallback["risk_level"]).lower()
    if risk_level not in {"low", "medium", "high"}:
        risk_level = fallback["risk_level"]
    source_verification = normalize_source_verification(parsed.get("source_verification"))
    # Manual review is sticky: any of the three sources can force it on, and
    # the model cannot turn it off once the fallback or verification set it.
    manual_review_required = bool(
        parsed.get("manual_review_required", False)
        or fallback["manual_review_required"]
        or source_verification.get("manual_review_required", False)
    )
    retry_required = bool(parsed.get("retry_required", False) or fallback["retry_required"])
    # Hard cap: never request another retry once the retry budget is exhausted.
    if pass_number > VALIDATION_RETRY_PASSES:
        retry_required = False
    retry_targets = [
        normalize_stage_name(item)
        for item in parsed.get("retry_targets", [])
        if normalize_stage_name(item)
    ]
    # A retry without explicit targets re-runs every pipeline stage.
    if retry_required and not retry_targets:
        retry_targets = list(PIPELINE_STAGE_ORDER)
    conflicts = parsed.get("conflicts")
    retry_guidance = parsed.get("retry_guidance")
    return {
        "overall_verdict": str(parsed.get("overall_verdict") or fallback["overall_verdict"]).upper(),
        "risk_level": risk_level,
        "summary": str(parsed.get("summary") or fallback["summary"]),
        "applicability_summary": applicability_summary,
        "validated_findings": validated_findings,
        "safe_rewrite": str(parsed.get("safe_rewrite") or fallback["safe_rewrite"]),
        "conflicts": conflicts if isinstance(conflicts, list) else fallback["conflicts"],
        "retry_required": retry_required,
        "retry_targets": retry_targets,
        "retry_reason": str(parsed.get("retry_reason") or fallback["retry_reason"]),
        "retry_guidance": retry_guidance if isinstance(retry_guidance, list) else fallback["retry_guidance"],
        "source_verification": {
            "verification_timestamp": str(
                source_verification.get("verification_timestamp")
                or fallback["source_verification"]["verification_timestamp"]
            ),
            "official_urls": source_verification.get("official_urls")
            or fallback["source_verification"]["official_urls"],
            "google_web_search_used": bool(
                source_verification.get("google_web_search_used")
                or fallback["source_verification"]["google_web_search_used"]
            ),
            "manual_review_required": manual_review_required,
        },
        "manual_review_required": manual_review_required,
    }
def build_legacy_output(validation_output: dict[str, Any]) -> dict[str, Any]:
    """Project the canonical validation output onto the legacy response shape.

    Each validated finding becomes a legacy "violation" entry (with
    ``rule_refs`` as a list); non-dict findings are silently skipped.
    Top-level fields fall back to neutral defaults when absent.
    """

    def _to_violation(finding: dict[str, Any]) -> dict[str, Any]:
        # Single rule_ref is wrapped into the legacy list-valued rule_refs.
        ref = str(finding.get("rule_ref") or "Unknown")
        return {
            "issue": str(finding.get("issue") or "Unspecified issue"),
            "rule_refs": [ref] if ref else [],
            "why": str(finding.get("why") or "No explanation provided."),
            "fix": str(finding.get("fix") or "No fix provided."),
            "module": str(finding.get("module") or "unknown"),
            "severity": str(finding.get("severity") or "ADVISORY"),
            "confidence": normalize_confidence(finding.get("confidence")),
            "source_url": str(finding.get("source_url") or ""),
        }

    violations = [
        _to_violation(entry)
        for entry in validation_output.get("validated_findings", [])
        if isinstance(entry, dict)
    ]
    return {
        "risk_level": validation_output.get("risk_level", "medium"),
        "summary": validation_output.get("summary", "No summary available."),
        "violations": violations,
        "safe_rewrite": validation_output.get("safe_rewrite", ""),
        "overall_verdict": validation_output.get("overall_verdict", "MANUAL_REVIEW"),
        "manual_review_required": bool(validation_output.get("manual_review_required", False)),
        "conflicts": validation_output.get("conflicts", []),
        "applicability_summary": validation_output.get("applicability_summary", {}),
        "source_verification": validation_output.get("source_verification", {}),
    }
def execute_parallel_stage_group(
    stage_prompts: dict[str, str],
    *,
    ad_text: str,
    request_id: str,
    trace_metadata: dict[str, Any] | None = None,
) -> dict[str, dict[str, Any]]:
    """Run every stage prompt concurrently and collect results by stage name.

    Each stage is submitted to a bounded thread pool (at most
    PIPELINE_STAGE_WORKERS threads). A stage whose future raises is recorded
    as a failed result dict rather than aborting the group.
    """
    stage_results: dict[str, dict[str, Any]] = {}
    if not stage_prompts:
        return stage_results
    base_metadata = trace_metadata or {}
    pool_size = min(PIPELINE_STAGE_WORKERS, len(stage_prompts))
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        pending: dict[Any, str] = {}
        for name, prompt in stage_prompts.items():
            future = pool.submit(
                run_named_stage,
                name,
                prompt,
                ad_text=ad_text,
                request_id=request_id,
                trace_metadata={"parallel_group": True, **base_metadata},
            )
            pending[future] = name
        for done in as_completed(pending):
            name = pending[done]
            try:
                stage_results[name] = done.result()
            except Exception as exc:
                # Defensive: stage workers normally report their own errors.
                stage_results[name] = {
                    "stage": name,
                    "ok": False,
                    "status": 500,
                    "parsed_output": None,
                    "raw_output": None,
                    "error": f"Unexpected stage error: {exc}",
                }
    return stage_results
def run_review_pipeline(
    *,
    ad_text: str,
    extra_context: str,
    system_prompt: str,
    image_at_path: str | None,
    request_id: str,
    trace_name: str,
    trace_metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Run the multi-pass review pipeline for one ad (text and/or image).

    Each pass fans out every stage in PIPELINE_STAGE_ORDER in parallel,
    coerces the legal-basis and per-regulator outputs, then runs a sequential
    "validation" stage that arbitrates them.  If validation requests a retry,
    another pass runs with the prior passes and retry guidance fed back in —
    one initial pass plus up to VALIDATION_RETRY_PASSES extra passes.  The
    whole run is wrapped in a LangSmith "chain" trace named ``trace_name``.

    Returns:
        A dict with ``ok`` (True; per-stage failures are embedded in the pass
        records rather than raised), ``parsed_output`` (legacy-shaped result),
        ``raw_output`` (indented JSON dump of the full pipeline output),
        ``pipeline_output`` (structured pass-by-pass record), and ``error``
        (None; unexpected exceptions propagate to the caller).
    """
    with traced_stage(
        trace_name,
        "chain",
        inputs=sanitize_for_langsmith(
            {
                "ad_text": ad_text,
                "extra_context": extra_context,
                "system_prompt": system_prompt,
                "image_at_path": image_at_path,
            },
            ad_text=ad_text,
        ),
        metadata={"request_id": request_id, **(trace_metadata or {})},
        tags=["review-pipeline"],
    ) as (_run, outputs):
        passes: list[dict[str, Any]] = []
        retry_context: dict[str, Any] | None = None
        final_validation_output: dict[str, Any] | None = None
        # Pass numbers run 1..VALIDATION_RETRY_PASSES+1 inclusive.
        for pass_number in range(1, VALIDATION_RETRY_PASSES + 2):
            # Build one prompt per parallel stage; prior passes and retry
            # guidance are threaded in so retry passes can self-correct.
            stage_prompts = {
                stage_name: build_parallel_stage_prompt(
                    stage_name,
                    ad_text=ad_text,
                    extra_context=extra_context,
                    image_at_path=image_at_path,
                    system_prompt=system_prompt,
                    pass_number=pass_number,
                    prior_passes=passes,
                    retry_context=retry_context,
                    request_id=request_id,
                )
                for stage_name in PIPELINE_STAGE_ORDER
            }
            stage_results = execute_parallel_stage_group(
                stage_prompts,
                ad_text=ad_text,
                request_id=request_id,
                trace_metadata={"pass_number": pass_number, **(trace_metadata or {})},
            )
            # Synthesize a failed-stage placeholder if legal_basis went missing
            # so downstream coercion always has a stage dict to work with.
            legal_basis_stage = stage_results.get("legal_basis") or {
                "stage": "legal_basis",
                "ok": False,
                "status": 500,
                "parsed_output": None,
                "raw_output": None,
                "error": "Legal basis stage missing.",
            }
            legal_basis_output = coerce_legal_basis_output(
                legal_basis_stage,
                ad_text=ad_text,
                image_at_path=image_at_path,
            )
            module_stage_results: dict[str, dict[str, Any]] = {}
            module_outputs: dict[str, dict[str, Any]] = {}
            for module_name in REGULATOR_STAGE_ORDER:
                # Same missing-stage placeholder pattern per regulator module.
                module_stage = stage_results.get(module_name) or {
                    "stage": module_name,
                    "ok": False,
                    "status": 500,
                    "parsed_output": None,
                    "raw_output": None,
                    "error": f"{module_name.upper()} stage missing.",
                }
                module_stage_results[module_name] = module_stage
                module_outputs[module_name] = coerce_module_output(module_name, module_stage)
            # Validation runs sequentially after the parallel group, consuming
            # all coerced outputs plus the retry context from prior passes.
            validation_prompt = build_validation_prompt(
                ad_text=ad_text,
                extra_context=extra_context,
                image_at_path=image_at_path,
                system_prompt=system_prompt,
                pass_number=pass_number,
                legal_basis_output=legal_basis_output,
                module_outputs=module_outputs,
                prior_passes=passes,
                retry_context=retry_context,
                request_id=request_id,
            )
            validation_stage = run_named_stage(
                "validation",
                validation_prompt,
                ad_text=ad_text,
                request_id=request_id,
                trace_metadata={"pass_number": pass_number, **(trace_metadata or {})},
            )
            validation_output = coerce_validation_output(
                validation_stage,
                legal_basis_output=legal_basis_output,
                module_outputs=module_outputs,
                pass_number=pass_number,
            )
            # Record both raw stage results and coerced outputs for this pass.
            pass_record = {
                "pass_number": pass_number,
                "parallel_stage_order": list(PIPELINE_STAGE_ORDER),
                "parallel_stages": {
                    "legal_basis": {
                        "stage": legal_basis_stage,
                        "output": legal_basis_output,
                    },
                    **{
                        module_name: {
                            "stage": module_stage_results[module_name],
                            "output": module_outputs[module_name],
                        }
                        for module_name in REGULATOR_STAGE_ORDER
                    },
                },
                "validation": {
                    "stage": validation_stage,
                    "output": validation_output,
                },
            }
            passes.append(pass_record)
            # Retry only while budget remains; otherwise fall through to final.
            if validation_output.get("retry_required") and pass_number <= VALIDATION_RETRY_PASSES:
                retry_context = {
                    "retry_reason": validation_output.get("retry_reason", ""),
                    "retry_targets": validation_output.get("retry_targets", list(PIPELINE_STAGE_ORDER)),
                    "retry_guidance": validation_output.get("retry_guidance", []),
                    "prior_validation_output": validation_output,
                }
                continue
            final_validation_output = validation_output
            break
        # Defensive: if every pass requested a retry, use the last pass's output.
        if final_validation_output is None:
            final_validation_output = passes[-1]["validation"]["output"]
        legacy_output = build_legacy_output(final_validation_output)
        pipeline_output = {
            "request_id": request_id,
            "input_mode": infer_input_mode(ad_text, image_at_path),
            "parallel_stage_order": list(PIPELINE_STAGE_ORDER),
            "retry_performed": len(passes) > 1,
            "total_passes": len(passes),
            "passes": passes,
            "final_validation": final_validation_output,
            "legacy_output": legacy_output,
        }
        # Attach the (sanitized) result to the trace before returning.
        outputs["pipeline_output"] = sanitize_for_langsmith(pipeline_output, ad_text=ad_text)
        return {
            "ok": True,
            "parsed_output": legacy_output,
            "raw_output": json.dumps(pipeline_output, ensure_ascii=True, indent=2),
            "pipeline_output": pipeline_output,
            "error": None,
        }
class AppHandler(SimpleHTTPRequestHandler):
    """HTTP handler: static files on GET, the review API on POST /api/check.

    GET/HEAD requests are served entirely by SimpleHTTPRequestHandler, rooted
    at STATIC_DIR.  POST is accepted only on /api/check and supports a single
    text check or a bulk image check run across a thread pool.
    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Root the inherited static-file serving at the app's static directory.
        super().__init__(*args, directory=str(STATIC_DIR), **kwargs)
    def _send_json(self, status: int, payload: dict[str, Any]) -> None:
        """Serialize *payload* as ASCII-safe JSON and write a complete response."""
        data = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)
    def do_POST(self) -> None:
        """Handle POST /api/check: validate input, run single or bulk pipeline checks."""
        request_id = uuid.uuid4().hex
        # Wrap the whole request in a LangSmith tracing context when enabled;
        # otherwise a nullcontext makes the `with` a no-op.
        trace_context = (
            ls.tracing_context(
                enabled=True,
                client=LANGSMITH_CLIENT,
                project_name=LANGSMITH_PROJECT,
                tags=["hf-space-v2", "api-check"],
                metadata={"request_id": request_id, "path": self.path},
            )
            if LANGSMITH_ENABLED and ls is not None
            else nullcontext()
        )
        with trace_context:
            with traced_stage(
                "http_post_api_check",
                "chain",
                inputs={
                    "path": self.path,
                    "headers": {
                        "content_type": self.headers.get("Content-Type", ""),
                        "content_length": self.headers.get("Content-Length", "0"),
                    },
                },
                metadata={"request_id": request_id},
                tags=["http-request"],
            ) as (_request_run, request_outputs):
                # Local helper: records the response on the trace, stamps the
                # request_id, then writes JSON to the client.
                def send_response(status: int, payload: dict[str, Any], *, ad_text: str = "") -> None:
                    payload.setdefault("request_id", request_id)
                    request_outputs["http_status"] = status
                    request_outputs["response"] = sanitize_for_langsmith(payload, ad_text=ad_text)
                    self._send_json(status, payload)
                if self.path != "/api/check":
                    send_response(404, {"ok": False, "error": "Not found"})
                    return
                # NOTE(review): a non-numeric Content-Length header would raise
                # ValueError here and surface as an unhandled 500 — consider guarding.
                content_length = int(self.headers.get("Content-Length", "0"))
                if content_length <= 0:
                    send_response(400, {"ok": False, "error": "Request body is required."})
                    return
                content_type = self.headers.get("Content-Type", "")
                if "application/json" not in content_type.lower():
                    send_response(400, {"ok": False, "error": "Content-Type must be application/json."})
                    return
                raw_body = self.rfile.read(content_length)
                try:
                    body_str = raw_body.decode("utf-8")
                except UnicodeDecodeError:
                    send_response(400, {"ok": False, "error": "Body contains invalid UTF-8 data."})
                    return
                try:
                    payload = json.loads(body_str)
                except json.JSONDecodeError:
                    send_response(400, {"ok": False, "error": "Body must be valid JSON."})
                    return
                ad_text = str(payload.get("ad_text", "")).strip()
                extra_context = str(payload.get("extra_context", "")).strip()
                system_prompt = str(payload.get("system_prompt", DEFAULT_SYSTEM_PROMPT)).strip()
                # Raw request capture is opt-out via LANGSMITH_TRACE_RAW_REQUEST.
                if LANGSMITH_TRACE_RAW_REQUEST:
                    request_outputs["raw_body"] = sanitize_for_langsmith(body_str, ad_text=ad_text)
                    request_outputs["payload"] = sanitize_for_langsmith(payload, ad_text=ad_text)
                try:
                    image_inputs = normalize_image_inputs(payload, ad_text=ad_text, request_id=request_id)
                except ValueError as err:
                    send_response(400, {"ok": False, "error": str(err)}, ad_text=ad_text)
                    return
                if not ad_text and not image_inputs:
                    send_response(400, {"ok": False, "error": "Provide 'ad_text' or an image."})
                    return
                if not system_prompt:
                    system_prompt = DEFAULT_SYSTEM_PROMPT
                # Text-only path: run a single pipeline and return a one-item
                # result list shaped like the bulk response for UI compatibility.
                if not image_inputs:
                    try:
                        result = run_review_pipeline(
                            ad_text=ad_text,
                            extra_context=extra_context,
                            system_prompt=system_prompt,
                            image_at_path=None,
                            request_id=request_id,
                            trace_name="run_single_text_pipeline",
                            trace_metadata={"mode": "single"},
                        )
                    except Exception as err:
                        send_response(500, {"ok": False, "error": f"Pipeline error: {err}"}, ad_text=ad_text)
                        return
                    if not result.get("ok"):
                        send_response(500, {"ok": False, "error": result["error"]}, ad_text=ad_text)
                        return
                    send_response(
                        200,
                        {
                            "ok": True,
                            "mode": "single",
                            "parallel_workers": 1,
                            "all_success": True,
                            "total": 1,
                            "success_count": 1,
                            "failure_count": 0,
                            "results": [
                                {
                                    "index": 1,
                                    "ok": True,
                                    "image_reference": None,
                                    "parsed_output": result["parsed_output"],
                                    "raw_output": result["raw_output"],
                                    "error": None,
                                    "pipeline_output": result.get("pipeline_output"),
                                }
                            ],
                            "parsed_output": result["parsed_output"],
                            "raw_output": result["raw_output"],
                            "image_reference": None,
                            "pipeline_output": result.get("pipeline_output"),
                        },
                        ad_text=ad_text,
                    )
                    return
                # Image path: persist each uploaded data URL to disk first so
                # workers reference files rather than large in-memory payloads.
                image_refs: list[str] = []
                for image in image_inputs:
                    try:
                        image_ref = save_image_from_data_url(
                            image_data_url=image["data_url"],
                            image_filename=image["filename"],
                            request_id=request_id,
                        )
                    except ValueError as err:
                        send_response(400, {"ok": False, "error": str(err)}, ad_text=ad_text)
                        return
                    image_refs.append(image_ref)
                total = len(image_refs)
                worker_count = max(1, min(MAX_PARALLEL_WORKERS, total))
                request_outputs["bulk_meta"] = {
                    "total_images": total,
                    "parallel_workers": worker_count,
                }
                print(
                    f"Starting bulk Gemini checks: total_images={total}, parallel_workers={worker_count}",
                    flush=True,
                )
                # Pre-sized slot list keeps results in submission order even
                # though futures complete out of order.
                results: list[dict[str, Any] | None] = [None] * total
                completed = 0
                with ThreadPoolExecutor(max_workers=worker_count) as executor:
                    future_to_slot = {
                        executor.submit(
                            run_single_image_check,
                            idx,
                            total,
                            image_ref,
                            ad_text,
                            extra_context,
                            system_prompt,
                            request_id,
                        ): (idx - 1, image_ref)
                        for idx, image_ref in enumerate(image_refs, start=1)
                    }
                    for future in as_completed(future_to_slot):
                        slot, image_ref = future_to_slot[future]
                        try:
                            results[slot] = future.result()
                        except Exception as err:
                            # Defensive fallback: this should be rare because worker handles model errors.
                            results[slot] = {
                                "index": slot + 1,
                                "ok": False,
                                "image_reference": image_ref,
                                "parsed_output": None,
                                "raw_output": None,
                                "error": f"Unexpected worker error: {err}",
                                "pipeline_output": None,
                            }
                        completed += 1
                        print(f"Bulk progress: {completed}/{total} completed", flush=True)
                finalized_results = [item for item in results if item is not None]
                finalized_results.sort(key=lambda item: int(item["index"]))
                success_count = sum(1 for item in finalized_results if item["ok"])
                failure_count = len(finalized_results) - success_count
                # First result doubles as the top-level single-result fields below.
                first = finalized_results[0]
                send_response(
                    200,
                    {
                        "ok": True,
                        "mode": "bulk" if len(finalized_results) > 1 else "single",
                        "parallel_workers": worker_count,
                        "all_success": failure_count == 0,
                        "total": len(finalized_results),
                        "success_count": success_count,
                        "failure_count": failure_count,
                        "results": finalized_results,
                        # Keep compatibility with single-result UI consumers.
                        "parsed_output": first.get("parsed_output"),
                        "raw_output": first.get("raw_output"),
                        "image_reference": first.get("image_reference"),
                        "pipeline_output": first.get("pipeline_output"),
                    },
                    ad_text=ad_text,
                )
def main() -> None:
    """Ensure app directories exist, then serve HTTP until interrupted."""
    for directory in (STATIC_DIR, UPLOADS_DIR):
        directory.mkdir(parents=True, exist_ok=True)
    httpd = ThreadingHTTPServer((HOST, PORT), AppHandler)
    print(f"Server running at http://{HOST}:{PORT}")
    try:
        httpd.serve_forever()
    finally:
        # Drain any pending LangSmith traces even on interrupted shutdown.
        flush_langsmith()
# Script entry point: start the server only when run directly, not on import.
if __name__ == "__main__":
    main()