# Spaces: Running
| import json | |
| import threading | |
| from pathlib import Path | |
| import asyncio | |
# Python 3.14 workaround: asyncio.get_event_loop() raises RuntimeError when the
# main thread has no current event loop. In that case create and install a fresh
# loop, so that code imported afterwards (presumably Gradio) can find one.
try:
    asyncio.get_event_loop()
except RuntimeError:
    _fallback_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(_fallback_loop)
| import gradio as gr | |
| import math | |
| from multi_agent_lab.config import TrainConfig | |
| from multi_agent_lab.orchestrator import Orchestrator | |
# Module-wide lock. run_pipeline holds it for its entire body, so concurrent
# "开始运行" clicks execute the training pipeline one at a time.
LOCK = threading.Lock()
def _runs_dir() -> Path:
    """Return the root directory for training runs, creating it on demand.

    The path is relative (``data/runs``), so it resolves against the process's
    current working directory. Calling this repeatedly is safe: an existing
    directory is left untouched.
    """
    runs_root = Path("data", "runs")
    runs_root.mkdir(exist_ok=True, parents=True)
    return runs_root
def _latest_run_dir() -> Path | None:
    """Return the most recent run directory under the runs root, or ``None``.

    "Most recent" means the subdirectory whose name sorts last as a string.
    Run directory names are presumably timestamp-like, so this matches creation
    order; TODO confirm against how Orchestrator names run directories.
    Non-directory entries are ignored.
    """
    runs_root = _runs_dir()
    if not runs_root.exists():
        return None
    candidates = (entry for entry in runs_root.iterdir() if entry.is_dir())
    return max(candidates, key=lambda entry: entry.name, default=None)
def run_pipeline(scenario: str, seed: int, epochs: int, lr: float, csv_file, target_col: str) -> tuple[str, str]:
    """Run the orchestrated training pipeline once and return its report.

    The whole run executes while holding ``LOCK``, so only one pipeline runs at
    a time.

    Args:
        scenario: Scenario name passed to ``Orchestrator.run``. The UI offers
            "stable" and "unstable".
        seed: Forwarded to ``TrainConfig(seed=...)``.
        epochs: Forwarded to ``TrainConfig(epochs=...)``.
        lr: Forwarded to ``TrainConfig(lr=...)``.
        csv_file: Optional uploaded CSV path. A falsy value becomes
            ``csv_path=None``.
        target_col: Target column name. Blank or ``None`` falls back to
            "target".

    Returns:
        ``(run_dir, report_markdown)`` on success, or ``("Error", message)`` if
        any step raises. This includes building the config and orchestrator.
    """
    with LOCK:
        try:
            csv_path = str(csv_file) if csv_file else None
            # A cleared textbox may arrive as None; treat it the same as blank.
            target_col_val = (target_col or "").strip() or "target"
            # Config/orchestrator construction lives inside the try so invalid
            # inputs produce the friendly error message instead of a traceback.
            cfg = TrainConfig(seed=seed, epochs=epochs, lr=lr, csv_path=csv_path, target_col=target_col_val)
            orch = Orchestrator(base_runs_dir=_runs_dir())
            store, _summary = orch.run(cfg=cfg, scenario=scenario)
            report = store.path("report.md").read_text(encoding="utf-8")
            return str(store.run_dir), report
        except Exception as e:
            # UI boundary: surface any failure in the report pane.
            return "Error", f"**运行失败**:{str(e)}"
def load_latest_report() -> tuple[str, str]:
    """Return ``(run_dir, report_markdown)`` for the newest run.

    If no run exists, ``run_dir`` is "-" and a placeholder message is returned.
    If the newest run has no ``report.md``, its directory is returned together
    with a "not found" message.
    """
    run_dir = _latest_run_dir()
    if run_dir is None:
        return "-", "暂无运行记录。点击上方按钮开始一次训练。"
    report_path = run_dir / "report.md"
    try:
        report_text = report_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return str(run_dir), "未找到 report.md(可能运行未完成)。"
    return str(run_dir), report_text
