File size: 22,195 Bytes
d9d3238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0cf1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9d3238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0cf1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9d3238
 
 
 
 
 
 
 
 
8e0cf1c
d9d3238
 
 
 
 
8e0cf1c
 
 
 
 
 
 
 
 
 
d9d3238
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0cf1c
d9d3238
8e0cf1c
d9d3238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0cf1c
 
 
 
 
 
 
 
 
 
 
d9d3238
 
 
 
 
 
 
8e0cf1c
d9d3238
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0cf1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9d3238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
import argparse
import ast
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class ToolCall:
    name: str
    args: Dict[str, Any]


@dataclass
class ToolResult:
    ok: bool
    output: str
    artifacts: List[str]


def _tokenize(s: str) -> List[str]:
    s = s.lower()
    s = re.sub(r"[^0-9a-z\u4e00-\u9fff]+", " ", s)
    parts = [p.strip() for p in s.split() if p.strip()]
    return parts


def tool_search_kb(query: str, kb_dir: str, top_k: int = 4) -> ToolResult:
    base = Path(kb_dir)
    if not base.exists():
        return ToolResult(ok=True, output="[]", artifacts=[])

    q_tokens = _tokenize(query)
    if not q_tokens:
        return ToolResult(ok=True, output="[]", artifacts=[])

    hits: List[Tuple[int, str, str]] = []
    for p in sorted(base.glob("**/*.md")):
        try:
            text = p.read_text(encoding="utf-8")
        except Exception:
            continue
        chunks = [c.strip() for c in re.split(r"\n\s*\n+", text) if c.strip()]
        for c in chunks:
            c_tokens = _tokenize(c)
            score = sum(1 for t in q_tokens if t in c_tokens)
            if score <= 0:
                continue
            hits.append((score, str(p), c))

    hits.sort(key=lambda x: (-x[0], x[1]))
    picked = hits[: max(0, int(top_k))]
    out = [{"score": s, "source": src, "text": txt} for (s, src, txt) in picked]
    return ToolResult(ok=True, output=json.dumps(out, ensure_ascii=False, indent=2), artifacts=[])


class _SafeEval(ast.NodeVisitor):
    def visit(self, node):
        return super().visit(node)

    def generic_visit(self, node):
        raise ValueError(f"unsupported: {node.__class__.__name__}")

    def visit_Expression(self, node: ast.Expression):
        return self.visit(node.body)

    def visit_BinOp(self, node: ast.BinOp):
        left = self.visit(node.left)
        right = self.visit(node.right)
        if isinstance(node.op, ast.Add):
            return left + right
        if isinstance(node.op, ast.Sub):
            return left - right
        if isinstance(node.op, ast.Mult):
            return left * right
        if isinstance(node.op, ast.Div):
            return left / right
        if isinstance(node.op, ast.FloorDiv):
            return left // right
        if isinstance(node.op, ast.Mod):
            return left % right
        if isinstance(node.op, ast.Pow):
            return left ** right
        raise ValueError("unsupported operator")

    def visit_UnaryOp(self, node: ast.UnaryOp):
        v = self.visit(node.operand)
        if isinstance(node.op, ast.UAdd):
            return +v
        if isinstance(node.op, ast.USub):
            return -v
        raise ValueError("unsupported unary operator")

    def visit_Constant(self, node: ast.Constant):
        if isinstance(node.value, (int, float)):
            return node.value
        raise ValueError("only numbers allowed")


def tool_calc(expression: str) -> ToolResult:
    expr = expression.strip()
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\(\)\%\^]+", expr):
        return ToolResult(ok=False, output="表达式包含不允许的字符", artifacts=[])
    expr = expr.replace("^", "**")
    try:
        tree = ast.parse(expr, mode="eval")
        value = _SafeEval().visit(tree)
    except Exception as e:
        return ToolResult(ok=False, output=f"计算失败: {e}", artifacts=[])
    return ToolResult(ok=True, output=str(value), artifacts=[])


def tool_write_text(path: str, text: str) -> ToolResult:
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(text, encoding="utf-8")
    return ToolResult(ok=True, output=f"wrote: {p}", artifacts=[str(p)])


def tool_read_text(path: str) -> ToolResult:
    p = Path(path)
    if not p.exists():
        return ToolResult(ok=False, output="not found", artifacts=[])
    return ToolResult(ok=True, output=p.read_text(encoding="utf-8"), artifacts=[])


