# techreg / server.py
# Uploaded by bardd ("Upload 39 files", commit 7153d5c, verified).
#!/usr/bin/env python3
from __future__ import annotations
import base64
import binascii
import json
import os
import re
import subprocess
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from time import monotonic
from typing import Any
def _env_int(name: str, default: int) -> int:
    """Read integer setting *name* from the environment, falling back to *default*."""
    return int(os.environ.get(name, str(default)))


# Filesystem layout: everything lives next to this script.
APP_DIR = Path(__file__).resolve().parent
STATIC_DIR = APP_DIR / "static"  # document root served by AppHandler
UPLOADS_DIR = APP_DIR / "uploads"  # decoded image uploads are written here

# Network and Gemini CLI settings (overridable via environment variables).
HOST = os.environ.get("HOST", "127.0.0.1")
PORT = _env_int("PORT", 8080)
GEMINI_TIMEOUT_SEC = _env_int("GEMINI_TIMEOUT_SEC", 90)
GEMINI_CLI_BINARY = os.environ.get("GEMINI_CLI_BINARY", "gemini")
# Deliberately not configurable: the model is never exposed to users.
LOCKED_GEMINI_MODEL = "gemini-3-flash-preview"

# Upload limits.
MAX_IMAGE_BYTES = _env_int("MAX_IMAGE_BYTES", 8 * 1024 * 1024)  # per decoded image
MAX_BATCH_IMAGES = _env_int("MAX_BATCH_IMAGES", 20)  # images per request
MAX_PARALLEL_WORKERS = max(1, _env_int("MAX_PARALLEL_WORKERS", 4))  # never below one

# Accepted MIME types mapped to the file extension used when saving.
ALLOWED_IMAGE_MIME_TO_EXT = {
    f"image/{subtype}": extension
    for subtype, extension in (
        ("png", "png"),
        ("jpeg", "jpg"),
        ("jpg", "jpg"),
        ("webp", "webp"),
        ("gif", "gif"),
    )
}
# data:<mime>;base64,<payload>; the payload may contain whitespace.
DATA_URL_RE = re.compile(r"^data:(?P<mime>[-\w.+/]+);base64,(?P<data>[A-Za-z0-9+/=\s]+)$")

DEFAULT_SYSTEM_PROMPT = " ".join(
    (
        "You are a UK fintech ad compliance screening assistant.",
        "Return only valid JSON and nothing else.",
    )
)
# Example response shape embedded verbatim (as JSON) in every prompt.
JSON_SCHEMA_HINT = dict(
    risk_level="low | medium | high",
    summary="short sentence",
    violations=[
        dict(
            issue="what is risky",
            rule_refs=["FCA handbook or principle area"],
            why="why this is a risk",
            fix="specific rewrite guidance",
        )
    ],
    safe_rewrite="optional ad rewrite",
)
def build_prompt(ad_text: str, extra_context: str, system_prompt: str, image_at_path: str | None) -> str:
    """Assemble the complete text prompt handed to the Gemini CLI.

    Sections, in order and separated by blank lines: the system prompt, the
    task plus the expected JSON shape (JSON_SCHEMA_HINT), an optional
    ``@<path>`` image reference, the ad copy (``[Not provided]`` when blank)
    and, if non-blank, the extra context.

    NOTE(review): ad_text, extra_context and system_prompt are taken from the
    HTTP request. The image is referenced via ``@<path>``, which relies on the
    CLI expanding ``@`` references, so ``@...`` inside user text may be
    expanded the same way; confirm that is acceptable.
    """
    copy = ad_text.strip()
    context = extra_context.strip()

    sections: list[str] = [system_prompt.strip(), ""]
    sections.append("Task: Screen this ad copy for UK fintech/FCA-style risk.")
    sections.append("Output format: Return only JSON in this shape:")
    sections.append(json.dumps(JSON_SCHEMA_HINT, ensure_ascii=True, indent=2))
    sections.append("")

    if image_at_path:
        sections.extend(
            (
                "Creative image reference:",
                f"@{image_at_path}",
                "Use this image as part of your compliance risk review.",
                "",
            )
        )

    sections.extend(("Ad copy:", copy or "[Not provided]"))

    if context:
        sections.extend(("", "Extra context:", context))

    return "\n".join(sections)
