# NOTE(review): removed non-Python extraction artifact ("Spaces: Running" page chrome).
| import asyncio | |
| from collections import Counter | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import asynccontextmanager | |
| import json | |
| from typing import Annotated | |
| from fastapi import Depends, FastAPI, Header, HTTPException, Request | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from tqdm import tqdm | |
| from utils.proof_utils import split_proof_header | |
| from utils.repl_cache import LRUReplCache | |
| from .config import settings | |
| from .healthcheck import router | |
| from .kimina_api import router as kimina_router | |
| from .leanrepl import LeanCrashError, LeanREPL | |
# Module-wide state shared by the endpoints and background tasks below.
repls = {}  # NOTE(review): not referenced anywhere in this file — confirm it is used elsewhere before removing
semaphore = asyncio.Semaphore(settings.MAX_CONCURRENT_REQUESTS)  # throttles concurrent verification work
repl_cache = LRUReplCache(max_size=settings.MAX_REPLS)  # LRU pool of warm Lean REPLs keyed by proof header
async def _repl_creater():
    """Background task: drain ``repl_cache.create_queue`` and pre-warm REPLs.

    Every 10 seconds, collapse the queued headers into (header, count) pairs
    and spawn ``count`` fresh Lean REPLs per header, initializing each one's
    environment before handing it to the cache. A batch that crashes during
    environment creation is routed to the cleaner task instead of the cache.
    """
    while True:
        if len(repl_cache.create_queue) > 0:
            # Snapshot and reset the queue; no await between read and reset,
            # so this is atomic with respect to the single event loop.
            repl_to_create = Counter(repl_cache.create_queue)
            repl_cache.create_queue = []
            for header, amount in tqdm(repl_to_create.items(), desc="Creating REPLs"):
                tasks = []
                creating_repls = []
                try:
                    for _ in range(amount):
                        repl = LeanREPL()
                        creating_repls.append(repl)
                        tasks.append(asyncio.to_thread(repl.create_env, header, 600))
                    responses = await asyncio.gather(*tasks)
                    logger.info(
                        f"Created {len(responses)} {str([header])[:50]} repls with response {str(responses)[:50]}"
                    )
                except LeanCrashError:
                    # Crashed REPLs must not be cached: hand every spawned
                    # REPL to the cleaner (which unpacks (id, repl) tuples
                    # from close_queue) and skip the cache insertion below.
                    for repl in creating_repls:
                        repl_cache.close_queue.put((None, repl))
                    continue
                # put the repls in the cache
                for repl in creating_repls:
                    await repl_cache.put(header, repl)
            repl_cache.evict_if_needed()
        await asyncio.sleep(10)
async def _repl_cleaner():
    """Background task: once per second, close every REPL queued for disposal."""
    while True:
        await asyncio.sleep(1)
        close_queue = repl_cache.close_queue
        while not close_queue.empty():
            repl_id, repl = close_queue.get()
            # repl.close() may block, so push it onto the default executor.
            await asyncio.to_thread(repl.close)
            logger.info(f"Closed {repl_id} repl")
