new-human
feat: 支持真实CSV数据训练与模型推理闭环
0bb564a
import json
import threading
from pathlib import Path
import asyncio
# Fix for a Python 3.14 asyncio loop error in Gradio: on 3.14,
# asyncio.get_event_loop() raises RuntimeError when the current thread has no
# event loop set, so install a fresh one before gradio is imported below.
def _ensure_event_loop() -> None:
    """Make sure the calling thread has a current asyncio event loop."""
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)


_ensure_event_loop()
import gradio as gr
import math
from multi_agent_lab.config import TrainConfig
from multi_agent_lab.orchestrator import Orchestrator
# Process-wide lock that run_pipeline holds for the whole pipeline run, so at
# most one training run executes at a time within this process.
LOCK = threading.Lock()
def _runs_dir() -> Path:
d = Path("./data/runs")
d.mkdir(parents=True, exist_ok=True)
return d
def _latest_run_dir() -> Path | None:
    """Return the newest run directory, or ``None`` when there are no runs.

    "Newest" is the subdirectory of the runs root whose name sorts last
    lexicographically; plain files in the runs root are ignored.
    NOTE(review): this presumes run folder names are timestamp-like so that
    name order matches creation order -- confirm against Orchestrator.
    """
    runs_root = _runs_dir()
    if runs_root.exists():
        candidates = (entry for entry in runs_root.iterdir() if entry.is_dir())
        # Names within one directory are unique, so max() picks the same entry
        # that the last element of a name-sorted list would be.
        return max(candidates, key=lambda entry: entry.name, default=None)
    return None
def run_pipeline(scenario: str, seed: int, epochs: int, lr: float, csv_file, target_col: str) -> tuple[str, str]:
    """Run the full multi-agent training pipeline and return its report.

    Args:
        scenario: Scenario name forwarded to ``Orchestrator.run`` (the UI
            offers "stable" and "unstable").
        seed: Random seed; coerced to ``int``.
        epochs: Number of training epochs; coerced to ``int``.
        lr: Learning rate; coerced to ``float``.
        csv_file: Optional uploaded CSV. A path (``str``/``Path``, as produced
            by Gradio's ``File(type="filepath")``) or a file-like wrapper with
            a ``.name`` attribute. Falsy means "use the built-in synthetic data".
        target_col: Label column name in the CSV. Blank or ``None`` falls back
            to ``"target"``.

    Returns:
        ``(run_dir, report_markdown)`` on success, or ``("Error", message)``
        if any step fails (config construction, training, reading report.md).
    """
    import logging

    # Serialize pipeline runs: concurrent clicks wait here for the lock.
    with LOCK:
        try:
            if not csv_file:
                csv_path = None
            elif isinstance(csv_file, (str, Path)):
                csv_path = str(csv_file)
            else:
                # File-wrapper objects expose the temp file path as `.name`;
                # str() of the wrapper would be its repr, not a path.
                csv_path = str(getattr(csv_file, "name", csv_file))
            target_col_val = (target_col or "").strip() or "target"
            # Explicit coercion turns a cleared gr.Number (None) into a
            # reported error instead of an invalid config.
            cfg = TrainConfig(
                seed=int(seed),
                epochs=int(epochs),
                lr=float(lr),
                csv_path=csv_path,
                target_col=target_col_val,
            )
            orch = Orchestrator(base_runs_dir=_runs_dir())
            store, _summary = orch.run(cfg=cfg, scenario=scenario)
            report = store.path("report.md").read_text(encoding="utf-8")
            return str(store.run_dir), report
        except Exception as e:
            # UI boundary: keep the full traceback in the server log and show
            # a short message in the page.
            logging.getLogger(__name__).exception("run_pipeline failed")
            return "Error", f"**运行失败**:{e}"
def load_latest_report() -> tuple[str, str]:
    """Return ``(run_dir, markdown)`` for the newest run's ``report.md``.

    When no run exists yet, returns ``"-"`` and a hint to start training.
    When the newest run has no ``report.md``, returns that run's path and a
    notice that the run may be unfinished.
    """
    latest = _latest_run_dir()
    if latest is None:
        return "-", "暂无运行记录。点击上方按钮开始一次训练。"
    report_file = latest / "report.md"
    if report_file.exists():
        return str(latest), report_file.read_text(encoding="utf-8")
    return str(latest), "未找到 report.md(可能运行未完成)。"
