"""Gradio front end for the multi-agent training and debugging lab.

Tab 1 launches a training run through ``Orchestrator`` and renders the
resulting ``report.md``. Tab 2 loads the weights written by the most recent
run and evaluates a logistic-regression prediction for user-supplied
features.
"""

import asyncio
import json
import math
import threading
from pathlib import Path

# Fix Python 3.14 asyncio loop error in Gradio: make sure a current event
# loop exists before gradio is imported. This must stay above the gradio
# import.
try:
    asyncio.get_event_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())

import gradio as gr

from multi_agent_lab.config import TrainConfig
from multi_agent_lab.orchestrator import Orchestrator

# Serializes training runs started from the UI.
LOCK = threading.Lock()


def _runs_dir() -> Path:
    """Return the base directory for run outputs, creating it if needed."""
    runs_root = Path("./data/runs")
    runs_root.mkdir(parents=True, exist_ok=True)
    return runs_root


def _latest_run_dir() -> Path | None:
    """Return the run directory whose name sorts last, or None if there is none."""
    # _runs_dir() has just created the base directory, so no separate
    # existence check is needed here.
    candidates = (entry for entry in _runs_dir().iterdir() if entry.is_dir())
    return max(candidates, key=lambda entry: entry.name, default=None)


def _sigmoid(z: float) -> float:
    """Return the logistic function of *z* without overflowing for large |z|.

    The naive form ``1 / (1 + exp(-z))`` raises OverflowError once ``-z``
    exceeds roughly 709, so each branch only ever exponentiates a
    non-positive number.
    """
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    exp_z = math.exp(z)
    return exp_z / (1.0 + exp_z)


def run_pipeline(
    scenario: str,
    seed: int,
    epochs: int,
    lr: float,
    csv_file,
    target_col: str,
) -> tuple[str, str]:
    """Run one training pipeline and return ``(run_dir, report_markdown)``.

    Args:
        scenario: Scenario name forwarded to ``Orchestrator.run``.
        seed: Seed stored in the training config.
        epochs: Epoch count stored in the training config.
        lr: Learning rate stored in the training config.
        csv_file: Path of the uploaded CSV, or a falsy value when nothing
            was uploaded.
        target_col: Target column name; blank input falls back to "target".

    Returns:
        The run directory and the contents of ``report.md``. If the run
        raises, returns ``("Error", <message>)`` instead so the UI can show
        the failure.
    """
    with LOCK:
        csv_path = str(csv_file) if csv_file else None
        target_col_val = target_col.strip() or "target"
        cfg = TrainConfig(
            seed=seed,
            epochs=epochs,
            lr=lr,
            csv_path=csv_path,
            target_col=target_col_val,
        )
        orch = Orchestrator(base_runs_dir=_runs_dir())
        try:
            store, _summary = orch.run(cfg=cfg, scenario=scenario)
            report = store.path("report.md").read_text(encoding="utf-8")
            return str(store.run_dir), report
        except Exception as e:
            # UI boundary: report the failure in the page rather than raising.
            return "Error", f"**运行失败**:{e}"


def load_latest_report() -> tuple[str, str]:
    """Return ``(run_dir, report_markdown)`` for the most recent run.

    When there is no run yet, or the latest run has no ``report.md``, a
    placeholder message is returned in place of the report.
    """
    latest = _latest_run_dir()
    if latest is None:
        return "-", "暂无运行记录。点击上方按钮开始一次训练。"
    report_path = latest / "report.md"
    if not report_path.exists():
        return str(latest), "未找到 report.md(可能运行未完成)。"
    return str(latest), report_path.read_text(encoding="utf-8")


def run_inference(features_text: str) -> str:
    """Score comma-separated features with the latest run's model.

    The model is read from ``model.attempt<N>.json`` in the latest run
    directory, where ``N`` is the ``attempts`` value in ``report.json``
    (default 1). The model file supplies the weights ``w`` and bias ``b``.

    Args:
        features_text: Feature values separated by commas; blank tokens are
            ignored.

    Returns:
        A Markdown string with the predicted class, the positive-class
        probability and the margin, or a message describing why inference
        could not be performed.
    """
    latest = _latest_run_dir()
    if latest is None:
        return "暂无模型,请先完成一次训练。"

    try:
        report_json = json.loads((latest / "report.json").read_text(encoding="utf-8"))
        attempts = report_json.get("attempts", 1)
        model_path = latest / f"model.attempt{attempts}.json"
        model_json = json.loads(model_path.read_text(encoding="utf-8"))
    except Exception as e:
        return f"无法加载模型数据:{e}"

    w = model_json.get("w", [])
    b = model_json.get("b", 0.0)

    try:
        features = [
            float(token)
            for token in (features_text or "").split(",")
            if token.strip()
        ]
        # float() also parses "nan" and "inf"; treat them as malformed
        # input rather than producing a meaningless prediction.
        if not all(math.isfinite(value) for value in features):
            raise ValueError("non-finite feature value")
    except ValueError:
        return "输入格式错误,请确保输入的是用逗号分隔的数字。"

    if len(features) != len(w):
        return f"特征数量不匹配:模型需要 {len(w)} 个特征,但输入了 {len(features)} 个。"

    # Logistic regression inference
    margin = sum(x * weight for x, weight in zip(features, w)) + b
    prob = _sigmoid(margin)
    pred = 1 if prob >= 0.5 else 0
    return f"**预测类别**:{pred}\n\n**正类概率**:{prob:.4f}\n**Margin**:{margin:.4f}"


with gr.Blocks(title="多 Agent 训练与调试实验室") as demo:
    gr.Markdown("# 多 Agent 训练与调试实验室\n\n一键演示:数据→训练→评估→自动诊断→自动修复→复训→报告。\n\n**新增功能**:支持上传真实 CSV 数据进行训练,并提供推理闭环!")

    with gr.Tabs():
        with gr.Tab("1. 模型训练与多 Agent 调试"):
            with gr.Row():
                with gr.Column():
                    scenario = gr.Dropdown(choices=["stable", "unstable"], value="unstable", label="场景")
                    seed = gr.Number(value=42, precision=0, label="seed")
                    epochs = gr.Number(value=20, precision=0, label="epochs")
                    lr = gr.Number(value=0.2, label="lr(stable 场景有效;unstable 会先覆盖成易发散值)")
                with gr.Column():
                    csv_file = gr.File(label="上传 CSV 数据集 (可选,不传则使用内置合成数据)", file_types=[".csv"], type="filepath")
                    target_col = gr.Textbox(value="target", label="目标列名 (例如:target)")

            with gr.Row():
                btn_run = gr.Button("开始运行", variant="primary")
                btn_load = gr.Button("加载最新报告")

            out_run_dir = gr.Textbox(label="run_dir", interactive=False)
            out_report = gr.Markdown(label="report.md")

            btn_run.click(
                fn=run_pipeline,
                inputs=[scenario, seed, epochs, lr, csv_file, target_col],
                outputs=[out_run_dir, out_report],
            )
            btn_load.click(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
            # Show the most recent report as soon as the page loads.
            demo.load(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])

        with gr.Tab("2. 模型推理测试"):
            gr.Markdown("### 使用最新训练的模型进行推理")
            gr.Markdown("请输入逗号分隔的特征值。例如:`1.5, -0.2, 3.1, 0.8, -1.1`")
            feat_input = gr.Textbox(label="特征输入 (逗号分隔)", lines=3)
            btn_infer = gr.Button("执行推理", variant="primary")
            infer_result = gr.Markdown(label="推理结果")
            btn_infer.click(fn=run_inference, inputs=[feat_input], outputs=[infer_result])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)