Spaces:
Runtime error
Runtime error
File size: 6,338 Bytes
f7240ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 | # image_processor.py
"""图像处理辅助工具 - 专为Gradio Web应用优化(根目录版)"""
from PIL import Image, ImageEnhance
class ImageProcessor:
"""轻量级图像处理工具,专为YouTube缩略图生成优化"""
@staticmethod
def prepare_for_training(image, target_size=512):
"""
为LoRA训练准备单张图像
用于Gradio界面中用户上传的训练图片
"""
if image is None:
return None
# 确保是PIL Image对象
if not isinstance(image, Image.Image):
return None
# 转换为RGB
if image.mode != 'RGB':
image = image.convert('RGB')
# 智能裁剪到正方形
width, height = image.size
if width != height:
# 裁剪到正方形,保持中心
size = min(width, height)
left = (width - size) // 2
top = (height - size) // 2
image = image.crop((left, top, left + size, top + size))
# 调整到目标尺寸
image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
return image
@staticmethod
def enhance_thumbnail(image, enhance_level="medium"):
"""
增强生成的缩略图效果
enhance_level: "light", "medium", "strong"
"""
if image is None or not isinstance(image, Image.Image):
return image
# 根据增强级别设置参数
if enhance_level == "light":
brightness, contrast, sharpness = 1.05, 1.05, 1.1
elif enhance_level == "medium":
brightness, contrast, sharpness = 1.1, 1.15, 1.2
elif enhance_level == "strong":
brightness, contrast, sharpness = 1.15, 1.25, 1.3
else:
return image
# 应用增强
enhanced = ImageEnhance.Brightness(image).enhance(brightness)
enhanced = ImageEnhance.Contrast(enhanced).enhance(contrast)
enhanced = ImageEnhance.Sharpness(enhanced).enhance(sharpness)
return enhanced
@staticmethod
def create_comparison(image1, image2, labels=None):
"""
创建两张图片的对比视图(A/B测试用)
"""
if not image1 or not image2:
return image1 or image2
# 确保尺寸一致
width = max(image1.width, image2.width)
height = max(image1.height, image2.height)
image1 = image1.resize((width, height), Image.Resampling.LANCZOS)
image2 = image2.resize((width, height), Image.Resampling.LANCZOS)
# 创建并排对比
comparison = Image.new('RGB', (width * 2, height), color='white')
comparison.paste(image1, (0, 0))
comparison.paste(image2, (width, 0))
return comparison
@staticmethod
def resize_for_web(image, max_size=1024):
"""
优化图片用于网页显示(减小文件大小)
"""
if not image:
return image
width, height = image.size
# 如果图片太大,按比例缩小
if width > max_size or height > max_size:
ratio = min(max_size / width, max_size / height)
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
return image
@staticmethod
def validate_training_images(images):
"""
验证训练图片是否符合要求
返回: (valid_count, issues)
"""
if not images:
return 0, ["❌ 未上传任何图片"]
valid_count = 0
issues = []
# 检查图片数量
if len(images) < 5:
issues.append(f"⚠️ 图片数量较少:{len(images)}张,建议至少5-10张")
for i, img in enumerate(images):
if img is None:
issues.append(f"❌ 第{i+1}张图片无法读取")
continue
width, height = img.size
# 检查分辨率
if width < 256 or height < 256:
issues.append(f"⚠️ 第{i+1}张图片分辨率过低: {width}x{height}")
continue
# 检查纵横比
aspect_ratio = width / height
if aspect_ratio < 0.3 or aspect_ratio > 3.0:
issues.append(f"⚠️ 第{i+1}张图片比例异常: {aspect_ratio:.2f}")
continue
valid_count += 1
# 生成总体建议
if valid_count >= 10:
quality = "✅ 优秀"
elif valid_count >= 5:
quality = "✅ 良好"
elif valid_count >= 3:
quality = "⚠️ 基本可用"
else:
quality = "❌ 不足"
if not issues:
issues.append(f"{quality} - 数据质量评分")
return valid_count, issues
# 实用工具函数,可以直接在Gradio界面中调用
def quick_enhance(image, level="medium"):
"""
快速增强缩略图 - 用于生成后的可选后处理
"""
return ImageProcessor.enhance_thumbnail(image, level)
def prepare_uploaded_images(images):
"""
批量处理用户上传的训练图片
返回: (processed_images, validation_report)
"""
if not images:
return [], "❌ 未上传图片"
processed = []
for img in images:
if img is not None:
processed_img = ImageProcessor.prepare_for_training(img)
if processed_img:
processed.append(processed_img)
valid_count, issues = ImageProcessor.validate_training_images(processed)
report = f"✅ 成功处理 {len(processed)} 张图片\n"
report += f"📊 有效图片: {valid_count} 张\n"
if issues:
report += "⚠️ 检测到的问题:\n" + "\n".join(issues)
return processed, report
def create_ab_test_comparison(image1, image2):
"""
创建A/B测试对比图 - 用于比较不同prompt效果
"""
return ImageProcessor.create_comparison(image1, image2)
|