File size: 7,808 Bytes
3f9fa87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
#!/usr/bin/env python3
"""
图像质量优化系统 - 测试版本
用于快速测试评估和生成功能
"""
import os
import json
import gc
import time
from pathlib import Path
from typing import List, Dict, Any
from PIL import Image
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from diffusers import StableDiffusionXLPipeline
import logging
# Root logging setup: INFO and above, each record prefixed with time and level.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(name=__name__)
# Prompt sent to the VL model. It asks for a 1-10 score, a face-quality rating
# (good/average/poor), the main problems, and whether to regenerate the image.
# Kept at module level so the multi-line text is not affected by indentation.
_QUALITY_ASSESSMENT_PROMPT = """请仔细分析这张动漫风格图像的质量,特别关注以下方面:
1. 角色脸部质量(重点评估):
- 脸部细节是否清晰
- 眼睛是否对称且细节丰富
- 鼻子和嘴巴的比例是否正确
- 脸部轮廓是否自然
2. 整体图像质量:
- 线条清晰度
- 色彩饱和度和对比度
- 构图和比例
请给出总体质量评分(1-10分),并说明是否需要重新生成。
请按以下格式回答:
评分:X/10
脸部质量:[好/一般/差]
主要问题:[具体描述问题]
是否需要重新生成:[是/否]"""


def _resolve_qwen_vl_class(model_name: str):
    """Return the Qwen VL model class matching the checkpoint's ``model_type``.

    Qwen2.5-VL checkpoints (``model_type == "qwen2_5_vl"``) need
    ``Qwen2_5_VLForConditionalGeneration``; the Qwen2-VL class does not match
    their architecture. Anything else falls back to the Qwen2-VL class.
    """
    from transformers import AutoConfig

    model_type = getattr(AutoConfig.from_pretrained(model_name), "model_type", None)
    if model_type == "qwen2_5_vl":
        # Imported lazily so Qwen2-VL checkpoints still load on transformers
        # versions that predate the Qwen2.5-VL class.
        from transformers import Qwen2_5_VLForConditionalGeneration
        return Qwen2_5_VLForConditionalGeneration
    return Qwen2VLForConditionalGeneration