def gemini_cmd_candidates(prompt: str) -> list[list[str]]:
    """Return the argv variants to try, in order, for one Gemini CLI run.

    Every combination of the model flag (``--model`` / ``-m``) and the prompt
    flag (``-p`` / ``--prompt``) is produced, presumably so the call works
    with CLI versions that accept only one spelling. run_gemini walks this
    list in order. The model is always LOCKED_GEMINI_MODEL and is
    intentionally never exposed to users.

    NOTE(review): *prompt* begins with the caller-supplied system prompt. If
    that text starts with "-", the CLI might parse it as a flag rather than
    as the prompt value; confirm how the CLI handles that.
    """
    return [
        [GEMINI_CLI_BINARY, model_flag, LOCKED_GEMINI_MODEL, prompt_flag, prompt]
        for prompt_flag in ("-p", "--prompt")
        for model_flag in ("--model", "-m")
    ]
def is_flag_parse_error(stderr: str, stdout: str) -> bool:
    """Return True when CLI output looks like a command-line flag parsing failure.

    Both streams are joined and lower-cased, then searched for wording that
    argument parsers typically emit for unknown or misspelled flags.
    run_gemini uses this to decide whether another argv shape is worth trying.
    """
    markers = (
        "unknown option",
        "unknown argument",
        "invalid option",
        "unrecognized option",
        "unrecognized argument",
        "unexpected argument",
        "did you mean",
    )
    haystack = "\n".join((stderr, stdout)).lower()
    for marker in markers:
        if marker in haystack:
            return True
    return False
def run_gemini(prompt: str) -> str:
    """Run the Gemini CLI on *prompt* and return its stripped stdout.

    The argv shapes from gemini_cmd_candidates() are tried in order. The next
    shape is attempted only when the failed run's output looks like a
    flag-parsing error (is_flag_parse_error); any other failure stops at once.

    Raises:
        RuntimeError: the CLI exited non-zero. The message is its stderr,
            else its stdout, else the exit code.
        FileNotFoundError: the CLI binary could not be launched (raised by
            subprocess.run).
        subprocess.TimeoutExpired: one attempt ran longer than
            GEMINI_TIMEOUT_SEC.
    """
    attempts = gemini_cmd_candidates(prompt)
    last_error = "Gemini CLI invocation failed."
    child_env = os.environ.copy()
    # Keep only GEMINI_API_KEY to avoid CLI warnings when both vars are set;
    # GOOGLE_API_KEY is promoted when GEMINI_API_KEY is absent.
    if not child_env.get("GEMINI_API_KEY") and child_env.get("GOOGLE_API_KEY"):
        child_env["GEMINI_API_KEY"] = child_env["GOOGLE_API_KEY"]
    child_env.pop("GOOGLE_API_KEY", None)
    for idx, cmd in enumerate(attempts):
        proc = subprocess.run(
            cmd,
            capture_output=True,
            # Decode explicitly as UTF-8 (presumably what the CLI writes;
            # confirm) instead of the platform's locale encoding. Undecodable
            # bytes become U+FFFD instead of raising UnicodeDecodeError
            # mid-request (e.g. "£" in UK ad copy on a non-UTF-8 locale).
            text=True,
            encoding="utf-8",
            errors="replace",
            cwd=str(APP_DIR),
            env=child_env,
            timeout=GEMINI_TIMEOUT_SEC,
            check=False,
        )
        if proc.returncode == 0:
            return (proc.stdout or "").strip()
        stderr = (proc.stderr or "").strip()
        stdout = (proc.stdout or "").strip()
        details = stderr if stderr else stdout
        last_error = details or f"Gemini CLI exited with code {proc.returncode}."
        # Only retry different flag shapes if this appears to be flag parsing trouble.
        if idx < len(attempts) - 1 and is_flag_parse_error(stderr, stdout):
            continue
        break
    raise RuntimeError(last_error)
