# Hugging Face Space — Lean 4 Prover demo
# (Spaces runtime status banner removed; it was page residue, not source code.)
| import gradio as gr | |
| import os | |
| import json | |
| import subprocess | |
| from huggingface_hub import InferenceClient | |
# Cache of InferenceClient instances, keyed by (model, endpoint, token) so that
# repeated requests with the same settings reuse one client.
_client_cache: dict = {}

def get_client(model_id: str, endpoint_url: str | None = None, user_token: str | None = None):
    """Return a (cached) huggingface_hub InferenceClient.

    Args:
        model_id: Hub model id used for serverless inference when no
            dedicated endpoint is given.
        endpoint_url: Optional dedicated-endpoint base URL; when set it takes
            precedence over ``model_id``. (Fixed annotation: was ``str = None``.)
        user_token: Optional caller-supplied HF token; falls back to the
            ``HF_TOKEN`` environment variable when empty or None.

    Returns:
        An ``InferenceClient`` configured for the requested target.
    """
    cache_key = f"{model_id}_{endpoint_url}_{user_token}"
    if cache_key not in _client_cache:
        # Empty string and None both fall back to the env token.
        token = user_token if user_token else os.environ.get("HF_TOKEN")
        if endpoint_url:
            # Dedicated endpoint: route by base_url, model id is irrelevant.
            _client_cache[cache_key] = InferenceClient(
                base_url=endpoint_url,
                token=token,
            )
        else:
            # Serverless inference addressed by model id.
            _client_cache[cache_key] = InferenceClient(
                model=model_id,
                token=token,
            )
    return _client_cache[cache_key]
def validate_with_repl(lean_code):
    """Validate Lean 4 code by piping it through the Lean REPL binary.

    Args:
        lean_code: Lean 4 source to check; ``import Mathlib``/``Aesop`` are
            prepended automatically when missing.

    Returns:
        A dict ``{"valid": ..., "error": ...}`` where ``valid`` is
        True/False when the REPL actually ran, or ``None`` when no REPL
        installation is available (e.g. local dev without elan).
    """
    # Candidate installations: local dev layout first, then Docker image layout.
    local_repl = "/tmp/repl/.lake/build/bin/repl"
    local_lake = os.path.expanduser("~/.elan/bin/lake")
    local_cwd = "/tmp/repl"
    docker_repl = "/app/repl/.lake/build/bin/repl"
    docker_lake = "/root/.elan/bin/lake"
    docker_cwd = "/app/eval_project"
    # Use whichever layout exists on this machine.
    if os.path.exists(local_repl):
        repl_bin, lake_bin, cwd = local_repl, local_lake, local_cwd
    elif os.path.exists(docker_repl):
        repl_bin, lake_bin, cwd = docker_repl, docker_lake, docker_cwd
    else:
        # No REPL installed anywhere: signal "validation unavailable".
        return {"valid": None, "error": "REPL not available (run elan installer)"}
    # The REPL evaluates the snippet as-is; make sure Mathlib is in scope.
    if "import Mathlib" not in lean_code:
        lean_code = "import Mathlib\nimport Aesop\n\n" + lean_code
    input_json = json.dumps({"cmd": lean_code})
    try:
        result = subprocess.run(
            [lake_bin, "env", repl_bin],
            input=input_json,
            text=True,
            capture_output=True,
            cwd=cwd,
            timeout=60,
            # Prepend elan's bin dir so lake can find the toolchain binaries.
            env={**os.environ, "PATH": os.path.dirname(lake_bin) + ":" + os.environ.get("PATH", "")},
        )
        # Fix: if the REPL crashed or printed nothing, json.loads("") would
        # raise an opaque JSONDecodeError; surface exit code + stderr instead.
        if not result.stdout.strip():
            stderr_tail = (result.stderr or "").strip()[-300:]
            return {
                "valid": False,
                "error": f"REPL produced no output (exit {result.returncode}): {stderr_tail}",
            }
        output = json.loads(result.stdout)
        # A proof is valid only if there are no error messages and no sorries.
        has_error = any(m.get("severity") == "error" for m in output.get("messages", []))
        has_sorry = len(output.get("sorries", [])) > 0
        return {
            "valid": not has_error and not has_sorry,
            "error": "Contains sorry" if has_sorry else ("Has errors" if has_error else None),
        }
    except subprocess.TimeoutExpired:
        return {"valid": False, "error": "Validation timed out after 60s"}
    except Exception as e:
        return {"valid": False, "error": str(e)}
