File size: 5,431 Bytes
d536096
 
 
4f9d2b1
 
 
 
 
 
 
d536096
 
0bb564a
d536096
 
 
 
 
 
 
4f9d2b1
d536096
 
 
 
 
 
 
 
 
 
0bb564a
d536096
0bb564a
 
 
 
d536096
0bb564a
 
 
 
 
 
d536096
 
 
 
 
 
 
 
 
 
0bb564a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d536096
 
0bb564a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d536096
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import threading
from pathlib import Path
import asyncio

# Fix Python 3.14 asyncio loop error in Gradio.
# If asking for the current event loop raises RuntimeError, create a fresh loop
# and install it as the current one. This sits ahead of `import gradio` below,
# presumably so that a loop is already set when Gradio is imported (confirm
# before moving it).
try:
    asyncio.get_event_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())

import gradio as gr
import math

from multi_agent_lab.config import TrainConfig
from multi_agent_lab.orchestrator import Orchestrator

# Module-wide lock; run_pipeline holds it for the whole run, so only one
# training pipeline executes at a time in this process.
LOCK = threading.Lock()

def _runs_dir() -> Path:
    """Return the runs directory ``./data/runs``, creating it when missing.

    The path is relative, so it resolves against the current working
    directory. Missing parent directories are created as well, and an
    already-existing directory is not an error.
    """
    runs_root = Path("./data/runs")
    runs_root.mkdir(exist_ok=True, parents=True)
    return runs_root

def _latest_run_dir() -> Path | None:
    """Return the run directory whose name sorts last, or ``None``.

    Only sub-directories of the runs directory are considered; they are
    ordered by directory name and the last one is returned. ``None`` is
    returned when the runs directory does not exist or holds no
    sub-directories.
    """
    root = _runs_dir()
    if not root.exists():
        return None

    candidates = [entry for entry in root.iterdir() if entry.is_dir()]
    if not candidates:
        return None

    # Order by directory name; the greatest name is treated as the latest run.
    candidates.sort(key=lambda entry: entry.name)
    return candidates[-1]

def run_pipeline(scenario: str, seed: int, epochs: int, lr: float, csv_file, target_col: str) -> tuple[str, str]:
    """Run one training pipeline and return ``(run_dir, report_markdown)``.

    Args:
        scenario: Scenario name forwarded to ``Orchestrator.run``.
        seed: Random seed stored in the ``TrainConfig``.
        epochs: Number of epochs stored in the ``TrainConfig``.
        lr: Learning rate stored in the ``TrainConfig``.
        csv_file: Optional uploaded CSV; any falsy value means "no CSV" and
            ``csv_path=None`` is passed to the config.
        target_col: Name of the target column. Blank, whitespace-only or
            ``None`` falls back to ``"target"``.

    Returns:
        On success, the run directory (as a string) and the contents of the
        run's ``report.md``. On any failure, ``("Error", <message>)`` where the
        message embeds the exception text.
    """
    csv_path = str(csv_file) if csv_file else None
    # `or ""` guards against a None value, which has no .strip().
    target_col_val = (target_col or "").strip() or "target"

    # Hold the lock for the whole run so pipelines never overlap.
    with LOCK:
        try:
            # Config and orchestrator construction live inside the try so that
            # their failures (including directory creation in _runs_dir) are
            # reported through the same user-facing error path as run failures.
            cfg = TrainConfig(seed=seed, epochs=epochs, lr=lr, csv_path=csv_path, target_col=target_col_val)
            orch = Orchestrator(base_runs_dir=_runs_dir())
            store, _summary = orch.run(cfg=cfg, scenario=scenario)
            report = store.path("report.md").read_text(encoding="utf-8")
            return str(store.run_dir), report
        except Exception as e:
            # Top-level UI boundary: surface the failure text instead of raising.
            return "Error", f"**运行失败**:{e}"

def load_latest_report() -> tuple[str, str]:
    """Return ``(run_dir, report_markdown)`` for the latest run.

    When no run directory exists, the first element is ``"-"`` and the second
    is a placeholder message. When the latest run directory has no
    ``report.md``, the directory path is returned together with a
    "report not found" message.
    """
    latest = _latest_run_dir()
    if latest is None:
        return "-", "暂无运行记录。点击上方按钮开始一次训练。"

    report_file = latest / "report.md"
    if report_file.exists():
        return str(latest), report_file.read_text(encoding="utf-8")
    return str(latest), "未找到 report.md(可能运行未完成)。"

def _stable_sigmoid(z: float) -> float:
    """Return ``1 / (1 + exp(-z))`` without overflowing for large ``|z|``."""
    if z >= 0:
        # exp(-z) is in (0, 1] here, so it cannot overflow.
        return 1.0 / (1.0 + math.exp(-z))
    # For negative z, exp(-z) may exceed the float range and raise
    # OverflowError; use the algebraically equivalent form based on exp(z).
    ez = math.exp(z)
    return ez / (1.0 + ez)


