Spaces:
Running
Running
| import importlib | |
| import json | |
| import traceback | |
| import base64 | |
| import gzip | |
| import pickle | |
| import time | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import streamlit_antd_components as sac | |
| from streamlit_ace import st_ace | |
| import torch | |
| import torchvision | |
| import xgboost | |
| import lightgbm | |
| from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor | |
| from sklearn.linear_model import LinearRegression | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import StandardScaler | |
| from utils.sanitize_code import sanitize_code, to_json_serializable | |
def train_execution(agent):
    """Run the agent's saved training script and persist what it produces.

    The script returned by ``agent.load_code()`` is executed with ``exec`` in
    a namespace holding the dataframe from ``agent.load_df()`` (as ``df``)
    plus a fixed set of ML libraries and estimators. The script is expected
    to assign a module-level variable named ``result_dict``.

    * If the script raises, the traceback is displayed, stored through
      ``agent.save_error`` and ``modeling_code_gen(agent, debug=True)`` is
      called to regenerate the code.
    * If the script finishes without defining ``result_dict``, an error is
      displayed and nothing is saved.
    * Otherwise ``result_dict`` is serialised to JSON, passed to
      ``agent.result_format_prompt`` and the formatted text is stored with
      ``agent.save_modeling_result``. When the script supplied
      ``result_dict['artifacts']['best_model_b64']`` it is base64-decoded,
      stored with ``agent.save_best_model_gz_bytes``, gunzipped, unpickled
      and stored with ``agent.save_best_model``.

    Args:
        agent: Project object providing ``load_code``, ``load_df``,
            ``save_error``, ``result_format_prompt``,
            ``save_modeling_result``, ``save_best_model_gz_bytes`` and
            ``save_best_model``.
    """
    code = agent.load_code()
    df = agent.load_df()

    # Resolved by name at call time; the file also imports both modules at
    # top level, so these locals simply shadow the module-level names.
    torch = importlib.import_module("torch")
    torchvision = importlib.import_module("torchvision")

    exec_ns = {
        "df": df,
        "np": np,
        "pd": pd,
        "torch": torch,
        "torchvision": torchvision,
        "train_test_split": train_test_split,
        "StandardScaler": StandardScaler,
        "LinearRegression": LinearRegression,
        "RandomForestRegressor": RandomForestRegressor,
        "GradientBoostingRegressor": GradientBoostingRegressor,
        "RandomForestClassifier": RandomForestClassifier,
        "xgboost": xgboost,
        "lightgbm": lightgbm,
    }

    try:
        with st.spinner("正在运行程序..."):
            # NOTE(security): exec runs the stored script with full
            # interpreter privileges; the namespace above is a convenience,
            # not a sandbox.
            exec(code, exec_ns)
    except Exception:
        # Format once so the text shown to the user and the text saved for
        # the debugging round are guaranteed to be identical.
        tb_text = traceback.format_exc()
        st.error("已保存报错,请重新调用llm生成代码debug")
        st.text(tb_text)
        agent.save_error(tb_text)
        modeling_code_gen(agent, debug=True)
        return

    result_dict = exec_ns.get("result_dict")
    if result_dict is None:
        st.error(
            "脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。"
        )
        return

    # `or {}` also covers a script that sets artifacts to None, which would
    # otherwise fail on `.pop`.
    art = result_dict.get('artifacts') or {}
    b64 = art.pop('best_model_b64', None)
    # Removed so it is not serialised below; the value itself is not used.
    result_dict.pop('artifact_warning', None)
    if not art:
        result_dict.pop('artifacts', None)

    serializable = to_json_serializable(result_dict)
    try:
        result_json = json.dumps(serializable, ensure_ascii=False)
    except (TypeError, ValueError):
        # Fall back to stringifying anything json cannot encode natively.
        result_json = json.dumps(serializable, default=str, ensure_ascii=False)

    with st.spinner("请求 LLM 格式化结果为 Markdown..."):
        formatted = agent.result_format_prompt(result_json)
        agent.save_modeling_result(formatted)

    if not b64:
        return

    try:
        # Decoding is inside the try so that malformed base64 is reported
        # through the same error path as a failed decompress/unpickle.
        gz_bytes = base64.b64decode(b64)
        agent.save_best_model_gz_bytes(gz_bytes)
        # NOTE(security): pickle.loads can execute arbitrary code. The
        # payload originates from the script executed above, so it is
        # exactly as trustworthy as that script.
        model_obj = pickle.loads(gzip.decompress(gz_bytes))
        st.success("最佳模型已加载到内存,可用于即时推理(示例)。")
        agent.save_best_model(model_obj)
    except Exception as e:
        st.error(f"加载模型失败:{e}")