async def _stat_printer():
    """Background task: periodically log REPL-cache statistics."""
    interval_seconds = 15
    while True:
        await asyncio.sleep(interval_seconds)
        await repl_cache.print_status(interval_seconds)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan context manager.

    Startup: install a large default thread pool (Lean work is offloaded via
    ``asyncio.to_thread``), start the cache-manager background tasks, and
    queue an initial batch of pre-warmed REPL environments.
    Shutdown: cancel the background tasks, await their termination, and drain
    the thread pool.
    """
    app.state.executor = ThreadPoolExecutor(max_workers=5000)
    asyncio.get_running_loop().set_default_executor(app.state.executor)
    # Repl cache manager tasks
    repl_cache_tasks = [
        asyncio.create_task(_repl_cleaner()),
        asyncio.create_task(_repl_creater()),
        asyncio.create_task(_stat_printer()),
    ]
    # Prefill repl_cache. The pre-filled amount should not be greater than
    # settings.MAX_REPLS.
    # TODO: Make it an initialization parameter.
    repl_cache.create_queue.extend(
        ["import Mathlib\nimport Aesop"] * int(settings.MAX_REPLS)
    )
    try:
        yield
    finally:
        # Cancel each cache manager task and wait for it to acknowledge the
        # cancellation before shutting the executor down.
        for task in repl_cache_tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Close thread pools
        app.state.executor.shutdown(wait=True)
# The ASGI application; startup/shutdown work is driven by the lifespan handler.
app = FastAPI(lifespan=lifespan)
| # ------ Dependencies ------ | |
def validate_api_access(request: Request, authorization: str = Header(None)) -> None:
    """FastAPI dependency enforcing bearer-token authentication.

    A no-op when ``settings.API_KEY`` is unset (authentication disabled).

    Raises:
        HTTPException: 401 for a missing/malformed Authorization header,
            403 when the supplied token does not match the configured key.
    """
    api_key = settings.API_KEY
    if api_key is None:
        return
    if authorization is None or not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401, detail="Missing or invalid Authorization header"
        )
    # removeprefix (not split) so a token that itself contains "Bearer "
    # is not truncated; the startswith guard guarantees the prefix exists.
    token = authorization.removeprefix("Bearer ")
    if token != api_key:
        raise HTTPException(status_code=403, detail="Invalid API Key")


require_access_dep = Annotated[None, Depends(validate_api_access)]
| # ------ Schemas ------ | |
class Code(BaseModel):
    """One proof-checking job; ``proof`` takes precedence over ``code``."""

    custom_id: str | int
    # Declared optional so an explicit JSON ``null`` validates; previously the
    # annotation was a bare ``str`` with a ``None`` default, which rejects null.
    proof: str | None = Field(None)
    code: str | None = Field(None)  # To be backward compatibility with autoformalizer client

    def get_proof_content(self) -> str | None:
        """Return ``proof`` when set, otherwise fall back to ``code``."""
        return self.proof if self.proof is not None else self.code
class VerifyRequestBody(BaseModel):
    """Request payload for the verify endpoints."""

    codes: list[Code]
    # Per-proof timeout forwarded to the Lean REPL calls (presumably seconds —
    # confirm against LeanREPL).
    timeout: int = 300
    infotree_type: str | None = None
    # When True, bypass the warm-REPL cache and verify in a fresh one-shot REPL.
    disable_cache: bool = False
| # ------ Endpoint ------ | |
async def root(require_access_dep: require_access_dep):
    """Auth-gated liveness probe returning a static OK payload."""
    payload = {"status": "ok"}
    return payload
async def verify(
    body: VerifyRequestBody,
    access: require_access_dep,
):
    """Verify every proof in the request concurrently and collect the results."""
    verification_tasks = [
        process_one_code_with_repl_fast(
            item,
            body.timeout,
            body.infotree_type,
            disable_cache=body.disable_cache,
        )
        for item in body.codes
    ]
    # Await the results of all the tasks concurrently
    outcomes = await asyncio.gather(*verification_tasks)
    return {
        "results": [
            {
                "custom_id": custom_id,
                "error": error,
                "response": response,
            }
            for custom_id, error, response in outcomes
        ]
    }
async def process_one_code_with_repl_fast(
    code: Code,
    timeout: int,
    infotree_type: str | None,
    disable_cache: bool = False,
):
    """Verify one proof, preferring a cached REPL warmed with its header.

    Returns a ``(custom_id, error_msg, response)`` tuple; on failure
    ``error_msg`` is set and ``response`` may be None.
    """
    # Throttle the incoming request
    async with semaphore:
        error_msg = None
        response = None
        custom_id = code.custom_id
        proof = code.get_proof_content()
        if proof is None:
            logger.warning(f"[{custom_id}] No code provided")
            return custom_id, "No code provided", response
        # Split off the header, which serves as the REPL-cache key.
        proof_header, proof_body = split_proof_header(proof)
        log_message = {
            'custom_id': custom_id,
            'proof_header': proof_header,
            'proof_body': proof_body,
            'timeout': timeout,
        }
        logger.debug(
            f"[{custom_id}] Processing code: {json.dumps(log_message)}"
        )
        # if we can not found the proof header, create a new repl
        if len(proof_header.strip()) == 0 or disable_cache:
            # One-shot path: a throwaway REPL verifies the whole proof.
            lean_repl = LeanREPL()
            try:
                response = await asyncio.to_thread(
                    lean_repl.one_pass_verify, proof, timeout, infotree_type
                )
            except LeanCrashError as e:
                error_msg = str(e)
                log_message["error"] = error_msg
                logger.error(
                    f"[{custom_id}] Error raised in one_pass_verify with 1-shot repl: {json.dumps(log_message)}"
                )
            finally:
                # Drop our reference; cleanup is left to LeanREPL itself.
                del lean_repl
            return custom_id, error_msg, response
        # Get lean repl instance from the lrucache
        grep_id, repl = await repl_cache.get(proof_header)
        # If we can not get the repl from the lrucache, we will create a new repl
        if grep_id is None:
            repl = LeanREPL()
            # And import the proof header
            try:
                response = await asyncio.to_thread(
                    repl.create_env, proof_header, timeout
                )
            except LeanCrashError as e:
                error_msg = str(e)
                log_message["error"] = error_msg
                # NOTE(review): this message says "one_pass_verify" but the
                # failure comes from create_env — looks copy-pasted; confirm
                # and correct the wording.
                logger.error(
                    f"[{custom_id}] Error raised in one_pass_verify with 1-shot repl: {json.dumps(log_message)}"
                )
                del repl
                return custom_id, error_msg, response
        try:
            # Run the proof body on top of environment 0 (presumably the
            # header environment created above — confirm against LeanREPL).
            response = await asyncio.to_thread(
                repl.extend_env,
                0,
                proof_body,
                timeout,
                infotree_type,
            )
        except LeanCrashError as e:
            error_msg = str(e)
            log_message["error"] = error_msg
            logger.error(
                f"[{custom_id}] Error raised while extending repl env with proof: {json.dumps(log_message)}"
            )
            # A crashed cached REPL is purged from the cache; a freshly
            # created one (grep_id is None) is simply dropped.
            if grep_id is not None:
                logger.error(f"[{custom_id}] Removing repl from cache: {grep_id}")
                await repl_cache.destroy(proof_header, grep_id, repl)
            else:
                del repl
            return custom_id, error_msg, response
        exceeds_limit = False
        if (
            settings.REPL_MEMORY_CHECK_INTERVAL is not None
            and settings.REPL_MEMORY_LIMIT_GB is not None
            and repl.run_command_total % settings.REPL_MEMORY_CHECK_INTERVAL == 0
        ):
            # Check if the REPL exceeds memory limit after execution
            exceeds_limit = await asyncio.to_thread(
                repl.exceeds_memory_limit, settings.REPL_MEMORY_LIMIT_GB
            )
        if exceeds_limit:
            logger.warning(
                f"REPL exceeds memory limit after execution, destroying it. last verified proof: {json.dumps(log_message)}"
            )
            if grep_id is None:
                del repl
            else:
                logger.warning(f"Removing repl from cache: {grep_id}")
                await repl_cache.destroy(proof_header, grep_id, repl)
        else:
            # release back to the cache if memory is within limits
            if grep_id is None:
                # Newly created REPL: insert it into the cache for reuse.
                await repl_cache.put(proof_header, repl)
            else:
                # Cached REPL: mark it available again.
                await repl_cache.release(proof_header, grep_id, repl)
        return custom_id, error_msg, response
async def one_pass_verify_batch(
    body: VerifyRequestBody,
    access: require_access_dep,
):
    """Backward compatible endpoint: accepts both 'proof' / 'code' fields."""
    # Thin delegation — behavior is identical to verify().
    result = await verify(body, access)
    return result
# Mount the healthcheck and Kimina-API routers onto the application.
app.include_router(router)
app.include_router(kimina_router)