# Source: algorithm-agent / tools.py, 19.3 kB (file-view page residue: uploaded by
# fanjingbo111, commit ae0a268 "Deploy algorithm agent app", verified).
from __future__ import annotations
import ast
import contextlib
import io
import math
import multiprocessing as mp
import re
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List
@dataclass
class AlgorithmPattern:
    """One algorithm-design paradigm plus the prompts used to apply it.

    The fields are rendered verbatim into prompt text by ``format_patterns``
    and are consulted by ``retrieve_algorithm_patterns`` when ranking
    paradigms against a question.
    """

    # Paradigm name; also the lookup key into the retrieval keyword table.
    name: str
    # Free-text description of the situations the paradigm fits; retrieval
    # matches question tokens against this text.
    when_to_use: str
    # Questions to clarify or answer when applying the paradigm.
    key_questions: List[str]
    # Elements a complete answer must contain.
    answer_requirements: List[str]
@dataclass
class ScriptResult:
    """Outcome of an attempt to run a verification script."""

    # False when the script was rejected before being run (empty or failed
    # validation); True once a child process was started for it.
    executed: bool
    # True only when the script ran to completion without raising.
    ok: bool
    # Text the script printed to stdout.
    output: str
    # Rejection reason, timeout notice, or traceback text; empty on success.
    error: str = ""