def process(theorem, model_id, endpoint_url, user_token, temperature, top_p, max_tokens, request: gr.Request):
    """Gradio streaming handler: generate a Lean proof, then validate it.

    Yields 4-tuples matching the bound outputs
    ``(status_text, live_output, results_json, final_proof)``; each yield
    updates the UI in place. ``request`` is injected by Gradio and carries
    the logged-in username (None when not authenticated).
    """
    username = request.username if request else None
    OWNER_USERNAME = "lzumot"
    # Prover-family models require a dedicated endpoint (user-supplied or the
    # owner's private one) rather than serverless inference.
    prover_models = ["DeepSeek-Prover", "Goedel-Prover", "Kimina-Prover"]
    is_prover = any(m in model_id for m in prover_models)
    final_endpoint = None
    final_token = None
    if is_prover:
        if endpoint_url:
            # Caller brings their own endpoint; use their token if given.
            final_endpoint = endpoint_url
            final_token = user_token if user_token else os.environ.get("HF_TOKEN")
        else:
            # No endpoint supplied: only the Space owner may use the private one.
            if username != OWNER_USERNAME:
                yield "Access denied: Prover models require login or provide endpoint", "", [{"error": "Auth required"}], ""
                return
            final_endpoint = os.environ.get("PROVER_ENDPOINT_URL")
            if not final_endpoint:
                yield "Error: PROVER_ENDPOINT_URL not set", "", [{"error": "Not configured"}], ""
                return
    # Check for basic test theorem
    if not theorem or len(theorem.strip()) < 5:
        yield "Error: Enter a theorem", "", [{"error": "Empty input"}], ""
        return
    # NOTE(review): the imports here duplicate what validate_with_repl prepends;
    # the generated proof is validated separately, so that is harmless.
    prompt = f"""Complete the following Lean 4 code:
```lean4
import Mathlib
import Aesop
{theorem}
```
Proof:"""
    print(f"DEBUG: Generating with {model_id}...")
    yield "Initializing...", "", [], ""
    try:
        client = get_client(model_id, final_endpoint, final_token)
    except Exception as e:
        yield f"Client Error: {str(e)}", "", [{"error": str(e)}], ""
        return
    live_proof = ""
    try:
        yield "Waiting for model...", "⏳ Generating proof...", [], ""
        # Use non-streaming first to test if API works, then switch to streaming
        live_proof = client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p
        )
        yield "Generation complete...", live_proof, [], ""  # Show result after
    except Exception as e:
        print(f"DEBUG: Generation error: {e}")
        yield f"Generation Error: {str(e)}", live_proof, [{"error": str(e)}], ""
        return
    # Strip markdown code fences the model may have wrapped the proof in.
    proof = live_proof.replace("```lean4", "").replace("```", "").strip()
    print(f"DEBUG: Generated {len(proof)} chars")
    yield "Validating...", proof, [], ""
    validation = validate_with_repl(proof)
    # Handle None (local dev) vs True/False
    if validation["valid"] is None:
        status_icon = "⚠️"
        status_text = "NO VALIDATION (LOCAL)"
    else:
        status_icon = "✅" if validation["valid"] else "❌"
        status_text = "PASS" if validation["valid"] else "FAIL"
    result_json = [{
        "status": f"{status_icon} {status_text}",
        "valid": validation["valid"],
        "error": validation.get("error", "None")
    }]
    final_status = f"{status_icon} {status_text}"
    if validation.get("error"):
        final_status += f" - {validation['error']}"
    # Final yield also fills the copyable "Final Proof" box.
    yield final_status, proof, result_json, proof
# UI layout: left column = inputs/settings, right column = live results.
# Component creation order inside each container determines on-screen order.
with gr.Blocks(title="Lean 4 Prover") as demo:
    gr.Markdown("# Lean 4 Proof Validator")
    with gr.Row():
        with gr.Column():
            model_id = gr.Dropdown(
                choices=[
                    "Qwen/Qwen2.5-0.5B-Instruct",
                    "deepseek-ai/DeepSeek-Prover-V2-7B",
                    "Goedel-LM/Goedel-Prover-V2-8B",
                    "AI-MO/Kimina-Prover-Distill-0.6B",
                    "AI-MO/Kimina-Prover-Distill-8B"
                ],
                value="AI-MO/Kimina-Prover-Distill-0.6B",
                label="Model"
            )
            with gr.Accordion("Generation Parameters", open=True):
                temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.1, label="Temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.05, label="Top-p")
                max_tokens = gr.Slider(256, 4096, value=1024, step=256, label="Max Tokens")
            # Optional: bring-your-own inference endpoint (see process() auth logic).
            with gr.Accordion("Endpoint Settings (Advanced)", open=True):
                endpoint_url = gr.Textbox(label="Your Endpoint URL", placeholder="https://xxx.aws.endpoints.huggingface.cloud")
                user_token = gr.Textbox(label="Your HF Token", type="password", placeholder="hf_...")
            theorem = gr.Textbox(label="Theorem", placeholder="theorem test : 1 + 1 = 2 := by", lines=3)
            btn = gr.Button("Generate & Validate", variant="primary")
        with gr.Column():
            status_text = gr.Textbox(label="Status", value="Ready")
            live_output = gr.Textbox(label="Live Generation", lines=15, interactive=False)
            results_json = gr.JSON(label="Result")
            final_proof = gr.Textbox(label="Final Proof (Copy)", lines=5)
    # process() is a generator, so each yield streams into these four outputs;
    # gr.Request is injected automatically and is not listed as an input.
    btn.click(
        process,
        [theorem, model_id, endpoint_url, user_token, temperature, top_p, max_tokens],
        [status_text, live_output, results_json, final_proof]
    )

if __name__ == "__main__":
    # queue() is required for generator (streaming) handlers.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860,)