# resNet0 / app.py
# suzakudry's picture
# Update app.py
# ed3da8d verified
import gradio as gr
import pandas as pd
from pathlib import Path
import os
import matplotlib.pyplot as plt
# 导入所有模块中的核心函数和配置
from utils import load_images_from_folder_for_import, save_data_to_data_dir, \
load_training_data_for_management, \
add_new_image_entry, delete_image_entry, update_data_from_management_dataframe, \
_update_managed_display_from_index, _select_managed_image_for_edit, _navigate_managed_image, \
_process_managed_single_score_edit
from train_predict import TrainingAndPredictionEngine
from config import DEFAULT_EPOCHS, DEFAULT_LR, MODEL_SAVE_BASE_PATH, \
DEFAULT_DROPOUT_RATE, DEFAULT_WEIGHT_DECAY, DEFAULT_PCA_VARIANCE_RATIO, \
DEFAULT_OPTIMIZER, DEFAULT_LR_SCHEDULER, DEFAULT_SCHEDULER_PATIENCE, \
DEFAULT_SCHEDULER_FACTOR, DEFAULT_SCHEDULER_T_MAX, DEFAULT_LOSS_FUNCTION, \
DEFAULT_EARLY_STOPPING_PATIENCE, DEFAULT_BATCH_SIZE, DEFAULT_DATA_AUGMENTATION, DATA_DIR
# --- UI辅助函数 (原始数据导入 Tab 专用) ---
def _update_main_display_from_index(index, all_image_paths, all_scores):
if not all_image_paths or not (0 <= index < len(all_image_paths)):
return "", None, None
preview_path = all_image_paths[index]
score = all_scores[index]
image_name = Path(preview_path).name
return image_name, preview_path, score
def _select_image_for_edit(evt: gr.SelectData, current_data_state):
    """Handle a click on a gallery item in the import tab.

    Args:
        evt: Selection event; ``evt.index`` is used as the clicked position.
        current_data_state: ``(image paths, scores, selected index)`` tuple.

    Returns:
        ``(file name, image path, score, state)``. For a valid click the
        state is a new tuple whose index is the clicked position; otherwise
        the display fields are blank and the incoming state is returned
        unchanged.
    """
    image_paths, scores, _ = current_data_state
    if not image_paths:
        return "", None, None, current_data_state

    clicked = evt.index
    if not 0 <= clicked < len(image_paths):
        return "", None, None, current_data_state

    name, path, score = _update_main_display_from_index(clicked, image_paths, scores)
    return name, path, score, (image_paths, scores, clicked)