# Catalogue of algorithm-design paradigms offered to the agent.
# Order is significant: retrieve_algorithm_patterns sorts by score with a
# stable sort (ties keep this order) and falls back to the first and last
# entries when no paradigm matches the question.
PATTERNS = [
    # Dynamic programming.
    AlgorithmPattern(
        name="动态规划",
        when_to_use="问题具有最优子结构、重复子问题,常见关键词包括最大/最小/方案数/背包/序列/区间。",
        key_questions=[
            "状态 dp 的含义是什么?",
            "每一步有哪些选择?",
            "状态转移依赖哪些更小状态?",
            "边界条件和计算顺序是什么?",
        ],
        answer_requirements=[
            "给出状态定义",
            "写出转移方程",
            "说明初始化和答案位置",
            "分析时间与空间复杂度",
        ],
    ),
    # Greedy.
    AlgorithmPattern(
        name="贪心",
        when_to_use="每步局部最优可推出全局最优,常见于区间选择、排序后选择、交换论证问题。",
        key_questions=[
            "贪心选择是什么?",
            "为什么当前选择不会破坏最优解?",
            "是否需要排序或优先队列?",
        ],
        answer_requirements=[
            "说明贪心规则",
            "给出正确性证明或交换论证",
            "说明失败条件",
            "分析复杂度",
        ],
    ),
    # Graph algorithms.
    AlgorithmPattern(
        name="图算法",
        when_to_use="输入包含点、边、路径、连通性、匹配、网络或拓扑依赖。",
        key_questions=[
            "图是有向还是无向?是否带权?",
            "目标是最短路、连通性、生成树、匹配还是拓扑顺序?",
            "边权是否允许负数?",
        ],
        answer_requirements=[
            "明确图模型",
            "选择 BFS/DFS/Dijkstra/Bellman-Ford/拓扑排序等范式",
            "给出关键步骤",
            "分析复杂度和适用条件",
        ],
    ),
    # Recursion and divide-and-conquer.
    AlgorithmPattern(
        name="递归与分治",
        when_to_use="问题可以拆成相同形式的子问题,再合并子问题结果;常见于排序、选择、第 k 大、矩阵乘法、最近点对等问题。",
        key_questions=[
            "递归函数的语义是什么?",
            "如何划分子问题?",
            "子问题之间是否独立?",
            "合并步骤的代价是多少?",
            "递归边界是什么?",
        ],
        answer_requirements=[
            "给出递归定义",
            "写出合并过程",
            "列出递推式",
            "用主定理或递归树分析复杂度",
        ],
    ),
    # Backtracking.
    AlgorithmPattern(
        name="回溯",
        when_to_use="需要在解空间树中搜索所有可行解或最优解,常见于排列组合、N 皇后、图着色、0/1 选择、约束满足问题。",
        key_questions=[
            "解空间树的每一层表示什么决策?",
            "当前部分解如何扩展?",
            "约束函数如何剪掉不可行分支?",
            "什么时候记录一个完整解?",
        ],
        answer_requirements=[
            "定义解空间和递归搜索状态",
            "给出扩展规则和约束剪枝条件",
            "说明递归终止条件",
            "分析最坏时间复杂度和剪枝效果",
        ],
    ),
    # Branch and bound.
    AlgorithmPattern(
        name="分支限界",
        when_to_use="需要在组合优化问题中寻找最优解,并可用上界/下界剪枝,常见于 0/1 背包、旅行商、任务分配、装载问题。",
        key_questions=[
            "每个节点对应什么部分解?",
            "界函数如何估计最优可能值?",
            "采用队列式、优先队列式还是深度优先式扩展?",
            "何时剪枝,何时更新当前最优解?",
        ],
        answer_requirements=[
            "给出状态节点定义",
            "设计上界或下界函数",
            "说明节点扩展和剪枝规则",
            "分析最坏复杂度和实际剪枝效果",
        ],
    ),
    # Randomized algorithms.
    AlgorithmPattern(
        name="随机化算法",
        when_to_use="允许引入随机选择来降低期望复杂度或提高实践性能,常见于随机快速排序、随机选择、哈希、Monte Carlo/Las Vegas 算法。",
        key_questions=[
            "随机性用于选择样本、划分点还是搜索方向?",
            "算法是 Monte Carlo 还是 Las Vegas?",
            "正确性是确定保证还是概率保证?",
            "期望时间复杂度如何分析?",
        ],
        answer_requirements=[
            "说明随机步骤",
            "区分正确性保证类型",
            "给出期望复杂度或错误概率分析",
            "说明重复运行或放大成功概率的方法",
        ],
    ),
    # Genetic algorithms.
    AlgorithmPattern(
        name="遗传算法",
        when_to_use="用于大规模、非凸、难以精确建模的优化问题,通过种群进化搜索近似解,常见于调度、路径规划、参数优化。",
        key_questions=[
            "个体如何编码为染色体?",
            "适应度函数如何衡量解的好坏?",
            "选择、交叉、变异算子如何设计?",
            "终止条件是什么?",
        ],
        answer_requirements=[
            "给出编码方式",
            "定义适应度函数",
            "说明选择/交叉/变异流程",
            "说明参数设置、终止条件和近似性",
        ],
    ),
    # Simulated annealing.
    AlgorithmPattern(
        name="模拟退火",
        when_to_use="用于组合优化或连续优化中的近似搜索,通过温度下降机制跳出局部最优,常见于 TSP、排课、布局、调度。",
        key_questions=[
            "当前解和邻域解如何表示?",
            "目标函数或能量函数是什么?",
            "温度初值、降温策略和终止条件如何设置?",
            "以什么概率接受更差解?",
        ],
        answer_requirements=[
            "定义解表示和邻域生成方式",
            "给出能量函数和接受准则",
            "说明降温策略",
            "说明该方法是近似算法并分析参数影响",
        ],
    ),
    # Complexity analysis and correctness proofs.
    AlgorithmPattern(
        name="复杂度与正确性证明",
        when_to_use="问题要求证明算法正确性、分析复杂度、比较算法或解释为什么某策略可行。",
        key_questions=[
            "需要证明什么性质?",
            "可用循环不变式、归纳法、交换论证还是反证法?",
            "时间复杂度由哪些循环、递归或数据结构操作决定?",
        ],
        answer_requirements=[
            "给出证明结构",
            "说明不变式或归纳假设",
            "分析边界情况",
            "给出清晰复杂度结论",
        ],
    ),
]
def retrieve_algorithm_patterns(question: str, top_k: int = 3) -> List[AlgorithmPattern]:
    """Rank the catalogued paradigms by lexical relevance to ``question``.

    Each paradigm earns 3 points per hand-picked keyword found in the
    lower-cased question and 1 point per question token (see ``tokenize``)
    that occurs in the paradigm's ``when_to_use`` text.

    Args:
        question: The user's problem statement.
        top_k: Maximum number of paradigms to return. Values below 1 are
            treated as 1 so the result is never empty.

    Returns:
        Up to ``top_k`` paradigms with a positive score, best first (ties
        keep catalogue order). When nothing scores, the first and last
        catalogue entries are returned, truncated to ``top_k``.
    """
    keyword_map: Dict[str, List[str]] = {
        "动态规划": ["动态规划", "dp", "背包", "最大", "最小", "最优", "序列", "子序列", "编辑距离", "区间"],
        "贪心": ["贪心", "活动", "区间调度", "最早结束", "局部最优", "排序选择"],
        "图算法": ["图", "节点", "边", "路径", "最短路", "连通", "拓扑", "匹配", "网络"],
        "递归与分治": ["分治", "递归", "归并", "快速", "二分", "最近点对", "第k", "第 k", "递推式", "主定理"],
        "回溯": ["回溯", "解空间", "排列", "组合", "n皇后", "n 皇后", "图着色", "约束满足", "可行解", "所有解"],
        "分支限界": ["分支限界", "分枝限界", "限界", "上界", "下界", "优先队列", "旅行商", "tsp", "装载", "任务分配"],
        "随机化算法": ["随机", "随机化", "概率", "期望", "monte carlo", "las vegas", "随机快排", "随机选择", "哈希"],
        "遗传算法": ["遗传", "种群", "染色体", "适应度", "交叉", "变异", "选择算子", "进化"],
        "模拟退火": ["模拟退火", "退火", "温度", "降温", "邻域", "接受概率", "局部最优", "全局搜索"],
        "复杂度与正确性证明": ["证明", "正确性", "复杂度", "为什么", "不变式", "归纳", "交换论证"],
    }

    # Every keyword above is lower-case, so matching against the lower-cased
    # question already covers a match against the raw question.
    text = question.lower()
    # Loop-invariant: tokenise the question once, not once per paradigm.
    tokens = tokenize(question)
    # A non-positive top_k would otherwise slice from the wrong end or yield
    # nothing; clamp it so callers always get at least one paradigm.
    limit = max(top_k, 1)

    scores: List[tuple[int, AlgorithmPattern]] = []
    for pattern in PATTERNS:
        keyword_hits = sum(1 for keyword in keyword_map.get(pattern.name, []) if keyword in text)
        usage = pattern.when_to_use.lower()
        token_hits = sum(1 for token in tokens if token in usage)
        scores.append((3 * keyword_hits + token_hits, pattern))

    # list.sort is stable (also with reverse=True), so equal scores keep
    # their catalogue order.
    scores.sort(key=lambda item: item[0], reverse=True)
    selected = [pattern for score, pattern in scores if score > 0][:limit]
    if selected:
        return selected
    # Nothing matched: offer the first and last catalogue entries, still
    # honouring the caller's limit.
    return [PATTERNS[0], PATTERNS[-1]][:limit]