def try_parse_json(text: str) -> Any | None:
    """Best-effort decode of a model reply into a JSON value.

    Strategies, in order:
      1. The whole reply, after removing a surrounding markdown code fence
         (a first line starting with ``` plus a last line starting with ```)
         and a leading ``json`` language tag.
      2. The first complete JSON value starting at the first ``{`` (then the
         first ``[``) in the reply. This handles JSON wrapped in prose or in
         fences the first strategy does not recognise, such as a closing
         ``` on the same line as the JSON. Trailing text after the value is
         ignored.

    Returns the decoded value, or None when the reply is blank or no JSON
    could be decoded.
    """
    trimmed = text.strip()
    if not trimmed:
        return None

    candidate = trimmed
    # Handle markdown fences if the model returns them.
    if candidate.startswith("```"):
        lines = candidate.splitlines()
        if len(lines) >= 3 and lines[-1].strip().startswith("```"):
            candidate = "\n".join(lines[1:-1]).strip()
    # Valid JSON never starts with the letters "json", so stripping this tag
    # cannot break a document that would otherwise parse.
    if candidate.lower().startswith("json"):
        candidate = candidate[4:].strip()
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass

    # Fallback: decode the first JSON object/array embedded in the reply.
    decoder = json.JSONDecoder()
    for opener in ("{", "["):
        start = trimmed.find(opener)
        if start == -1:
            continue
        try:
            value, _end = decoder.raw_decode(trimmed[start:])
        except json.JSONDecodeError:
            continue
        return value
    return None
def safe_filename_stem(raw_name: str) -> str:
    """Turn a client-supplied filename into a short, filesystem-safe stem.

    Directory parts and the final extension are dropped. Each run of
    characters outside ``[A-Za-z0-9_-]`` becomes a single "-", leading and
    trailing dashes are trimmed, and the result is cut to 40 characters.
    Returns "ad-image" when nothing usable remains.
    """
    fallback = "ad-image"
    stem = Path(raw_name).stem if raw_name else fallback
    slug = re.sub(r"[^A-Za-z0-9_-]+", "-", stem).strip("-")
    return slug[:40] if slug else fallback
def save_image_from_data_url(image_data_url: str, image_filename: str) -> str:
    """Validate a base64 image data URL, write it to disk, and return its path.

    The decoded bytes go to ``UPLOADS_DIR/<stem>-<10 hex chars>.<ext>``:
    ``<stem>`` is safe_filename_stem(image_filename) and ``<ext>`` comes from
    the declared MIME type. The image bytes themselves are not inspected.
    The returned path has the form "uploads/<name>", i.e. relative to APP_DIR.

    Raises:
        ValueError: malformed data URL, MIME type not in
            ALLOWED_IMAGE_MIME_TO_EXT, invalid or empty base64 payload, or a
            decoded image larger than MAX_IMAGE_BYTES.
    """
    found = DATA_URL_RE.match(image_data_url.strip())
    if found is None:
        raise ValueError("Image must be a valid base64 data URL (data:image/...;base64,...).")

    mime_type = found.group("mime").lower()
    if mime_type not in ALLOWED_IMAGE_MIME_TO_EXT:
        accepted = ", ".join(sorted(ALLOWED_IMAGE_MIME_TO_EXT))
        raise ValueError(f"Unsupported image type '{mime_type}'. Allowed: {accepted}.")
    extension = ALLOWED_IMAGE_MIME_TO_EXT[mime_type]

    # DATA_URL_RE tolerates whitespace (e.g. line-wrapped base64); strict
    # decoding (validate=True) does not, so remove it first.
    compact = re.sub(r"\s+", "", found.group("data"))
    try:
        decoded = base64.b64decode(compact, validate=True)
    except (ValueError, binascii.Error):
        raise ValueError("Image base64 payload is invalid.") from None

    if not decoded:
        raise ValueError("Image payload is empty.")
    if len(decoded) > MAX_IMAGE_BYTES:
        raise ValueError(f"Image is too large. Max size is {MAX_IMAGE_BYTES} bytes.")

    UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
    stored_name = f"{safe_filename_stem(image_filename)}-{uuid.uuid4().hex[:10]}.{extension}"
    (UPLOADS_DIR / stored_name).write_bytes(decoded)
    return f"uploads/{stored_name}"
