Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from pyDOE2 import lhs | |
| from io import BytesIO | |
| from PIL import Image | |
| from sklearn.decomposition import PCA | |
| from scipy.stats.qmc import Sobol, Halton | |
| from time import time, sleep | |
| import tempfile | |
def plot_pairplot(df, title="Pairplot"):
    """Draw a seaborn pair plot of ``df`` and return it as a PIL image.

    Args:
        df: Data handed to ``seaborn.pairplot``.
        title: Figure-level title, positioned at ``y=1.02`` (just above the grid).

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG rendering of the figure.
    """
    import seaborn as sns  # imported lazily: only this plot type uses seaborn

    grid = sns.pairplot(df)
    grid.fig.suptitle(title, y=1.02)
    plt.tight_layout()

    # Render the current figure into memory instead of a file on disk.
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format='png')
    plt.close()
    png_buffer.seek(0)
    return Image.open(png_buffer)
def plot_parallel_coords(df, title="Parallel Coordinates"):
    """Draw a parallel-coordinates plot of ``df`` and return it as a PIL image.

    Args:
        df: Data to plot; it is copied, so the caller's frame is not modified.
        title: Title placed on the plot.

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG rendering of the figure.
    """
    from pandas.plotting import parallel_coordinates

    # parallel_coordinates requires a class column; give every row its own
    # class (its position) on a copy of the data.
    plot_frame = df.copy()
    plot_frame['index'] = range(len(plot_frame))

    plt.figure(figsize=(8, 5))
    parallel_coordinates(plot_frame, class_column='index', color=['#348ABD'])
    plt.title(title)
    # Replace the per-row legend with an empty, frameless one.
    plt.legend([], [], frameon=False)
    plt.tight_layout()

    png_buffer = BytesIO()
    plt.savefig(png_buffer, format='png')
    plt.close()
    png_buffer.seek(0)
    return Image.open(png_buffer)
def plot_pca(df, title="PCA"):
    """Scatter-plot ``df`` projected onto its first two principal components.

    ``PCA(n_components=2)`` raises when the data has fewer than two columns
    or fewer than two rows. In that case as many components as the data
    allows are computed and the remaining axis is drawn as zeros, so a
    single-column design still produces an image instead of an exception.

    Args:
        df: DataFrame of numeric values (rows are points, columns are features).
        title: Title placed on the plot.

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG rendering of the figure.
    """
    X = df.values
    n_rows, n_cols = X.shape

    # Start from an all-zero projection and fill in whatever PCA can provide.
    X_pca = np.zeros((n_rows, 2))
    if n_rows >= 2 and n_cols >= 1:
        n_components = min(2, n_cols)
        X_pca[:, :n_components] = PCA(n_components=n_components).fit_transform(X)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(X_pca[:, 0], X_pca[:, 1], s=60)
    ax.set_xlabel("PCA1")
    ax.set_ylabel("PCA2")
    ax.set_title(title)
    fig.tight_layout()

    img_buf = BytesIO()
    fig.savefig(img_buf, format='png')
    plt.close(fig)
    img_buf.seek(0)
    return Image.open(img_buf)