def format_patterns(patterns: List[AlgorithmPattern]) -> str:
    """Render paradigms as plain-text sections separated by blank lines.

    Each section lists the paradigm name, its applicability text, its key
    questions and its answer requirements, the latter two as "- " bullets.
    An empty ``patterns`` list yields an empty string.
    """
    rendered: List[str] = []
    for entry in patterns:
        lines = [
            f"算法范式: {entry.name}",
            f"适用场景: {entry.when_to_use}",
            "需要澄清/回答的问题:",
        ]
        lines.extend(f"- {question}" for question in entry.key_questions)
        lines.append("答案必须包含:")
        lines.extend(f"- {requirement}" for requirement in entry.answer_requirements)
        rendered.append("\n".join(lines))
    return "\n\n".join(rendered)
def tokenize(text: str) -> set[str]:
    """Split ``text`` into a set of lower-cased matching tokens.

    ASCII runs of letters, digits and underscores (two or more characters)
    are kept whole. Runs of two or more CJK ideographs are expanded into all
    of their 2- and 3-character sliding windows.
    """
    lowered = text.lower()
    tokens = set(re.findall(r"[a-z0-9_]{2,}", lowered))
    for run in re.findall(r"[\u4e00-\u9fff]{2,}", lowered):
        for width in (2, 3):
            # range() is empty when the run is shorter than the window.
            tokens.update(run[start : start + width] for start in range(len(run) - width + 1))
    return tokens
