|
|
|
|
|
""" |
|
|
Image Quality Assessment and Optimization System |
|
|
Uses Qwen2.5-VL to assess image quality and an Illustrious (SDXL-based) model to regenerate low-quality images.
|
|
Focuses on anime character facial quality |
|
|
""" |
|
|
|
|
|
import gc
import json
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
|
|
|
|
|
|
|
# Log both to a file in the working directory and to the console.
_log_handlers = [
    logging.FileHandler('data_quality_optimization.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
class QualityAssessment:
    """Assess anime-style image quality with a Qwen2.5-VL vision-language model.

    The model is prompted to score an image from 1 to 10, focusing on
    character faces. Its free-text answer is parsed into a structured dict.
    """

    # Scores strictly below this value always trigger regeneration,
    # regardless of what the model answered for "Needs Regeneration".
    REGENERATION_THRESHOLD = 7

    # Accepted (lower-cased) field labels in the model response, in English
    # and Chinese. A label is the text before the first colon of a line.
    _FIELD_ALIASES = {
        'score': ('score', '评分'),
        'face_quality': ('face quality', '脸部质量'),
        'issues': ('main issues', '主要问题'),
        'needs_regeneration': ('needs regeneration', '是否需要重新生成'),
    }

    # First integer or decimal number in the score value, e.g. "8" in "8/10".
    _SCORE_PATTERN = re.compile(r'\d+(?:\.\d+)?')

    ASSESSMENT_PROMPT = """You are a professional illustrator. Now an AI model generates some images for you. To satisfy your high-end customers, please carefully analyze the quality of this anime-style image, paying special attention to the following aspects:

1. Character Face Quality (Primary Focus):
- Are facial details clear and well-defined?
- Are the eyes symmetrical and detailed?
- Are the proportions of nose and mouth correct?
- Is the facial contour natural?
- Are there any blurry, distorted, or unnatural areas?

2. Overall Image Quality:
- Line clarity and sharpness
- Color saturation and contrast
- Composition and proportions
- Level of detail richness

3. Technical Issues:
- Are there any artifacts or noise?
- Are there obvious generation errors?
- Is the resolution sufficient?

Please provide an overall quality score (1-10, with 10 being the highest) and explain specific issues. If there are facial quality problems or the total score is below 7, recommend regeneration.

Please respond in the following format:
Score: X/10
Face Quality: [Good/Average/Poor]
Main Issues: [Specific description of problems]
Needs Regeneration: [Yes/No]"""

    def __init__(self, model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
        """Create the assessor and immediately load the model.

        Args:
            model_name: Hugging Face model id or local path of a Qwen2.5-VL checkpoint.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self.load_model()

    def load_model(self):
        """Load the Qwen2.5-VL model (bf16, flash-attention 2, device_map="auto") and its processor.

        Raises:
            Exception: re-raised from transformers if loading fails.
        """
        logger.info("Loading Qwen-VL 2.5 model...")
        try:
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_name,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                device_map="auto",
            )
            self.processor = AutoProcessor.from_pretrained(self.model_name)
            logger.info("Qwen-VL 2.5 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Qwen-VL model: {e}")
            raise

    def assess_image_quality(self, image_path: str) -> Dict[str, Any]:
        """Assess one image's quality, focusing on anime character faces.

        Args:
            image_path: Image file path.

        Returns:
            The dict from ``parse_assessment_response`` plus 'raw_response' (the
            model text) and 'image_path'. On any error, returns score 0 with
            'needs_regeneration' True and the error text in 'issues'.
        """
        prompt = self.ASSESSMENT_PROMPT

        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_path},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            # Follow the model's own placement (device_map="auto") instead of
            # assuming a single "cuda" device.
            inputs = inputs.to(self.model.device)

            with torch.no_grad():
                generated_ids = self.model.generate(**inputs, max_new_tokens=512)
            # Strip the prompt tokens so only the newly generated answer is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            response = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

            assessment = self.parse_assessment_response(response)
            assessment['raw_response'] = response
            assessment['image_path'] = image_path
            logger.info(f"Image:\n{image_path} Assessment Result:\n{response}")
            return assessment

        except Exception as e:
            logger.error(f"Error evaluating image {image_path}: {e}")
            return {
                'score': 0,
                'face_quality': 'Unable to assess',
                'issues': f'Assessment failed: {str(e)}',
                'needs_regeneration': True,
                'raw_response': '',
                'image_path': image_path,
            }

    def parse_assessment_response(self, response: str) -> Dict[str, Any]:
        """Parse the model's free-text answer into a structured assessment.

        Recognizes "Score", "Face Quality", "Main Issues" and "Needs Regeneration"
        lines, or their Chinese equivalents. The label may be followed by an
        ASCII ":" or a full-width "：", and the line may carry Markdown bold or
        bullet decoration. Fields that are missing or unparseable keep their
        defaults.

        Args:
            response: Raw decoded text produced by the model.

        Returns:
            Dict with these keys:
              * 'score': int, or float for fractional scores; defaults to 5.
              * 'face_quality': str; defaults to 'Average'.
              * 'issues': str; defaults to ''.
              * 'needs_regeneration': bool.
            'needs_regeneration' is forced to True whenever the score is below
            ``REGENERATION_THRESHOLD``.
        """
        assessment = {
            'score': 5,
            'face_quality': 'Average',
            'issues': '',
            'needs_regeneration': False,
        }

        for raw_line in (response or '').splitlines():
            field, value = self._split_field_line(raw_line)
            if field is None:
                continue
            if field == 'score':
                score = self._parse_score(value)
                if score is not None:
                    assessment['score'] = score
            elif field == 'needs_regeneration':
                # Accept answers such as "Yes", "Yes, the hands are broken", "是".
                assessment['needs_regeneration'] = value.lower().startswith(('yes', 'true', '是'))
            else:
                assessment[field] = value

        if assessment['score'] < self.REGENERATION_THRESHOLD:
            assessment['needs_regeneration'] = True

        return assessment

    @classmethod
    def _split_field_line(cls, raw_line: str) -> Tuple[Optional[str], str]:
        """Map one response line to (field_name, value), or (None, '') if it is not a known field."""
        # Drop Markdown emphasis/bullets and unify full-width colons so that
        # "**Score:** 8/10", "- Score: 8/10" and "评分：8/10" all parse alike.
        line = raw_line.replace('**', '').replace('：', ':').strip().lstrip('#*->').strip()
        label, sep, value = line.partition(':')  # first colon only: values may contain ':'
        if not sep:
            return None, ''
        label = label.strip().lower()
        for field, aliases in cls._FIELD_ALIASES.items():
            if label in aliases:
                value = value.strip()
                # The prompt template shows "[Good/Average/Poor]"; drop enclosing brackets.
                if value.startswith('[') and value.endswith(']'):
                    value = value[1:-1].strip()
                return field, value
        return None, ''

    @classmethod
    def _parse_score(cls, text: str):
        """Return the first number in *text* ("8/10" -> 8, "7.5/10" -> 7.5), or None if absent."""
        match = cls._SCORE_PATTERN.search(text)
        if match is None:
            return None
        value = float(match.group())
        return int(value) if value.is_integer() else value

    def clear_memory(self):
        """Drop the model and processor references and release cached GPU memory."""
        self.model = None
        self.processor = None
        # Collect first so the dropped tensors are actually freed, then return
        # the now-unused cached blocks to the CUDA driver.
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
class ImageGenerator:
    """Regenerate images with an Illustrious (SDXL-based) checkpoint via diffusers."""

    def __init__(self, model_path: str = "models/waiNSFWIllustrious_v140.safetensors"):
        """Create the generator and immediately load the pipeline.

        Args:
            model_path: Path to a single-file (.safetensors) SDXL checkpoint.
        """
        self.model_path = model_path
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load the Illustrious checkpoint as an fp16 SDXL pipeline and move it to CUDA.

        Raises:
            Exception: re-raised from diffusers if the checkpoint cannot be loaded.
        """
        logger.info("Loading Illustrious model...")
        try:
            self.pipeline = StableDiffusionXLPipeline.from_single_file(
                self.model_path,
                torch_dtype=torch.float16,
                use_safetensors=True,
            )
            self.pipeline = self.pipeline.to("cuda")
            logger.info("Illustrious model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Illustrious model: {e}")
            raise

    def generate_images(self, prompt_data: Dict[str, Any], num_candidates: int = 5, *,
                        width: int = 1024, height: int = 512,
                        num_inference_steps: int = 35, guidance_scale: float = 7.5,
                        seed: Optional[int] = None) -> List[Image.Image]:
        """Generate several candidate images from the original prompts.

        Each candidate is generated independently. A failure is logged and
        skipped, so the remaining candidates are still attempted.

        Args:
            prompt_data: Dict read for 'positive_prompt' and 'negative_prompt';
                a missing key is treated as an empty prompt.
            num_candidates: Number of candidate images to attempt.
            width: Output width in pixels.
            height: Output height in pixels.
            num_inference_steps: Denoising steps per image.
            guidance_scale: Classifier-free guidance scale.
            seed: If given, candidate ``i`` is generated with seed ``seed + i``
                for reproducible results; otherwise sampling is unseeded.

        Returns:
            The successfully generated images; may contain fewer than
            ``num_candidates`` entries, or none.
        """
        positive_prompt = prompt_data.get('positive_prompt', '')
        negative_prompt = prompt_data.get('negative_prompt', '')

        logger.info(f"Generating {num_candidates} candidate images")

        images = []
        for i in range(num_candidates):
            logger.info(f"Generating image {i+1}/{num_candidates}")
            try:
                generator = None
                if seed is not None:
                    generator = torch.Generator(device=self.pipeline.device).manual_seed(seed + i)

                image = self.pipeline(
                    prompt=positive_prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=1,
                    generator=generator,
                ).images[0]
            except Exception as e:
                logger.error(f"Error generating images: {e}")
                continue
            finally:
                # Release per-image intermediates before the next candidate.
                torch.cuda.empty_cache()

            images.append(image)

        return images

    def clear_memory(self):
        """Drop the pipeline reference and release cached GPU memory."""
        self.pipeline = None
        # Collect first so the pipeline's tensors are freed, then return the
        # cached blocks to the CUDA driver.
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
class QualityOptimizer:
    """Image quality optimization pipeline: assess, regenerate, then select.

    Layout under ``illustrious_generated_dir``:
      * ``*.png``: images to assess; each needs ``metadata/<stem>.json``.
      * ``candidates/<stem>/candidate_<n>.png``: regenerated candidates.
      * ``improved/<original name>``: accepted replacements.
      * ``low_quality_images.json``, ``regeneration_results.json``,
        ``optimization_final_results.json``, ``optimization_summary_report.txt``.

    The assessor and generator are loaded lazily and swapped so that only one
    of them holds GPU memory at a time.
    """

    def __init__(self,
                 illustrious_generated_dir: str = "/home/ubuntu/lyl/QwenIllustrious/illustrious_generated",
                 qwen_model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
                 illustrious_model_path: str = "models/waiNSFWIllustrious_v140.safetensors"):
        """Set up output directories; models are not loaded until needed.

        Args:
            illustrious_generated_dir: Existing directory holding the generated images.
            qwen_model_name: Model id/path for the Qwen2.5-VL assessor.
            illustrious_model_path: Checkpoint path for the Illustrious generator.
        """
        self.illustrious_generated_dir = Path(illustrious_generated_dir)
        self.metadata_dir = self.illustrious_generated_dir / "metadata"
        self.candidates_dir = self.illustrious_generated_dir / "candidates"
        self.improved_dir = self.illustrious_generated_dir / "improved"

        self.candidates_dir.mkdir(exist_ok=True)
        self.improved_dir.mkdir(exist_ok=True)

        self.qwen_model_name = qwen_model_name
        self.illustrious_model_path = illustrious_model_path

        # Lazily created; see the class docstring on model swapping.
        self.assessor = None
        self.generator = None

        self.low_quality_images = []
        self.optimization_results = []

    def scan_and_assess_images(self, batch_size: int = 50) -> List[Dict[str, Any]]:
        """Assess every ``*.png`` in the generated directory and record low-quality ones.

        Args:
            batch_size: After every ``batch_size`` assessed images, run garbage
                collection and release cached CUDA memory. Values <= 0 disable
                the periodic cleanup.

        Returns:
            Records for images the assessor flagged for regeneration. Each record
            holds the image path, metadata path and content, the assessment and a
            timestamp. The list is also stored on ``self.low_quality_images`` and
            written to ``low_quality_images.json``.
        """
        logger.info("Starting to scan and assess image quality...")

        if self.assessor is None:
            self.assessor = QualityAssessment(self.qwen_model_name)

        # Sorted for a deterministic processing order (and a deterministic
        # subset when regeneration is later capped with max_images).
        image_files = sorted(self.illustrious_generated_dir.glob("*.png"))
        total_images = len(image_files)
        logger.info(f"Found {total_images} image files")

        low_quality_images = []
        processed_count = 0

        for image_file in image_files:
            try:
                metadata_file = self.metadata_dir / f"{image_file.stem}.json"
                if not metadata_file.exists():
                    logger.warning(f"Missing metadata file: {metadata_file}")
                    continue

                with open(metadata_file, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)

                logger.info(f"Assessing image ({processed_count + 1}/{total_images}): {image_file.name}")
                assessment = self.assessor.assess_image_quality(str(image_file))

                if assessment['needs_regeneration']:
                    low_quality_images.append({
                        'image_file': str(image_file),
                        'metadata_file': str(metadata_file),
                        'metadata': metadata,
                        'assessment': assessment,
                        'timestamp': time.time(),
                    })
                    logger.info(f"Detected low quality image: {image_file.name} (Score: {assessment['score']}/10)")

                processed_count += 1

                # Guard against batch_size <= 0, which would divide by zero here.
                if batch_size > 0 and processed_count % batch_size == 0:
                    logger.info(f"Processed {processed_count} images, clearing memory...")
                    gc.collect()
                    torch.cuda.empty_cache()

            except Exception as e:
                logger.error(f"Error processing image {image_file}: {e}")

        logger.info(f"Quality assessment completed. Found {len(low_quality_images)} low quality images")
        self.low_quality_images = low_quality_images
        self.save_low_quality_records()
        return low_quality_images

    def save_low_quality_records(self):
        """Write ``self.low_quality_images`` to ``low_quality_images.json``; failures are logged."""
        output_file = self.illustrious_generated_dir / "low_quality_images.json"
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(self.low_quality_images, f, ensure_ascii=False, indent=2)
            logger.info(f"Low quality image records saved to: {output_file}")
        except Exception as e:
            logger.error(f"Failed to save low quality image records: {e}")

    def regenerate_low_quality_images(self, max_images: Optional[int] = None, num_candidates: int = 5):
        """Generate candidate replacements for the recorded low-quality images.

        Frees the assessor before loading the generator. Candidates go to
        ``candidates/<stem>/candidate_<n>.png``. Images for which no candidate
        could be generated are logged and not added to the results.

        Args:
            max_images: Maximum number of images to process; None processes all.
            num_candidates: Candidates to generate per image.
        """
        if not self.low_quality_images:
            logger.info("No low quality images need regeneration")
            return

        # Swap models: release the VLM before loading the diffusion pipeline.
        if self.assessor is not None:
            self.assessor.clear_memory()
            self.assessor = None

        if self.generator is None:
            self.generator = ImageGenerator(self.illustrious_model_path)

        images_to_process = self.low_quality_images
        if max_images is not None:
            images_to_process = images_to_process[:max_images]

        logger.info(f"Starting regeneration for {len(images_to_process)} low quality images...")

        for idx, record in enumerate(images_to_process):
            try:
                image_hash = Path(record['image_file']).stem
                logger.info(f"Regenerating image ({idx + 1}/{len(images_to_process)}): {image_hash}")

                prompt_data = record['metadata']['original_prompt_data']
                candidate_images = self.generator.generate_images(prompt_data, num_candidates=num_candidates)
                if not candidate_images:
                    logger.warning(f"No candidate images generated for {image_hash}")
                    continue

                candidate_dir = self.candidates_dir / image_hash
                candidate_dir.mkdir(parents=True, exist_ok=True)

                candidate_files = []
                for i, img in enumerate(candidate_images):
                    candidate_file = candidate_dir / f"candidate_{i+1}.png"
                    img.save(candidate_file)
                    candidate_files.append(str(candidate_file))

                self.optimization_results.append({
                    'original_image': record['image_file'],
                    'candidates': candidate_files,
                    'generation_timestamp': time.time(),
                    'original_assessment': record['assessment'],
                })
                logger.info(f"Generated {len(candidate_images)} candidate images for {image_hash}")

            except Exception as e:
                logger.error(f"Error regenerating image: {e}")

        self.save_generation_results()

    def save_generation_results(self):
        """Write ``self.optimization_results`` to ``regeneration_results.json``; failures are logged."""
        output_file = self.illustrious_generated_dir / "regeneration_results.json"
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(self.optimization_results, f, ensure_ascii=False, indent=2)
            logger.info(f"Regeneration results saved to: {output_file}")
        except Exception as e:
            logger.error(f"Failed to save regeneration results: {e}")

    def select_best_candidates(self, min_improvement: float = 1):
        """Assess candidate images and keep the best one when it beats the original.

        Frees the generator before loading the assessor. A candidate replaces
        the original, copied into ``improved/``, only when its score exceeds the
        original score by more than ``min_improvement``. Results are saved and a
        summary report is generated.

        Args:
            min_improvement: Required score margin over the original (exclusive).
        """
        if not self.optimization_results:
            logger.info("没有候选图像需要评估")
            return

        # Swap models: release the diffusion pipeline before loading the VLM.
        if self.generator is not None:
            self.generator.clear_memory()
            self.generator = None

        if self.assessor is None:
            self.assessor = QualityAssessment(self.qwen_model_name)

        logger.info(f"开始评估 {len(self.optimization_results)} 组候选图像...")

        final_results = []
        for idx, result in enumerate(self.optimization_results):
            try:
                logger.info(f"评估候选组 ({idx + 1}/{len(self.optimization_results)})")

                if not result['candidates']:
                    # max() below would raise on an empty list.
                    logger.warning(f"No candidates to evaluate for {result['original_image']}")
                    continue

                candidate_assessments = [
                    {'file': candidate_file, 'assessment': self.assessor.assess_image_quality(candidate_file)}
                    for candidate_file in result['candidates']
                ]
                best_candidate = max(candidate_assessments, key=lambda x: x['assessment']['score'])

                original_score = result['original_assessment']['score']
                best_score = best_candidate['assessment']['score']

                if best_score > original_score + min_improvement:
                    original_filename = Path(result['original_image']).name
                    improved_file = self.improved_dir / original_filename

                    # Context manager closes the source file handle.
                    with Image.open(best_candidate['file']) as best_image:
                        best_image.save(improved_file)

                    final_result = {
                        'original_image': result['original_image'],
                        'improved_image': str(improved_file),
                        'original_score': original_score,
                        'improved_score': best_score,
                        'improvement': best_score - original_score,
                        'best_candidate_source': best_candidate['file'],
                        'all_candidates': candidate_assessments,
                    }
                    logger.info(f"图像质量提升: {original_filename} "
                                f"({original_score}/10 -> {best_score}/10, +{best_score - original_score})")
                else:
                    final_result = {
                        'original_image': result['original_image'],
                        'improved_image': None,
                        'original_score': original_score,
                        'best_candidate_score': best_score,
                        'improvement': 0,
                        'reason': '候选图像质量未达到替换标准',
                        'all_candidates': candidate_assessments,
                    }
                    logger.info(f"保持原图: {Path(result['original_image']).name} "
                                f"(原图{original_score}/10, 最佳候选{best_score}/10)")

                final_results.append(final_result)

            except Exception as e:
                logger.error(f"评估候选图像时出错: {e}")

        self.save_final_results(final_results)
        self.generate_summary_report(final_results)

    def save_final_results(self, final_results: List[Dict[str, Any]]):
        """Write the final per-image results to ``optimization_final_results.json``; failures are logged."""
        output_file = self.illustrious_generated_dir / "optimization_final_results.json"
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, ensure_ascii=False, indent=2)
            logger.info(f"最终优化结果已保存到: {output_file}")
        except Exception as e:
            logger.error(f"保存最终结果失败: {e}")

    @staticmethod
    def _summary_stats(final_results: List[Dict[str, Any]]):
        """Return (improved_count, total_processed, avg_improvement, success_rate_percent).

        Both ratios are 0 when their denominator is 0, so an empty result list
        does not raise ZeroDivisionError.
        """
        improvements = [r.get('improvement', 0) for r in final_results if r.get('improved_image') is not None]
        improved_count = len(improvements)
        total_processed = len(final_results)
        avg_improvement = sum(improvements) / improved_count if improved_count else 0
        success_rate = improved_count / total_processed * 100 if total_processed else 0.0
        return improved_count, total_processed, avg_improvement, success_rate

    def generate_summary_report(self, final_results: List[Dict[str, Any]]):
        """Build a text summary of the run, log it and write ``optimization_summary_report.txt``."""
        improved_count, total_processed, avg_improvement, success_rate = self._summary_stats(final_results)
        total_low_quality = len(self.low_quality_images)
        total_images = sum(1 for _ in self.illustrious_generated_dir.glob('*.png'))

        report = f"""
=== 图像质量优化总结报告 ===

处理统计:
- 总图像数: {total_images}
- 检测到低质量图像: {total_low_quality}
- 重新生成处理: {total_processed}
- 成功改善质量: {improved_count}
- 改善成功率: {success_rate:.1f}%

质量提升:
- 平均质量提升: {avg_improvement:.1f} 分
- 改善图像保存位置: {self.improved_dir}

详细结果文件:
- 低质量图像记录: low_quality_images.json
- 重新生成结果: regeneration_results.json
- 最终优化结果: optimization_final_results.json

优化完成时间: {time.strftime('%Y-%m-%d %H:%M:%S')}
"""

        logger.info(report)

        report_file = self.illustrious_generated_dir / "optimization_summary_report.txt"
        try:
            with open(report_file, 'w', encoding='utf-8') as f:
                f.write(report)
            logger.info(f"优化报告已保存到: {report_file}")
        except Exception as e:
            logger.error(f"保存优化报告失败: {e}")

    def run_full_optimization(self, batch_size: int = 50, max_regenerate: Optional[int] = None):
        """Run assess -> regenerate -> select, always releasing GPU models at the end.

        Args:
            batch_size: Memory-cleanup interval during assessment.
            max_regenerate: Maximum number of images to regenerate; None means all.

        Raises:
            Exception: any error from the pipeline steps is logged and re-raised.
        """
        logger.info("开始完整的图像质量优化流程...")

        try:
            logger.info("=== 步骤1: 图像质量评估 ===")
            self.scan_and_assess_images(batch_size=batch_size)

            if not self.low_quality_images:
                logger.info("未发现需要优化的低质量图像,优化流程结束")
                return

            logger.info("=== 步骤2: 重新生成图像 ===")
            self.regenerate_low_quality_images(max_images=max_regenerate)

            logger.info("=== 步骤3: 选择最佳候选图像 ===")
            self.select_best_candidates()

            logger.info("=== 图像质量优化完成 ===")

        except Exception as e:
            logger.error(f"优化流程出错: {e}")
            raise

        finally:
            # Drop the wrappers too, so no stale (model-less) objects are reused later.
            if self.assessor is not None:
                self.assessor.clear_memory()
                self.assessor = None
            if self.generator is not None:
                self.generator.clear_memory()
                self.generator = None
|
|
|
|
|
def main():
    """Script entry point: build the configuration and run the full optimization pipeline."""
    config = dict(
        illustrious_generated_dir='/home/ubuntu/lyl/QwenIllustrious/illustrious_generated',
        qwen_model_name='models/Qwen2.5-VL-7B-Instruct',
        illustrious_model_path='models/waiNSFWIllustrious_v140.safetensors',
        batch_size=30,
        max_regenerate=100,
    )

    logger.info("启动图像质量优化系统...")
    logger.info(f"配置参数: {config}")

    # The path/model entries configure the optimizer; the rest drive the run.
    constructor_keys = ('illustrious_generated_dir', 'qwen_model_name', 'illustrious_model_path')
    optimizer = QualityOptimizer(**{key: config[key] for key in constructor_keys})

    run_kwargs = {'batch_size': config['batch_size'], 'max_regenerate': config['max_regenerate']}
    optimizer.run_full_optimization(**run_kwargs)
|
|
|
|
|
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, never on import.
    main()
|
|
|