def plot_scatter_2d(df, title="2D Scatter"):
    """Scatter-plot the first two columns of ``df`` and return a PIL image.

    Args:
        df: DataFrame whose first column is the x axis and second the y axis;
            the column labels are used as axis labels.
        title: Title placed on the plot.

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG rendering of the figure.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(df.iloc[:, 0], df.iloc[:, 1], s=60)
    ax.set_xlabel(df.columns[0])
    ax.set_ylabel(df.columns[1])
    ax.set_title(title)
    fig.tight_layout()

    png_buffer = BytesIO()
    fig.savefig(png_buffer, format='png')
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
def plot_scatter_3d(df, title="3D Scatter"):
    """Scatter-plot the first three columns of ``df`` in 3D and return a PIL image.

    Args:
        df: DataFrame whose first three columns are the x, y and z axes; the
            column labels are used as axis labels.
        title: Title placed on the plot.

    Returns:
        A ``PIL.Image`` opened on an in-memory PNG rendering of the figure.
    """
    # Kept from the original; presumably imported for its side effect of
    # registering the '3d' projection on older matplotlib releases.
    from mpl_toolkits.mplot3d import Axes3D

    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(df.iloc[:, 0], df.iloc[:, 1], df.iloc[:, 2], s=60)
    ax.set_xlabel(df.columns[0])
    ax.set_ylabel(df.columns[1])
    ax.set_zlabel(df.columns[2])
    ax.set_title(title)
    fig.tight_layout()

    png_buffer = BytesIO()
    fig.savefig(png_buffer, format='png')
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
def is_valid_row(row):
    """Return True when ``row`` is a usable parameter definition.

    A usable row is a list or tuple of at least four items:
    ``[name, low, high, step, ...]`` where ``name`` is a non-blank label and
    ``low``, ``high`` and ``step`` each convert to a finite float.

    Missing cells are rejected explicitly: a ``None`` or NaN name (which
    ``str()`` would otherwise turn into the labels "None"/"nan") and NaN or
    infinite numbers (which ``float()`` accepts) would otherwise slip through
    and produce NaN design points.

    Args:
        row: One row of the parameter table.

    Returns:
        bool: True if the row can be used, False otherwise.
    """
    import math

    if not isinstance(row, (list, tuple)) or len(row) < 4:
        return False

    name = row[0]
    if name is None or (isinstance(name, float) and math.isnan(name)):
        return False

    try:
        if not str(name).strip():
            return False
        return all(math.isfinite(float(cell)) for cell in row[1:4])
    except (TypeError, ValueError, OverflowError):
        # Raised by float() for None, non-numeric text or oversized integers.
        return False
def _parse_seed(seed):
    """Return ``seed`` as an int, or None when it is blank, NaN or zero."""
    if seed is None:
        return None
    text = str(seed).strip()
    if text == "" or text.lower() == "nan":
        return None
    # int() still raises ValueError for non-numeric text, as before.
    return int(seed) or None


def _decimal_places(value):
    """Return how many fractional digits the shortest repr of ``value`` has."""
    from decimal import Decimal

    exponent = Decimal(repr(float(value))).as_tuple().exponent
    # Non-finite values report a non-integer exponent; treat them as 0 places.
    return max(0, -exponent) if isinstance(exponent, int) else 0


def gen_design(design_type, n_params, n_samples, param_lows, param_highs, param_steps, seed):
    """Generate a space-filling design scaled to the given parameter ranges.

    A unit-hypercube sample of shape ``(n_samples, n_params)`` is drawn with
    the requested method, then each column is scaled to ``[low, high]``.
    When a column's step is positive its values are snapped to the grid
    ``low + k * step`` and rounded to the decimals of ``step``/``low`` to
    remove floating-point noise; with a step of zero (or below) the values
    are left continuous. Every column is finally clipped to ``[low, high]``.

    Args:
        design_type: One of "LHS", "Sobol", "Halton" or "Uniform".
        n_params: Number of parameters (columns).
        n_samples: Number of design points (rows).
        param_lows: Lower bound of each parameter.
        param_highs: Upper bound of each parameter.
        param_steps: Grid step of each parameter; ``<= 0`` means continuous.
        seed: Random seed. ``None``, blank, NaN or 0 means "not seeded".

    Returns:
        pandas.DataFrame with one row per design point and one integer-labelled
        column per parameter.

    Raises:
        ValueError: If ``design_type`` is not one of the supported names, or
            ``seed`` is non-blank text that is not an integer.
    """
    my_seed = _parse_seed(seed)

    if design_type == "LHS":
        # Seed lhs() through its own random_state argument so the seed
        # actually controls the LHS draw, without touching NumPy's global RNG.
        design = lhs(n_params, samples=n_samples, criterion='maximin', random_state=my_seed)
    elif design_type == "Sobol":
        design = Sobol(d=n_params, scramble=True, seed=my_seed).random(n_samples)
    elif design_type == "Halton":
        design = Halton(d=n_params, scramble=True, seed=my_seed).random(n_samples)
    elif design_type == "Uniform":
        # A local RandomState yields the same numbers as np.random.seed(seed)
        # followed by np.random.rand(), but leaves the global RNG untouched.
        design = np.random.RandomState(my_seed).rand(n_samples, n_params)
    else:
        raise ValueError("Unknown SFD type!")

    design = np.asarray(design, dtype=float)
    real_samples = np.zeros_like(design)
    for idx, (low, high, step) in enumerate(zip(param_lows, param_highs, param_steps)):
        column = design[:, idx] * (high - low) + low
        if step > 0:
            # Snap to the nearest grid point, then strip float noise such as
            # 0.30000000000000004 left over from the multiply/add.
            column = np.round((column - low) / step) * step + low
            column = np.round(column, max(_decimal_places(step), _decimal_places(low)))
        real_samples[:, idx] = np.clip(column, low, high)
    return pd.DataFrame(real_samples)
def auto_plot(df, design_type):
    """Pick a plot type from the number of columns in ``df`` and render it.

    Column count to plot: 2 gives a 2D scatter, 3 a 3D scatter, 4-8 a pair
    plot, 9-15 parallel coordinates, and anything else a PCA projection.

    Args:
        df: Design points, one column per parameter.
        design_type: Design name, used as the prefix of the plot title.

    Returns:
        The image returned by the chosen plot function, or None when ``df``
        is None or has no rows.
    """
    if df is None or df.shape[0] == 0:
        return None

    column_count = df.shape[1]
    if column_count == 2:
        renderer = plot_scatter_2d
    elif column_count == 3:
        renderer = plot_scatter_3d
    elif column_count in range(4, 9):
        renderer = plot_pairplot
    elif column_count in range(9, 16):
        renderer = plot_parallel_coords
    else:
        renderer = plot_pca
    return renderer(df, f"{design_type} 設計")
def _collect_params(param_table):
    """Split table rows into parallel name/low/high/step lists, skipping invalid rows."""
    if isinstance(param_table, pd.DataFrame):
        param_table = param_table.values.tolist()
    if param_table is None:
        param_table = []
    names, lows, highs, steps = [], [], [], []
    for row in param_table:
        if not is_valid_row(row):
            continue
        try:
            name = str(row[0]).strip()
            low, high, step = float(row[1]), float(row[2]), float(row[3])
        except (TypeError, ValueError):
            continue
        # Append only after every conversion succeeded so the four lists
        # can never get out of step with each other.
        names.append(name)
        lows.append(low)
        highs.append(high)
        steps.append(step)
    return names, lows, highs, steps


def _reminder_outputs(message, all_types):
    """Build the 20-value output tuple shown when no design can be generated."""
    msg = pd.DataFrame({"提醒": [message]})
    # Create one empty placeholder CSV per design so the file outputs
    # still receive a real path.
    tmpcsvs = []
    for des in all_types:
        with tempfile.NamedTemporaryFile(delete=False, suffix=f'_{des}_empty.csv', mode='w', encoding='utf-8') as tmpfile:
            pd.DataFrame().to_csv(tmpfile)
            tmpcsvs.append(tmpfile.name)
    return (msg, msg, msg, msg,
            None, None, None, None,
            *tmpcsvs,
            "", "", "", "",
            0.0, 0.0, 0.0, 0.0)


def compare_all_designs(param_table, n_samples, seed, prog=gr.Progress()):
    """Generate, plot and export all four designs for the given parameters.

    Args:
        param_table: Parameter rows ``[name, low, high, step]`` as a list of
            rows or a DataFrame; rows rejected by ``is_valid_row`` are skipped.
        n_samples: Number of design points per design; must convert to a
            positive integer.
        seed: Random seed forwarded to ``gen_design``.
        prog: Progress callback, called with a fraction and a description.

    Returns:
        A 20-tuple grouped as four DataFrames, four images, four CSV file
        paths, four timing strings and four progress values, each group in
        the order LHS, Sobol, Halton, Uniform. When there is no valid
        parameter row or ``n_samples`` is not a positive integer, the
        DataFrames carry a reminder message, the images are None, the CSVs
        are empty files, the timings are "" and the progress values are 0.0.
    """
    all_types = ["LHS", "Sobol", "Halton", "Uniform"]

    param_names, param_lows, param_highs, param_steps = _collect_params(param_table)
    n_params = len(param_names)
    if n_params == 0:
        return _reminder_outputs("請正確輸入至少一列參數(含名稱/最小/最大/間隔)", all_types)

    # An empty or non-positive sample count would otherwise crash in int()
    # or inside the samplers.
    try:
        sample_count = int(n_samples)
    except (TypeError, ValueError, OverflowError):
        sample_count = 0
    if sample_count < 1:
        return _reminder_outputs("組數必須為正整數", all_types)

    # ----------- generate the designs ----------
    dfs, imgs, csvs, times, bars = [], [], [], [], []
    total = len(all_types)
    for i, des in enumerate(all_types):
        # Report progress as a fraction of the whole run rather than
        # restarting from zero for every design.
        prog(i / total, desc=f"正在計算 {des} ...")
        sleep(0.15)  # presumably a pause so the progress text is visible
        prog((i + 0.5) / total, desc=f"正在計算 {des} ...")

        started = time()
        df = gen_design(des, n_params, sample_count, param_lows, param_highs, param_steps, seed)
        df.columns = param_names
        exec_time = time() - started

        prog((i + 0.85) / total, desc=f"正在繪圖 {des} ...")
        sleep(0.1)

        # Time only the real work so the reported figure excludes the sleeps.
        started = time()
        img = auto_plot(df, des)
        exec_time += time() - started

        # Write the CSV to a real file (kept on disk) and hand back its path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=f'_{des}_design.csv', mode='w', encoding='utf-8') as tmpfile:
            df.to_csv(tmpfile, index=False)
            csvs.append(tmpfile.name)

        dfs.append(df)
        imgs.append(img)
        times.append(f"計算時間: {exec_time:.2f}s")
        bars.append(1.0)

    return (*dfs, *imgs, *csvs, *times, *bars)
# Build the Gradio UI: an introduction, the parameter table, the run
# controls and one result tab per design method.
with gr.Blocks() as demo:
    gr.Markdown("""