def extract_python_code(text: str) -> str:
    """Return the body of the last ``python``/``py`` fenced block in ``text``.

    The fence language tag is matched case-insensitively. The result is
    stripped of surrounding whitespace; an empty string is returned when no
    such block exists.
    """
    last_block = ""
    for found in re.finditer(r"```(?:python|py)\s*([\s\S]*?)```", text, flags=re.IGNORECASE):
        # Later fences overwrite earlier ones, so only the final one survives.
        last_block = found.group(1)
    return last_block.strip()
def safe_run_python(code: str, timeout: int = 3) -> ScriptResult:
    """Validate ``code`` and run it in a child process under a time limit.

    Empty scripts and scripts rejected by ``validate_python_code`` are
    reported with ``executed=False`` and never start a process. Otherwise
    the script runs via ``_run_python_child`` and its reported outcome is
    converted into a ``ScriptResult``.

    Args:
        code: Python source of the verification script.
        timeout: Seconds to wait for the child to report a result.

    Returns:
        A ``ScriptResult``; on timeout, or when the child exits without
        reporting, ``ok`` is False and ``error`` explains why.
    """
    # Function-scope imports: only this function needs them.
    import time
    from queue import Empty

    code = code.strip()
    if not code:
        return ScriptResult(executed=False, ok=False, output="", error="未提供 Python 验证脚本。")
    validation_error = validate_python_code(code)
    if validation_error:
        return ScriptResult(executed=False, ok=False, output="", error=validation_error)

    result_queue = mp.Queue()
    process = mp.Process(target=_run_python_child, args=(code, result_queue))
    process.start()

    payload = None
    timed_out = False
    deadline = time.monotonic() + timeout
    try:
        # Read the result BEFORE joining. A child that has queued a large
        # payload cannot exit until the parent drains the pipe, so joining
        # first would misreport a finished script as a timeout.
        while payload is None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = process.is_alive()
                break
            try:
                # get() returns as soon as data arrives; the short poll only
                # bounds how quickly a dead child is noticed.
                payload = result_queue.get(timeout=min(remaining, 0.05))
            except Empty:
                if not process.is_alive():
                    # The child is gone; allow a brief grace period for a
                    # result that was flushed just before it exited.
                    try:
                        payload = result_queue.get(timeout=0.1)
                    except Empty:
                        pass
                    break
    finally:
        if payload is not None:
            # The script finished; give the interpreter a moment to exit.
            process.join(1)
        if process.is_alive():
            process.terminate()
            process.join(1)
        if process.is_alive():
            # terminate() was not enough; force the child down.
            process.kill()
            process.join(1)
        result_queue.close()

    if payload is not None:
        return ScriptResult(
            executed=True,
            ok=bool(payload.get("ok")),
            output=str(payload.get("output", "")),
            error=str(payload.get("error", "")),
        )
    if timed_out:
        return ScriptResult(executed=True, ok=False, output="", error=f"脚本超过 {timeout}s 超时限制。")
    return ScriptResult(executed=True, ok=False, output="", error="脚本未返回结果。")
