AutoSTAT / workflow /modeling /model_inference.py
ElvisWang111's picture
Upload folder using huggingface_hub
342e4c4 verified
import base64
import gzip
import io
import json
import traceback
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import streamlit as st
from workflow.dataloading.dataloading_core import process_complex_data
from utils.sanitize_code import sanitize_code, to_json_serializable
def infer_load_data(agent) -> None:
uploaded_files = st.file_uploader(
"选择推理数据集",
accept_multiple_files=True,
help="拖拽或点击上传多个文件",
)
if uploaded_files:
try:
with st.spinner("正在处理数据..."):
big_df, dfs = process_complex_data(uploaded_files, agent)
if big_df is not None:
agent.save_inference_data(big_df)
st.success("导入并处理完成!")
except Exception as err:
st.error(f"导入失败:{err}")
def infer_execution(agent):
inference_df = agent.load_inference_processed_df()
edited_code = agent.load_inference_code()
try:
model_obj = agent.load_best_model()
exec_ns = {
"inference_df": inference_df,
'model_obj': model_obj,
"np": np,
"pd": pd,
"StandardScaler": StandardScaler
}
with st.spinner("正在进行推断分析..."):
exec(edited_code, exec_ns)
result_dict = exec_ns.get("result_dict")
if result_dict is None:
st.error("脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。")
else:
art = result_dict.get('artifacts', {})
b64 = art.pop('predictions_df_b64', None)
if not art:
result_dict.pop('artifacts', None)
serializable = to_json_serializable(result_dict)
try:
result_json = json.dumps(serializable, ensure_ascii=False)
except Exception:
result_json = json.dumps(serializable, default=str, ensure_ascii=False)
with st.expander("推理结果", True):
if b64:
try:
gz_bytes = base64.b64decode(b64)
csv_bytes = gzip.decompress(gz_bytes)
df_pred = pd.read_csv(io.BytesIO(csv_bytes))
st.success("已加载带预测结果的 DataFrame")
st.dataframe(df_pred)
st.download_button(
label="下载带预测结果(predictions.csv)",
data=csv_bytes,
file_name="predictions.csv",
mime="text/csv"
)
except Exception as e:
st.error(f"解码 predictions_df 失败: {e}")
# 兜底:尝试从 records 字段恢复
records = result_dict.get('predictions_df_records')
if records:
try:
df_pred = pd.DataFrame(records)
st.dataframe(df_pred)
except Exception as e2:
st.error(f"从 records 恢复表格失败: {e2}")
except Exception as e:
st.error(f"推断失败:{e}")
st.text(traceback.format_exc())
agent.save_inference_error(traceback.format_exc())
raw = agent.code_generation_for_inference(agent.load_code(), inference_data.head(), auto=True)
code = sanitize_code(raw)
agent.save_inference_code(code)