# utils.py
import os
import re
import shutil
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd  # FIX: pd.DataFrame is used below but pandas was never imported
import torch
from PIL import Image
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.utils.data import Dataset, random_split
from torchvision import transforms

# Constants imported from config.py
from config import DATA_DIR, SCORE_FILE_NAME


# --- Dataset class ---
class ScoreDataset(Dataset):
    """Dataset pairing image files with min-max-normalized scores.

    Scores are normalized into [0, 1] using the min/max of the provided
    score list, so a regression head (e.g. with sigmoid output) can fit them.
    """

    def __init__(self, image_paths, scores, transform=None):
        self.image_paths = image_paths
        self.scores = scores
        self.transform = transform
        if not self.scores:
            # No scores supplied: fall back to the app's default 0-100 range.
            self.min_label = 0.0
            self.max_label = 100.0
        else:
            self.min_label = float(min(self.scores))
            self.max_label = float(max(self.scores))
            if self.max_label == self.min_label:
                # Widen a degenerate (zero-width) range to avoid division by ~0.
                self.max_label += 1.0

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return (image_tensor, normalized_score_tensor) for sample *idx*.

        On any load/transform failure a neutral placeholder (zeros image,
        score 0.5) is returned so a DataLoader batch does not crash.
        """
        img_path = self.image_paths[idx]
        score = self.scores[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            # Min-max normalize; the epsilon guards against a zero-width range.
            score_normalized = (score - self.min_label) / (self.max_label - self.min_label + 1e-7)
            return image, torch.tensor(score_normalized, dtype=torch.float32)
        except Exception as e:
            print(f"Error loading or processing image {img_path}: {e}")
            return torch.zeros(3, 224, 224), torch.tensor(0.5, dtype=torch.float32)


# --- Image transform factory (augmentations gated by enable_augmentation) ---
def get_transforms(train=True, image_size=224, enable_augmentation=True,
                   random_rotation_degrees=15,
                   random_affine_degrees=15,
                   random_affine_translate=(0.1, 0.1),
                   random_affine_scale=(0.9, 1.1),
                   grayscale_p=0.1,
                   random_erasing_p=0.1):
    """Build the torchvision preprocessing pipeline.

    Args:
        train: Whether the pipeline is for training (enables augmentation).
        image_size: Target square side length in pixels.
        enable_augmentation: Master switch for all random augmentations.
        random_rotation_degrees / random_affine_*: Geometric augmentation knobs.
        grayscale_p: Probability of random grayscale conversion.
        random_erasing_p: Probability of random erasing (tensor-stage op).

    Returns:
        A transforms.Compose pipeline ending with ImageNet normalization.
    """
    transform_list = []
    if train and enable_augmentation:
        transform_list.extend([
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(random_rotation_degrees),
            transforms.RandomAffine(degrees=random_affine_degrees,
                                    translate=random_affine_translate,
                                    scale=random_affine_scale,
                                    shear=0),
            transforms.RandomGrayscale(p=grayscale_p),
        ])
    else:
        transform_list.extend([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ])
    transform_list.extend([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    if train and enable_augmentation:
        # FIX: RandomErasing operates on tensors, not PIL images, so it must
        # run after ToTensor(); the original placement raised a TypeError.
        transform_list.append(transforms.RandomErasing(p=random_erasing_p))
    return transforms.Compose(transform_list)


# --- UI helper (raw data import tab) ---
def load_images_from_folder_for_import(folder_path):
    """Scan *folder_path* for images, parsing a leading integer in each
    filename as its score (clamped to 0-100; default 50 if absent).

    Returns:
        (paths, status_message, first_name, first_preview_path, first_score,
         (paths, scores, selected_index)) — the tuple is the shared UI state.
    """
    all_image_paths = []
    all_scores = []
    if not os.path.isdir(folder_path):
        return [], "错误: 文件夹不存在。", "", None, None, ([], [], -1)
    image_files = [f for f in os.listdir(folder_path)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    if not image_files:
        return [], "警告: 文件夹中没有找到图片。", "", None, None, ([], [], -1)
    for img_name in sorted(image_files):
        full_path = os.path.join(folder_path, img_name)
        try:
            # Leading digits in the filename encode the score, e.g. "87_cat.jpg".
            match = re.match(r'(\d+)', img_name)
            if match:
                score = int(match.group(1))
                score = max(0, min(100, score))
            else:
                score = 50  # no parsable score -> neutral default
            all_image_paths.append(full_path)
            all_scores.append(score)
        except Exception as e:
            print(f"处理文件 {img_name} 失败: {e}")
            continue
    if not all_image_paths:
        return [], "警告: 没有找到有效图片或解析分数失败。", "", None, None, ([], [], -1)
    initial_index = 0
    initial_image_name = Path(all_image_paths[initial_index]).name
    initial_preview_path = all_image_paths[initial_index]
    initial_score = all_scores[initial_index]
    current_data_state_tuple = (all_image_paths, all_scores, initial_index)
    return (all_image_paths, "图片加载完成。", initial_image_name,
            initial_preview_path, initial_score, current_data_state_tuple)


# --- UI helpers (training data management tab) ---
def _update_managed_display_from_index(index, all_image_paths, all_scores):
    """Return (name, preview_path, score, delete_field_name) for *index*,
    or empty values when the index is out of range."""
    if not all_image_paths or not (0 <= index < len(all_image_paths)):
        return "", None, None, ""
    preview_path = all_image_paths[index]
    score = all_scores[index]
    image_name = Path(preview_path).name
    return image_name, preview_path, score, image_name


def _select_managed_image_for_edit(evt: gr.SelectData, current_data_state):
    """Gallery/dataframe select handler: move the selection cursor in the
    shared (paths, scores, index) state and refresh the preview widgets."""
    all_image_paths, all_scores, _ = current_data_state
    if all_image_paths and 0 <= evt.index < len(all_image_paths):
        current_data_state = (all_image_paths, all_scores, evt.index)
        selected_name, selected_path, selected_score, delete_filename = \
            _update_managed_display_from_index(evt.index, all_image_paths, all_scores)
        return selected_name, selected_path, selected_score, delete_filename, current_data_state
    return "", None, None, "", current_data_state


def _navigate_managed_image(direction, current_data_state):
    """Move the selection by *direction* (+1/-1), wrapping around the ends."""
    all_image_paths, all_scores, current_index = current_data_state
    if not all_image_paths:
        return "", None, None, "", current_data_state
    num_images = len(all_image_paths)
    if num_images == 0:
        return "", None, None, "", current_data_state
    new_index = current_index + direction
    # Wrap-around navigation.
    if new_index < 0:
        new_index = num_images - 1
    elif new_index >= num_images:
        new_index = 0
    current_data_state = (all_image_paths, all_scores, new_index)
    selected_name, selected_path, selected_score, delete_filename = \
        _update_managed_display_from_index(new_index, all_image_paths, all_scores)
    return selected_name, selected_path, selected_score, delete_filename, current_data_state


def _process_managed_single_score_edit(entered_score, current_data_state):
    """Apply an edited score to the currently selected image, persist all
    data to DATA_DIR, and return refreshed dataframe/state outputs."""
    all_image_paths, all_scores, selected_index = current_data_state
    if entered_score is not None:
        final_score = max(0, min(100, round(entered_score)))
    else:
        final_score = 50  # empty input falls back to the neutral default
    if selected_index != -1 and 0 <= selected_index < len(all_scores):
        all_scores[selected_index] = final_score
        current_data_state = (all_image_paths, all_scores, selected_index)
        status_msg, new_managed_state = save_data_to_data_dir(current_data_state)
        dataframe_data = pd.DataFrame(
            [[Path(img_path_str).name, score]
             for img_path_str, score in zip(new_managed_state[0], new_managed_state[1])],
            columns=["文件名", "分数"])
        return status_msg, dataframe_data, new_managed_state[0], new_managed_state[1], new_managed_state
    dataframe_data_current = pd.DataFrame()
    if all_image_paths and all_scores:
        dataframe_data_current = pd.DataFrame(
            [[Path(p).name, s] for p, s in zip(all_image_paths, all_scores)],
            columns=["文件名", "分数"])
    return "无图片选中或数据无效。", dataframe_data_current, [], [], current_data_state


def load_training_data_for_management():
    """Load the "filename,score" records from DATA_DIR's score file.

    Missing image files are skipped with a console warning.

    Returns:
        (dataframe_rows, status_message, (paths, scores, index), paths, scores)
    """
    image_paths = []
    scores = []
    score_file_path = Path(DATA_DIR) / SCORE_FILE_NAME
    if not score_file_path.exists():
        return [], "数据文件不存在,请先导入并保存数据。", ([], [], -1), [], []
    try:
        with open(score_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                # FIX: rsplit so filenames containing commas still parse,
                # since the score is always the last comma-separated field.
                filename, score_str = line.rsplit(',', 1)
                full_image_path = Path(DATA_DIR) / filename
                if full_image_path.exists():
                    image_paths.append(str(full_image_path))
                    scores.append(float(score_str))
                else:
                    print(f"警告: 图像文件 {full_image_path} 不存在,已跳过。")
    except Exception as e:
        return [], f"错误: 读取分数文件 {score_file_path} 失败: {e}", ([], [], -1), [], []
    if not image_paths:
        return [], "没有找到有效的训练数据。", ([], [], -1), [], []
    df_data = [[Path(p).name, s] for p, s in zip(image_paths, scores)]
    status_msg = f"加载完成,共 {len(image_paths)} 张图片。"
    return df_data, status_msg, (image_paths, scores, 0), image_paths, scores


def save_data_to_data_dir(current_data_state):
    """Copy every listed image into DATA_DIR (skipping identical files) and
    rewrite the score file as "filename,score" lines.

    Returns:
        (status_message, (paths_in_data_dir, scores, -1)); on failure the
        original state tuple is returned unchanged.
    """
    all_image_paths, all_scores, _ = current_data_state
    if not all_image_paths:
        return "没有要保存的图片和分数。", ([], [], -1)
    Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
    score_file_path = Path(DATA_DIR) / SCORE_FILE_NAME
    updated_image_paths_in_data_dir = []
    try:
        with open(score_file_path, 'w', encoding='utf-8') as f:
            for i, original_img_path_str in enumerate(all_image_paths):
                original_filename = Path(original_img_path_str).name
                dest_path = Path(DATA_DIR) / original_filename
                score = all_scores[i]
                # Only copy when the destination is missing or is a different
                # file; samefile() is only evaluated when dest_path exists.
                if not dest_path.exists() or not Path(original_img_path_str).samefile(dest_path):
                    try:
                        shutil.copy2(original_img_path_str, dest_path)
                        print(f"复制文件: {original_img_path_str} -> {dest_path}")
                    except shutil.SameFileError:
                        pass
                f.write(f"{original_filename},{score}\n")
                updated_image_paths_in_data_dir.append(str(dest_path))
        return f"数据已保存到 {DATA_DIR},共 {len(all_image_paths)} 条。", (
            updated_image_paths_in_data_dir, all_scores, -1)
    except Exception as e:
        return f"保存数据失败: {e}", current_data_state


def add_new_image_entry(new_image_file, new_image_name, new_score, current_data_state):
    """Register a new image + score, either from an uploaded file (copied
    into DATA_DIR) or by filename for an image already in DATA_DIR.

    Duplicate filenames are rejected; the full dataset is re-saved on success.
    """
    all_image_paths, all_scores, _ = current_data_state
    if new_image_file is None and not new_image_name:
        return "请提供图片文件或图片名称。", None, None, [], [], current_data_state
    source_path_str = None
    source_path = None
    if new_image_file:
        # Gradio may hand us either a temp-file dict or a plain path string.
        if isinstance(new_image_file, dict) and 'name' in new_image_file:
            source_path_str = new_image_file['name']
        elif isinstance(new_image_file, str):
            source_path_str = new_image_file
        if source_path_str:
            source_path = Path(source_path_str)
        else:
            # FIX: only reject here when a file WAS uploaded but is unusable;
            # the original returned unconditionally, making the name-only
            # branch below unreachable.
            return "无效的图片文件上传。", None, None, [], [], current_data_state
    final_image_name = new_image_name.strip() if new_image_name and new_image_name.strip() else None
    if source_path_str:
        if not final_image_name:
            final_image_name = source_path.name
        dest_path = Path(DATA_DIR) / final_image_name
        try:
            shutil.copy2(source_path, dest_path)
            new_image_path = str(dest_path)
        except Exception as e:
            return f"复制图片文件失败: {e}", None, None, [], [], current_data_state
    elif final_image_name:
        # Name-only entry: the image is expected to already be in DATA_DIR.
        new_image_path = str(Path(DATA_DIR) / final_image_name)
        if not Path(new_image_path).exists():
            print(f"警告: 图片文件 {final_image_name} 在 data 目录中不存在,但已添加记录。")
    else:
        return "请提供图片文件或图片名称。", None, None, [], [], current_data_state
    existing_filenames = {Path(p).name for p in all_image_paths}
    if final_image_name in existing_filenames:
        return f"错误: 图片 {final_image_name} 已存在。", None, None, [], [], current_data_state
    score_to_add = max(0, min(100, round(new_score))) if new_score is not None else 50
    all_image_paths.append(new_image_path)
    all_scores.append(score_to_add)
    updated_df_data = [[Path(p).name, s] for p, s in zip(all_image_paths, all_scores)]
    status_msg, new_managed_state = save_data_to_data_dir((all_image_paths, all_scores, -1))
    return status_msg, updated_df_data, None, new_managed_state[0], new_managed_state[1], new_managed_state


def delete_image_entry(selected_filename, current_data_state):
    """Delete the record (and the file on disk) matching *selected_filename*,
    then re-save the remaining dataset."""
    all_image_paths, all_scores, _ = current_data_state
    if not selected_filename:
        return "请选择要删除的图片。", None, None, [], [], current_data_state
    idx_to_delete = -1
    for i, p in enumerate(all_image_paths):
        if Path(p).name == selected_filename:
            idx_to_delete = i
            break
    if idx_to_delete == -1:
        return f"错误: 未找到图片 {selected_filename}。", None, None, [], [], current_data_state
    file_to_delete_path = Path(all_image_paths[idx_to_delete])
    try:
        if file_to_delete_path.exists():
            os.remove(file_to_delete_path)
            print(f"物理删除文件: {file_to_delete_path}")
    except Exception as e:
        # FIX: the original passed e twice (inside the f-string AND as a
        # second print argument), printing the exception twice.
        print(f"删除文件 {selected_filename} 失败: {e}")
        return f"删除文件 {selected_filename} 失败: {e}", None, None, [], [], current_data_state
    del all_image_paths[idx_to_delete]
    del all_scores[idx_to_delete]
    updated_df_data = [[Path(p).name, s] for p, s in zip(all_image_paths, all_scores)]
    status_msg, new_managed_state = save_data_to_data_dir((all_image_paths, all_scores, -1))
    return status_msg, updated_df_data, None, new_managed_state[0], new_managed_state[1], new_managed_state


def update_data_from_management_dataframe(dataframe_data, current_data_state):
    """Sync edits made directly in the management dataframe back to disk.

    Rows whose image file no longer exists in DATA_DIR are dropped with a
    console warning; scores are clamped to 0-100.
    """
    if dataframe_data.empty:
        new_data_state = ([], [], -1)
        save_data_to_data_dir(new_data_state)
        return new_data_state[0], new_data_state[1], new_data_state
    updated_image_paths_in_data_dir = []
    updated_all_scores = []
    for _, row in dataframe_data.iterrows():
        filename = row["文件名"]
        score = max(0, min(100, round(row["分数"])))
        full_path = str(Path(DATA_DIR) / filename)
        if Path(full_path).exists():
            updated_image_paths_in_data_dir.append(full_path)
            updated_all_scores.append(score)
        else:
            # FIX: report the actual filename instead of the literal
            # "(unknown)" placeholder the original printed.
            print(f"警告: 文件 {filename} 在磁盘上不存在,已跳过此条更新。")
    new_data_state = (updated_image_paths_in_data_dir, updated_all_scores, -1)
    status_msg, final_saved_state = save_data_to_data_dir(new_data_state)
    print(status_msg)
    return final_saved_state[0], final_saved_state[1], final_saved_state


def get_image_size_by_model_name(model_name):
    """Return the expected input side length for *model_name*
    (Inception v3 requires 299x299; everything else uses 224x224)."""
    if model_name == "inception_v3":
        return 299
    return 224


# --- Evaluation metrics (denormalizes via min_label / max_label) ---
def calculate_metrics(y_true_normalized, y_pred_normalized, min_label, max_label):
    """Compute MSE, MAE and R2 on the ORIGINAL score scale.

    Args:
        y_true_normalized (np.ndarray): Ground-truth labels, 0-1 normalized.
        y_pred_normalized (np.ndarray): Predictions, 0-1 normalized.
        min_label (float): Minimum of the original label range.
        max_label (float): Maximum of the original label range.

    Returns:
        Tuple of (mse, mae, r2).
    """
    if not isinstance(y_true_normalized, np.ndarray):
        y_true_normalized = np.array(y_true_normalized)
    if not isinstance(y_pred_normalized, np.ndarray):
        y_pred_normalized = np.array(y_pred_normalized)
    # Clamp predictions to [0, 1] in case the model head lacks a sigmoid.
    y_pred_normalized = np.clip(y_pred_normalized, 0, 1)
    # Denormalize both back to the original score range before scoring.
    y_true_original = y_true_normalized * (max_label - min_label) + min_label
    y_pred_original = y_pred_normalized * (max_label - min_label) + min_label
    mse = mean_squared_error(y_true_original, y_pred_original)
    mae = mean_absolute_error(y_true_original, y_pred_original)
    r2 = r2_score(y_true_original, y_pred_original)
    return mse, mae, r2