Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import time | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| from fastapi import Body, FastAPI, HTTPException, Header, Query, Request | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from pydantic import BaseModel | |
| from auditenv.models import AuditAction, AuditFinding, AuditObservation, EnvState, StepResult, TaskId | |
| from auditenv.state import AuditEnvRuntime | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
# Optional API key: set AUDITENV_API_KEY to require the key on protected calls;
# leave it empty (the default) to disable the check.
import os

API_KEY = os.environ.get("AUDITENV_API_KEY", "")

# Rate limiting: maximum number of requests per client IP within one window.
RATE_LIMIT_MAX = int(os.environ.get("AUDITENV_RATE_LIMIT", "120"))
# Length of one rate-limit window, in seconds.
RATE_LIMIT_WINDOW_SECONDS = 60
| # --------------------------------------------------------------------------- | |
| # App setup | |
| # --------------------------------------------------------------------------- | |
# The FastAPI application object that every route in this module hangs off.
app = FastAPI(title="AuditEnv", version="0.2.0")

# Attach the leaderboard routes defined in the auditenv.leaderboard module.
from auditenv.leaderboard import router as leaderboard_router

app.include_router(leaderboard_router)

# One runtime per session, keyed by session id, so sessions stay isolated.
_sessions: dict[str, AuditEnvRuntime] = {}

# In-memory rate limiter state: request timestamps keyed by client IP.
_rate_tracker: dict[str, list[float]] = defaultdict(list)
| # --------------------------------------------------------------------------- | |
| # Middleware — rate limiting | |
| # --------------------------------------------------------------------------- | |
async def rate_limit_middleware(request: Request, call_next):
    """Limit how many requests one client IP may make per time window.

    Request timestamps are kept per client IP in ``_rate_tracker``. On every
    call, timestamps older than ``RATE_LIMIT_WINDOW_SECONDS`` are dropped; if
    ``RATE_LIMIT_MAX`` or more remain, the request is answered with HTTP 429
    instead of being forwarded. A non-positive ``RATE_LIMIT_MAX`` disables the
    limiter entirely.

    Args:
        request: The incoming request; ``request.client.host`` identifies the
            caller, falling back to ``"unknown"`` when no client is attached.
        call_next: Awaitable that forwards the request down the stack.

    Returns:
        The downstream response, or a 429 ``JSONResponse`` when limited.

    NOTE(review): no registration of this middleware is visible in this view;
    confirm it is attached to ``app``.
    NOTE(review): keys in ``_rate_tracker`` are never removed, so the mapping
    grows with the number of distinct client IPs seen.
    """
    if RATE_LIMIT_MAX <= 0:
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"
    now = time.time()
    window_start = now - RATE_LIMIT_WINDOW_SECONDS

    # Keep only the timestamps that still fall inside the current window.
    recent = [stamp for stamp in _rate_tracker[client_ip] if stamp > window_start]
    _rate_tracker[client_ip] = recent

    if len(recent) >= RATE_LIMIT_MAX:
        return JSONResponse(
            status_code=429,
            # Mirror the body's retry hint in the standard header so generic
            # HTTP clients can back off without parsing the JSON payload.
            headers={"Retry-After": str(RATE_LIMIT_WINDOW_SECONDS)},
            content={
                "error": "rate_limit_exceeded",
                "detail": f"Max {RATE_LIMIT_MAX} requests per {RATE_LIMIT_WINDOW_SECONDS}s",
                "retry_after_seconds": RATE_LIMIT_WINDOW_SECONDS,
            },
        )

    recent.append(now)
    return await call_next(request)
| # --------------------------------------------------------------------------- | |
| # Auth dependency | |
| # --------------------------------------------------------------------------- | |
def _check_api_key(x_api_key: Optional[str] = Header(None)) -> None:
    """Validate the caller's API key when ``AUDITENV_API_KEY`` is configured.

    The check is skipped entirely when ``API_KEY`` is empty. Otherwise the
    supplied value must be a string equal to ``API_KEY``.

    Args:
        x_api_key: Value of the ``X-API-Key`` header. Anything that is not a
            string (``None``, or the unresolved ``Header`` default when an
            endpoint function is called directly) counts as missing.

    Raises:
        HTTPException: 401 when a key is required and the supplied value is
            missing or does not match.
    """
    import hmac

    if not API_KEY:
        return

    # Compare in constant time so response timing does not reveal how much of
    # a guessed key matched. Encoding to bytes keeps compare_digest usable for
    # non-ASCII keys.
    key_matches = isinstance(x_api_key, str) and hmac.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(
            status_code=401,
            detail={"error": "unauthorized", "detail": "Invalid or missing X-API-Key header"},
        )