def validate_python_code(code: str) -> str:
    """Statically screen a verification script before it is executed.

    The script is parsed and every AST node is checked against a denylist:
    import/with/class statements, double-underscore names and attributes,
    frame-introspection attributes, and direct calls to a set of dangerous
    builtins.

    NOTE(review): this is a denylist over untrusted code, not a security
    boundary. It narrows the obvious escape routes only.

    Args:
        code: Python source to check.

    Returns:
        An empty string when the script is acceptable, otherwise a
        human-readable rejection reason.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError) as exc:
        # Besides SyntaxError, ast.parse raises ValueError for source
        # containing NUL bytes on older interpreters, and RecursionError for
        # pathologically nested input. Report them all as a rejection
        # instead of letting them propagate to the caller.
        return f"脚本语法错误: {exc}"

    denied_nodes = (ast.Import, ast.ImportFrom, ast.With, ast.AsyncWith, ast.ClassDef)
    denied_calls = {
        "open",
        "eval",
        "exec",
        "compile",
        "input",
        "__import__",
        "globals",
        "locals",
        "vars",
        "dir",
        "getattr",
        "setattr",
        "delattr",
        "help",
        "breakpoint",
    }
    # Frame objects expose the caller's globals and builtins without any
    # double-underscore name (e.g. a running generator's gi_frame.f_back
    # chain), so they must be denied explicitly.
    denied_attrs = {
        "gi_frame",
        "cr_frame",
        "ag_frame",
        "tb_frame",
        "f_back",
        "f_globals",
        "f_locals",
        "f_builtins",
    }

    for node in ast.walk(tree):
        if isinstance(node, denied_nodes):
            return "验证脚本包含 import、with 或 class 等不允许的结构。"
        if isinstance(node, ast.Name) and node.id.startswith("__"):
            return "验证脚本包含不允许的双下划线名称。"
        if isinstance(node, ast.Attribute):
            if node.attr.startswith("__"):
                return "验证脚本包含不允许的双下划线属性。"
            if node.attr in denied_attrs:
                return f"验证脚本访问了不允许的属性: {node.attr}"
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in denied_calls:
            return f"验证脚本调用了不允许的函数: {node.func.id}"
    return ""
def _run_python_child(code: str, queue: mp.Queue[Any]) -> None:
    """Execute ``code`` with a restricted builtin set and report on ``queue``.

    Exactly one dict with the keys ``ok``, ``output`` (captured stdout,
    stripped) and ``error`` (traceback text, empty on success) is put on
    ``queue``.

    NOTE(review): this calls ``exec`` on script text. The only protections
    visible here are the reduced ``__builtins__`` mapping and whatever
    screening the caller performed beforehand.
    """
    exposed = (
        abs, all, any, bool, dict, enumerate, float, int, len, list, max, min,
        pow, print, range, reversed, round, set, sorted, str, sum, tuple, zip,
    )
    sandbox: Dict[str, Any] = {
        # Each builtin is published under its own name.
        "__builtins__": {func.__name__: func for func in exposed},
        "math": math,
    }
    captured = io.StringIO()

    def report(succeeded: bool, failure: str) -> Dict[str, Any]:
        # Captured stdout is included on failure too, so partial output survives.
        return {"ok": succeeded, "output": captured.getvalue().strip(), "error": failure}

    try:
        with contextlib.redirect_stdout(captured):
            # The same mapping serves as globals and locals.
            exec(compile(code, "<agent-script>", "exec"), sandbox, sandbox)
        queue.put(report(True, ""))
    except Exception:
        queue.put(report(False, traceback.format_exc(limit=3)))
def offline_algorithm_answer(question: str, patterns: List[AlgorithmPattern], note: str = "") -> str:
    """Build a canned design outline for when no LLM response is available.

    The outline is chosen by the name of the first paradigm in ``patterns``;
    paradigms without a dedicated outline, and an empty ``patterns`` list,
    get a generic modelling checklist.

    Args:
        question: The original problem statement, echoed in the answer.
        patterns: Candidate paradigms, most relevant first.
        note: Optional text placed before everything else.

    Returns:
        The sections joined by blank lines: optional note, candidate
        paradigm names, the question, the outline, and a notice that the
        answer was produced offline.
    """
    # Outline text keyed by paradigm name.
    outlines: Dict[str, List[str]] = {
        "动态规划": [
            "建议优先用动态规划建模。",
            "1. 状态定义: 设 dp[...] 表示处理到某个前缀、容量、位置或区间时的最优值。",
            "2. 状态转移: 枚举当前决策,把“不选/选择/划分/匹配”等选择映射到更小子问题。",
            "3. 边界条件: 空集合、容量为 0、长度为 0 或单元素区间通常作为初始状态。",
            "4. 计算顺序: 按依赖关系从小规模到大规模填表。",
            "5. 复杂度: 通常由状态数乘以每个状态的转移代价得到。",
        ],
        "贪心": [
            "建议先检验是否存在可证明的贪心选择。",
            "1. 对输入按关键指标排序或维护优先队列。",
            "2. 每一步选择当前最不影响未来可行性的对象。",
            "3. 用交换论证说明任意最优解都能替换为包含当前选择的最优解。",
            "4. 若交换论证无法成立,应退回动态规划或搜索。",
        ],
        "递归与分治": [
            "建议用递归与分治建模。",
            "1. 定义递归函数的输入规模和返回含义。",
            "2. 将原问题划分为若干相同形式的子问题。",
            "3. 分别求解子问题,再设计合并过程。",
            "4. 写出递归边界和递推式。",
            "5. 用主定理、递归树或展开法分析复杂度。",
        ],
        "回溯": [
            "建议用回溯搜索解空间树。",
            "1. 定义每一层决策变量和当前部分解。",
            "2. 枚举当前层所有候选选择。",
            "3. 用约束函数剪掉不可行分支。",
            "4. 到达叶子或完整解时记录答案。",
            "5. 分析最坏搜索规模,并说明剪枝如何减少实际搜索。",
        ],
        "分支限界": [
            "建议用分支限界处理组合优化。",
            "1. 将部分解表示为搜索树节点。",
            "2. 设计上界或下界估计当前节点可能达到的最优值。",
            "3. 用队列、优先队列或深度优先策略扩展节点。",
            "4. 如果界值不可能优于当前最优解,则剪枝。",
            "5. 输出最优解,并分析最坏指数复杂度和剪枝效果。",
        ],
        "随机化算法": [
            "建议明确随机化算法的随机步骤和概率保证。",
            "1. 说明随机选择发生在采样、划分、哈希还是搜索方向。",
            "2. 区分算法是 Monte Carlo 还是 Las Vegas。",
            "3. 分析正确性概率、失败概率或期望运行时间。",
            "4. 如有必要,通过重复运行降低错误概率。",
        ],
        "遗传算法": [
            "建议用遗传算法描述近似优化流程。",
            "1. 设计染色体编码,把候选解表示为个体。",
            "2. 定义适应度函数衡量解的好坏。",
            "3. 说明选择、交叉、变异和保留策略。",
            "4. 设置种群规模、迭代次数和终止条件。",
            "5. 强调该方法通常给出近似解,需要实验评价质量。",
        ],
        "模拟退火": [
            "建议用模拟退火描述近似搜索过程。",
            "1. 定义当前解、邻域生成方式和目标函数。",
            "2. 设定初始温度、降温策略和终止条件。",
            "3. 若新解更优则接受;若更差则按概率接受以跳出局部最优。",
            "4. 说明参数对收敛速度和解质量的影响。",
        ],
    }
    generic_outline = [
        "建议先完成结构化建模,再选择算法范式。",
        "1. 明确输入、约束、目标函数和输出。",
        "2. 判断是否存在最优子结构、图结构、排序选择或递归拆分。",
        "3. 给出伪代码和复杂度分析。",
        "4. 对小规模样例写脚本验证结论。",
    ]

    # With no paradigms the placeholder name matches no outline, so the
    # generic checklist is used.
    primary = patterns[0].name if patterns else "通用算法设计"
    body = "\n".join(outlines.get(primary, generic_outline))
    pattern_names = "、".join(pattern.name for pattern in patterns)

    sections = [note] if note else []
    sections.extend(
        [
            f"候选算法范式: {pattern_names}",
            f"原始问题: {question}",
            body,
            "当前未获得公开 LLM 的有效响应,因此这里给出的是离线算法设计框架。请检查 API Key、Base URL 和模型名后重新运行。",
        ]
    )
    return "\n\n".join(sections)