| |
| """ |
| 数据清理脚本:清理prompt和neg prompt中的权重值和LoRA标签 |
| """ |
|
|
| import pandas as pd |
| import os |
| import sys |
| import re |
| from pathlib import Path |
|
|
# Matches a LoRA tag such as "<lora:name:0.8>".
_LORA_RE = re.compile(r'<lora:[^>]*>')

# ":<weight>" suffix of an attention group, e.g. ":1.2", ": 1.2 " or ":-0.5".
_WEIGHT_SUFFIX = r':\s*-?[\d.]+\s*'

# Weighted groups wrapped in 3, 2 and 1 pairs of parentheses. They must be
# applied outermost-first: if the single-paren pattern ran first it would turn
# "((word:1.2))" into "(word)" and the nested patterns could never match.
_TRIPLE_WEIGHT_RE = re.compile(r'\(\(\(([^()]+)' + _WEIGHT_SUFFIX + r'\)\)\)')
_DOUBLE_WEIGHT_RE = re.compile(r'\(\(([^()]+)' + _WEIGHT_SUFFIX + r'\)\)')
_WEIGHT_RE = re.compile(r'\(([^():]+)' + _WEIGHT_SUFFIX + r'\)')

# CSV columns this script cleans.
_PROMPT_COLUMNS = ('prompt', 'neg prompt')


def _unwrap_weighted(match):
    """Return the text inside a weighted group, without surrounding spaces."""
    return match.group(1).strip()


def _clean_prompt_text(text):
    """Remove LoRA tags and weight annotations from one prompt string.

    Missing values (NaN/None) and empty strings are returned unchanged; any
    other value is converted to ``str`` before cleaning.
    """
    if pd.isna(text) or text == '':
        return text
    text = _LORA_RE.sub('', str(text))
    for pattern in (_TRIPLE_WEIGHT_RE, _DOUBLE_WEIGHT_RE, _WEIGHT_RE):
        text = pattern.sub(_unwrap_weighted, text)
    text = re.sub(r'\s+', ' ', text)
    # Collapse any run of commas left by removed tags (", , ,") into a single
    # comma, also dropping whitespace directly before it.
    text = re.sub(r'\s*,(?:\s*,)*', ',', text)
    # Whitespace is now single spaces, so this trims leading/trailing separators.
    return text.strip(' ,')


def _count_matches(series, pattern):
    """Count matches of a compiled pattern across the non-missing cells of a Series."""
    return sum(len(pattern.findall(str(value))) for value in series if pd.notna(value))


def _count_artifacts(df):
    """Return one (lora_count, weight_count) pair per column in _PROMPT_COLUMNS."""
    return [
        (_count_matches(df[col], _LORA_RE), _count_matches(df[col], _WEIGHT_RE))
        for col in _PROMPT_COLUMNS
    ]


def _print_counts(title, counts):
    """Print per-column LoRA / weight counts under a heading."""
    print(f"\n{title}:")
    for heading, (loras, weights) in zip(("Prompt列:", "Neg prompt列:"), counts):
        print(heading)
        print(f" - LoRA标签数量: {loras}")
        print(f" - 权重值数量: {weights}")


def _backup_original(input_file):
    """Copy input_file byte-for-byte to a backup path that does not exist yet.

    Returns the backup path: ``<input_file>.backup``, or ``.backup.N`` when
    earlier backups are already present.
    """
    import shutil  # only needed when overwriting the input in place

    backup_file = f"{input_file}.backup"
    suffix = 1
    # Never clobber an earlier backup: a second in-place run would otherwise
    # replace the pristine original with already-cleaned data.
    while os.path.exists(backup_file):
        backup_file = f"{input_file}.backup.{suffix}"
        suffix += 1
    shutil.copy2(input_file, backup_file)
    return backup_file


def _as_text(value):
    """Return value as str, or '' for missing values."""
    return str(value) if pd.notna(value) else ''


def _print_change(label, original, cleaned):
    """Print what was removed from one prompt cell plus a before/after view."""
    loras = _LORA_RE.findall(original)
    weights = [m.group(0) for m in _WEIGHT_RE.finditer(original)]
    if loras:
        print(f" {label}中移除的LoRA: {loras[:3]}{'...' if len(loras) > 3 else ''}")
    if weights:
        print(f" {label}中移除的权重: {weights[:3]}{'...' if len(weights) > 3 else ''}")
    if len(original) > 100:
        print(f" {label}原文片段: {original[:100]}...")
        print(f" {label}清理后片段: {cleaned[:100]}...")
    else:
        print(f" {label}原文: {original}")
        print(f" {label}清理后: {cleaned}")