def _sigmoid(z: float) -> float:
    """Numerically stable logistic function; never raises OverflowError."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    # For very negative z, math.exp(-z) would overflow, so use e^z / (1 + e^z).
    # NaN also lands here and propagates as NaN, same as the naive formula.
    exp_z = math.exp(z)
    return exp_z / (1.0 + exp_z)


def run_inference(features_text: str) -> str:
    """Score a comma-separated feature vector with the newest trained model.

    The model is read from ``model.attempt{N}.json`` in the latest run
    directory, where N is ``report.json["attempts"]`` (default 1). That file
    presumably holds the final attempt's model; TODO confirm with the
    orchestrator. It must contain ``w`` (a list of weights) and ``b`` (a bias).

    Args:
        features_text: Numbers separated by commas. Blank items are skipped.

    Returns:
        A Markdown string with the predicted class, the positive-class
        probability and the raw margin, or a user-facing error message.
    """
    p = _latest_run_dir()
    if p is None:
        return "暂无模型,请先完成一次训练。"
    try:
        report_json = json.loads((p / "report.json").read_text(encoding="utf-8"))
        attempts = report_json.get("attempts", 1)
        model_json = json.loads((p / f"model.attempt{attempts}.json").read_text(encoding="utf-8"))
    except Exception as e:
        return f"无法加载模型数据:{str(e)}"
    w = model_json.get("w", [])
    b = model_json.get("b", 0.0)
    try:
        features = [float(x.strip()) for x in (features_text or "").split(",") if x.strip()]
    except ValueError:
        return "输入格式错误,请确保输入的是用逗号分隔的数字。"
    if len(features) != len(w):
        return f"特征数量不匹配:模型需要 {len(w)} 个特征,但输入了 {len(features)} 个。"
    # Logistic regression inference.
    margin = sum(x * weight for x, weight in zip(features, w)) + b
    prob = _sigmoid(margin)
    pred = 1 if prob >= 0.5 else 0
    # Blank lines between fields: a single "\n" is a Markdown soft break and
    # would merge the probability and margin onto one rendered line.
    return f"**预测类别**:{pred}\n\n**正类概率**:{prob:.4f}\n\n**Margin**:{margin:.4f}"
# Build the two-tab Gradio UI. Component creation order and context-manager
# nesting determine the on-page layout.
# NOTE(review): the indentation below was reconstructed from a flattened paste;
# the nesting is a best guess. Confirm it against the deployed layout.
with gr.Blocks(title="多 Agent 训练与调试实验室") as demo:
    gr.Markdown("# 多 Agent 训练与调试实验室\n\n一键演示:数据→训练→评估→自动诊断→自动修复→复训→报告。\n\n**新增功能**:支持上传真实 CSV 数据进行训练,并提供推理闭环!")
    with gr.Tabs():
        # Tab 1: configure and launch a training run, then view its report.
        with gr.Tab("1. 模型训练与多 Agent 调试"):
            with gr.Row():
                # Left column: scenario and hyperparameters, passed to run_pipeline.
                with gr.Column():
                    scenario = gr.Dropdown(choices=["stable", "unstable"], value="unstable", label="场景")
                    seed = gr.Number(value=42, precision=0, label="seed")
                    epochs = gr.Number(value=20, precision=0, label="epochs")
                    # Per its label, lr only takes effect in the "stable" scenario.
                    lr = gr.Number(value=0.2, label="lr(stable 场景有效;unstable 会先覆盖成易发散值)")
                # Right column: optional CSV upload (delivered as a file path) and target column.
                with gr.Column():
                    csv_file = gr.File(label="上传 CSV 数据集 (可选,不传则使用内置合成数据)", file_types=[".csv"], type="filepath")
                    target_col = gr.Textbox(value="target", label="目标列名 (例如:target)")
            with gr.Row():
                btn_run = gr.Button("开始运行", variant="primary")
                btn_load = gr.Button("加载最新报告")
            out_run_dir = gr.Textbox(label="run_dir", interactive=False)
            out_report = gr.Markdown(label="report.md")
            # Input order must match run_pipeline's parameter order.
            btn_run.click(
                fn=run_pipeline,
                inputs=[scenario, seed, epochs, lr, csv_file, target_col],
                outputs=[out_run_dir, out_report],
            )
            btn_load.click(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
            # Also show the latest report as soon as the page loads.
            demo.load(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
        # Tab 2: score a hand-entered feature vector with the newest model.
        with gr.Tab("2. 模型推理测试"):
            gr.Markdown("### 使用最新训练的模型进行推理")
            gr.Markdown("请输入逗号分隔的特征值。例如:`1.5, -0.2, 3.1, 0.8, -1.1`")
            feat_input = gr.Textbox(label="特征输入 (逗号分隔)", lines=3)
            btn_infer = gr.Button("执行推理", variant="primary")
            infer_result = gr.Markdown(label="推理结果")
            btn_infer.click(fn=run_inference, inputs=[feat_input], outputs=[infer_result])
if __name__ == "__main__":
    # Script entry point: listen on all network interfaces, port 7860.
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )