Spaces:
Running
Running
File size: 5,431 Bytes
d536096 4f9d2b1 d536096 0bb564a d536096 4f9d2b1 d536096 0bb564a d536096 0bb564a d536096 0bb564a d536096 0bb564a d536096 0bb564a d536096 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | import json
import threading
from pathlib import Path
import asyncio
# Fix Python 3.14 asyncio loop error in Gradio.
# Ensure the current thread has an event loop *before* gradio is imported below.
# On Python 3.14, asyncio.get_event_loop() raises RuntimeError when no current
# loop has been set; in that case a fresh loop is created and installed as the
# current one. If the call succeeds, this block does nothing further.
# NOTE(review): must stay above `import gradio` — presumably gradio (or a
# dependency) needs a current loop at import time; confirm.
try:
    asyncio.get_event_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())
import gradio as gr
import math
from multi_agent_lab.config import TrainConfig
from multi_agent_lab.orchestrator import Orchestrator
# Process-wide lock. run_pipeline holds it for the whole duration of a training
# run, so runs started from the UI execute one at a time within this process.
LOCK = threading.Lock()
def _runs_dir() -> Path:
    """Return the root directory for run outputs, creating it on first use.

    The location is ``data/runs`` relative to the current working directory,
    so it depends on where the process was started.
    """
    runs_root = Path("./data") / "runs"
    runs_root.mkdir(exist_ok=True, parents=True)
    return runs_root
def _latest_run_dir() -> Path | None:
    """Return the run directory whose name sorts last, or None if there is none.

    Only sub-directories of the runs root are considered; plain files are
    ignored. Selection is purely by lexicographic directory name.
    NOTE(review): this is the most recent run only if run directory names sort
    chronologically (e.g. timestamp-prefixed) — confirm against Orchestrator.
    """
    root = _runs_dir()
    if not root.exists():
        return None
    candidates = (entry for entry in root.iterdir() if entry.is_dir())
    # Directory names are unique, so max() picks the same entry as sorted()[-1].
    return max(candidates, key=lambda entry: entry.name, default=None)
def run_pipeline(scenario: str, seed: int, epochs: int, lr: float, csv_file, target_col: str) -> tuple[str, str]:
    """Run one training/debugging pipeline and return ``(run_dir, report_markdown)``.

    Runs are serialized through the module-level ``LOCK`` so only one
    orchestrator run executes at a time in this process.

    Args:
        scenario: Scenario name forwarded to ``Orchestrator.run`` (the UI offers
            "stable" and "unstable").
        seed: Random seed forwarded to ``TrainConfig``.
        epochs: Number of epochs forwarded to ``TrainConfig``.
        lr: Learning rate forwarded to ``TrainConfig``.
        csv_file: Optional uploaded CSV; the UI's file component uses
            ``type="filepath"``. Any falsy value means "no CSV".
        target_col: Label column name; blank or None falls back to "target".

    Returns:
        On success, the run directory as a string and the text of its
        ``report.md``. On any failure (including config/orchestrator
        construction), ``"Error"`` and a Markdown error message.
    """
    with LOCK:
        try:
            csv_path = str(csv_file) if csv_file else None
            # Treat None like an empty textbox so it cannot raise outside the try.
            target_col_val = (target_col or "").strip() or "target"
            cfg = TrainConfig(seed=seed, epochs=epochs, lr=lr, csv_path=csv_path, target_col=target_col_val)
            orch = Orchestrator(base_runs_dir=_runs_dir())
            store, _summary = orch.run(cfg=cfg, scenario=scenario)
            report = store.path("report.md").read_text(encoding="utf-8")
            return str(store.run_dir), report
        except Exception as e:
            # UI boundary: show the failure in the report pane rather than a generic Gradio error.
            return "Error", f"**运行失败**:{str(e)}"
def load_latest_report() -> tuple[str, str]:
    """Return ``(run_dir, report_markdown)`` for the latest run directory.

    Returns placeholder text when no run exists yet, or when the latest run
    directory has no ``report.md``.
    """
    latest = _latest_run_dir()
    if latest is None:
        return "-", "暂无运行记录。点击上方按钮开始一次训练。"
    report_file = latest / "report.md"
    if report_file.exists():
        return str(latest), report_file.read_text(encoding="utf-8")
    return str(latest), "未找到 report.md(可能运行未完成)。"
