# utils.py
import os
import re
import shutil
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, random_split
from torchvision import transforms
import gradio as gr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np
# 从 config.py 导入常量
from config import DATA_DIR, SCORE_FILE_NAME
# --- 数据集类 ---
class ScoreDataset(Dataset):
    """Dataset of (image, normalised score) pairs for score regression.

    Scores are min-max normalised with the smallest and largest value found
    in ``scores``. ``min_label`` and ``max_label`` keep that range so callers
    can map predictions back to the original scale.

    Args:
        image_paths: Sequence of image file paths.
        scores: Sequence of numeric scores, parallel to ``image_paths``.
        transform: Optional callable applied to the RGB PIL image.
        fallback_image_size: Side length of the all-zero placeholder image
            returned when a sample cannot be loaded. It is only used until a
            real sample has been loaded; after that the placeholder copies
            the shape of the last successfully loaded tensor.
    """

    def __init__(self, image_paths, scores, transform=None, fallback_image_size=224):
        self.image_paths = image_paths
        self.scores = scores
        self.transform = transform
        # Placeholder shape for unreadable samples; must match the real
        # samples or batching them together fails.
        self._fallback_shape = (3, fallback_image_size, fallback_image_size)
        # len() instead of truthiness so array-like score containers work too.
        if self.scores is None or len(self.scores) == 0:
            self.min_label = 0.0
            self.max_label = 100.0
        else:
            self.min_label = float(min(self.scores))
            self.max_label = float(max(self.scores))
            if self.max_label == self.min_label:
                # Avoid a zero-width range when every score is identical.
                self.max_label += 1.0

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, normalised_score)`` for sample ``idx``.

        Any failure while loading or transforming the image is reported on
        stdout and replaced by an all-zero image with the label 0.5.
        """
        img_path = self.image_paths[idx]
        score = self.scores[idx]
        try:
            # The context manager releases the file handle right away.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
            score_normalized = (score - self.min_label) / (self.max_label - self.min_label + 1e-7)
            if isinstance(image, torch.Tensor):
                # Remember the real sample shape for later placeholders.
                self._fallback_shape = tuple(image.shape)
            return image, torch.tensor(score_normalized, dtype=torch.float32)
        except Exception as e:
            print(f"Error loading or processing image {img_path}: {e}")
            return torch.zeros(*self._fallback_shape), torch.tensor(0.5, dtype=torch.float32)
# --- 图片转换函数 (增加更多增强选项,并根据 enable_augmentation 控制) ---
def get_transforms(train=True, image_size=224, enable_augmentation=True,
                   random_rotation_degrees=15,
                   random_affine_degrees=15,
                   random_affine_translate=(0.1, 0.1),
                   random_affine_scale=(0.9, 1.1),
                   grayscale_p=0.1,
                   random_erasing_p=0.1
                   ):
    """Build the torchvision preprocessing pipeline for one dataset split.

    Args:
        train: Whether the pipeline is for the training split.
        image_size: Output side length in pixels.
        enable_augmentation: Augmentation is applied only when both this
            flag and ``train`` are true; otherwise the image is resized and
            centre-cropped.
        random_rotation_degrees: Range passed to ``RandomRotation``.
        random_affine_degrees: ``degrees`` passed to ``RandomAffine``.
        random_affine_translate: ``translate`` passed to ``RandomAffine``.
        random_affine_scale: ``scale`` passed to ``RandomAffine``.
        grayscale_p: Probability passed to ``RandomGrayscale``.
        random_erasing_p: Probability passed to ``RandomErasing``.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize``
        (followed by ``RandomErasing`` when augmenting).
    """
    augment = train and enable_augmentation
    if augment:
        transform_list = [
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(random_rotation_degrees),
            transforms.RandomAffine(
                degrees=random_affine_degrees,
                translate=random_affine_translate,
                scale=random_affine_scale,
                shear=0),
            transforms.RandomGrayscale(p=grayscale_p),
        ]
    else:
        transform_list = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
    transform_list.extend([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    if augment:
        # RandomErasing works on tensors only, so it has to run after
        # ToTensor; placed before it, it fails on the PIL image whenever
        # the erase is triggered.
        transform_list.append(transforms.RandomErasing(p=random_erasing_p))
    return transforms.Compose(transform_list)
# --- UI交互辅助函数 (原始数据导入Tab) ---
def load_images_from_folder_for_import(folder_path):
    """Scan a folder for images and derive an initial score for each one.

    Files ending in .png/.jpg/.jpeg (case-insensitive) are processed in
    sorted name order. Digits at the start of the file name become the
    score, clamped to 0..100; names without leading digits get 50.

    Returns:
        A 6-tuple ``(paths, status message, first image name, first image
        path, first score, (paths, scores, current index))``. On failure the
        lists are empty, the preview fields are ``None`` and the index is -1.
    """
    def _failure(message):
        # Common shape of every "nothing to show" result.
        return [], message, "", None, None, ([], [], -1)

    if not os.path.isdir(folder_path):
        return _failure("错误: 文件夹不存在。")

    candidates = sorted(
        entry for entry in os.listdir(folder_path)
        if entry.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not candidates:
        return _failure("警告: 文件夹中没有找到图片。")

    found_paths = []
    found_scores = []
    for file_name in candidates:
        image_path = os.path.join(folder_path, file_name)
        try:
            leading_digits = re.match(r'(\d+)', file_name)
            if leading_digits is None:
                parsed_score = 50
            else:
                parsed_score = max(0, min(100, int(leading_digits.group(1))))
        except Exception as e:
            print(f"处理文件 {file_name} 失败: {e}")
            continue
        found_paths.append(image_path)
        found_scores.append(parsed_score)

    if not found_paths:
        return _failure("警告: 没有找到有效图片或解析分数失败。")

    first_path = found_paths[0]
    return (
        found_paths,
        "图片加载完成。",
        Path(first_path).name,
        first_path,
        found_scores[0],
        (found_paths, found_scores, 0),
    )
# --- UI交互辅助函数 (训练数据管理Tab) ---
def _update_managed_display_from_index(index, all_image_paths, all_scores):
    """Return the display fields for the image at ``index``.

    Returns:
        ``(file name, image path, score, file name)``, or
        ``("", None, None, "")`` when the list is empty or ``index`` is out
        of range. The file name is returned twice, once for the name field
        and once for the delete-target field.
    """
    in_range = bool(all_image_paths) and 0 <= index < len(all_image_paths)
    if not in_range:
        return "", None, None, ""
    chosen_path = all_image_paths[index]
    chosen_score = all_scores[index]
    file_name = Path(chosen_path).name
    return file_name, chosen_path, chosen_score, file_name
def _select_managed_image_for_edit(evt: gr.SelectData, current_data_state):
    """Handle a selection event by making the selected image current.

    Args:
        evt: Selection event; ``evt.index`` is used as the list position.
        current_data_state: ``(image paths, scores, current index)``.

    Returns:
        ``(name, path, score, delete-target name, new state)``. When the
        list is empty or the index is out of range the display fields are
        blank and the state is returned unchanged.
    """
    all_image_paths, all_scores, _ = current_data_state
    if not all_image_paths or not (0 <= evt.index < len(all_image_paths)):
        return "", None, None, "", current_data_state

    picked = evt.index
    updated_state = (all_image_paths, all_scores, picked)
    name, path, score, delete_name = _update_managed_display_from_index(
        picked, all_image_paths, all_scores)
    return name, path, score, delete_name, updated_state
def _navigate_managed_image(direction, current_data_state):
    """Move the current selection by ``direction``, wrapping at both ends.

    Args:
        direction: Offset added to the current index.
        current_data_state: ``(image paths, scores, current index)``.

    Returns:
        ``(name, path, score, delete-target name, new state)``; blank
        display fields and the unchanged state when there are no images.
    """
    all_image_paths, all_scores, current_index = current_data_state
    total = len(all_image_paths) if all_image_paths else 0
    if total == 0:
        return "", None, None, "", current_data_state

    target = current_index + direction
    if target < 0:
        # Stepping before the first image jumps to the last one.
        target = total - 1
    elif target >= total:
        # Stepping past the last image jumps back to the first one.
        target = 0

    name, path, score, delete_name = _update_managed_display_from_index(
        target, all_image_paths, all_scores)
    return name, path, score, delete_name, (all_image_paths, all_scores, target)
def _process_managed_single_score_edit(entered_score, current_data_state):
    """Apply an edited score to the selected image and persist the data.

    Args:
        entered_score: New score; rounded and clamped to 0..100. ``None``
            is treated as 50.
        current_data_state: ``(image paths, scores, selected index)``.

    Returns:
        ``(status message, DataFrame of file names and scores, image paths,
        scores, data state)``. When no valid image is selected, nothing is
        saved, the two lists are empty and the state is returned unchanged.
    """
    # Imported here because the module does not import pandas at top level;
    # without it every call failed with NameError on ``pd``.
    import pandas as pd

    def _as_frame(paths, scores):
        # Table shown in the UI: one row per image, file name plus score.
        return pd.DataFrame(
            [[Path(p).name, s] for p, s in zip(paths, scores)],
            columns=["文件名", "分数"])

    all_image_paths, all_scores, selected_index = current_data_state
    if entered_score is not None:
        final_score = max(0, min(100, round(entered_score)))
    else:
        final_score = 50

    if 0 <= selected_index < len(all_scores):
        all_scores[selected_index] = final_score
        current_data_state = (all_image_paths, all_scores, selected_index)
        status_msg, new_managed_state = save_data_to_data_dir(current_data_state)
        dataframe_data = _as_frame(new_managed_state[0], new_managed_state[1])
        return status_msg, dataframe_data, new_managed_state[0], new_managed_state[1], new_managed_state

    if all_image_paths and all_scores:
        dataframe_data_current = _as_frame(all_image_paths, all_scores)
    else:
        dataframe_data_current = pd.DataFrame()
    return "无图片选中或数据无效。", dataframe_data_current, [], [], current_data_state
def load_training_data_for_management():
    """Load the saved ``filename,score`` records from the data directory.

    Blank lines are skipped. Each record is split at its last comma, so a
    file name that itself contains commas is still read correctly. Records
    whose image file is missing from the data directory are skipped with a
    warning on stdout.

    Returns:
        ``(table rows, status message, (paths, scores, index), paths,
        scores)``. On any failure the lists are empty and the index is -1.
    """
    score_file_path = Path(DATA_DIR) / SCORE_FILE_NAME
    if not score_file_path.exists():
        return [], "数据文件不存在,请先导入并保存数据。", ([], [], -1), [], []

    image_paths = []
    scores = []
    try:
        with open(score_file_path, 'r') as f:
            for line in f:
                record = line.strip()
                if not record:
                    # A stray empty line must not abort the whole load.
                    continue
                # The score is always the last field of the record.
                filename, score_str = record.rsplit(',', 1)
                full_image_path = Path(DATA_DIR) / filename
                if full_image_path.exists():
                    image_paths.append(str(full_image_path))
                    scores.append(float(score_str))
                else:
                    print(f"警告: 图像文件 {full_image_path} 不存在,已跳过。")
    except (OSError, ValueError) as e:
        return [], f"错误: 读取分数文件 {score_file_path} 失败: {e}", ([], [], -1), [], []

    if not image_paths:
        return [], "没有找到有效的训练数据。", ([], [], -1), [], []

    df_data = [[Path(p).name, s] for p, s in zip(image_paths, scores)]
    status_msg = f"加载完成,共 {len(image_paths)} 张图片。"
    return df_data, status_msg, (image_paths, scores, 0), image_paths, scores
def save_data_to_data_dir(current_data_state):
    """Copy the images into the data directory and rewrite the score file.

    All copies are done before the score file is opened for writing, so a
    failed copy leaves the previously saved score file untouched.

    Args:
        current_data_state: ``(image paths, scores, index)``; the index is
            ignored.

    Returns:
        ``(status message, state)``. On success the state holds the image
        paths inside the data directory, the same score list and index -1.
        On failure the input state is returned unchanged. With no images,
        nothing is written and an empty state is returned.
    """
    all_image_paths, all_scores, _ = current_data_state
    if not all_image_paths:
        return "没有要保存的图片和分数。", ([], [], -1)

    try:
        data_dir = Path(DATA_DIR)
        data_dir.mkdir(parents=True, exist_ok=True)
        score_file_path = data_dir / SCORE_FILE_NAME

        score_lines = []
        updated_image_paths_in_data_dir = []
        for i, original_img_path_str in enumerate(all_image_paths):
            source_path = Path(original_img_path_str)
            original_filename = source_path.name
            dest_path = data_dir / original_filename
            score = all_scores[i]
            # Copy unless the image already is the file in the data directory.
            if not dest_path.exists() or not source_path.samefile(dest_path):
                try:
                    shutil.copy2(original_img_path_str, dest_path)
                    print(f"复制文件: {original_img_path_str} -> {dest_path}")
                except shutil.SameFileError:
                    pass
            score_lines.append(f"{original_filename},{score}\n")
            updated_image_paths_in_data_dir.append(str(dest_path))

        # Only reached when every image is in place.
        with open(score_file_path, 'w') as f:
            f.writelines(score_lines)
    except Exception as e:
        return f"保存数据失败: {e}", current_data_state

    return f"数据已保存到 {DATA_DIR},共 {len(all_image_paths)} 条。", (
        updated_image_paths_in_data_dir, all_scores, -1)
def add_new_image_entry(new_image_file, new_image_name, new_score, current_data_state):
    """Add one image record (optionally with an uploaded file) and save.

    Args:
        new_image_file: Uploaded file, either a path string or a dict with
            a ``'name'`` key holding the path; may be ``None`` when only a
            name is given.
        new_image_name: Optional file name for the record. Only its final
            path component is used, so the record always refers to a file
            directly inside the data directory.
        new_score: Score; rounded and clamped to 0..100, ``None`` means 50.
        current_data_state: ``(image paths, scores, index)``.

    Returns:
        ``(status message, table rows, None, image paths, scores, state)``.
        On any validation or copy failure the rows are ``None``, the lists
        are empty and the input state is returned unchanged.
    """
    all_image_paths, all_scores, _ = current_data_state

    def _reject(message):
        # Common shape of every failure result.
        return message, None, None, [], [], current_data_state

    if new_image_file is None and not new_image_name:
        return _reject("请提供图片文件或图片名称。")

    source_path = None
    if new_image_file:
        source_path_str = None
        if isinstance(new_image_file, dict) and 'name' in new_image_file:
            source_path_str = new_image_file['name']
        elif isinstance(new_image_file, str):
            source_path_str = new_image_file
        if not source_path_str:
            return _reject("无效的图片文件上传。")
        source_path = Path(source_path_str)

    final_image_name = new_image_name.strip() if new_image_name and new_image_name.strip() else None
    if final_image_name:
        # Keep only the last path component: the save step records
        # basenames, and this keeps the copy inside the data directory.
        final_image_name = Path(final_image_name).name
        if final_image_name in ("", ".", ".."):
            final_image_name = None
    if not final_image_name and source_path is not None:
        final_image_name = source_path.name
    if not final_image_name:
        return _reject("请提供图片文件或图片名称。")

    # Reject duplicates before touching the disk, so a refused upload can
    # never overwrite the image that already owns this name.
    existing_filenames = {Path(p).name for p in all_image_paths}
    if final_image_name in existing_filenames:
        return _reject(f"错误: 图片 {final_image_name} 已存在。")

    dest_path = Path(DATA_DIR) / final_image_name
    if source_path is not None:
        try:
            dest_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source_path, dest_path)
        except OSError as e:
            return _reject(f"复制图片文件失败: {e}")
    elif not dest_path.exists():
        print(f"警告: 图片文件 {final_image_name} 在 data 目录中不存在,但已添加记录。")
    new_image_path = str(dest_path)

    score_to_add = max(0, min(100, round(new_score))) if new_score is not None else 50
    all_image_paths.append(new_image_path)
    all_scores.append(score_to_add)
    updated_df_data = [[Path(p).name, s] for p, s in zip(all_image_paths, all_scores)]
    status_msg, new_managed_state = save_data_to_data_dir((all_image_paths, all_scores, -1))
    return status_msg, updated_df_data, None, new_managed_state[0], new_managed_state[1], new_managed_state
def delete_image_entry(selected_filename, current_data_state):
    """Delete one image (file on disk and record) and save the rest.

    Args:
        selected_filename: File name of the record to delete.
        current_data_state: ``(image paths, scores, index)``.

    Returns:
        ``(status message, table rows, None, image paths, scores, state)``.
        When nothing is selected, the name is unknown or the file cannot be
        removed, the rows are ``None``, the lists are empty and the input
        state is returned unchanged.
    """
    all_image_paths, all_scores, _ = current_data_state

    def _abort(message):
        # Common shape of every failure result.
        return message, None, None, [], [], current_data_state

    if not selected_filename:
        return _abort("请选择要删除的图片。")

    # Position of the first record whose file name matches, or -1.
    position = next(
        (i for i, candidate in enumerate(all_image_paths)
         if Path(candidate).name == selected_filename),
        -1,
    )
    if position == -1:
        return _abort(f"错误: 未找到图片 {selected_filename}。")

    doomed_file = Path(all_image_paths[position])
    try:
        if doomed_file.exists():
            os.remove(doomed_file)
            print(f"物理删除文件: {doomed_file}")
    except Exception as e:
        print(f"删除文件 {selected_filename} 失败: {e}", e)
        return _abort(f"删除文件 {selected_filename} 失败: {e}")

    del all_image_paths[position]
    del all_scores[position]
    remaining_rows = [[Path(p).name, s] for p, s in zip(all_image_paths, all_scores)]
    status_msg, saved_state = save_data_to_data_dir((all_image_paths, all_scores, -1))
    return status_msg, remaining_rows, None, saved_state[0], saved_state[1], saved_state
def update_data_from_management_dataframe(dataframe_data, current_data_state):
    """Rebuild and save the data state from the edited management table.

    Args:
        dataframe_data: DataFrame with the columns "文件名" (file name) and
            "分数" (score).
        current_data_state: Not used; kept for the callback signature.

    Returns:
        ``(image paths, scores, state)`` as returned by the save step.
        Rows whose file name is blank or whose file is missing from the
        data directory are skipped with a warning on stdout. A score that
        is blank, NaN or not numeric falls back to 50 instead of raising.
    """
    if dataframe_data is None or dataframe_data.empty:
        new_data_state = ([], [], -1)
        save_data_to_data_dir(new_data_state)
        return new_data_state[0], new_data_state[1], new_data_state

    updated_image_paths_in_data_dir = []
    updated_all_scores = []
    for filename, raw_score in zip(dataframe_data["文件名"], dataframe_data["分数"]):
        # A blank name would resolve to the data directory itself, which
        # exists but is not an image, so it is treated as missing.
        has_name = isinstance(filename, str) and bool(filename.strip())
        full_path = str(Path(DATA_DIR) / filename) if has_name else None
        if full_path is None or not Path(full_path).exists():
            print(f"警告: 文件 {filename} 在磁盘上不存在,已跳过此条更新。")
            continue
        try:
            score = max(0, min(100, round(float(raw_score))))
        except (TypeError, ValueError, OverflowError):
            # Emptied or non-numeric cell: use the same default score the
            # other editing functions use for a missing value.
            score = 50
        updated_image_paths_in_data_dir.append(full_path)
        updated_all_scores.append(score)

    new_data_state = (updated_image_paths_in_data_dir, updated_all_scores, -1)
    status_msg, final_saved_state = save_data_to_data_dir(new_data_state)
    print(status_msg)
    return final_saved_state[0], final_saved_state[1], final_saved_state
def get_image_size_by_model_name(model_name):
    """Return the input side length for ``model_name``.

    "inception_v3" gets 299; every other name gets 224.
    """
    return 299 if model_name == "inception_v3" else 224
# --- 评估指标函数 (关键修改:接受 min_label 和 max_label 进行反归一化) ---
def calculate_metrics(y_true_normalized, y_pred_normalized, min_label, max_label):
    """Compute MSE, MAE and R2 on the original score scale.

    Predictions are first clipped to [0, 1]; labels and predictions are
    then mapped back with ``value * (max_label - min_label) + min_label``
    before the metrics are computed.

    Args:
        y_true_normalized: True labels, normalised to 0-1.
        y_pred_normalized: Predictions, normalised to 0-1.
        min_label (float): Minimum of the original labels.
        max_label (float): Maximum of the original labels.

    Returns:
        ``(mse, mae, r2)``.
    """
    truth, prediction = (
        values if isinstance(values, np.ndarray) else np.array(values)
        for values in (y_true_normalized, y_pred_normalized)
    )
    # Model outputs may leave the 0-1 range (e.g. when there is no sigmoid).
    prediction = np.clip(prediction, 0, 1)

    # Undo the min-max normalisation for both arrays.
    label_span = max_label - min_label
    truth_original = truth * label_span + min_label
    prediction_original = prediction * label_span + min_label

    return (
        mean_squared_error(truth_original, prediction_original),
        mean_absolute_error(truth_original, prediction_original),
        r2_score(truth_original, prediction_original),
    )