| # --------------------------------------------------------------------------- | |
| # Error handler — structured 422 errors | |
| # --------------------------------------------------------------------------- | |
| from fastapi.exceptions import RequestValidationError | |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request validation failure into a structured 422 JSON body.

    Each entry from ``exc.errors()`` becomes a dict with a dotted ``field``
    path built from ``loc``, plus ``message`` and ``type``; missing keys fall
    back to empty values. ``request`` is accepted but not used.
    """
    problems = []
    for issue in exc.errors():
        location_parts = [str(part) for part in issue.get("loc", [])]
        problems.append(
            {
                "field": ".".join(location_parts),
                "message": issue.get("msg", ""),
                "type": issue.get("type", ""),
            }
        )

    body = {"error": "validation_error", "detail": problems}
    return JSONResponse(status_code=422, content=body)
| # --------------------------------------------------------------------------- | |
| # Request / Response models | |
| # --------------------------------------------------------------------------- | |
class ResetRequest(BaseModel):
    """Parameters for starting an episode; both fields have defaults."""

    # Task to run; "easy" when omitted.
    task_id: TaskId = "easy"
    # Seed handed to the runtime; 42 when omitted.
    seed: int = 42
class ResetResponse(BaseModel):
    """A session id paired with an observation.

    NOTE(review): not referenced by any endpoint visible in this view.
    """

    # Identifier of the session the observation belongs to.
    session_id: str
    # Observation payload for that session.
    observation: AuditObservation
class SessionStepRequest(BaseModel):
    """A session id paired with an action.

    NOTE(review): not referenced by any endpoint visible in this view.
    """

    # Identifier of the session the action targets.
    session_id: str
    # Action to apply in that session.
    action: AuditAction
def _resolve_runtime_for_step(action: AuditAction) -> AuditEnvRuntime:
    """Find the runtime that should process ``action``.

    With ``action.session_id`` set, that exact session is looked up and its
    task must match ``action.task_id``. Without it, the single session whose
    current task equals ``action.task_id`` is used.

    Raises:
        HTTPException: 404 for an unknown or inactive session id, 409 when the
            session's task differs from the action's task or several sessions
            match, 400 when no active session matches.
    """
    if action.session_id:
        selected = _sessions.get(action.session_id)
        if selected is None or selected.current is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": "unknown_session_id",
                    "detail": f"No session found for {action.session_id}",
                },
            )
        if selected.current.task_id != action.task_id:
            raise HTTPException(
                status_code=409,
                detail={
                    "error": "task_session_mismatch",
                    "detail": "action.task_id does not match session task",
                },
            )
        return selected

    # No explicit session: collect every active runtime working on this task.
    candidates = []
    for candidate in _sessions.values():
        if candidate.current is not None and candidate.current.task_id == action.task_id:
            candidates.append(candidate)

    if not candidates:
        raise HTTPException(
            status_code=400,
            detail={"error": "no_active_session", "detail": "No active session. Call /reset first."},
        )
    if len(candidates) > 1:
        raise HTTPException(
            status_code=409,
            detail={
                "error": "ambiguous_session",
                "detail": "Multiple active sessions for task_id. Provide session_id.",
            },
        )
    return candidates[0]
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
def health() -> dict[str, str]:
    """Return a fixed payload indicating the service is up."""
    payload = {"status": "ok"}
    return payload
def root() -> RedirectResponse:
    """Send the caller to the dashboard at ``/dashboard/``."""
    dashboard_url = "/dashboard/"
    return RedirectResponse(url=dashboard_url)
def reset(req: Optional[ResetRequest] = Body(default=None), x_api_key: Optional[str] = Header(None)) -> AuditObservation:
    """Start a new episode in a fresh runtime and register its session.

    Args:
        req: Task and seed to use; a default ``ResetRequest`` is substituted
            when no body is supplied.
        x_api_key: API key header value, passed to ``_check_api_key``.

    Returns:
        The observation produced by the runtime's ``reset`` call.
    """
    _check_api_key(x_api_key)

    params = req if req else ResetRequest()
    runtime = AuditEnvRuntime(default_seed=params.seed, enable_logging=True)
    observation = runtime.reset(task_id=params.task_id, seed=params.seed)

    # Remember the runtime under the observation's session id for later calls.
    _sessions[observation.session_id] = runtime
    return observation
def step(action: AuditAction, x_api_key: Optional[str] = Header(None)) -> StepResult:
    """Apply ``action`` to the matching session and return the step result.

    The session is dropped from ``_sessions`` once the result reports ``done``.

    Raises:
        HTTPException: 400 when the runtime raises ``RuntimeError``, 422 when
            it raises ``ValueError``; session-resolution and auth errors
            propagate from the helpers.
    """
    _check_api_key(x_api_key)
    runtime = _resolve_runtime_for_step(action)

    try:
        outcome = runtime.step(action)
        if outcome.done:
            finished = runtime.current
            if finished:
                # Episode is over: forget the session so it cannot be stepped again.
                _sessions.pop(finished.session_id, None)
        return outcome
    except RuntimeError as exc:
        failure = {"error": "runtime_error", "detail": str(exc)}
        raise HTTPException(status_code=400, detail=failure) from exc
    except ValueError as exc:
        failure = {"error": "validation_error", "detail": str(exc)}
        raise HTTPException(status_code=422, detail=failure) from exc
def state(session_id: Optional[str] = Query(None), x_api_key: Optional[str] = Header(None)) -> EnvState:
    """Return the state of one session.

    With ``session_id`` given, that session is used. Otherwise the call only
    succeeds when exactly one session is active.

    Raises:
        HTTPException: 400 when no session exists or is active, or when the
            runtime raises ``RuntimeError``; 404 for an unknown or inactive
            ``session_id``; 409 when several sessions are active and no
            ``session_id`` was given.
    """
    _check_api_key(x_api_key)

    if not _sessions:
        raise HTTPException(
            status_code=400,
            detail={"error": "no_active_session", "detail": "No active session. Call /reset first."},
        )

    if session_id:
        chosen = _sessions.get(session_id)
        if chosen is None or chosen.current is None:
            raise HTTPException(
                status_code=404,
                detail={"error": "unknown_session_id", "detail": f"No session found for {session_id}"},
            )
    else:
        live = [candidate for candidate in _sessions.values() if candidate.current is not None]
        if not live:
            raise HTTPException(
                status_code=400,
                detail={"error": "no_active_session", "detail": "No active session. Call /reset first."},
            )
        if len(live) > 1:
            raise HTTPException(
                status_code=409,
                detail={"error": "ambiguous_session", "detail": "Multiple active sessions. Provide session_id."},
            )
        chosen = live[0]

    # Both lookup paths end in the same state() call and error mapping.
    try:
        return chosen.state()
    except RuntimeError as exc:
        raise HTTPException(
            status_code=400,
            detail={"error": "runtime_error", "detail": str(exc)},
        ) from exc
| # --------------------------------------------------------------------------- | |
| # Visual dashboard (Gradio) | |
| # --------------------------------------------------------------------------- | |
| def _json_out(payload: object) -> str: | |
| if hasattr(payload, "model_dump"): | |
| payload = payload.model_dump() # type: ignore[assignment] | |
| return json.dumps(payload, indent=2, ensure_ascii=False) | |
def _dashboard_reset(task_id: str, seed: float) -> str:
    """Dashboard handler: call ``reset`` and return the result as JSON text.

    ``seed`` is truncated to an int. Any failure is reported as a JSON error
    object instead of being raised.
    """
    try:
        request_model = ResetRequest(task_id=task_id, seed=int(seed))
        return _json_out(reset(request_model))
    except HTTPException as http_err:
        failure = {"error": "http_error", "status_code": http_err.status_code, "detail": http_err.detail}
        return _json_out(failure)
    except Exception as err:
        failure = {"error": "runtime_error", "detail": str(err)}
        return _json_out(failure)
def _dashboard_state() -> str:
    """Dashboard handler: return the active session's state as JSON text.

    ``state`` is called as a plain function here, so its ``Query``/``Header``
    parameter defaults are not resolved by the web framework. Both arguments
    are therefore passed explicitly as ``None``; otherwise the truthy default
    object would be treated as a session id and the lookup would go down the
    session-id path instead of selecting the single active session.

    Returns:
        JSON text of the state, or a JSON error object on any failure.
    """
    try:
        snapshot = state(session_id=None, x_api_key=None)
        return _json_out(snapshot)
    except HTTPException as exc:
        return _json_out({"error": "http_error", "status_code": exc.status_code, "detail": exc.detail})
    except Exception as exc:
        return _json_out({"error": "runtime_error", "detail": str(exc)})
def _dashboard_step(
    action_type: str,
    task_id: str,
    document_id: str,
    violation_type: str,
    confidence: float,
    note: str,
) -> str:
    """Dashboard handler: build an action from form values and run ``step``.

    A finding is attached only for ``submit_finding``, which requires a
    non-blank ``document_id``; its confidence is clamped to the 0.0-1.0 range.
    An empty ``note`` becomes ``"dashboard"``.

    Returns:
        JSON text of the step result, or a JSON error object on any failure.
    """
    try:
        if action_type == "submit_finding":
            doc_id = document_id.strip()
            if not doc_id:
                missing_doc = {
                    "error": "validation_error",
                    "detail": "document_id is required for submit_finding",
                }
                return _json_out(missing_doc)
            bounded_confidence = max(0.0, min(1.0, float(confidence)))
            finding = AuditFinding(
                document_id=doc_id,
                violation_type=violation_type,
                evidence=[doc_id],
                confidence=bounded_confidence,
            )
        else:
            finding = None

        action = AuditAction(
            action_type=action_type,  # type: ignore[arg-type]
            task_id=task_id,  # type: ignore[arg-type]
            finding=finding,
            note=note or "dashboard",
        )
        return _json_out(step(action))
    except HTTPException as http_err:
        return _json_out({"error": "http_error", "status_code": http_err.status_code, "detail": http_err.detail})
    except Exception as err:
        return _json_out({"error": "runtime_error", "detail": str(err)})
| def _extract_uploaded_text(file_path: str) -> tuple[str, dict[str, object]]: | |
| path = Path(file_path) | |
| if not path.exists() or not path.is_file(): | |
| raise ValueError("Uploaded file was not found on server.") | |
| metadata: dict[str, object] = { | |
| "file_name": path.name, | |
| "suffix": path.suffix.lower(), | |
| "size_bytes": path.stat().st_size, | |
| } | |
| suffix = path.suffix.lower() | |
| if suffix in {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}: | |
| try: | |
| from PIL import Image | |
| import pytesseract | |
| except Exception as exc: | |
| raise ValueError( | |
| "Image OCR dependencies are missing. Upload text/csv/json files, " | |
| "or install Pillow and pytesseract on this runtime." | |
| ) from exc | |
| with Image.open(path) as image: | |
| text = pytesseract.image_to_string(image, lang="eng") | |
| else: | |
| raw = path.read_bytes() | |
| text = "" | |
| for encoding in ("utf-8", "utf-16", "latin-1"): | |
| try: | |
| text = raw.decode(encoding) | |
| metadata["decoded_with"] = encoding | |
| break | |
| except Exception: | |
| continue | |
| if not text: | |
| raise ValueError("Could not decode the uploaded file as text.") | |
| text = text.strip() | |
| if not text: | |
| raise ValueError("Uploaded file produced empty text.") | |
| return text, metadata | |
| def _infer_violation_signals(task_id: str, text: str) -> list[dict[str, object]]: | |
| content = text.lower() | |
| signal_map: dict[str, list[tuple[str, list[str]]]] = { | |
| "easy": [ | |
| ("duplicate_receipt", ["duplicate_flag=true", "matches_receipt", "duplicate receipt"]), | |
| ("alcohol_over_limit", ["alcohol_amount", "alcohol over limit", "policy_limit"]), | |
| ("late_submission", ["late=true", "late submission", "policy_deadline"]), | |
| ], | |
| "medium": [ | |
| ("sod_conflict", ["sod_conflict", "segregation_of_duties", "segregation-of-duties"]), | |
| ("dormant_account_reactivation", ["dormant_account", "dormant=true", "reactivation"]), | |
| ("temporal_anomaly", ["temporal_anomaly", "off_hours", "suspicious_hour"]), | |
| ], | |
| "hard": [ | |
| ("shell_company", ["shell_company", "shell=true", "front company"]), | |
| ("invoice_splitting", ["invoice_splitting", "split_invoice", "split invoice"]), | |
| ("round_tripping", ["round_tripping", "round_trip=true", "round tripping"]), | |
| ], | |
| } | |
| matches: list[dict[str, object]] = [] | |
| for violation_type, keywords in signal_map.get(task_id, []): | |
| matched = [kw for kw in keywords if kw in content] | |
| if matched: | |
| matches.append({ | |
| "violation_type": violation_type, | |
| "matched_keywords": matched, | |
| }) | |
| return matches | |
def _dashboard_analyze_file(task_id: str, uploaded_file: str | None) -> str:
    """Dashboard handler: extract text from an upload and summarise signals.

    Returns JSON text containing file metadata, character and line counts,
    the keyword signals found for ``task_id``, a suggested action derived
    from the first signal, and the first 1200 characters of the text.
    Missing uploads and extraction failures are reported as JSON error
    objects.
    """
    if not uploaded_file:
        return _json_out({
            "error": "validation_error",
            "detail": "Please upload a file first.",
        })

    try:
        text, metadata = _extract_uploaded_text(uploaded_file)
    except ValueError as bad_input:
        return _json_out({"error": "validation_error", "detail": str(bad_input)})
    except Exception as failure:
        return _json_out({"error": "runtime_error", "detail": str(failure)})

    signals = _infer_violation_signals(task_id=task_id, text=text)

    # Suggest submitting the first detected violation, or a noop without hits.
    if signals:
        suggested_action = {
            "action_type": "submit_finding",
            "task_id": task_id,
            "violation_type": signals[0]["violation_type"],
            "confidence": 0.8,
        }
    else:
        suggested_action = {
            "action_type": "noop",
            "task_id": task_id,
            "violation_type": None,
            "confidence": 0.5,
        }

    text_stats = {
        "chars": len(text),
        "lines": text.count("\n") + 1,
    }
    report = {
        "status": "ok",
        "file": metadata,
        "task_id": task_id,
        "text_stats": text_stats,
        "signals": signals,
        "suggested_action": suggested_action,
        "text_preview": text[:1200],
    }
    return _json_out(report)
# Builds the Gradio Blocks UI for manual interaction and returns it.
# The UI has a task dropdown and seed input, Reset / Get State buttons, a file
# upload with an Analyze button, the step-action inputs, and one JSON output
# pane shared by all four buttons.
# NOTE(review): indentation is lost in this view, so which components sit
# inside each gr.Row() cannot be confirmed from here.
| def _build_dashboard() -> gr.Blocks: | |
| with gr.Blocks(title="AuditEnv Dashboard") as demo: | |
| gr.Markdown("# AuditEnv Dashboard") | |
| gr.Markdown("Use this UI for manual episode interaction. API remains available at `/reset`, `/step`, `/state`, and docs at `/docs`.") | |
# Episode controls: task selection and seed, then reset / state buttons.
| with gr.Row(): | |
| task_dd = gr.Dropdown(choices=["easy", "medium", "hard"], value="easy", label="Task") | |
| seed_num = gr.Number(value=42, precision=0, label="Seed") | |
| with gr.Row(): | |
| reset_btn = gr.Button("Reset") | |
| state_btn = gr.Button("Get State") | |
# File upload; type="filepath" hands the handler a path string.
| gr.Markdown("### Upload and Analyze Your File") | |
| upload_file = gr.File( | |
| label="Upload file (.txt, .csv, .json, or image)", | |
| type="filepath", | |
| file_types=[".txt", ".md", ".csv", ".json", ".yaml", ".yml", ".png", ".jpg", ".jpeg"], | |
| ) | |
| analyze_btn = gr.Button("Analyze Uploaded File") | |
# Inputs used to assemble the action passed to _dashboard_step.
| gr.Markdown("### Step Action") | |
| with gr.Row(): | |
| action_dd = gr.Dropdown( | |
| choices=["submit_finding", "flag_human_review", "noop"], | |
| value="noop", | |
| label="Action Type", | |
| ) | |
| confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Confidence") | |
| with gr.Row(): | |
| document_id = gr.Textbox(label="Document ID (required for submit_finding)", value="") | |
| violation_type = gr.Dropdown( | |
| choices=[ | |
| "duplicate_receipt", | |
| "alcohol_over_limit", | |
| "late_submission", | |
| "sod_conflict", | |
| "dormant_account_reactivation", | |
| "temporal_anomaly", | |
| "shell_company", | |
| "invoice_splitting", | |
| "round_tripping", | |
| ], | |
| value="duplicate_receipt", | |
| label="Violation Type", | |
| ) | |
| note = gr.Textbox(label="Note", value="dashboard") | |
| step_btn = gr.Button("Step", variant="primary") | |
| output = gr.Code(label="Response", language="json", lines=20) | |
# Wire each button to its handler; every handler writes to the same output.
| reset_btn.click(fn=_dashboard_reset, inputs=[task_dd, seed_num], outputs=output) | |
| state_btn.click(fn=_dashboard_state, inputs=[], outputs=output) | |
| analyze_btn.click( | |
| fn=_dashboard_analyze_file, | |
| inputs=[task_dd, upload_file], | |
| outputs=output, | |
| ) | |
# The input order here must match the parameter order of _dashboard_step.
| step_btn.click( | |
| fn=_dashboard_step, | |
| inputs=[action_dd, task_dd, document_id, violation_type, confidence, note], | |
| outputs=output, | |
| ) | |
| return demo | |
# Build the dashboard once at import time, then mount it on the FastAPI app
# under /dashboard; the name ``app`` is rebound to the object returned by the
# mount call.
dashboard = _build_dashboard()
app = gr.mount_gradio_app(app, dashboard, path="/dashboard")