def normalize_image_inputs(payload: dict[str, Any]) -> list[dict[str, str]]:
    """Collect the request's image inputs into a uniform list.

    A non-empty ``images`` list of ``{"data_url", "filename"}`` objects takes
    precedence. Otherwise the single-image fields ``image_data_url`` /
    ``image_filename`` are used. JSON ``null`` values are treated exactly
    like missing fields.

    Returns:
        A list of ``{"data_url": str, "filename": str}`` dicts with stripped
        values and default filenames filled in. The list is empty when no
        image was supplied.

    Raises:
        ValueError: more than MAX_BATCH_IMAGES entries, an entry that is not
            an object, or an entry without a data_url.
    """

    def _text(value: Any) -> str:
        # str(None) would produce the literal text "None" for JSON nulls.
        return "" if value is None else str(value).strip()

    images_field = payload.get("images")
    if isinstance(images_field, list) and images_field:
        if len(images_field) > MAX_BATCH_IMAGES:
            raise ValueError(f"Too many images. Max is {MAX_BATCH_IMAGES}.")
        normalized: list[dict[str, str]] = []
        for idx, item in enumerate(images_field):
            if not isinstance(item, dict):
                raise ValueError(f"images[{idx}] must be an object.")
            data_url = _text(item.get("data_url"))
            if not data_url:
                raise ValueError(f"images[{idx}].data_url is required.")
            filename = _text(item.get("filename")) or f"image-{idx + 1}.png"
            normalized.append({"data_url": data_url, "filename": filename})
        return normalized

    single_data_url = _text(payload.get("image_data_url"))
    if not single_data_url:
        return []
    return [
        {
            "data_url": single_data_url,
            "filename": _text(payload.get("image_filename")) or "image.png",
        }
    ]
def run_single_check(prompt: str) -> tuple[bool, int, dict[str, Any]]:
    """Run one Gemini check and map the outcome to ``(ok, http_status, result)``.

    On success, *result* holds "parsed_output" (decoded JSON, or None when the
    reply is not JSON) and "raw_output" (the CLI's stdout). On failure it
    holds a user-facing "error" message, and the status is 504 for a timeout
    and 500 otherwise. Failures of the CLI invocation are always returned,
    never raised.
    """
    try:
        raw_output = run_gemini(prompt)
    except FileNotFoundError:
        return (
            False,
            500,
            {"error": f"Gemini CLI not found. Install it and ensure '{GEMINI_CLI_BINARY}' is on PATH."},
        )
    except subprocess.TimeoutExpired:
        return False, 504, {"error": f"Gemini CLI timed out after {GEMINI_TIMEOUT_SEC}s."}
    except RuntimeError as err:
        return False, 500, {"error": str(err)}
    except (OSError, ValueError) as err:
        # Examples: PermissionError for a non-executable binary, or
        # ValueError("embedded null byte") when the prompt contains NUL.
        # Uncaught, these would escape to the HTTP handler and, in text-only
        # mode, abort the request without any response.
        return False, 500, {"error": f"Gemini CLI invocation failed: {err}"}
    return True, 200, {"parsed_output": try_parse_json(raw_output), "raw_output": raw_output}
def run_single_image_check(
    index: int,
    total: int,
    image_ref: str,
    ad_text: str,
    extra_context: str,
    system_prompt: str,
) -> dict[str, Any]:
    """Check one saved image (with the shared ad text/context) and return a result row.

    Submitted to a thread pool by the batch handler. Logs a start line and a
    finish line with the elapsed time, both flushed immediately. The returned
    row has keys index, ok, image_reference, parsed_output, raw_output and
    error; fields that the outcome did not produce are None.
    """
    print(f"[batch {index}/{total}] starting check for {image_ref}", flush=True)
    t0 = monotonic()
    ok, _, outcome = run_single_check(
        build_prompt(
            ad_text=ad_text,
            extra_context=extra_context,
            system_prompt=system_prompt,
            image_at_path=image_ref,
        )
    )
    duration = monotonic() - t0
    verdict = "ok" if ok else "failed"
    print(f"[batch {index}/{total}] {verdict} in {duration:.1f}s", flush=True)

    row: dict[str, Any] = {"index": index, "ok": ok, "image_reference": image_ref}
    row.update((key, outcome.get(key)) for key in ("parsed_output", "raw_output", "error"))
    return row
