zhangjianfei09 commited on
Commit
6c144d3
·
1 Parent(s): 3e3144c

debug server

Browse files
Files changed (3) hide show
  1. fixed_server.py +318 -0
  2. kimina_api.py +144 -0
  3. start.sh +12 -6
fixed_server.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import hmac
import json
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Depends, FastAPI, Header, HTTPException, Request
from loguru import logger
from pydantic import BaseModel, Field
from tqdm import tqdm

from utils.proof_utils import split_proof_header
from utils.repl_cache import LRUReplCache

from .config import settings
from .healthcheck import router
from .kimina_api import router as kimina_router
from .leanrepl import LeanCrashError, LeanREPL
20
+
21
# Module-level REPL bookkeeping shared by the background tasks and endpoints.
repls = {}
# Caps how many /verify requests are processed concurrently.
semaphore = asyncio.Semaphore(settings.MAX_CONCURRENT_REQUESTS)
# LRU cache of warm REPLs keyed by proof header.
repl_cache = LRUReplCache(max_size=settings.MAX_REPLS)
24
+
25
+
26
async def _repl_creater():
    """Background task: batch-create REPLs requested via repl_cache.create_queue.

    Every 10 seconds this drains the queue, groups the requested proof headers
    with a Counter, creates the REPLs concurrently, and caches the survivors.
    """
    while True:
        if len(repl_cache.create_queue) > 0:
            repl_to_create = Counter(repl_cache.create_queue)
            repl_cache.create_queue = []

            for header, amount in tqdm(repl_to_create.items(), desc="Creating REPLs"):
                tasks = []
                creating_repls = []
                try:
                    for _ in range(amount):
                        repl = LeanREPL()
                        creating_repls.append(repl)
                        tasks.append(asyncio.to_thread(repl.create_env, header, 600))

                    responses = await asyncio.gather(*tasks)
                    logger.info(
                        f"Created {len(responses)} {str([header])[:50]} repls with response {str(responses)[:50]}"
                    )
                except LeanCrashError:
                    # BUG FIX: _repl_cleaner unpacks (id, repl) pairs from
                    # close_queue, so enqueue tuples, not bare repl objects.
                    for repl in creating_repls:
                        repl_cache.close_queue.put((None, repl))
                    # BUG FIX: do not fall through and put crashed repls into
                    # the cache (use-after-close); skip to the next header.
                    continue

                # Put the successfully created repls in the cache.
                for repl in creating_repls:
                    await repl_cache.put(header, repl)

            repl_cache.evict_if_needed()
        await asyncio.sleep(10)
55
+
56
+
57
async def _repl_cleaner():
    """Background task: close REPLs queued on repl_cache.close_queue."""
    while True:
        await asyncio.sleep(1)
        while not repl_cache.close_queue.empty():
            # Renamed from `id` — avoid shadowing the builtin.
            repl_id, repl = repl_cache.close_queue.get()
            # close() blocks, so run it off the event loop.
            await asyncio.to_thread(repl.close)
            logger.info(f"Closed {repl_id} repl")
64
+
65
+
66
async def _stat_printer():
    """Background task: periodically log REPL cache statistics."""
    interval_seconds = 15
    while True:
        await asyncio.sleep(interval_seconds)
        await repl_cache.print_status(interval_seconds)