def _generate_assessment(model, processor, image_path: str) -> str:
    """Run the quality prompt on *image_path* and return the decoded reply.

    All intermediate tensors are locals of this helper, so they are released
    as soon as it returns (before the caller empties the CUDA cache).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": _QUALITY_ASSESSMENT_PROMPT},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # device_map="auto" decides placement, so send inputs to the device the
    # model reports instead of assuming the default "cuda" device.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=512)
    # generate() returns prompt + completion; keep only the newly generated ids.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def test_quality_assessment(image_path: str, model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
    """Assess the quality of a single image with a Qwen VL model.

    The model and processor are loaded for this call only and released
    (including the CUDA cache) before returning.

    Args:
        image_path: Path of the image to assess.
        model_name: Hugging Face model id or local path of the Qwen VL checkpoint.

    Returns:
        The model's decoded text answer, or ``None`` if inference fails.

    Raises:
        Exceptions raised while loading the config, model or processor are
        not caught and propagate to the caller.
    """
    logger.info(f"测试图像质量评估: {image_path}")
    model = processor = None
    try:
        model = _resolve_qwen_vl_class(model_name).from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(model_name)
        try:
            response = _generate_assessment(model, processor, image_path)
        except Exception as e:
            # logger.exception keeps the traceback (logged at ERROR level).
            logger.exception(f"评估失败: {e}")
            return None
        logger.info(f"评估结果:\n{response}")
        return response
    finally:
        # Drop references, collect any cycles, and only then release cached
        # CUDA blocks: empty_cache() cannot free memory still owned by tensors.
        del model, processor
        gc.collect()
        torch.cuda.empty_cache()
def test_image_generation(
    prompt_data: Dict[str, Any],
    model_path: str = "models/waiNSFWIllustrious_v140.safetensors",
    *,
    output_path: str = "/home/ubuntu/lyl/QwenIllustrious/data_tool/improve_data_quality/test_generated.png",
    width: int = 1024,
    height: int = 512,
    num_inference_steps: int = 35,
    guidance_scale: float = 7.5,
):
    """Generate one image with an SDXL single-file checkpoint and save it.

    Args:
        prompt_data: Mapping read for ``'positive_prompt'`` and
            ``'negative_prompt'`` (each defaults to ``''`` when absent).
        model_path: Path to the ``.safetensors`` checkpoint to load.
        output_path: Where the generated PNG is written; parent directories
            are created if missing.
        width, height: Output resolution passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.

    Returns:
        ``output_path`` on success, or ``None`` if generation or saving fails.

    Raises:
        Errors while loading the pipeline or moving it to CUDA propagate.
    """
    logger.info("测试图像生成功能")
    # Load the Illustrious checkpoint as an SDXL pipeline in half precision.
    pipeline = StableDiffusionXLPipeline.from_single_file(
        model_path,
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    try:
        pipeline = pipeline.to("cuda")
        try:
            positive_prompt = prompt_data.get('positive_prompt', '')
            negative_prompt = prompt_data.get('negative_prompt', '')
            logger.info(f"正面提示词: {positive_prompt[:100]}...")
            image = pipeline(
                prompt=positive_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
            ).images[0]
            # Create the target directory so save() does not fail on a fresh checkout.
            target = Path(output_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            image.save(target)
            logger.info(f"测试图像已保存到: {output_path}")
            return output_path
        except Exception as e:
            logger.exception(f"生成失败: {e}")
            return None
    finally:
        # Drop the pipeline, collect cycles, then release cached CUDA memory.
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
def quick_assessment_test(
    *,
    illustrious_dir: str = "/home/ubuntu/lyl/QwenIllustrious/illustrious_generated",
    num_images: int = 5,
    output_file: str = "/home/ubuntu/lyl/QwenIllustrious/data_tool/improve_data_quality/quick_test_results.json",
):
    """Assess the first few PNG images of a directory and save the results.

    For each image, ``<illustrious_dir>/metadata/<stem>.json`` must exist;
    images without metadata are skipped with a warning. Images whose
    assessment is falsy (``None`` on failure, or an empty reply) are not
    recorded.

    Args:
        illustrious_dir: Directory containing the generated ``*.png`` files.
        num_images: How many images to assess, taken in filename order.
        output_file: JSON file the result list is written to (parent
            directories are created if missing).

    Returns:
        A list of ``{'image', 'assessment', 'metadata'}`` dicts.
    """
    image_dir = Path(illustrious_dir)
    metadata_dir = image_dir / "metadata"
    # glob() order is filesystem-dependent; sort so "the first N" is reproducible.
    image_files = sorted(image_dir.glob("*.png"))[:max(num_images, 0)]
    logger.info(f"快速测试 - 评估前 {len(image_files)} 张图像")
    results = []
    for image_file in image_files:
        logger.info(f"评估: {image_file.name}")
        metadata_file = metadata_dir / f"{image_file.stem}.json"
        if not metadata_file.exists():
            logger.warning(f"缺少metadata: {metadata_file}")
            continue
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        assessment = test_quality_assessment(str(image_file))
        if assessment:
            results.append({
                'image': image_file.name,
                'assessment': assessment,
                'metadata': metadata,
            })
        # Short pause after each assessment. The reason is not stated in the
        # code; presumably it gives the freed GPU memory time to settle — confirm.
        time.sleep(2)
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    logger.info(f"快速测试完成,结果保存到: {output_file}")
    return results
def generation_test(
    *,
    illustrious_dir: str = "/home/ubuntu/lyl/QwenIllustrious/illustrious_generated",
):
    """Generate a test image from the prompt data of one sample metadata file.

    The metadata file used is the lexicographically first ``*.json`` in
    ``<illustrious_dir>/metadata``; its ``'original_prompt_data'`` entry is
    passed to ``test_image_generation``.

    Args:
        illustrious_dir: Directory whose ``metadata`` subdirectory is scanned.

    Returns:
        The path returned by ``test_image_generation``, or ``None`` when no
        metadata file exists or it lacks ``'original_prompt_data'``.
    """
    metadata_dir = Path(illustrious_dir) / "metadata"
    # min() instead of glob()[0]: the choice must not depend on directory order.
    metadata_file = min(metadata_dir.glob("*.json"), default=None)
    if metadata_file is None:
        logger.error("未找到metadata文件")
        return None
    logger.info(f"使用metadata: {metadata_file.name}")
    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)
    try:
        prompt_data = metadata['original_prompt_data']
    except (KeyError, TypeError):
        # KeyError: key missing; TypeError: the JSON top level is not an object.
        logger.error(f"metadata缺少 'original_prompt_data' 字段: {metadata_file}")
        return None
    logger.info(f"提示词数据: {prompt_data}")
    return test_image_generation(prompt_data)
def main():
    """Interactive entry point: show a menu, read a choice, run that test.

    Returns whatever the chosen test returns, or ``None`` when the input is
    not a menu key. Exceptions raised by the test are logged and re-raised.
    """
    logger.info("开始图像质量优化系统测试")
    # Menu key -> (label, zero-argument callable).
    menu = {
        '1': ('快速质量评估测试', quick_assessment_test),
        '2': ('图像生成测试', generation_test),
        '3': ('完整流程测试', lambda: logger.info("完整流程测试尚未实现")),
    }
    print("请选择测试类型:")
    for option_key, (option_label, _unused) in menu.items():
        print(f"{option_key}. {option_label}")
    selected = input("输入选择 (1-3): ").strip()
    entry = menu.get(selected)
    if entry is None:
        logger.error("无效选择")
        return None
    label, runner = entry
    logger.info(f"执行: {label}")
    try:
        outcome = runner()
        logger.info(f"测试完成: {label}")
        return outcome
    except Exception as exc:
        logger.error(f"测试失败: {exc}")
        raise
# Launch the interactive test menu only when run as a script, not on import.
if __name__ == "__main__":
    main()
|