# 自動實驗設計法配置工具
這個工具會自動產生四種常見的實驗設計法,你可以自行決定該使用哪一種:
| 設計法 | 簡單說明 | 什麼時候用? |
|-------------|----------------------------------------|----------------------------------|
| **LHS** | 把每個參數平均分散,讓實驗覆蓋更完整 | 最推薦:AI訓練、參數探索 |
| **Sobol** | 類似 LHS,但分布更均勻,適合做大量實驗 | 要仔細找出各參數影響時 |
| **Halton** | 跟 Sobol 很像,數量少時也很平均 | 想要分布均勻,但實驗組數不多 |
| **Uniform** | 完全亂數,可能有空隙、不太平均 | 只是初步測試、不在意分布均不均 |
## **建議**:
> - 若不確定設計法,LHS 通常為機器學習/AI建模較穩定選擇
> - Sobol/ Halton 適合追求高覆蓋性、對全域敏感度要求高的設計
> - Uniform 僅建議少量初步探索,正式建模請選用上面三種之一
## **懶人選擇**:
> - 只要不知道怎麼選,「**LHS**」最安全!
---
""")

    # Editable parameter table: name, minimum, maximum and grid step.
    param_table = gr.Dataframe(
        headers=["名稱", "最小值", "最大值", "間隔(step)"],
        datatype=["str", "number", "number", "number"],
        row_count=(3, "dynamic"),
        col_count=(4, "fixed"),
        value=[
            ["A", 50, 100, 25],
            ["B", 10, 30, 2],
            ["C", 100, 1000, 250],
        ],
        label="參數設定(可新增行,間隔=1為整數)"
    )
    with gr.Row():
        n_samples = gr.Number(label="組數", value=8, precision=0)
        seed = gr.Number(label="亂數種子(留空或0為隨機)", value=42, precision=0)
    btn = gr.Button("自動產生與比較所有設計圖")

    # One tab per design, each holding the same five output components.
    tables, images, files, timings, sliders = [], [], [], [], []
    for design_name in ("LHS", "Sobol", "Halton", "Uniform"):
        with gr.Tab(design_name):
            tables.append(gr.Dataframe(label=f"{design_name} 設計點"))
            images.append(gr.Image(type="pil", label=f"{design_name} 分布圖"))
            files.append(gr.File(label="CSV下載"))
            timings.append(gr.Markdown())
            sliders.append(gr.Slider(label="進度", minimum=0, maximum=1, step=0.01, value=0))

    # Keep the original module-level component names available.
    df1, df2, df3, df4 = tables
    img1, img2, img3, img4 = images
    csv1, csv2, csv3, csv4 = files
    time1, time2, time3, time4 = timings
    bar1, bar2, bar3, bar4 = sliders

    # Output order must match the tuple returned by compare_all_designs.
    btn.click(
        compare_all_designs,
        inputs=[param_table, n_samples, seed],
        outputs=[*tables, *images, *files, *timings, *sliders]
    )
demo.launch()