def _print_examples(df, cleaned_df, max_examples=3):
    """Print up to max_examples rows whose prompt or negative prompt changed."""
    print(f"\n清理示例:")
    shown = 0
    # Positional zip: cleaned_df is a copy of df, so rows line up one-to-one
    # regardless of the index (duplicate index labels included).
    rows = zip(df['prompt'], cleaned_df['prompt'], df['neg prompt'], cleaned_df['neg prompt'])
    for before_pos, after_pos, before_neg, after_neg in rows:
        if shown >= max_examples:
            break
        pairs = (
            ("Prompt", _as_text(before_pos), _as_text(after_pos)),
            ("Neg prompt", _as_text(before_neg), _as_text(after_neg)),
        )
        changed = [(label, old, new) for label, old, new in pairs if old != new]
        if not changed:
            continue
        shown += 1
        print(f"\n示例 {shown}:")
        for label, old, new in changed:
            _print_change(label, old, new)


def clean_prompt_weights_and_loras(input_file, output_file=None):
    """Strip LoRA tags and attention weights from the 'prompt' and 'neg prompt' columns.

    Reads ``input_file`` as CSV and cleans both columns. It prints
    before/after statistics and a few changed examples, then writes the
    result as CSV.

    Args:
        input_file (str): Path of the input CSV file.
        output_file (str, optional): Path of the output CSV file. When omitted
            (or when it resolves to ``input_file``), the input is overwritten
            in place. Before that, a byte-exact copy is saved to
            ``<input_file>.backup``, or ``.backup.N`` if earlier backups exist.

    Returns:
        bool: True on success. False if the input is missing, lacks a
        required column, or any error occurs while processing.
    """
    if not os.path.exists(input_file):
        print(f"错误: 输入文件 {input_file} 不存在")
        return False

    try:
        print(f"正在读取文件: {input_file}")
        df = pd.read_csv(input_file)

        print(f"原始数据行数: {len(df)}")
        print(f"原始数据列数: {len(df.columns)}")
        print(f"列名: {list(df.columns)}")

        if any(col not in df.columns for col in _PROMPT_COLUMNS):
            print("错误: CSV文件中缺少'prompt'或'neg prompt'列")
            return False

        before = _count_artifacts(df)
        _print_counts("清理前统计", before)

        cleaned_df = df.copy()
        for col in _PROMPT_COLUMNS:
            cleaned_df[col] = cleaned_df[col].apply(_clean_prompt_text)

        after = _count_artifacts(cleaned_df)
        _print_counts("清理后统计", after)

        total_loras_cleaned = sum(b[0] - a[0] for b, a in zip(before, after))
        total_weights_cleaned = sum(b[1] - a[1] for b, a in zip(before, after))

        print(f"\n清理总结:")
        print(f"- 清理的LoRA标签数量: {total_loras_cleaned}")
        print(f"- 清理的权重值数量: {total_weights_cleaned}")
        print(f"- 总共清理的元素数量: {total_loras_cleaned + total_weights_cleaned}")

        # Treat an explicit output path that points at the input as in-place too,
        # so the original data is never overwritten without a backup.
        in_place = output_file is None or (
            os.path.realpath(output_file) == os.path.realpath(input_file)
        )
        if output_file is None:
            output_file = input_file
        if in_place:
            backup_file = _backup_original(input_file)
            print(f"原始文件已备份到: {backup_file}")

        cleaned_df.to_csv(output_file, index=False)
        print(f"清理后的数据已保存到: {output_file}")

        _print_examples(df, cleaned_df)
        return True

    except Exception as e:
        # Top-level boundary of the script: report and signal failure to main().
        print(f"处理过程中发生错误: {str(e)}")
        return False
|
|
def main(argv=None):
    """Command-line entry point.

    Usage:
        python clean_weights_and_loras.py [input_file [output_file]]

    With no arguments the built-in default CSV path is used. Without an
    output file the input is cleaned in place.

    Args:
        argv (list[str], optional): Full argument vector including the program
            name. Defaults to ``sys.argv``; exposed so the CLI can be driven
            from tests.

    Exits with status 2 on a usage error and status 1 when cleaning fails.
    """
    if argv is None:
        argv = sys.argv
    default_input = "/home/ubuntu/lyl/QwenIllustrious/civitai_image.csv"

    args = list(argv[1:])
    if len(args) > 2:
        print("使用方法:")
        print(" python clean_weights_and_loras.py")
        print(" python clean_weights_and_loras.py <input_file>")
        print(" python clean_weights_and_loras.py <input_file> <output_file>")
        # Non-zero status so shell scripts notice the bad invocation.
        sys.exit(2)

    input_file = args[0] if args else default_input
    output_file = args[1] if len(args) == 2 else None

    print("=" * 60)
    print("Civitai数据清理工具 - 移除权重值和LoRA标签")
    print("=" * 60)

    success = clean_prompt_weights_and_loras(input_file, output_file)

    if success:
        print("\n✅ 数据清理完成!")
    else:
        print("\n❌ 数据清理失败!")
        sys.exit(1)
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()
|
|