def run_inference(features_text: str) -> str:
p = _latest_run_dir()
if p is None:
return "暂无模型,请先完成一次训练。"
try:
report_json = json.loads((p / "report.json").read_text(encoding="utf-8"))
attempts = report_json.get("attempts", 1)
model_json = json.loads((p / f"model.attempt{attempts}.json").read_text(encoding="utf-8"))
except Exception as e:
return f"无法加载模型数据:{str(e)}"
w = model_json.get("w", [])
b = model_json.get("b", 0.0)
try:
features = [float(x.strip()) for x in features_text.split(",") if x.strip()]
except ValueError:
return "输入格式错误,请确保输入的是用逗号分隔的数字。"
if len(features) != len(w):
return f"特征数量不匹配:模型需要 {len(w)} 个特征,但输入了 {len(features)} 个。"
# Logistic regression inference
margin = sum(x * weight for x, weight in zip(features, w)) + b
prob = 1.0 / (1.0 + math.exp(-margin))
pred = 1 if prob >= 0.5 else 0
return f"**预测类别**:{pred}\n\n**正类概率**:{prob:.4f}\n**Margin**:{margin:.4f}"
# Gradio UI. Tab 1 configures and launches training runs and shows the
# resulting report; tab 2 runs inference via run_inference, which loads the
# model from the newest run directory.
with gr.Blocks(title="多 Agent 训练与调试实验室") as demo:
    gr.Markdown("# 多 Agent 训练与调试实验室\n\n一键演示:数据→训练→评估→自动诊断→自动修复→复训→报告。\n\n**新增功能**:支持上传真实 CSV 数据进行训练,并提供推理闭环!")
    with gr.Tabs():
        with gr.Tab("1. 模型训练与多 Agent 调试"):
            with gr.Row():
                # Left column: training hyperparameters passed to run_pipeline.
                with gr.Column():
                    scenario = gr.Dropdown(choices=["stable", "unstable"], value="unstable", label="场景")
                    seed = gr.Number(value=42, precision=0, label="seed")
                    epochs = gr.Number(value=20, precision=0, label="epochs")
                    lr = gr.Number(value=0.2, label="lr(stable 场景有效;unstable 会先覆盖成易发散值)")
                # Right column: optional CSV dataset. type="filepath" means
                # run_pipeline receives the uploaded file as a path.
                with gr.Column():
                    csv_file = gr.File(label="上传 CSV 数据集 (可选,不传则使用内置合成数据)", file_types=[".csv"], type="filepath")
                    target_col = gr.Textbox(value="target", label="目标列名 (例如:target)")
            with gr.Row():
                btn_run = gr.Button("开始运行", variant="primary")
                btn_load = gr.Button("加载最新报告")
            # Outputs shared by "run" and "load latest": run directory + report markdown.
            out_run_dir = gr.Textbox(label="run_dir", interactive=False)
            out_report = gr.Markdown(label="report.md")
            btn_run.click(
                fn=run_pipeline,
                inputs=[scenario, seed, epochs, lr, csv_file, target_col],
                outputs=[out_run_dir, out_report],
            )
            btn_load.click(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
            # Also show the latest report as soon as the page loads.
            demo.load(fn=load_latest_report, inputs=[], outputs=[out_run_dir, out_report])
        with gr.Tab("2. 模型推理测试"):
            gr.Markdown("### 使用最新训练的模型进行推理")
            gr.Markdown("请输入逗号分隔的特征值。例如:`1.5, -0.2, 3.1, 0.8, -1.1`")
            feat_input = gr.Textbox(label="特征输入 (逗号分隔)", lines=3)
            btn_infer = gr.Button("执行推理", variant="primary")
            infer_result = gr.Markdown(label="推理结果")
            btn_infer.click(fn=run_inference, inputs=[feat_input], outputs=[infer_result])
if __name__ == "__main__":
    # Serve on every network interface (0.0.0.0) at port 7860, so the app is
    # reachable from other machines, not just localhost.
    demo.launch(server_port=7860, server_name="0.0.0.0")