def _navigate_image(direction, current_data_state):
all_image_paths, all_scores, current_index = current_data_state
if not all_image_paths:
return "", None, None, current_data_state
num_images = len(all_image_paths)
if num_images == 0:
return "", None, None, current_data_state
new_index = current_index + direction
if new_index < 0:
new_index = num_images - 1
elif new_index >= num_images:
new_index = 0
current_data_state = (all_image_paths, all_scores, new_index)
selected_name, selected_path, selected_score = _update_main_display_from_index(new_index, all_image_paths,
all_scores)
return selected_name, selected_path, selected_score, current_data_state
def _process_single_score_edit(entered_score, current_data_state):
all_image_paths, all_scores, selected_index = current_data_state
if entered_score is not None:
final_score = max(0, min(100, round(entered_score)))
else:
final_score = 50
if selected_index != -1 and 0 <= selected_index < len(all_scores):
all_scores[selected_index] = final_score
current_data_state = (all_image_paths, all_scores, selected_index)
dataframe_data = pd.DataFrame(
[[Path(img_path_str).name, score] for img_path_str, score in zip(all_image_paths, all_scores)],
columns=["文件名", "分数"])
return dataframe_data, current_data_state
dataframe_data_current = pd.DataFrame()
if all_image_paths and all_scores:
dataframe_data_current = pd.DataFrame([[Path(p).name, s] for p, s in zip(all_image_paths, all_scores)],
columns=["文件名", "分数"])
return dataframe_data_current, current_data_state
def _create_static_error_plot(message):
    """Build two placeholder figures shown when training cannot run.

    Args:
        message: Text rendered in red at the centre of the first figure.

    Returns:
        ``(fig, fig_mse)``: the first figure carries ``message`` under the
        title "训练失败"; the second carries a fixed "MSE图表:训练失败"
        notice. Two figures are returned because the training step feeds
        two plot outputs.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.text(0.5, 0.5, message,
            horizontalalignment='center', verticalalignment='center',
            transform=ax.transAxes, fontsize=12, color='red')
    ax.set_title("训练失败")
    ax.axis('off')
    # Lay out through the figure itself so the call does not depend on
    # which figure pyplot currently considers active.
    fig.tight_layout()

    # Placeholder for the second (metrics) plot.
    fig_mse, ax_mse = plt.subplots(figsize=(8, 4))
    ax_mse.text(0.5, 0.5, "MSE图表:训练失败",
                horizontalalignment='center', verticalalignment='center',
                transform=ax_mse.transAxes, fontsize=12, color='red')
    ax_mse.axis('off')
    # pyplot.tight_layout() takes keyword-only padding options; passing the
    # figure positionally raised TypeError, so use the Figure method instead.
    fig_mse.tight_layout()

    # NOTE(review): both figures stay registered with pyplot after return;
    # consider closing them once rendered if this path runs repeatedly.
    return fig, fig_mse
# --- Global state and initialisation ---
# Single engine instance created at import time; the training and prediction
# event handlers below call its methods.
training_engine = TrainingAndPredictionEngine()
# Selectable base CNN models (choices of the two "base CNN" dropdowns)
BASE_CNN_MODELS = ["resnet50", "resnet18", "vgg16", "densenet121", "efficientnet_b0", "inception_v3"]
# Choices offered by the optimizer dropdown
OPTIMIZERS = ["Adam", "SGD", "AdamW"]
# Choices offered by the learning-rate scheduler dropdown ("None" is a literal string choice)
LR_SCHEDULERS = ["None", "ReduceLROnPlateau", "CosineAnnealingLR"]
# Choices offered by the loss-function dropdown
LOSS_FUNCTIONS = ["MSELoss", "L1Loss", "SmoothL1Loss"]
# --- Gradio UI definition ---
with gr.Blocks(title="图片分数编辑与模型训练") as demo:
# One state per data tab, each starting as two empty lists and -1.
# The handlers in this file unpack it as (image paths, scores, selected index).
current_import_data_state = gr.State(([], [], -1))
current_managed_data_state = gr.State(([], [], -1))
# Page heading followed by a four-step usage guide, one step per tab.
gr.Markdown("## 图片分数编辑与模型训练系统")
gr.Markdown("### 使用流程:")
gr.Markdown(
"1. **原始数据导入**: 从外部文件夹加载图片,解析文件名中的分数进行初始标注,并可预览、编辑。点击“保存数据”将图片复制到 `data` 目录并生成 `scores.txt`。")
gr.Markdown(
"2. **训练数据管理**: 浏览、编辑、添加、删除已保存到 `data` 目录的训练图片和分数。此处的更改会立即同步到文件。")
gr.Markdown("3. **训练模型**: 选择模型类型和基础CNN模型,调整参数,开始训练。")
gr.Markdown("4. **预测图片**: 上传图片,选择训练好的模型进行预测。")
# Tab 1: raw data import. Widgets only; their events are wired further down.
with gr.Tab("原始数据导入"):
gr.Markdown("### 从外部文件夹导入图片并进行初始分数编辑")
# Folder path input with its load button.
with gr.Row():
folder_input = gr.Textbox(label="输入图片文件夹路径", placeholder="请输入绝对路径,例如 D:\\images",
value="")
load_btn = gr.Button("加载图片", variant="primary", scale=0)
# Read-only status line written by the load and save handlers.
import_status_text = gr.Textbox(label="操作状态", interactive=False, max_lines=2)
with gr.Column():
# Thumbnail gallery; selecting an item triggers the gallery.select handler.
gallery = gr.Gallery(
label="图片列表", show_label=True, elem_id="gallery", height=600,
preview=True, columns=5, rows=2, object_fit="contain"
)
# Editor for the currently selected image: name, score, confirm, navigation, preview.
with gr.Group():
gr.Markdown("#### 当前图片信息与分数编辑")
with gr.Row():
current_image_name_display = gr.Textbox(
label="当前图片名称", interactive=False, show_label=True,
placeholder="图片名称将在此处显示...", scale=3
)
current_score_input = gr.Number(
label="当前分数 (0-100)",
scale=1
)
confirm_score_btn = gr.Button("确定", variant="secondary", scale=0)
with gr.Row():
prev_btn = gr.Button("⬅️ 上一张", scale=1)
next_btn = gr.Button("下一张 ➡️", scale=1)
current_image_preview = gr.Image(
label="当前选中图片(预览)", height=300, show_download_button=False,
container=False, interactive=False
)
# Editable two-column table (file name, score); its change event writes scores back to the state.
import_score_dataframe = gr.Dataframe(
headers=["文件名", "分数"], datatype=["str", "number"], col_count=(2, "fixed"),
interactive=True, label="所有图片分数表格 (可直接编辑)", value=[]
)
# Saves the imported data to the training directory via save_data_to_data_dir.
save_import_data_btn = gr.Button("保存数据到训练目录", variant="secondary")
# Tab 2: management of the saved training data. Widgets only; events are wired further down.
with gr.Tab("训练数据管理"):
gr.Markdown("### 管理已保存的训练图片和分数")
# Shows the configured DATA_DIR value as it is when the UI is built.
gr.Markdown(f"当前训练数据目录: `{DATA_DIR}`")
load_managed_data_btn = gr.Button("加载训练数据", variant="primary")
# Read-only status line written by the management handlers.
managed_data_status_text = gr.Textbox(label="管理状态", interactive=False, max_lines=2)
with gr.Column():
managed_gallery = gr.Gallery(
label="已保存图片列表", show_label=True, elem_id="managed_gallery", height=400,
preview=True, columns=5, rows=2, object_fit="contain", value=[]
)
# Editor for the currently selected saved image: name, score, confirm, navigation, preview.
with gr.Group():
gr.Markdown("#### 当前图片信息与分数编辑 (训练数据)")
with gr.Row():
managed_image_name_display = gr.Textbox(
label="当前图片名称", interactive=False, show_label=True,
placeholder="图片名称将在此处显示...", scale=3
)
managed_score_input = gr.Number(
label="当前分数 (0-100)",
scale=1
)
confirm_managed_score_btn = gr.Button("确定", variant="secondary", scale=0)
with gr.Row():
prev_managed_btn = gr.Button("⬅️ 上一张", scale=1)
next_managed_btn = gr.Button("下一张 ➡️", scale=1)
managed_image_preview = gr.Image(
label="当前选中图片(预览)", height=200, show_download_button=False,
container=False, interactive=False
)
# Editable two-column table (file name, score); its change event calls update_data_from_management_dataframe.
managed_score_dataframe = gr.Dataframe(
headers=["文件名", "分数"], datatype=["str", "number"], col_count=(2, "fixed"),
interactive=True, label="所有图片分数表格 (可直接编辑)", value=[]
)
# Collapsed-by-default panel with the add-entry and delete-entry controls.
with gr.Accordion("添加/删除图片条目", open=False):
with gr.Row():
new_image_file_input = gr.File(label="上传新图片文件", type="filepath")
new_image_name_input = gr.Textbox(label="或输入文件名 (例如:85.jpg)", placeholder="若上传文件可留空",
scale=1)
new_score_input = gr.Number(label="新分数 (0-100)", value=50)
add_entry_btn = gr.Button("添加条目", variant="secondary", scale=0)
with gr.Row():
# Also filled by the display-refresh handlers with a value for the current image.
delete_filename_input = gr.Textbox(label="要删除的文件名", placeholder="请输入完整文件名,例如 85.jpg")
delete_entry_btn = gr.Button("删除条目", variant="stop", scale=0)
# Tab 3: model training. Widgets only; events are wired further down.
with gr.Tab("训练模型"):
gr.Markdown("## 模型训练")
gr.Markdown(
"1. **重要**: 确保在“原始数据导入”或“训练数据管理”标签页已加载并**保存**了数据。模型将使用 `data` 文件夹中的图片和分数。")
gr.Markdown("2. 选择模型类型和基础CNN模型,调整参数,开始训练。")
gr.Markdown("3. 点击“开始训练”按钮。训练过程中的损失曲线会实时显示。")
# Basic settings: model type, epochs, learning rate.
with gr.Row():
model_type_selector = gr.Dropdown(
["深度学习", "端到端深度学习", "随机森林", "支持向量回归", "梯度提升回归", "堆叠回归", "K近邻",
"线性回归"],
label="选择模型类型", value="深度学习", interactive=True
)
epochs_input = gr.Slider(1, 100, DEFAULT_EPOCHS, label="训练轮次", step=1)
lr_input = gr.Number(DEFAULT_LR, label="学习率", precision=5)
base_cnn_selector_train = gr.Dropdown(
BASE_CNN_MODELS,
label="选择基础CNN模型 (深度学习模式)",
value="resnet50",
interactive=True,
visible=True
)
# Advanced parameters; visibility is toggled by the model-type and scheduler change handlers.
with gr.Accordion("高级训练参数", open=False):
with gr.Row():
batch_size_input = gr.Slider(
1, 128, DEFAULT_BATCH_SIZE, label="批量大小 (Batch Size)", step=1, interactive=True,
visible=True
)
dropout_rate_input = gr.Slider(
0.0, 0.9, DEFAULT_DROPOUT_RATE, label="Dropout 比率", step=0.05, interactive=True,
visible=True
)
with gr.Row():
weight_decay_input = gr.Number(
DEFAULT_WEIGHT_DECAY, label="权重衰减 (L2正则化)", precision=7, interactive=True,
visible=True
)
# Starts hidden; the model-type handler shows it for the sklearn-style model types.
pca_variance_ratio_input = gr.Slider(
0.7, 1.0, DEFAULT_PCA_VARIANCE_RATIO, label="PCA保留方差比例 (仅Sklearn)", step=0.01,
interactive=True,
visible=False
)
with gr.Row():
optimizer_selector = gr.Dropdown(
OPTIMIZERS, label="优化器", value=DEFAULT_OPTIMIZER, interactive=True,
visible=True
)
loss_function_selector = gr.Dropdown(
LOSS_FUNCTIONS, label="损失函数", value=DEFAULT_LOSS_FUNCTION, interactive=True,
visible=True
)
# NOTE(review): the three scheduler-specific sliders below all start visible,
# whatever DEFAULT_LR_SCHEDULER is; they are only filtered after a change event fires.
with gr.Row():
lr_scheduler_selector = gr.Dropdown(
LR_SCHEDULERS, label="学习率调度器", value=DEFAULT_LR_SCHEDULER, interactive=True,
visible=True
)
scheduler_patience_input = gr.Slider(
1, 20, DEFAULT_SCHEDULER_PATIENCE, label="调度器耐心 (ReduceLROnPlateau)", step=1, interactive=True,
visible=True
)
with gr.Row():
scheduler_factor_input = gr.Slider(
0.01, 0.5, DEFAULT_SCHEDULER_FACTOR, label="调度器因子 (ReduceLROnPlateau)", step=0.01,
interactive=True,
visible=True
)
scheduler_t_max_input = gr.Slider(
1, 200, DEFAULT_SCHEDULER_T_MAX, label="调度器T_max (CosineAnnealingLR)", step=1, interactive=True,
visible=True
)
with gr.Row():
early_stopping_patience_input = gr.Slider(
1, 20, DEFAULT_EARLY_STOPPING_PATIENCE, label="早停耐心值", step=1, interactive=True,
visible=True
)
with gr.Row():
enable_augmentation_checkbox = gr.Checkbox(
DEFAULT_DATA_AUGMENTATION, label="开启数据增强", interactive=True,
visible=True
)
# Start button, status line and the two plot outputs filled by the training chain.
train_start_btn = gr.Button("开始训练", variant="primary")
train_status_text = gr.Textbox(label="训练状态", interactive=False)
with gr.Row():
loss_plot_output = gr.Plot(label="训练与验证损失", scale=1) # label changed
metrics_plot_output = gr.Plot(label="验证MSE与MAE", scale=1) # label changed
# Tab 4: prediction. Widgets only; events are wired further down.
with gr.Tab("预测图片"):
gr.Markdown("## 图片预测")
gr.Markdown("1. 选择之前训练过的模型类型。")
gr.Markdown("2. 上传您要预测的图片。")
gr.Markdown("3. 点击“预测”按钮获取分数。")
with gr.Row():
# Same list of model types as the training tab's dropdown.
predict_model_type_selector = gr.Dropdown(
["深度学习", "端到端深度学习", "随机森林", "支持向量回归", "梯度提升回归", "堆叠回归", "K近邻",
"线性回归"],
label="选择预测模型类型", value="深度学习", interactive=True
)
# The uploaded image is handed to the handler as a file path.
image_for_predict = gr.Image(type="filepath", label="上传图片进行预测")
predict_btn = gr.Button("预测", variant="primary")
predicted_score_output = gr.Label(label="预测结果")
# Base CNN choice passed to predict_score together with the model type.
base_cnn_selector_predict = gr.Dropdown(
BASE_CNN_MODELS,
label="选择基础CNN模型 (深度学习模式)",
value="resnet50",
interactive=True,
visible=True
)
# --- Event wiring ---
# 1. Events of the "raw data import" tab

# Load a folder, then rebuild the score table from the state the loader stored.
load_btn.click(
    fn=load_images_from_folder_for_import,
    inputs=[folder_input],
    outputs=[
        gallery,
        import_status_text,
        current_image_name_display,
        current_image_preview,
        current_score_input,
        current_import_data_state,
    ],
).success(
    fn=lambda state_tuple_from_load: pd.DataFrame(
        [[Path(p).name, s] for p, s in zip(state_tuple_from_load[0], state_tuple_from_load[1])],
        columns=["文件名", "分数"],
    ),
    inputs=[current_import_data_state],
    outputs=import_score_dataframe,
)

# Clicking a thumbnail selects that image for editing.
gallery.select(
    fn=_select_image_for_edit,
    inputs=[current_import_data_state],
    outputs=[current_image_name_display, current_image_preview, current_score_input, current_import_data_state],
)

# Previous / next navigation; the constant gr.State supplies the step direction.
prev_btn.click(
    fn=_navigate_image,
    inputs=[gr.State(-1), current_import_data_state],
    outputs=[current_image_name_display, current_image_preview, current_score_input, current_import_data_state],
)
next_btn.click(
    fn=_navigate_image,
    inputs=[gr.State(1), current_import_data_state],
    outputs=[current_image_name_display, current_image_preview, current_score_input, current_import_data_state],
)

# Confirm button and Enter in the score box both apply the typed score.
confirm_score_btn.click(
    fn=_process_single_score_edit,
    inputs=[current_score_input, current_import_data_state],
    outputs=[import_score_dataframe, current_import_data_state],
)
current_score_input.submit(
    fn=_process_single_score_edit,
    inputs=[current_score_input, current_import_data_state],
    outputs=[import_score_dataframe, current_import_data_state],
)

# Direct table edits: copy the score column back into the state, keeping
# the path list and the selected index as they are.
import_score_dataframe.change(
    fn=lambda dataframe_data, state: (state[0], dataframe_data["分数"].tolist(), state[2]),
    inputs=[import_score_dataframe, current_import_data_state],
    outputs=[current_import_data_state],
)

# load_training_data_for_management returns one more value than the UI
# displays (presumably the scores list); it is parked in this unused state.
_import_scores_sink = gr.State([])

# Save, reload the management tab from disk, then refresh its editor.
save_import_data_btn.click(
    fn=save_data_to_data_dir,
    inputs=[current_import_data_state],
    outputs=[import_status_text, current_managed_data_state],
).then(
    fn=load_training_data_for_management,
    inputs=[],
    outputs=[managed_score_dataframe, managed_data_status_text, current_managed_data_state, managed_gallery,
             _import_scores_sink],
).then(
    # Read the live managed state here. Wrapping `.value[...]` in new
    # gr.State components froze the inputs at the build-time ([], [], -1),
    # so the editor was always refreshed with empty data.
    fn=lambda managed_state: _update_managed_display_from_index(
        managed_state[2], managed_state[0], managed_state[1]
    ),
    inputs=[current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input],
)
# 2. Events of the "training data management" tab

# Several handlers from utils return one more value than the UI displays
# (presumably the scores list); it is parked in this single unused state.
_managed_scores_sink = gr.State([])


def _show_current_managed_entry(managed_state):
    """Refresh the management editor from the live state.

    ``managed_state`` is the ``(image paths, scores, selected index)`` tuple
    held in ``current_managed_data_state``. Reading it at event time matters:
    wrapping ``current_managed_data_state.value[...]`` in new ``gr.State``
    components froze the inputs at the build-time ``([], [], -1)``.
    """
    all_image_paths, all_scores, index = managed_state
    return _update_managed_display_from_index(index, all_image_paths, all_scores)


# Load the saved data from disk, then refresh the editor.
load_managed_data_btn.click(
    fn=load_training_data_for_management,
    inputs=[],
    outputs=[managed_score_dataframe, managed_data_status_text, current_managed_data_state, managed_gallery,
             _managed_scores_sink],
).then(
    fn=_show_current_managed_entry,
    inputs=[current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input],
)

# Clicking a thumbnail selects that image for editing.
managed_gallery.select(
    fn=_select_managed_image_for_edit,
    inputs=[current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input,
             current_managed_data_state],
)

# Previous / next navigation; the constant gr.State supplies the step direction.
prev_managed_btn.click(
    fn=_navigate_managed_image,
    inputs=[gr.State(-1), current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input,
             current_managed_data_state],
)
next_managed_btn.click(
    fn=_navigate_managed_image,
    inputs=[gr.State(1), current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input,
             current_managed_data_state],
)

# Confirm button and Enter in the score box both apply the typed score.
confirm_managed_score_btn.click(
    fn=_process_managed_single_score_edit,
    inputs=[managed_score_input, current_managed_data_state],
    outputs=[managed_data_status_text, managed_score_dataframe, managed_gallery,
             _managed_scores_sink, current_managed_data_state],
)
managed_score_input.submit(
    fn=_process_managed_single_score_edit,
    inputs=[managed_score_input, current_managed_data_state],
    outputs=[managed_data_status_text, managed_score_dataframe, managed_gallery,
             _managed_scores_sink, current_managed_data_state],
)

# Direct table edits; the status message is only written when the update succeeds.
managed_score_dataframe.change(
    fn=update_data_from_management_dataframe,
    inputs=[managed_score_dataframe, current_managed_data_state],
    outputs=[managed_gallery, _managed_scores_sink, current_managed_data_state],
).success(
    fn=lambda: "表格更新并保存成功!",
    inputs=[],
    outputs=[managed_data_status_text],
)

# Add an entry, then refresh the editor from the updated state.
add_entry_btn.click(
    fn=add_new_image_entry,
    inputs=[new_image_file_input, new_image_name_input, new_score_input, current_managed_data_state],
    outputs=[managed_data_status_text, managed_score_dataframe, new_image_file_input, managed_gallery,
             _managed_scores_sink, current_managed_data_state],
).then(
    fn=_show_current_managed_entry,
    inputs=[current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input],
)

# Delete an entry, then refresh the editor from the updated state.
delete_entry_btn.click(
    fn=delete_image_entry,
    inputs=[delete_filename_input, current_managed_data_state],
    outputs=[managed_data_status_text, managed_score_dataframe, delete_filename_input, managed_gallery,
             _managed_scores_sink, current_managed_data_state],
).then(
    fn=_show_current_managed_entry,
    inputs=[current_managed_data_state],
    outputs=[managed_image_name_display, managed_image_preview, managed_score_input, delete_filename_input],
)
# 3. Events of the "train model" tab
def update_advanced_params_visibility(model_type, scheduler_name=None):
    """Compute which training parameters are visible for ``model_type``.

    Args:
        model_type: Current value of the model-type dropdown.
        scheduler_name: Current value of the LR-scheduler dropdown. When
            omitted, the dropdown's build-time value is used, which is what
            this function always read before and does not follow later
            user selections.

    Returns:
        Dict mapping each affected component to a ``gr.update(visible=...)``.
    """
    if scheduler_name is None:
        scheduler_name = lr_scheduler_selector.value

    is_pytorch_model = model_type in ("深度学习", "端到端深度学习")
    is_sklearn_model = model_type in ("随机森林", "支持向量回归", "梯度提升回归", "堆叠回归", "K近邻", "线性回归")
    # Scheduler-specific sliders follow both the model family and the
    # scheduler that is actually selected right now.
    uses_plateau = is_pytorch_model and scheduler_name == "ReduceLROnPlateau"
    uses_cosine = is_pytorch_model and scheduler_name == "CosineAnnealingLR"

    return {
        base_cnn_selector_train: gr.update(visible=is_pytorch_model or is_sklearn_model),
        lr_input: gr.update(visible=is_pytorch_model),
        epochs_input: gr.update(visible=is_pytorch_model),
        batch_size_input: gr.update(visible=is_pytorch_model),
        dropout_rate_input: gr.update(visible=is_pytorch_model),
        weight_decay_input: gr.update(visible=is_pytorch_model),
        optimizer_selector: gr.update(visible=is_pytorch_model),
        loss_function_selector: gr.update(visible=is_pytorch_model),
        lr_scheduler_selector: gr.update(visible=is_pytorch_model),
        scheduler_patience_input: gr.update(visible=uses_plateau),
        scheduler_factor_input: gr.update(visible=uses_plateau),
        scheduler_t_max_input: gr.update(visible=uses_cosine),
        early_stopping_patience_input: gr.update(visible=is_pytorch_model),
        enable_augmentation_checkbox: gr.update(visible=is_pytorch_model or is_sklearn_model),
        pca_variance_ratio_input: gr.update(visible=is_sklearn_model),
    }


# Pass the scheduler dropdown as a second input so the handler sees its
# current selection instead of the value it had when the UI was built.
model_type_selector.change(
    fn=update_advanced_params_visibility,
    inputs=[model_type_selector, lr_scheduler_selector],
    outputs=[
        base_cnn_selector_train, lr_input, epochs_input,
        batch_size_input, dropout_rate_input, weight_decay_input,
        optimizer_selector, loss_function_selector, lr_scheduler_selector,
        scheduler_patience_input, scheduler_factor_input, scheduler_t_max_input,
        early_stopping_patience_input, enable_augmentation_checkbox, pca_variance_ratio_input,
    ],
)

# Changing the scheduler shows only the sliders that belong to it.
lr_scheduler_selector.change(
    fn=lambda scheduler_name: {
        scheduler_patience_input: gr.update(visible=scheduler_name == "ReduceLROnPlateau"),
        scheduler_factor_input: gr.update(visible=scheduler_name == "ReduceLROnPlateau"),
        scheduler_t_max_input: gr.update(visible=scheduler_name == "CosineAnnealingLR"),
    },
    inputs=[lr_scheduler_selector],
    outputs=[
        scheduler_patience_input, scheduler_factor_input, scheduler_t_max_input,
    ],
)

# Carries the outcome of the data-preparation step into the training step.
data_prep_success_flag_state = gr.State(False)

# Training runs in three chained steps: configure the engine, prepare the
# data, then train or, if preparation failed, show the static error plots.
train_start_btn.click(
    fn=training_engine.switch_model_type,
    inputs=[
        model_type_selector,
        base_cnn_selector_train,
        dropout_rate_input,
        weight_decay_input,
        pca_variance_ratio_input,
        optimizer_selector,
        lr_scheduler_selector,
        scheduler_patience_input,
        scheduler_factor_input,
        scheduler_t_max_input,
        loss_function_selector,
        early_stopping_patience_input,
        batch_size_input,
        enable_augmentation_checkbox,
    ],
    outputs=[train_status_text],
).then(
    fn=training_engine.prepare_data_for_training,
    inputs=[],
    outputs=[data_prep_success_flag_state, train_status_text],
).then(
    fn=lambda success_flag, epochs, lr: (
        training_engine.train_model(epochs, lr)
        if success_flag
        else _create_static_error_plot("数据准备失败或无数据,无法训练")
    ),
    inputs=[data_prep_success_flag_state, epochs_input, lr_input],
    outputs=[loss_plot_output, metrics_plot_output],
)
# 4. Events of the "predict image" tab
def update_predict_cnn_visibility(model_type):
    """Return a visibility update for the prediction tab's base-CNN dropdown.

    The dropdown is shown whenever ``model_type`` is one of the listed
    model types and hidden for any other value.
    """
    cnn_backed_types = (
        "深度学习", "端到端深度学习", "随机森林", "支持向量回归",
        "梯度提升回归", "堆叠回归", "K近邻", "线性回归",
    )
    should_show = model_type in cnn_backed_types
    return gr.update(visible=should_show)


# Re-evaluate the dropdown's visibility whenever the model type changes.
predict_model_type_selector.change(
    fn=update_predict_cnn_visibility,
    inputs=[predict_model_type_selector],
    outputs=[base_cnn_selector_predict],
)

# Run the prediction for the uploaded image with the chosen model settings.
predict_btn.click(
    fn=training_engine.predict_score,
    inputs=[image_for_predict, predict_model_type_selector, base_cnn_selector_predict],
    outputs=predicted_score_output,
)
# Launch the Gradio application
if __name__ == "__main__":
    # Make sure the data directory and the directory that will hold the
    # saved model exist before the server starts.
    for required_dir in (Path(DATA_DIR), Path(MODEL_SAVE_BASE_PATH).parent):
        required_dir.mkdir(parents=True, exist_ok=True)

    launch_options = {
        "share": True,
        "server_port": 7860,
        "show_error": True,
        "debug": True,
    }
    demo.launch(**launch_options)