def _extract_math_expr(goal: str) -> Optional[str]:
    m = re.search(r"(\d[\d\s\+\-\*\/\(\)\%\^\.]*\d)", goal)
    if not m:
        return None
    expr = m.group(1)
    expr = re.sub(r"\s+", "", expr)
    if len(expr) > 120:
        return None
    if not re.search(r"[\+\-\*\/\%\^]", expr):
        return None
    return expr


def _extract_faq_count(goal: str) -> Optional[int]:
    m = re.search(r"(\d+)\s*条\s*FAQ", goal, flags=re.IGNORECASE)
    if m:
        return max(1, min(20, int(m.group(1))))
    m = re.search(r"(\d+)\s*条", goal)
    if m and ("faq" in goal.lower() or "常见问题" in goal):
        return max(1, min(20, int(m.group(1))))
    if "faq" in goal.lower() or "常见问题" in goal:
        return 5
    return None


def _norm_space(s: str) -> str:
    return re.sub(r"\s+", " ", (s or "").strip())


def _lead_to_goal(lead: Dict[str, Any]) -> str:
    company = _norm_space(str(lead.get("company") or ""))
    contact = _norm_space(str(lead.get("contact") or ""))
    role = _norm_space(str(lead.get("role") or ""))
    channel = _norm_space(str(lead.get("channel") or ""))
    product = _norm_space(str(lead.get("product") or ""))
    stage = _norm_space(str(lead.get("stage") or ""))
    pain = _norm_space(str(lead.get("pain_points") or ""))
    budget = _norm_space(str(lead.get("budget") or ""))
    timeline = _norm_space(str(lead.get("timeline") or ""))
    notes = _norm_space(str(lead.get("notes") or ""))

    parts: List[str] = []
    parts.append("生成一份销售线索跟进方案(含补问清单、下一步任务清单、跟进话术:微信/电话/邮件、风险提示),并写进报告。")
    if company:
        parts.append(f"公司/组织:{company}")
    if contact or role:
        who = contact if contact else "(未提供姓名)"
        if role:
            who = f"{who}{role})"
        parts.append(f"联系人:{who}")
    if channel:
        parts.append(f"来源渠道:{channel}")
    if product:
        parts.append(f"关注产品:{product}")
    if stage:
        parts.append(f"当前阶段:{stage}")
    if pain:
        parts.append(f"主要诉求/痛点:{pain}")
    if budget:
        parts.append(f"预算:{budget}")
    if timeline:
        parts.append(f"期望时间:{timeline}")
    if notes:
        parts.append(f"备注:{notes}")
    return "\n".join(parts)


def make_plan(goal: str, kb_dir: str, out_report: str) -> List[ToolCall]:
    plan: List[ToolCall] = []
    plan.append(ToolCall(name="search_kb", args={"query": goal, "kb_dir": kb_dir, "top_k": 6}))
    expr = _extract_math_expr(goal)
    if expr:
        plan.append(ToolCall(name="calc", args={"expression": expr}))
    plan.append(ToolCall(name="write_report", args={"goal": goal, "out_report": out_report}))
    return plan


def _format_faq_items(snippets: List[Dict[str, Any]], n: int) -> List[str]:
    pool: List[str] = []
    for s in snippets:
        t = str(s.get("text", "")).strip()
        t = re.sub(r"\s+", " ", t)
        if len(t) >= 12:
            pool.append(t)
    if not pool:
        pool = [
            "我们提供从获客到转化的自动化流程编排能力,减少重复操作。",
            "支持接入常见广告与社媒渠道,并统一沉淀线索与素材。",
            "关键数据可追踪,可用规则触发通知与任务分配。",
            "提供基础的权限与审计,方便多人协作与风控。",
            "支持将常见问题与业务规范整理进知识库,供智能体检索。",
        ]
    items: List[str] = []
    i = 0
    while len(items) < n:
        items.append(pool[i % len(pool)])
        i += 1
    return items


def _render_report(goal: str, snippets: List[Dict[str, Any]], faq_count: Optional[int], math_expr: Optional[str], math_result: Optional[str]) -> str:
    lines: List[str] = []
    lines.append("# 智能体生成报告")
    lines.append("")
    lines.append("## 目标")
    lines.append(goal.strip())
    lines.append("")
    if snippets:
        lines.append("## 知识库检索摘要")
        for i, s in enumerate(snippets[:4], start=1):
            src = os.path.basename(str(s.get("source", "")))
            txt = str(s.get("text", "")).strip().replace("\n", " ")
            txt = re.sub(r"\s+", " ", txt)
            lines.append(f"- 片段{i}{src}):{txt}")
        lines.append("")
    if faq_count:
        lines.append(f"## FAQ({faq_count} 条)")
        items = _format_faq_items(snippets, faq_count)
        for it in items:
            lines.append(f"- {it}")
        lines.append("")
    if math_expr:
        lines.append("## 计算结果")
        if math_result is None:
            lines.append(f"- 表达式:{math_expr}")
            lines.append("- 结果:计算失败")
        else:
            lines.append(f"- 表达式:{math_expr}")
            lines.append(f"- 结果:{math_result}")
        lines.append("")
    lines.append("## 结论")
    lines.append("已基于本地知识库完成检索,并生成可复核的报告工件。")
    lines.append("")
    return "\n".join(lines)


def _lead_sections(goal: str, snippets: List[Dict[str, Any]]) -> Dict[str, Any]:
    lines = goal.splitlines()
    facts: Dict[str, str] = {}
    for ln in lines:
        if ":" in ln:
            k, v = ln.split(":", 1)
            k = _norm_space(k)
            v = _norm_space(v)
            if k and v:
                facts[k] = v

    company = facts.get("公司/组织", "")
    product = facts.get("关注产品", "")
    stage = facts.get("当前阶段", "")
    pain = facts.get("主要诉求/痛点", "")
    budget = facts.get("预算", "")
    timeline = facts.get("期望时间", "")

    summary = ";".join([p for p in [company, product, stage, pain] if p]) or "已记录线索信息。"

    questions: List[str] = [
        "当前目标是什么(增长/获客/转化/复购/效率/风控)?优先级如何排序?",
        "现有流程里最耗时或最容易丢信息的环节是哪一步?",
        "线索来源结构与规模(渠道、日均量、转化率、客单价)大概是多少?",
        "团队协作方式(谁负责、如何交接、是否需要权限与审计)?",
        "期望在多长周期内看到可量化改进?评估指标是什么?",
    ]
    if budget:
        questions.append("预算口径(一次性/按月/按量)与采购流程(招采/合同/付款节点)?")
    if timeline:
        questions.append("上线时间点是否有外部约束(活动/发布/投放节奏)?")

    tasks: List[Dict[str, str]] = [
        {"task": "补齐关键字段:目标、现状流程、线索规模、关键指标", "owner": "你", "due": "今天"},
        {"task": "确认决策链路:需求方/使用方/审批方/预算方", "owner": "你", "due": "今天"},
        {"task": "整理 3 个典型线索样本与跟进记录(脱敏)", "owner": "客户", "due": "本周"},
        {"task": "对齐一版 SOP:触达→跟进→转化→复盘的关键动作与责任人", "owner": "你", "due": "本周"},
        {"task": "输出落地方案:自动化触发点、权限与审计、指标口径", "owner": "你", "due": "本周"},
        {"task": "安排演示/试用:用样本数据跑一遍闭环", "owner": "你", "due": "下周"},
    ]

    wechat = "我先基于你们当前的线索流程整理一版“可落地的跟进方案+自动化触发点”,你看方便补充下:目标/线索规模/当前痛点三项吗?补齐后我今天给你一版可执行清单。"
    phone = "你好,我是这边负责增长流程自动化方案的。想快速了解下:你们目前线索从哪个渠道进来、谁负责跟进、最容易卡在哪一步?我这边把方案做成你们团队可直接执行的清单。"
    email = "\n".join(
        [
            "主题:线索跟进闭环方案梳理(补问清单 + 可执行任务)",
            "",
            "你好,",
            "",
            "我已根据当前线索信息先拟了一版跟进方案(含补问清单、下一步任务、话术与风险点)。",
            "为保证方案能直接落地,麻烦补充 3 点:目标/线索规模/当前流程卡点(可简单一句话)。",
            "",
            "我收到后会在当天给出:",
            "1) 可执行的任务清单(负责人/截止时间)",
            "2) 自动化触发点建议",
            "3) 指标口径与复盘方式",
            "",
            "谢谢",
        ]
    )

    risks = [
        "对外触达需避免夸大承诺,话术以“可量化目标/可验证步骤”为准。",
        "涉及个人信息与线索数据时,需明确数据来源与留存周期,按权限最小化原则访问。",
        "对接渠道 API/表单时,先做字段映射与去重规则,避免重复触达与口径混乱。",
    ]

    cited: List[str] = []
    for s in snippets[:3]:
        src = os.path.basename(str(s.get("source", "")))
        txt = _norm_space(str(s.get("text", "")).replace("\n", " "))
        if txt:
            cited.append(f"{src}{txt}")

    return {
        "summary": summary,
        "questions": questions[:7],
        "tasks": tasks,
        "scripts": {"wechat": wechat, "phone": phone, "email": email},
        "risks": risks,
        "cited": cited,
    }


def _render_lead_report(goal: str, snippets: List[Dict[str, Any]]) -> Tuple[str, Dict[str, Any]]:
    data = _lead_sections(goal=goal, snippets=snippets)
    lines: List[str] = []
    lines.append("# 线索跟进方案")
    lines.append("")
    lines.append("## 线索摘要")
    lines.append(data["summary"])
    lines.append("")
    lines.append("## 需要补问(建议按顺序)")
    for q in data["questions"]:
        lines.append(f"- {q}")
    lines.append("")
    lines.append("## 下一步任务清单")
    for t in data["tasks"]:
        lines.append(f"- {t['task']}(负责人:{t['owner']};截止:{t['due']})")
    lines.append("")
    lines.append("## 跟进话术")
    lines.append("### 微信/IM")
    lines.append(data["scripts"]["wechat"])
    lines.append("")
    lines.append("### 电话开场")
    lines.append(data["scripts"]["phone"])
    lines.append("")
    lines.append("### 邮件")
    lines.append(data["scripts"]["email"])
    lines.append("")
    lines.append("## 风险提示")
    for r in data["risks"]:
        lines.append(f"- {r}")
    lines.append("")
    if data["cited"]:
        lines.append("## 参考知识片段")
        for c in data["cited"]:
            lines.append(f"- {c}")
        lines.append("")
    lines.append("## 原始目标")
    lines.append(goal.strip())
    lines.append("")
    return "\n".join(lines), data



def _needs_report(goal: str) -> bool:
    g = goal.lower()
    if "报告" in goal or "写进报告" in goal or "写入报告" in goal:
        return True
    if "report" in g:
        return True
    return True


def check_closure(goal: str, out_report: str, mode: str) -> Tuple[bool, str]:
    p = Path(out_report)
    if _needs_report(goal):
        if not p.exists():
            return False, "报告文件不存在"
        txt = p.read_text(encoding="utf-8")
        if mode == "lead_followup":
            if "线索跟进方案" not in txt:
                return False, "报告缺少标题"
            if "需要补问" not in txt or "下一步任务清单" not in txt or "跟进话术" not in txt:
                return False, "报告结构不完整"
            if len(re.findall(r"^\-\s+", txt, flags=re.MULTILINE)) < 8:
                return False, "清单项不足"
        else:
            if "智能体生成报告" not in txt:
                return False, "报告缺少标题"
        faq_count = _extract_faq_count(goal)
        if faq_count:
            if len(re.findall(r"^\-\s+", txt, flags=re.MULTILINE)) < faq_count:
                return False, "FAQ 条数不足"
        expr = _extract_math_expr(goal)
        if expr:
            if expr not in txt:
                return False, "报告未包含表达式"
            calc_res = tool_calc(expr)
            if calc_res.ok and str(calc_res.output) not in txt:
                return False, "报告未包含计算结果"
    return True, "ok"