71
+
72
+
73
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: start the cache-manager tasks on startup, tear down on shutdown."""
    app.state.executor = ThreadPoolExecutor(max_workers=5000)
    asyncio.get_running_loop().set_default_executor(app.state.executor)

    # REPL cache manager tasks.
    background_tasks = [
        asyncio.create_task(worker())
        for worker in (_repl_cleaner, _repl_creater, _stat_printer)
    ]

    # Prefill repl_cache; the pre-filled amount should not be greater than
    # settings.MAX_REPLS.
    # TODO: Make it an initialization parameter.
    repl_cache.create_queue.extend(
        ["import Mathlib\nimport Aesop"] * int(settings.MAX_REPLS)
    )

    try:
        yield
    finally:
        # Cancel cache manager tasks and wait for each one to finish.
        for task in background_tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Close the thread pool.
        app.state.executor.shutdown(wait=True)
106
+
107
+
108
# Application instance; `lifespan` manages the REPL-cache background tasks.
app = FastAPI(lifespan=lifespan)
109
+
110
+
111
# ------ Dependencies ------
def validate_api_access(request: Request, authorization: str = Header(None)) -> None:
    """FastAPI dependency: validate the Bearer token against settings.API_KEY.

    No-op when API_KEY is unset (auth disabled). Raises 401 for a missing or
    malformed Authorization header and 403 for a wrong key.
    """
    api_key = settings.API_KEY
    if api_key is None:
        return

    if authorization is None or not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401, detail="Missing or invalid Authorization header"
        )

    # removeprefix is exact (split("Bearer ")[-1] mangles tokens that contain
    # the substring "Bearer "); compare_digest avoids a timing side channel
    # on the secret comparison.
    token = authorization.removeprefix("Bearer ")
    if not hmac.compare_digest(token, api_key):
        raise HTTPException(status_code=403, detail="Invalid API Key")


require_access_dep = Annotated[None, Depends(validate_api_access)]
128
+
129
+
130
# ------ Schemas ------
class Code(BaseModel):
    """One proof snippet to verify.

    Exactly one of `proof` or `code` is expected; `code` exists for backward
    compatibility with the autoformalizer client.
    """

    custom_id: str | int
    # BUG FIX: both fields are optional, so annotate them as `str | None`
    # instead of declaring `str` with a None default (invalid under strict
    # pydantic validation and misleading to readers).
    proof: str | None = None
    code: str | None = None  # backward compatibility with autoformalizer client

    def get_proof_content(self) -> str | None:
        """Return whichever of `proof`/`code` is set (`proof` wins); None if neither."""
        return self.proof if self.proof is not None else self.code
138
+
139
+
140
class VerifyRequestBody(BaseModel):
    """Request body for /verify and /one_pass_verify_batch."""

    # Snippets to verify in this batch.
    codes: list[Code]
    # Per-snippet verification timeout in seconds.
    timeout: int = 300
    # Infotree format requested from the REPL, if any.
    infotree_type: str | None = None
    # When True, always use a fresh one-shot REPL instead of the header cache.
    disable_cache: bool = False
145
+
146
+
147
# ------ Endpoint ------
@app.get("/")
async def root(require_access_dep: require_access_dep):
    """Liveness endpoint; also exercises the auth dependency."""
    return {"status": "ok"}
151
+
152
+
153
+ @app.post("/verify")
154
+ async def verify(
155
+ body: VerifyRequestBody,
156
+ access: require_access_dep,
157
+ ):
158
+ """verify the proof code."""
159
+ codes = body.codes
160
+ timeout = body.timeout
161
+ infotree_type = body.infotree_type
162
+ disable_cache = body.disable_cache
163
+
164
+ tasks = [
165
+ process_one_code_with_repl_fast(
166
+ code, timeout, infotree_type, disable_cache=disable_cache
167
+ )
168
+ for code in codes
169
+ ]
170
+
171
+ # Await the results of all the tasks concurrently
172
+ results_data = await asyncio.gather(*tasks)
173
+
174
+ results = []
175
+ for result in results_data:
176
+ custom_id, error, response = result
177
+ results.append(
178
+ {
179
+ "custom_id": custom_id,
180
+ "error": error,
181
+ "response": response,
182
+ }
183
+ )
184
+
185
+ return {"results": results}
186
+
187
+
188
async def process_one_code_with_repl_fast(
    code: Code,
    timeout: int,
    infotree_type: str | None,
    disable_cache: bool = False,
):
    """Verify one proof snippet, reusing cached REPLs keyed by proof header.

    Returns a (custom_id, error_msg, response) triple; normally exactly one of
    error_msg / response is set.
    """
    # Throttle the incoming request
    async with semaphore:
        error_msg = None
        response = None

        custom_id = code.custom_id
        proof = code.get_proof_content()

        if proof is None:
            logger.warning(f"[{custom_id}] No code provided")
            return custom_id, "No code provided", response

        proof_header, proof_body = split_proof_header(proof)

        log_message = {
            'custom_id': custom_id,
            'proof_header': proof_header,
            'proof_body': proof_body,
            'timeout': timeout,
        }
        logger.debug(
            f"[{custom_id}] Processing code: {json.dumps(log_message)}"
        )

        # No header to key the cache on (or caching disabled): use a one-shot repl.
        if len(proof_header.strip()) == 0 or disable_cache:
            lean_repl = LeanREPL()
            try:
                response = await asyncio.to_thread(
                    lean_repl.one_pass_verify, proof, timeout, infotree_type
                )
            except LeanCrashError as e:
                error_msg = str(e)
                log_message["error"] = error_msg
                logger.error(
                    f"[{custom_id}] Error raised in one_pass_verify with 1-shot repl: {json.dumps(log_message)}"
                )
            finally:
                # Drop our reference so the one-shot repl can be reclaimed.
                del lean_repl
            # BUG FIX: return moved out of `finally` — a return inside finally
            # silently swallows any in-flight exception other than LeanCrashError.
            return custom_id, error_msg, response

        # Get lean repl instance from the lrucache
        grep_id, repl = await repl_cache.get(proof_header)

        # If we can not get the repl from the lrucache, we will create a new repl
        if grep_id is None:
            repl = LeanREPL()

            # And import the proof header
            try:
                response = await asyncio.to_thread(
                    repl.create_env, proof_header, timeout
                )
            except LeanCrashError as e:
                error_msg = str(e)
                log_message["error"] = error_msg
                # BUG FIX: message previously said "one_pass_verify with 1-shot
                # repl", which is the wrong code path — this is create_env.
                logger.error(
                    f"[{custom_id}] Error raised in create_env for proof header: {json.dumps(log_message)}"
                )
                del repl
                return custom_id, error_msg, response

        try:
            # Env id 0 is assumed to be the environment produced by create_env
            # above — TODO confirm against LeanREPL.create_env.
            response = await asyncio.to_thread(
                repl.extend_env,
                0,
                proof_body,
                timeout,
                infotree_type,
            )
        except LeanCrashError as e:
            error_msg = str(e)
            log_message["error"] = error_msg
            logger.error(
                f"[{custom_id}] Error raised while extending repl env with proof: {json.dumps(log_message)}"
            )
            if grep_id is not None:
                logger.error(f"[{custom_id}] Removing repl from cache: {grep_id}")
                await repl_cache.destroy(proof_header, grep_id, repl)
            else:
                del repl
            return custom_id, error_msg, response

        exceeds_limit = False
        if (
            settings.REPL_MEMORY_CHECK_INTERVAL is not None
            and settings.REPL_MEMORY_LIMIT_GB is not None
            and repl.run_command_total % settings.REPL_MEMORY_CHECK_INTERVAL == 0
        ):
            # Check if the REPL exceeds memory limit after execution
            exceeds_limit = await asyncio.to_thread(
                repl.exceeds_memory_limit, settings.REPL_MEMORY_LIMIT_GB
            )

        if exceeds_limit:
            logger.warning(
                f"REPL exceeds memory limit after execution, destroying it. last verified proof: {json.dumps(log_message)}"
            )

            if grep_id is None:
                # NOTE(review): this only drops the Python reference; the repl
                # process is not explicitly closed here — confirm LeanREPL
                # cleans up on garbage collection.
                del repl
            else:
                logger.warning(f"Removing repl from cache: {grep_id}")
                await repl_cache.destroy(proof_header, grep_id, repl)
        else:
            # release back to the cache if memory is within limits
            if grep_id is None:
                await repl_cache.put(proof_header, repl)
            else:
                await repl_cache.release(proof_header, grep_id, repl)

        return custom_id, error_msg, response
306
+
307
+
308
+ @app.post("/one_pass_verify_batch")
309
+ async def one_pass_verify_batch(
310
+ body: VerifyRequestBody,
311
+ access: require_access_dep,
312
+ ):
313
+ """Backward compatible endpoint: accepts both 'proof' / 'code' fields."""
314
+ return await verify(body, access)
315
+
316
+
317
+ app.include_router(router)
318
+ app.include_router(kimina_router)
kimina_api.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, status, Depends
2
+ from loguru import logger
3
+ from typing import List, Optional, Union, Dict, Any
4
+ from pydantic import BaseModel, Field
5
+
6
+ from .server import require_access_dep, process_one_code_with_repl_fast, Code, VerifyRequestBody
7
+ from .config import settings
8
+
9
# Router exposing the kimina-client-compatible endpoints; mounted by the server module.
router = APIRouter()
10
+
11
+
12
# ------ Kimina Client Compatible Schemas ------
class Snippet(BaseModel):
    """One code snippet in a kimina-client check request."""

    # Client-side identifier echoed back in the response.
    id: str
    # Lean source; the kimina client sends it under the "snip" key.
    code: str = Field(alias="snip")

    class Config:
        # Accept both the field name ("code") and the alias ("snip").
        # NOTE(review): inner-Config style is pydantic-v1 flavored while the
        # option name is v2 — confirm against the pydantic version in use.
        populate_by_name = True
19
+
20
+
21
class CheckRequest(BaseModel):
    """Kimina-client request body for /api/check."""

    snippets: List[Snippet]
    # Per-snippet timeout in seconds.
    timeout: int = 300
    # Accepted for client compatibility; not read by this module.
    debug: bool = False
    # When False, server-side REPL caching is disabled (maps to disable_cache).
    reuse: bool = True
    # Optional infotree options; only the "type" key is used downstream.
    infotree: Optional[Dict[str, Any]] = None
27
+
28
+
29
class Message(BaseModel):
    """A Lean diagnostic message (shape mirrors the REPL output).

    NOTE(review): not referenced elsewhere in this module — presumably kept
    for response-schema documentation; confirm before removing.
    """

    severity: str
    pos: Optional[Dict[str, int]] = None
    endPos: Optional[Dict[str, int]] = None
    data: str
34
+
35
+
36
class Sorry(BaseModel):
    """A `sorry` placeholder reported by Lean, with its remaining goal.

    NOTE(review): not referenced elsewhere in this module — presumably kept
    for response-schema documentation; confirm before removing.
    """

    pos: Optional[Dict[str, int]] = None
    endPos: Optional[Dict[str, int]] = None
    goal: str
40
+
41
+
42
class ReplResponse(BaseModel):
    """Verification outcome for one snippet, in kimina-client shape."""

    id: str
    error: Optional[str] = None
    response: Optional[Dict[str, Any]] = None

    def analyze(self):
        """Classify this response.

        Returns an object exposing the status string as both `.value` and
        `.status`: "error" (transport/REPL error), "unknown" (no response),
        "invalid" (Lean reported an error-severity message), "valid" otherwise.
        """

        # Simplified: the original property/setter pair around `status` was
        # pure boilerplate — a plain attribute has identical observable
        # behavior for readers of `.status` and `.value`.
        class Status:
            def __init__(self, status_value):
                self.value = status_value
                self.status = status_value

        if self.error:
            return Status("error")
        if self.response is None:
            return Status("unknown")

        # Invalid if Lean emitted any error-severity message.
        messages = self.response.get("messages", [])
        if any(msg.get("severity") == "error" for msg in messages):
            return Status("invalid")

        return Status("valid")
73
+
74
+
75
class CheckResponse(BaseModel):
    """Batch response envelope returned by /api/check."""

    results: List[ReplResponse]
77
+
78
+
79
# ------ Kimina Client Compatible Endpoint ------
@router.post("/api/check")
async def kimina_check(
    request: CheckRequest,
    # BUG FIX: require_access_dep is already Annotated[None, Depends(...)];
    # wrapping it in Depends() again passes a type alias (not a callable) as
    # the dependency, which FastAPI rejects at startup.
    access: require_access_dep,
):
    """
    Kimina client compatible endpoint for code verification.
    Converts Kimina client format to server format.
    """
    try:
        # Convert Kimina format to server format
        codes = [
            Code(custom_id=snippet.id, code=snippet.code)
            for snippet in request.snippets
        ]

        # Create server request body
        server_request = VerifyRequestBody(
            codes=codes,
            timeout=request.timeout,
            infotree_type=request.infotree.get("type") if request.infotree else None,
            disable_cache=not request.reuse,
        )

        # Process using the existing verify infrastructure
        # NOTE(review): this imports from `.server`, but the module added in
        # this commit is `fixed_server.py` — confirm the import path resolves.
        tasks = [
            process_one_code_with_repl_fast(
                code,
                server_request.timeout,
                server_request.infotree_type,
                disable_cache=server_request.disable_cache,
            )
            for code in server_request.codes
        ]

        # Await the results of all the tasks concurrently
        results_data = await asyncio.gather(*tasks)

        # Convert results to Kimina format
        results = [
            ReplResponse(id=str(custom_id), error=error, response=response)
            for custom_id, error, response in results_data
        ]

        return CheckResponse(results=results)

    except HTTPException:
        # BUG FIX: don't convert deliberate HTTP errors (401/403 from auth,
        # 422 from validation) into opaque 500s.
        raise
    except Exception as e:
        logger.error(f"Error in kimina_check endpoint: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
140
+
141
+
142
# Import asyncio for the gather function.
# NOTE(review): a module-level import at the bottom of the file does work (it
# runs at import time, before any request), but per PEP 8 it belongs at the
# top of the file with the other imports.
import asyncio
144
+
start.sh CHANGED
@@ -1,12 +1,7 @@
1
  #!/bin/bash
2
  set -e
3
 
4
- echo "LEAN_SERVER_REPL_PATH: $LEAN_SERVER_REPL_PATH"
5
- echo "LEAN_SERVER_PROJECT_DIR: $LEAN_SERVER_PROJECT_DIR"
6
-
7
- ls -la $LEAN_SERVER_REPL_PATH
8
- ls -la $LEAN_SERVER_PROJECT_DIR
9
-
10
  which lake
11
  which elan
12
  lake --version
@@ -14,6 +9,12 @@ elan --version
14
  lean --version
15
  elan toolchain list
16
 
 
 
 
 
 
 
17
  cd /home/user/app/kimina-lean-server
18
  python3 -m server &
19
 
@@ -22,6 +23,9 @@ ps aux | grep python
22
  netstat -tlnp | grep 8888
23
  ps aux | grep server
24
 
 
 
 
25
  echo "curl -v http://localhost:8888/"
26
  curl -v http://localhost:8888/
27
  echo "curl -v http://localhost:8888/health"
@@ -30,6 +34,8 @@ echo "curl -v http://localhost:8888/api/check"
30
  curl -v http://localhost:8888/api/check
31
  echo "curl -v http://localhost:8888/docs"
32
  curl -v http://localhost:8888/docs
 
 
33
 
34
  cd /home/user/app
35
  ls -l
 
1
  #!/bin/bash
2
  set -e
3
 
4
+ export PATH="~/.elan/bin:${PATH}"
 
 
 
 
 
5
  which lake
6
  which elan
7
  lake --version
 
9
  lean --version
10
  elan toolchain list
11
 
12
+ echo "LEAN_SERVER_REPL_PATH: $LEAN_SERVER_REPL_PATH"
13
+ echo "LEAN_SERVER_PROJECT_DIR: $LEAN_SERVER_PROJECT_DIR"
14
+
15
+ ls -la $LEAN_SERVER_REPL_PATH
16
+ ls -la $LEAN_SERVER_PROJECT_DIR
17
+
18
  cd /home/user/app/kimina-lean-server
19
  python3 -m server &
20
 
 
23
  netstat -tlnp | grep 8888
24
  ps aux | grep server
25
 
26
free -h
# BUG FIX: plain `top` is interactive and hangs a non-TTY startup script;
# run one batch-mode iteration instead.
# NOTE(review): PID 15 is hard-coded — confirm it is the server process.
top -b -n 1 -p 15
28
+
29
  echo "curl -v http://localhost:8888/"
30
  curl -v http://localhost:8888/
31
  echo "curl -v http://localhost:8888/health"
 
34
  curl -v http://localhost:8888/api/check
35
  echo "curl -v http://localhost:8888/docs"
36
  curl -v http://localhost:8888/docs
37
+ echo "curl -v http://localhost:8888/verify"
38
+ curl -v http://localhost:8888/verify
39
 
40
  cd /home/user/app
41
  ls -l