class AppHandler(SimpleHTTPRequestHandler):
    """Serve the static UI and the ``POST /api/check`` compliance-check API.

    GET/HEAD requests are handled by SimpleHTTPRequestHandler with STATIC_DIR
    as the document root. ``POST /api/check`` takes a JSON object with
    optional ``ad_text``, ``extra_context``, ``system_prompt`` and image
    fields (see normalize_image_inputs). It always answers with a JSON body
    carrying an ``ok`` flag.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pin the document root to STATIC_DIR instead of the process cwd.
        super().__init__(*args, directory=str(STATIC_DIR), **kwargs)

    def _send_json(self, status: int, payload: dict[str, Any]) -> None:
        """Send *payload* as an ASCII-safe JSON response with HTTP *status*."""
        data = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    @staticmethod
    def _max_body_bytes() -> int:
        """Return the largest request body POST /api/check will read.

        Budget: one MAX_IMAGE_BYTES image per batch slot (at least one slot,
        since the single-image fields bypass the batch limit) at twice its raw
        size, plus 1 MiB for text fields and JSON overhead. The factor of two
        covers base64's 4/3 inflation plus the whitespace DATA_URL_RE tolerates.
        """
        return max(MAX_BATCH_IMAGES, 1) * MAX_IMAGE_BYTES * 2 + 1024 * 1024

    @staticmethod
    def _text_field(payload: dict[str, Any], key: str) -> str:
        """Return ``payload[key]`` as stripped text; a missing key or JSON null gives ""."""
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    @staticmethod
    def _discard_uploads(image_refs: list[str]) -> None:
        """Best-effort removal of images already saved for a rejected request.

        *image_refs* are the "uploads/<name>" paths returned by
        save_image_from_data_url, which are relative to APP_DIR.
        """
        for ref in image_refs:
            try:
                (APP_DIR / ref).unlink()
            except OSError:
                # Cleanup is best-effort; it must not prevent the 400 response.
                pass

    def _read_json_object(self) -> dict[str, Any] | None:
        """Read the request body and decode it as a JSON object.

        Returns the decoded object, or None after an error response has
        already been sent. Error cases are a malformed, missing or oversized
        Content-Length, a non-JSON Content-Type, invalid UTF-8 or JSON, and a
        top-level JSON value that is not an object.
        """
        try:
            content_length = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self._send_json(400, {"ok": False, "error": "Content-Length must be an integer."})
            return None
        if content_length <= 0:
            self._send_json(400, {"ok": False, "error": "Request body is required."})
            return None
        if content_length > self._max_body_bytes():
            # Refuse before reading; otherwise rfile.read() would try to pull
            # whatever size the client claims into memory.
            self._send_json(413, {"ok": False, "error": "Request body is too large."})
            return None
        content_type = self.headers.get("Content-Type", "")
        if "application/json" not in content_type.lower():
            self._send_json(400, {"ok": False, "error": "Content-Type must be application/json."})
            return None
        raw_body = self.rfile.read(content_length)
        try:
            body_str = raw_body.decode("utf-8")
        except UnicodeDecodeError:
            self._send_json(400, {"ok": False, "error": "Body contains invalid UTF-8 data."})
            return None
        try:
            payload = json.loads(body_str)
        except json.JSONDecodeError:
            self._send_json(400, {"ok": False, "error": "Body must be valid JSON."})
            return None
        if not isinstance(payload, dict):
            # Field access relies on dict.get; an array or scalar would raise.
            self._send_json(400, {"ok": False, "error": "Body must be a JSON object."})
            return None
        return payload

    def do_POST(self) -> None:
        """Handle ``POST /api/check``; every other POST path gets a JSON 404."""
        if self.path != "/api/check":
            self._send_json(404, {"ok": False, "error": "Not found"})
            return
        payload = self._read_json_object()
        if payload is None:
            return
        ad_text = self._text_field(payload, "ad_text")
        extra_context = self._text_field(payload, "extra_context")
        system_prompt = self._text_field(payload, "system_prompt")
        try:
            image_inputs = normalize_image_inputs(payload)
        except ValueError as err:
            self._send_json(400, {"ok": False, "error": str(err)})
            return
        if not ad_text and not image_inputs:
            self._send_json(400, {"ok": False, "error": "Provide 'ad_text' or an image."})
            return
        if any("\x00" in text for text in (ad_text, extra_context, system_prompt)):
            # These strings end up inside a subprocess argument, which cannot
            # contain NUL characters.
            self._send_json(400, {"ok": False, "error": "Text fields must not contain NUL characters."})
            return
        if not system_prompt:
            system_prompt = DEFAULT_SYSTEM_PROMPT
        if image_inputs:
            self._check_images(image_inputs, ad_text, extra_context, system_prompt)
        else:
            self._check_text_only(ad_text, extra_context, system_prompt)

    def _check_text_only(self, ad_text: str, extra_context: str, system_prompt: str) -> None:
        """Run one text-only check and send the single-mode response."""
        prompt = build_prompt(
            ad_text=ad_text,
            extra_context=extra_context,
            system_prompt=system_prompt,
            image_at_path=None,
        )
        ok, status, result = run_single_check(prompt)
        if not ok:
            self._send_json(status, {"ok": False, "error": result["error"]})
            return
        self._send_json(
            200,
            {
                "ok": True,
                "mode": "single",
                "parallel_workers": 1,
                "all_success": True,
                "total": 1,
                "success_count": 1,
                "failure_count": 0,
                "results": [
                    {
                        "index": 1,
                        "ok": True,
                        "image_reference": None,
                        "parsed_output": result["parsed_output"],
                        "raw_output": result["raw_output"],
                        "error": None,
                    }
                ],
                "parsed_output": result["parsed_output"],
                "raw_output": result["raw_output"],
                "image_reference": None,
            },
        )

    def _check_images(
        self,
        image_inputs: list[dict[str, str]],
        ad_text: str,
        extra_context: str,
        system_prompt: str,
    ) -> None:
        """Save every image, check them in parallel, and send the batch response.

        If any image fails validation, the images already saved for this
        request are removed and a 400 is sent without running any check.
        Otherwise the response is 200 and per-image failures are reported
        inside ``results``.
        """
        image_refs: list[str] = []
        for image in image_inputs:
            try:
                image_ref = save_image_from_data_url(
                    image_data_url=image["data_url"], image_filename=image["filename"]
                )
            except ValueError as err:
                self._discard_uploads(image_refs)
                self._send_json(400, {"ok": False, "error": str(err)})
                return
            image_refs.append(image_ref)

        total = len(image_refs)
        worker_count = max(1, min(MAX_PARALLEL_WORKERS, total))
        print(
            f"Starting bulk Gemini checks: total_images={total}, parallel_workers={worker_count}",
            flush=True,
        )
        # Pre-sized so each result lands in its original slot regardless of
        # completion order.
        results: list[dict[str, Any] | None] = [None] * total
        completed = 0
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            future_to_slot = {
                executor.submit(
                    run_single_image_check,
                    idx,
                    total,
                    image_ref,
                    ad_text,
                    extra_context,
                    system_prompt,
                ): (idx - 1, image_ref)
                for idx, image_ref in enumerate(image_refs, start=1)
            }
            for future in as_completed(future_to_slot):
                slot, image_ref = future_to_slot[future]
                try:
                    results[slot] = future.result()
                except Exception as err:
                    # Defensive fallback: this should be rare because worker handles model errors.
                    results[slot] = {
                        "index": slot + 1,
                        "ok": False,
                        "image_reference": image_ref,
                        "parsed_output": None,
                        "raw_output": None,
                        "error": f"Unexpected worker error: {err}",
                    }
                completed += 1
                print(f"Bulk progress: {completed}/{total} completed", flush=True)

        finalized_results = [item for item in results if item is not None]
        finalized_results.sort(key=lambda item: int(item["index"]))
        success_count = sum(1 for item in finalized_results if item["ok"])
        failure_count = len(finalized_results) - success_count
        first = finalized_results[0]
        self._send_json(
            200,
            {
                "ok": True,
                "mode": "bulk" if len(finalized_results) > 1 else "single",
                "parallel_workers": worker_count,
                "all_success": failure_count == 0,
                "total": len(finalized_results),
                "success_count": success_count,
                "failure_count": failure_count,
                "results": finalized_results,
                # Keep compatibility with single-result UI consumers.
                "parsed_output": first.get("parsed_output"),
                "raw_output": first.get("raw_output"),
                "image_reference": first.get("image_reference"),
            },
        )
def main() -> None:
    """Create the static and upload directories, then serve HTTP on HOST:PORT until interrupted."""
    STATIC_DIR.mkdir(parents=True, exist_ok=True)
    UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
    # The context manager calls server_close() on exit, releasing the socket.
    with ThreadingHTTPServer((HOST, PORT), AppHandler) as server:
        # flush=True like every other log line, so the message shows up
        # promptly when stdout is piped.
        print(f"Server running at http://{HOST}:{PORT}", flush=True)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl+C is the normal way to stop this server; exit without a traceback.
            print("Shutting down.", flush=True)


if __name__ == "__main__":
    main()