def _stable_sigmoid(z: float) -> float:
    """Return the logistic function 1 / (1 + e**-z) without overflowing.

    ``math.exp`` raises OverflowError for large arguments (above ~709), so the
    naive formula crashes for strongly negative ``z``. Branching on the sign
    keeps the exponent passed to ``math.exp`` non-positive.
    """
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def run_inference(features_text: str) -> str:
    """Score a comma-separated feature vector with the latest run's model.

    Reads ``attempts`` (default 1) from the latest run's ``report.json`` and
    loads ``model.attempt{attempts}.json`` from the same directory. It then
    applies the model as logistic regression: ``margin = sum(x_i * w_i) + b``,
    ``prob = sigmoid(margin)``. The predicted class is 1 when ``prob >= 0.5``.

    Args:
        features_text: Feature values separated by commas (ASCII "," or
            full-width "，"). Blank entries are ignored; None is treated as "".

    Returns:
        A Markdown string with the prediction, or a user-facing message when no
        run exists, the model files cannot be loaded, the input is not numeric,
        or the feature count does not match the model.
    """
    p = _latest_run_dir()
    if p is None:
        return "暂无模型,请先完成一次训练。"
    try:
        report_json = json.loads((p / "report.json").read_text(encoding="utf-8"))
        # NOTE(review): assumes "attempts" equals the suffix of the final model file — confirm.
        attempts = report_json.get("attempts", 1)
        model_json = json.loads((p / f"model.attempt{attempts}.json").read_text(encoding="utf-8"))
    except Exception as e:
        return f"无法加载模型数据:{str(e)}"
    w = model_json.get("w", [])
    b = model_json.get("b", 0.0)
    # The UI is Chinese-language, so accept full-width commas as separators too.
    normalized = (features_text or "").replace("，", ",")
    try:
        features = [float(x.strip()) for x in normalized.split(",") if x.strip()]
    except ValueError:
        return "输入格式错误,请确保输入的是用逗号分隔的数字。"
    if len(features) != len(w):
        return f"特征数量不匹配:模型需要 {len(w)} 个特征,但输入了 {len(features)} 个。"
    # Logistic regression inference
    margin = sum(x * weight for x, weight in zip(features, w)) + b
    prob = _stable_sigmoid(margin)
    pred = 1 if prob >= 0.5 else 0
    # Blank lines between fields: a single "\n" is only a soft break in Markdown.
    return f"**预测类别**:{pred}\n\n**正类概率**:{prob:.4f}\n\n**Margin**:{margin:.4f}"
# Build the two-tab Gradio UI: tab 1 configures and runs the training pipeline
# and shows the resulting report; tab 2 runs inference with the latest model.
# Layout follows component-creation order inside the nested `with` contexts.
# NOTE(review): nesting reconstructed from a whitespace-stripped copy — confirm
# that out_run_dir/out_report sit below the button row inside tab 1.
with gr.Blocks(title="多 Agent 训练与调试实验室") as demo:
    gr.Markdown("# 多 Agent 训练与调试实验室\n\n一键演示:数据→训练→评估→自动诊断→自动修复→复训→报告。\n\n**新增功能**:支持上传真实 CSV 数据进行训练,并提供推理闭环!")
    with gr.Tabs():
        # Tab 1: training configuration, run/load buttons and report output.
        with gr.Tab("1. 模型训练与多 Agent 调试"):
            with gr.Row():
                # Left column: scenario and hyper-parameters.
                with gr.Column():
                    scenario = gr.Dropdown(choices=["stable", "unstable"], value="unstable", label="场景")
                    # precision=0: Gradio rounds these inputs to integers.
                    seed = gr.Number(value=42, precision=0, label="seed")
                    epochs = gr.Number(value=20, precision=0, label="epochs")
                    lr = gr.Number(value=0.2, label="lr(stable 场景有效;unstable 会先覆盖成易发散值)")
                # Right column: optional CSV upload (delivered as a file path) and label column.
                with gr.Column():
                    csv_file = gr.File(label="上传 CSV 数据集 (可选,不传则使用内置合成数据)", file_types=[".csv"], type="filepath")
                    target_col = gr.Textbox(value="target", label="目标列名 (例如:target)")
            with gr.Row():
                btn_run = gr.Button("开始运行", variant="primary")
                btn_load = gr.Button("加载最新报告")
            out_run_dir = gr.Textbox(label="run_dir", interactive=False)
            out_report = gr.Markdown(label="report.md")
            # The inputs list is passed positionally, so its order must match run_pipeline's parameters.
            btn_run.click(
                fn=run_pipeline,
                inputs=[scenario, seed, epochs, lr, csv_file, target_col],
                outputs=[out_run_dir, out_report],
            )
            btn_load.click(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
            # Also populate the report pane when the page is first loaded.
            demo.load(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
        # Tab 2: inference using the model from the most recent run directory.
        with gr.Tab("2. 模型推理测试"):
            gr.Markdown("### 使用最新训练的模型进行推理")
            gr.Markdown("请输入逗号分隔的特征值。例如:`1.5, -0.2, 3.1, 0.8, -1.1`")
            feat_input = gr.Textbox(label="特征输入 (逗号分隔)", lines=3)
            btn_infer = gr.Button("执行推理", variant="primary")
            infer_result = gr.Markdown(label="推理结果")
            btn_infer.click(fn=run_inference, inputs=[feat_input], outputs=[infer_result])
if __name__ == "__main__":
    # Script entry point: serve the UI on all network interfaces ("0.0.0.0")
    # at port 7860, Gradio's conventional default port.
    demo.launch(server_port=7860, server_name="0.0.0.0")
|