def run_agent(goal: str, kb_dir: str, out_report: str, max_rounds: int = 3, mode: str = "general") -> Dict[str, Any]:
    round_logs: List[Dict[str, Any]] = []
    state: Dict[str, Any] = {"snippets": [], "math_expr": None, "math_result": None, "artifacts": [], "structured": None}

    for r in range(1, max_rounds + 1):
        plan = make_plan(goal=goal, kb_dir=kb_dir, out_report=out_report)
        step_logs: List[Dict[str, Any]] = []

        for step in plan:
            if step.name == "search_kb":
                res = tool_search_kb(**step.args)
                if res.ok:
                    try:
                        state["snippets"] = json.loads(res.output)
                    except Exception:
                        state["snippets"] = []
                step_logs.append({"tool": step.name, "ok": res.ok, "output": res.output[:800]})
                continue

            if step.name == "calc":
                state["math_expr"] = step.args.get("expression")
                res = tool_calc(**step.args)
                state["math_result"] = res.output if res.ok else None
                step_logs.append({"tool": step.name, "ok": res.ok, "output": res.output})
                continue

            if step.name == "write_report":
                faq_count = _extract_faq_count(goal)
                if mode == "lead_followup":
                    report, structured = _render_lead_report(goal=goal, snippets=state.get("snippets") or [])
                    state["structured"] = structured
                else:
                    report = _render_report(
                        goal=goal,
                        snippets=state.get("snippets") or [],
                        faq_count=faq_count,
                        math_expr=state.get("math_expr"),
                        math_result=state.get("math_result"),
                    )
                res = tool_write_text(path=out_report, text=report)
                state["artifacts"].extend(res.artifacts)
                step_logs.append({"tool": step.name, "ok": res.ok, "output": res.output})
                continue

            step_logs.append({"tool": step.name, "ok": False, "output": "unknown tool"})

        ok, reason = check_closure(goal=goal, out_report=out_report, mode=mode)
        round_logs.append({"round": r, "plan": [c.__dict__ for c in plan], "steps": step_logs, "check": {"ok": ok, "reason": reason}})
        if ok:
            break

    final = []
    final.append("已完成闭环执行:规划→工具执行→自检→产出。")
    final.append(f"报告工件:{out_report}")
    expr = state.get("math_expr")
    if expr:
        final.append(f"计算:{expr} = {state.get('math_result')}")
    faq_count = _extract_faq_count(goal)
    if faq_count:
        final.append(f"FAQ:{faq_count} 条(见报告)")

    return {
        "goal": goal,
        "out_report": out_report,
        "round_logs": round_logs,
        "final_answer": "\n".join(final),
        "artifacts": state.get("artifacts"),
        "structured": state.get("structured"),
        "mode": mode,
    }


def run_agent_mode(mode: str, payload: Dict[str, Any], kb_dir: str, out_report: str, max_rounds: int = 3) -> Dict[str, Any]:
    mode = (mode or "general").strip()
    if mode == "lead_followup":
        lead = payload.get("lead") or {}
        goal = _lead_to_goal(lead if isinstance(lead, dict) else {})
        result = run_agent(goal=goal, kb_dir=kb_dir, out_report=out_report, max_rounds=max_rounds, mode=mode)
        return result
    goal = str(payload.get("goal") or "").strip()
    return run_agent(goal=goal, kb_dir=kb_dir, out_report=out_report, max_rounds=max_rounds, mode="general")


def _print_round_logs(round_logs: List[Dict[str, Any]]) -> None:
    for r in round_logs:
        print(f"\n=== Round {r['round']} ===")
        print("Plan:")
        for s in r["plan"]:
            print(f"- {s['name']} {json.dumps(s['args'], ensure_ascii=False)}")
        print("Steps:")
        for s in r["steps"]:
            out = s["output"]
            out = out.replace("\n", " ")
            if len(out) > 180:
                out = out[:180] + "..."
            print(f"- {s['tool']}: ok={s['ok']} output={out}")
        print(f"Check: ok={r['check']['ok']} reason={r['check']['reason']}")


def cmd_demo() -> int:
    here = Path(__file__).resolve().parent
    goal = "基于知识库,生成一份 5 条 FAQ,并计算 17*23 的结果写进报告"
    out_report = str(here / "out" / "demo_report.md")
    kb_dir = str(here / "kb")
    result = run_agent(goal=goal, kb_dir=kb_dir, out_report=out_report, max_rounds=3)
    _print_round_logs(result["round_logs"])
    print("\n=== Final ===")
    print(result["final_answer"])
    try:
        txt = (here / "out" / "demo_report.md").read_text(encoding="utf-8")
        head = "\n".join(txt.splitlines()[:22])
        print("\n=== Report Preview (Top) ===")
        print(head)
    except Exception:
        pass
    return 0


def cmd_run(goal: str) -> int:
    here = Path(__file__).resolve().parent
    out_report = str(here / "out" / "demo_report.md")
    kb_dir = str(here / "kb")
    result = run_agent(goal=goal, kb_dir=kb_dir, out_report=out_report, max_rounds=3)
    _print_round_logs(result["round_logs"])
    print("\n=== Final ===")
    print(result["final_answer"])
    return 0


def main(argv: List[str]) -> int:
    p = argparse.ArgumentParser(prog="agent.py")
    sub = p.add_subparsers(dest="cmd", required=True)

    sub.add_parser("demo")

    p_run = sub.add_parser("run")
    p_run.add_argument("--goal", required=True)

    args = p.parse_args(argv)
    if args.cmd == "demo":
        return cmd_demo()
    if args.cmd == "run":
        return cmd_run(goal=args.goal)
    return 2


if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))