import streamlit as st
import subprocess
import json
import re
import os

from huggingface_hub import InferenceClient


def _make_inference_client(model_id: str) -> InferenceClient:
    """Create an InferenceClient with broad compatibility across huggingface_hub versions.

    Supports:
    - HF Serverless Inference via token
    - Custom Inference Endpoint via base_url
    - Optional provider selection via env

    Env vars:
    - HF_TOKEN (required in this app)
    - HF_BASE_URL (optional): custom endpoint base URL
    - HF_PROVIDER (optional): e.g. "hf-inference"
    """
    token = os.environ.get("HF_TOKEN")
    base_url = os.environ.get("HF_BASE_URL")
    provider = os.environ.get("HF_PROVIDER")

    base_kwargs: dict[str, object] = {}
    if base_url:
        base_kwargs["base_url"] = base_url
    if provider:
        base_kwargs["provider"] = provider

    # Try a few constructor signatures to tolerate different hub versions.
    candidates = [
        {"model": model_id, "token": token, **base_kwargs},
        {"model": model_id, "api_key": token, **base_kwargs},
        {"token": token, **base_kwargs},
        {"api_key": token, **base_kwargs},
        {**base_kwargs},
    ]
    last_err: Exception | None = None
    for kwargs in candidates:
        filtered = {k: v for k, v in kwargs.items() if v is not None}
        try:
            return InferenceClient(**filtered)
        except TypeError as e:
            last_err = e
            continue

    # Should be unreachable, but keep a safe fallback.
    if last_err is not None:
        raise last_err
    return InferenceClient()
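
# A minimal configuration sketch for the factory above (all values below are
# placeholders, not real deployments):
#
#   HF_TOKEN=hf_xxx                                        # required: auth token
#   HF_BASE_URL=https://xxx.endpoints.huggingface.cloud    # optional endpoint
#   HF_PROVIDER=hf-inference                               # optional provider pin
#
# With these set, `_make_inference_client(model_id)` tries the newest
# constructor signature first and falls back on TypeError, so the same code
# works across several huggingface_hub releases.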


def _call_llm(client: InferenceClient, prompt: str, model_id: str) -> str:
    """Call the model with robust fallbacks.

    Primary: text_generation
    Fallback: chat completion (OpenAI-style)
    """
    # 1) Try text-generation first.
    try:
        return client.text_generation(
            prompt,
            model=model_id,
            max_new_tokens=4096,
            temperature=0.2,
            top_p=0.95,
        )
    except TypeError:
        # Older hub versions may not accept `model=` here if model is set in client.
        return client.text_generation(
            prompt,
            max_new_tokens=4096,
            temperature=0.2,
            top_p=0.95,
        )
    except StopIteration as e:
        # huggingface_hub currently may raise StopIteration when it cannot find
        # a provider for the requested (task, model).
        first_err: Exception = e
    except Exception as e:
        first_err = e

    # 2) Fall back to chat completion if available. Other exceptions are left
    # to the caller for better UI reporting.
    try:
        chat = client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=4096,
            temperature=0.2,
            top_p=0.95,
        )
        content = chat.choices[0].message.content
        if isinstance(content, str) and content.strip():
            return content
        raise RuntimeError("chat-completion returned empty content")
    except Exception:
        if isinstance(first_err, StopIteration):
            raise RuntimeError(
                "Hugging Face Inference could not find an available inference "
                "provider for this model."
            ) from first_err
        raise
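
# Example of the intended call order (illustrative sketch; `client` and
# MODEL_ID are the ones constructed further below):
#
#   text = _call_llm(client, "Complete the following Lean 4 code...", MODEL_ID)
#
# `text_generation` is tried first because prover-style models are usually
# served under the text-generation task; the OpenAI-style chat fallback only
# runs when that first call fails for the (provider, model) pair.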


def _extract_lean_code_blocks(text: str) -> list[str]:
    matches = re.findall(
        r"```(?:lean4?|lean)\s*\r?\n(.*?)\r?\n```",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return [m.strip() for m in matches if m.strip()]
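
# Illustrative round-trip for the extractor (the fence tag may be `lean` or
# `lean4`, any case, with LF or CRLF line endings):
#
#   _extract_lean_code_blocks("```lean4\ntheorem t : True := trivial\n```")
#   # -> ["theorem t : True := trivial"]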


def _extract_plan(text: str) -> str | None:
    m = re.search(
        r"(?is)(?:^|\n)\s*(?:plan|proof\s*plan)\s*:\s*(.*?)(?=\n```(?:lean4?|lean)|\Z)",
        text,
    )
    if not m:
        return None
    plan = m.group(1).strip()
    return plan or None
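
# Illustrative round-trip: the prompt below asks the model to begin with
# `PLAN:`, which this regex captures up to the first Lean code fence:
#
#   _extract_plan("PLAN: induction on n\n```lean4\nexample : True := trivial\n```")
#   # -> "induction on n"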


def _ensure_mathlib_prelude(lean_code: str) -> str:
    if re.search(r"(?im)^\s*import\s+Mathlib\b", lean_code):
        return lean_code
    return "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\n" + lean_code


# Page setup
st.title("Online Lean 4 Proof Evaluation")
st.markdown("Enter a math problem in natural language; an LLM generates a proof and the Lean 4 compiler verifies it in real time.")

# Configure the LLM client (add an HF_TOKEN secret in the HF Space settings).
# Default model (can be overridden via the HF_MODEL_ID environment variable).
MODEL_ID = os.environ.get("HF_MODEL_ID", "deepseek-ai/DeepSeek-Prover-V2-7B")
client = _make_inference_client(MODEL_ID)

# User input
nl_problem = st.text_area(
    "Enter a math problem (natural language):",
    "Prove: for any real number x, if x > 0, then x + 1/x >= 2.",
)

if st.button("Generate and verify proof"):
    if not os.environ.get("HF_TOKEN"):
        st.error(
            "Environment variable HF_TOKEN was not found, so Hugging Face "
            "Inference cannot be called. Please set HF_TOKEN in the Space secrets."
        )
        st.stop()

    with st.spinner("LLM is generating Lean 4 code..."):
        # DeepSeek-Prover-V2-7B usually runs as text-generation on HF Inference;
        # a "complete this Lean file, plan first, then code" instruction is more reliable.
        # Since the user input is natural language, we ask the model to produce a
        # complete, compilable Lean snippet (including the theorem statement).
        formal_skeleton = f"""
/-!
Problem (natural language):
{nl_problem}

Please formalize the statement in Lean 4 and prove it.
-/
import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

-- You may choose a reasonable theorem name.
theorem generated_problem : True := by
  trivial
""".strip()

        prompt = f"""
Complete the following Lean 4 code.

```lean4
{formal_skeleton}
```

Requirements:
1) First, provide a detailed proof plan. Start it with exactly `PLAN:`.
2) Then, output ONE Lean 4 code block (```lean4 ... ```), containing the complete final Lean file.
3) The final Lean code must compile with Mathlib. Do NOT use `sorry`.
""".strip()

        try:
            gen_text = _call_llm(client, prompt, MODEL_ID)
        except Exception as e:
            # Provide actionable guidance instead of crashing.
            st.error("Hugging Face Inference call failed.")
            st.write("Error message:", str(e))
            st.info(
                "Likely cause: the model does not support `text-generation`/`chat-completion` on Hugging Face Serverless Inference.\n"
                "Possible fixes:\n"
                "1) Set `HF_MODEL_ID` in the Space variables to a model that is available on Inference;\n"
                "2) Use a Hugging Face Inference Endpoint: point `HF_BASE_URL` at your endpoint and make sure `HF_TOKEN` has access;\n"
                "3) To pin a provider explicitly, set `HF_PROVIDER` (e.g. `hf-inference`)."
            )
            st.stop()

    plan = _extract_plan(gen_text)
    code_blocks = _extract_lean_code_blocks(gen_text)

    if plan:
        with st.expander("Proof plan (model-generated)", expanded=False):
            st.markdown(plan)

    if code_blocks:
        lean_code = _ensure_mathlib_prelude(code_blocks[-1])
        st.code(lean_code, language="lean")

        # Run verification
        with st.spinner("Lean 4 compiler verifying..."):
            full_input = json.dumps({"cmd": lean_code})
            repl_bin = os.environ.get(
                "LEAN_REPL_BIN", "/app/repl/.lake/build/bin/repl"
            )
            proc = subprocess.run(
                ["lake", "env", repl_bin],
                input=full_input,
                text=True,
                capture_output=True,
                cwd="./eval_project",
            )
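
            # A successful round-trip looks roughly like this (a sketch of the
            # leanprover-community/repl JSON protocol; clean runs may omit the
            # "messages"/"sorries" keys, hence the .get() defaults below):
            #
            #   stdin : {"cmd": "import Mathlib\ntheorem t : True := trivial"}
            #   stdout: {"env": 0}
            #
            # Compile errors arrive as "messages" entries with severity
            # "error"; any `sorry` placeholder is listed under "sorries".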

            try:
                res_json = json.loads(proc.stdout)
                has_sorry = len(res_json.get("sorries", [])) > 0
                has_error = any(
                    m.get("severity") == "error"
                    for m in res_json.get("messages", [])
                )

                if not has_sorry and not has_error:
                    st.success("✅ Proof verified! The logic is sound.")
                else:
                    st.error("❌ Verification failed.")
                    if has_error:
                        st.write("Error details:", res_json.get("messages", []))
                    if has_sorry:
                        st.write(
                            "Unfinished proofs (sorry):", res_json.get("sorries", [])
                        )
                    if proc.stderr:
                        st.write("repl stderr:", proc.stderr)
            except Exception as e:
                st.error(f"Failed to parse compiler output: {e}")
                if proc.stderr:
                    st.write("repl stderr:", proc.stderr)
    else:
        st.warning("The LLM did not produce a standard Lean code block (```lean4 ...```).")
        st.text_area("Raw model output (for debugging)", gen_text, height=240)