#!/usr/bin/env python3 from __future__ import annotations import base64 import binascii import json import os import re import subprocess import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from time import monotonic from typing import Any APP_DIR = Path(__file__).resolve().parent STATIC_DIR = APP_DIR / "static" UPLOADS_DIR = APP_DIR / "uploads" HOST = os.environ.get("HOST", "127.0.0.1") PORT = int(os.environ.get("PORT", "8080")) GEMINI_TIMEOUT_SEC = int(os.environ.get("GEMINI_TIMEOUT_SEC", "90")) GEMINI_CLI_BINARY = os.environ.get("GEMINI_CLI_BINARY", "gemini") LOCKED_GEMINI_MODEL = "gemini-3-flash-preview" MAX_IMAGE_BYTES = int(os.environ.get("MAX_IMAGE_BYTES", str(8 * 1024 * 1024))) MAX_BATCH_IMAGES = int(os.environ.get("MAX_BATCH_IMAGES", "20")) MAX_PARALLEL_WORKERS = max(1, int(os.environ.get("MAX_PARALLEL_WORKERS", "4"))) ALLOWED_IMAGE_MIME_TO_EXT = { "image/png": "png", "image/jpeg": "jpg", "image/jpg": "jpg", "image/webp": "webp", "image/gif": "gif", } DATA_URL_RE = re.compile(r"^data:(?P[-\w.+/]+);base64,(?P[A-Za-z0-9+/=\s]+)$") DEFAULT_SYSTEM_PROMPT = ( "You are a UK fintech ad compliance screening assistant. " "Return only valid JSON and nothing else." ) JSON_SCHEMA_HINT = { "risk_level": "low | medium | high", "summary": "short sentence", "violations": [ { "issue": "what is risky", "rule_refs": ["FCA handbook or principle area"], "why": "why this is a risk", "fix": "specific rewrite guidance", } ], "safe_rewrite": "optional ad rewrite", } def build_prompt(ad_text: str, extra_context: str, system_prompt: str, image_at_path: str | None) -> str: ad_text_clean = ad_text.strip() parts = [ system_prompt.strip(), "", "Task: Screen this ad copy for UK fintech/FCA-style risk.", "Output format: Return only JSON in this shape:", json.dumps(JSON_SCHEMA_HINT, ensure_ascii=True, indent=2), "", ] if image_at_path: parts += [ "Creative image reference:", f"@{image_at_path}", "Use this image as part of your compliance risk review.", "", ] parts += [ "Ad copy:", ad_text_clean if ad_text_clean else "[Not provided]", ] if extra_context.strip(): parts += ["", "Extra context:", extra_context.strip()] return "\n".join(parts) def gemini_cmd_candidates(prompt: str) -> list[list[str]]: # Model is intentionally locked and never exposed to users. return [ [GEMINI_CLI_BINARY, "--model", LOCKED_GEMINI_MODEL, "-p", prompt], [GEMINI_CLI_BINARY, "-m", LOCKED_GEMINI_MODEL, "-p", prompt], [GEMINI_CLI_BINARY, "--model", LOCKED_GEMINI_MODEL, "--prompt", prompt], [GEMINI_CLI_BINARY, "-m", LOCKED_GEMINI_MODEL, "--prompt", prompt], ] def is_flag_parse_error(stderr: str, stdout: str) -> bool: combined = f"{stderr}\n{stdout}".lower() return any( token in combined for token in ( "unknown option", "unknown argument", "invalid option", "unrecognized option", "unrecognized argument", "unexpected argument", "did you mean", ) ) def run_gemini(prompt: str) -> str: attempts = gemini_cmd_candidates(prompt) last_error = "Gemini CLI invocation failed." child_env = os.environ.copy() # Keep only GEMINI_API_KEY to avoid CLI warnings when both vars are set. if not child_env.get("GEMINI_API_KEY") and child_env.get("GOOGLE_API_KEY"): child_env["GEMINI_API_KEY"] = child_env["GOOGLE_API_KEY"] child_env.pop("GOOGLE_API_KEY", None) for idx, cmd in enumerate(attempts): proc = subprocess.run( cmd, capture_output=True, text=True, cwd=str(APP_DIR), env=child_env, timeout=GEMINI_TIMEOUT_SEC, check=False, ) if proc.returncode == 0: return (proc.stdout or "").strip() stderr = (proc.stderr or "").strip() stdout = (proc.stdout or "").strip() details = stderr if stderr else stdout last_error = details or f"Gemini CLI exited with code {proc.returncode}." # Only retry different flag shapes if this appears to be flag parsing trouble. if idx < len(attempts) - 1 and is_flag_parse_error(stderr, stdout): continue break raise RuntimeError(last_error) def try_parse_json(text: str) -> Any | None: trimmed = text.strip() if not trimmed: return None # Handle markdown fences if the model returns them. if trimmed.startswith("```"): lines = trimmed.splitlines() if len(lines) >= 3 and lines[-1].strip().startswith("```"): trimmed = "\n".join(lines[1:-1]).strip() if trimmed.lower().startswith("json"): trimmed = trimmed[4:].strip() try: return json.loads(trimmed) except json.JSONDecodeError: return None def safe_filename_stem(raw_name: str) -> str: stem = Path(raw_name).stem if raw_name else "ad-image" cleaned = re.sub(r"[^A-Za-z0-9_-]+", "-", stem).strip("-") if not cleaned: return "ad-image" return cleaned[:40] def save_image_from_data_url(image_data_url: str, image_filename: str) -> str: match = DATA_URL_RE.match(image_data_url.strip()) if not match: raise ValueError("Image must be a valid base64 data URL (data:image/...;base64,...).") mime_type = match.group("mime").lower() extension = ALLOWED_IMAGE_MIME_TO_EXT.get(mime_type) if not extension: allowed = ", ".join(sorted(ALLOWED_IMAGE_MIME_TO_EXT)) raise ValueError(f"Unsupported image type '{mime_type}'. Allowed: {allowed}.") base64_payload = re.sub(r"\s+", "", match.group("data")) try: image_bytes = base64.b64decode(base64_payload, validate=True) except (ValueError, binascii.Error): raise ValueError("Image base64 payload is invalid.") from None if not image_bytes: raise ValueError("Image payload is empty.") if len(image_bytes) > MAX_IMAGE_BYTES: raise ValueError(f"Image is too large. Max size is {MAX_IMAGE_BYTES} bytes.") UPLOADS_DIR.mkdir(parents=True, exist_ok=True) final_name = f"{safe_filename_stem(image_filename)}-{uuid.uuid4().hex[:10]}.{extension}" image_path = UPLOADS_DIR / final_name image_path.write_bytes(image_bytes) return f"uploads/{final_name}" def normalize_image_inputs(payload: dict[str, Any]) -> list[dict[str, str]]: images_field = payload.get("images") single_data_url = str(payload.get("image_data_url", "")).strip() single_filename = str(payload.get("image_filename", "")).strip() normalized: list[dict[str, str]] = [] if isinstance(images_field, list) and images_field: if len(images_field) > MAX_BATCH_IMAGES: raise ValueError(f"Too many images. Max is {MAX_BATCH_IMAGES}.") for idx, item in enumerate(images_field): if not isinstance(item, dict): raise ValueError(f"images[{idx}] must be an object.") data_url = str(item.get("data_url", "")).strip() filename = str(item.get("filename", "")).strip() or f"image-{idx + 1}.png" if not data_url: raise ValueError(f"images[{idx}].data_url is required.") normalized.append({"data_url": data_url, "filename": filename}) elif single_data_url: normalized.append( { "data_url": single_data_url, "filename": single_filename or "image.png", } ) return normalized def run_single_check(prompt: str) -> tuple[bool, int, dict[str, Any]]: try: raw_output = run_gemini(prompt) return True, 200, {"parsed_output": try_parse_json(raw_output), "raw_output": raw_output} except FileNotFoundError: return ( False, 500, {"error": f"Gemini CLI not found. Install it and ensure '{GEMINI_CLI_BINARY}' is on PATH."}, ) except subprocess.TimeoutExpired: return False, 504, {"error": f"Gemini CLI timed out after {GEMINI_TIMEOUT_SEC}s."} except RuntimeError as err: return False, 500, {"error": str(err)} def run_single_image_check( index: int, total: int, image_ref: str, ad_text: str, extra_context: str, system_prompt: str, ) -> dict[str, Any]: print(f"[batch {index}/{total}] starting check for {image_ref}", flush=True) started = monotonic() prompt = build_prompt( ad_text=ad_text, extra_context=extra_context, system_prompt=system_prompt, image_at_path=image_ref, ) ok, _status, result = run_single_check(prompt) elapsed = monotonic() - started status_text = "ok" if ok else "failed" print(f"[batch {index}/{total}] {status_text} in {elapsed:.1f}s", flush=True) return { "index": index, "ok": ok, "image_reference": image_ref, "parsed_output": result.get("parsed_output"), "raw_output": result.get("raw_output"), "error": result.get("error"), } class AppHandler(SimpleHTTPRequestHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, directory=str(STATIC_DIR), **kwargs) def _send_json(self, status: int, payload: dict[str, Any]) -> None: data = json.dumps(payload, ensure_ascii=True).encode("utf-8") self.send_response(status) self.send_header("Content-Type", "application/json; charset=utf-8") self.send_header("Content-Length", str(len(data))) self.end_headers() self.wfile.write(data) def do_POST(self) -> None: if self.path != "/api/check": self._send_json(404, {"ok": False, "error": "Not found"}) return content_length = int(self.headers.get("Content-Length", "0")) if content_length <= 0: self._send_json(400, {"ok": False, "error": "Request body is required."}) return content_type = self.headers.get("Content-Type", "") if "application/json" not in content_type.lower(): self._send_json(400, {"ok": False, "error": "Content-Type must be application/json."}) return raw_body = self.rfile.read(content_length) try: body_str = raw_body.decode("utf-8") except UnicodeDecodeError: self._send_json(400, {"ok": False, "error": "Body contains invalid UTF-8 data."}) return try: payload = json.loads(body_str) except json.JSONDecodeError: self._send_json(400, {"ok": False, "error": "Body must be valid JSON."}) return ad_text = str(payload.get("ad_text", "")).strip() extra_context = str(payload.get("extra_context", "")).strip() system_prompt = str(payload.get("system_prompt", DEFAULT_SYSTEM_PROMPT)).strip() try: image_inputs = normalize_image_inputs(payload) except ValueError as err: self._send_json(400, {"ok": False, "error": str(err)}) return if not ad_text and not image_inputs: self._send_json(400, {"ok": False, "error": "Provide 'ad_text' or an image."}) return if not system_prompt: system_prompt = DEFAULT_SYSTEM_PROMPT if not image_inputs: prompt = build_prompt( ad_text=ad_text, extra_context=extra_context, system_prompt=system_prompt, image_at_path=None, ) ok, status, result = run_single_check(prompt) if not ok: self._send_json(status, {"ok": False, "error": result["error"]}) return self._send_json( 200, { "ok": True, "mode": "single", "parallel_workers": 1, "all_success": True, "total": 1, "success_count": 1, "failure_count": 0, "results": [ { "index": 1, "ok": True, "image_reference": None, "parsed_output": result["parsed_output"], "raw_output": result["raw_output"], "error": None, } ], "parsed_output": result["parsed_output"], "raw_output": result["raw_output"], "image_reference": None, }, ) return image_refs: list[str] = [] for image in image_inputs: try: image_ref = save_image_from_data_url( image_data_url=image["data_url"], image_filename=image["filename"] ) except ValueError as err: self._send_json(400, {"ok": False, "error": str(err)}) return image_refs.append(image_ref) total = len(image_refs) worker_count = max(1, min(MAX_PARALLEL_WORKERS, total)) print( f"Starting bulk Gemini checks: total_images={total}, parallel_workers={worker_count}", flush=True, ) results: list[dict[str, Any] | None] = [None] * total completed = 0 with ThreadPoolExecutor(max_workers=worker_count) as executor: future_to_slot = { executor.submit( run_single_image_check, idx, total, image_ref, ad_text, extra_context, system_prompt, ): (idx - 1, image_ref) for idx, image_ref in enumerate(image_refs, start=1) } for future in as_completed(future_to_slot): slot, image_ref = future_to_slot[future] try: results[slot] = future.result() except Exception as err: # Defensive fallback: this should be rare because worker handles model errors. results[slot] = { "index": slot + 1, "ok": False, "image_reference": image_ref, "parsed_output": None, "raw_output": None, "error": f"Unexpected worker error: {err}", } completed += 1 print(f"Bulk progress: {completed}/{total} completed", flush=True) finalized_results = [item for item in results if item is not None] finalized_results.sort(key=lambda item: int(item["index"])) success_count = sum(1 for item in finalized_results if item["ok"]) failure_count = len(finalized_results) - success_count first = finalized_results[0] self._send_json( 200, { "ok": True, "mode": "bulk" if len(finalized_results) > 1 else "single", "parallel_workers": worker_count, "all_success": failure_count == 0, "total": len(finalized_results), "success_count": success_count, "failure_count": failure_count, "results": finalized_results, # Keep compatibility with single-result UI consumers. "parsed_output": first.get("parsed_output"), "raw_output": first.get("raw_output"), "image_reference": first.get("image_reference"), }, ) def main() -> None: STATIC_DIR.mkdir(parents=True, exist_ok=True) UPLOADS_DIR.mkdir(parents=True, exist_ok=True) server = ThreadingHTTPServer((HOST, PORT), AppHandler) print(f"Server running at http://{HOST}:{PORT}") server.serve_forever() if __name__ == "__main__": main()