# Chat message announcing a freshly generated script. It is also the marker
# searched for in the agent's memory to detect an earlier generation.
_CODE_UPDATED_MSG = "训练脚本已更新!请重新运行代码!"


def _generate_and_save_code(agent, df, suggest) -> None:
    """Generate a training script, save it, announce it in chat and rerun."""
    with st.spinner("建模 Agent 正在生成训练脚本..."):
        raw = agent.code_generation(
            df.head().to_string(),
            suggest,
        )
        agent.save_code(sanitize_code(raw))
    st.chat_message("assistant").write(_CODE_UPDATED_MSG)
    agent.add_memory({"role": "assistant", "content": _CODE_UPDATED_MSG})
    # Restart the Streamlit script run; code after this call is not reached.
    st.rerun()


def modeling_code_gen(agent, debug = False, auto = False, ) -> None:
    """Generate the modeling/training script automatically or on button click.

    A script is generated immediately, without rendering the button, when a
    suggestion is available and either ``debug`` is truthy or ``auto`` is
    truthy and the agent's memory holds no earlier "script updated"
    assistant message. Otherwise a button is rendered and generation runs
    when it is clicked.

    Generation calls ``agent.code_generation`` with the string form of
    ``df.head()`` and the suggestion, passes the output through
    ``sanitize_code``, stores it with ``agent.save_code``, posts the
    "script updated" message to the chat and to ``agent.add_memory``, and
    finally calls ``st.rerun()``.

    Args:
        agent: Project object providing ``load_df``, ``load_suggestion``,
            ``load_memory``, ``code_generation``, ``save_code`` and
            ``add_memory``.
        debug: Force regeneration (used after a failed script run).
        auto: Generate automatically, but only if no script was generated
            earlier according to the agent's memory.
    """
    df = agent.load_df()
    suggest = agent.load_suggestion()

    # Tolerate an empty/None memory and entries lacking the expected keys.
    chat_history = agent.load_memory() or []
    already_generated = any(
        entry.get("role") == "assistant"
        and _CODE_UPDATED_MSG in str(entry.get("content", ""))
        for entry in chat_history
    )

    if suggest is not None and (debug or (auto and not already_generated)):
        _generate_and_save_code(agent, df, suggest)

    # NOTE(review): the button path is not guarded by `suggest is not None`,
    # so `suggest` may be None here — confirm agent.code_generation accepts it.
    if st.button("🔧 生成建模代码", key='modeling_code'):
        _generate_and_save_code(agent, df, suggest)
def train_download_model(agent):
    """Render a download button for the stored best model, if there is one.

    The payload comes from ``agent.load_best_model_gz_bytes()``. When that
    returns ``None`` nothing is rendered.

    Args:
        agent: Project object providing ``load_best_model_gz_bytes``.
    """
    gz_payload = agent.load_best_model_gz_bytes()
    if gz_payload is None:
        return

    button_options = {
        "label": "⬇️ 下载最佳模型",
        "data": gz_payload,
        "file_name": "best_model.pkl.gz",
        "mime": "application/gzip",
    }
    st.download_button(**button_options)