import base64 import gzip import io import json import traceback import numpy as np import pandas as pd from sklearn.preprocessing import StandardScaler import streamlit as st from workflow.dataloading.dataloading_core import process_complex_data from utils.sanitize_code import sanitize_code, to_json_serializable def infer_load_data(agent) -> None: uploaded_files = st.file_uploader( "选择推理数据集", accept_multiple_files=True, help="拖拽或点击上传多个文件", ) if uploaded_files: try: with st.spinner("正在处理数据..."): big_df, dfs = process_complex_data(uploaded_files, agent) if big_df is not None: agent.save_inference_data(big_df) st.success("导入并处理完成!") except Exception as err: st.error(f"导入失败:{err}") def infer_execution(agent): inference_df = agent.load_inference_processed_df() edited_code = agent.load_inference_code() try: model_obj = agent.load_best_model() exec_ns = { "inference_df": inference_df, 'model_obj': model_obj, "np": np, "pd": pd, "StandardScaler": StandardScaler } with st.spinner("正在进行推断分析..."): exec(edited_code, exec_ns) result_dict = exec_ns.get("result_dict") if result_dict is None: st.error("脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。") else: art = result_dict.get('artifacts', {}) b64 = art.pop('predictions_df_b64', None) if not art: result_dict.pop('artifacts', None) serializable = to_json_serializable(result_dict) try: result_json = json.dumps(serializable, ensure_ascii=False) except Exception: result_json = json.dumps(serializable, default=str, ensure_ascii=False) with st.expander("推理结果", True): if b64: try: gz_bytes = base64.b64decode(b64) csv_bytes = gzip.decompress(gz_bytes) df_pred = pd.read_csv(io.BytesIO(csv_bytes)) st.success("已加载带预测结果的 DataFrame") st.dataframe(df_pred) st.download_button( label="下载带预测结果(predictions.csv)", data=csv_bytes, file_name="predictions.csv", mime="text/csv" ) except Exception as e: st.error(f"解码 predictions_df 失败: {e}") # 兜底:尝试从 records 字段恢复 records = result_dict.get('predictions_df_records') if records: try: df_pred = pd.DataFrame(records) st.dataframe(df_pred) except Exception as e2: st.error(f"从 records 恢复表格失败: {e2}") except Exception as e: st.error(f"推断失败:{e}") st.text(traceback.format_exc()) agent.save_inference_error(traceback.format_exc()) raw = agent.code_generation_for_inference(agent.load_code(), inference_data.head(), auto=True) code = sanitize_code(raw) agent.save_inference_code(code)