def run_inference(features_text: str) -> str:
    """Score one comma-separated feature vector with the latest trained model.

    The latest run directory's ``report.json`` supplies the ``attempts`` count
    (default 1), which selects ``model.attempt{attempts}.json``. That file's
    ``w`` (weights, default ``[]``) and ``b`` (bias, default ``0.0``) are used
    for a logistic-regression prediction.

    Args:
        features_text: Comma-separated numbers; empty items are skipped.

    Returns:
        A Markdown string with the predicted class, positive-class probability
        and margin, or a user-facing message when there is no run, the model
        files cannot be loaded, the input is not numeric, or the feature count
        does not match the number of weights.
    """
    p = _latest_run_dir()
    if p is None:
        return "暂无模型,请先完成一次训练。"

    try:
        report_json = json.loads((p / "report.json").read_text(encoding="utf-8"))
        attempts = report_json.get("attempts", 1)
        model_json = json.loads((p / f"model.attempt{attempts}.json").read_text(encoding="utf-8"))
    except Exception as e:
        # UI boundary: any load/parse failure becomes a message, not a crash.
        return f"无法加载模型数据:{e}"

    w = model_json.get("w", [])
    b = model_json.get("b", 0.0)

    try:
        features = [float(x.strip()) for x in features_text.split(",") if x.strip()]
    except ValueError:
        return "输入格式错误,请确保输入的是用逗号分隔的数字。"

    if len(features) != len(w):
        return f"特征数量不匹配:模型需要 {len(w)} 个特征,但输入了 {len(features)} 个。"

    # Logistic regression inference
    margin = sum(x * weight for x, weight in zip(features, w)) + b
    # A plain 1 / (1 + exp(-margin)) raises OverflowError for very negative
    # margins; the stable form returns a probability near 0 instead.
    prob = _stable_sigmoid(margin)
    pred = 1 if prob >= 0.5 else 0

    return f"**预测类别**:{pred}\n\n**正类概率**:{prob:.4f}\n**Margin**:{margin:.4f}"

# UI definition: a Gradio Blocks app bound to `demo`, with two tabs — one that
# runs the training pipeline and shows its report, and one that runs inference
# against the latest trained model.
with gr.Blocks(title="多 Agent 训练与调试实验室") as demo:
    gr.Markdown("# 多 Agent 训练与调试实验室\n\n一键演示:数据→训练→评估→自动诊断→自动修复→复训→报告。\n\n**新增功能**:支持上传真实 CSV 数据进行训练,并提供推理闭环!")

    with gr.Tabs():
        # Tab 1: training inputs, run/load buttons, and the report output.
        with gr.Tab("1. 模型训练与多 Agent 调试"):
            with gr.Row():
                # Left column: scenario and hyper-parameters passed to run_pipeline.
                with gr.Column():
                    scenario = gr.Dropdown(choices=["stable", "unstable"], value="unstable", label="场景")
                    seed = gr.Number(value=42, precision=0, label="seed")
                    epochs = gr.Number(value=20, precision=0, label="epochs")
                    lr = gr.Number(value=0.2, label="lr(stable 场景有效;unstable 会先覆盖成易发散值)")
                # Right column: optional CSV upload and the target column name.
                with gr.Column():
                    csv_file = gr.File(label="上传 CSV 数据集 (可选,不传则使用内置合成数据)", file_types=[".csv"], type="filepath")
                    target_col = gr.Textbox(value="target", label="目标列名 (例如:target)")

            with gr.Row():
                btn_run = gr.Button("开始运行", variant="primary")
                btn_load = gr.Button("加载最新报告")

            # Both buttons (and the load event below) write to these two outputs.
            out_run_dir = gr.Textbox(label="run_dir", interactive=False)
            out_report = gr.Markdown(label="report.md")

            # The inputs list order matches run_pipeline's parameter order.
            btn_run.click(
                fn=run_pipeline,
                inputs=[scenario, seed, epochs, lr, csv_file, target_col],
                outputs=[out_run_dir, out_report],
            )
            btn_load.click(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
            # Also registers load_latest_report on the app's load event, so the
            # outputs can be filled without pressing a button.
            demo.load(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])

        # Tab 2: free-text feature input scored by run_inference.
        with gr.Tab("2. 模型推理测试"):
            gr.Markdown("### 使用最新训练的模型进行推理")
            gr.Markdown("请输入逗号分隔的特征值。例如:`1.5, -0.2, 3.1, 0.8, -1.1`")
            feat_input = gr.Textbox(label="特征输入 (逗号分隔)", lines=3)
            btn_infer = gr.Button("执行推理", variant="primary")
            infer_result = gr.Markdown(label="推理结果")

            btn_infer.click(fn=run_inference, inputs=[feat_input], outputs=[infer_result])

# Script entry point: launch the app only when this file is executed directly,
# with the server host set to 0.0.0.0 and the port set to 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)