qwenillustrious / data_tool /clean_civitai_data /clean_weights_and_loras.py
lsmpp's picture
Add files using upload-large-folder tool
3f9fa87 verified
#!/usr/bin/env python3
"""
数据清理脚本:清理prompt和neg prompt中的权重值和LoRA标签
"""
import os
import re
import shutil
import sys
from pathlib import Path

import pandas as pd
# Prompt columns to clean, paired with the label used in the console report.
_PROMPT_COLUMNS = (('prompt', 'Prompt'), ('neg prompt', 'Neg prompt'))

# Numeric weight inside "(text:weight)". Optional surrounding whitespace and a
# leading minus sign are accepted because forms such as "(text: 1.2)" and
# "(text:-0.5)" occur in scraped prompts.
_WEIGHT_NUMBER = r'\s*-?[\d.]+\s*'

# LoRA / LyCORIS extra-network tags such as "<lora:name:0.8>"; matched
# case-insensitively so "<LoRA:...>" is removed as well.
_LORA_RE = re.compile(r'<(?:lora|lyco):[^>]*>', re.IGNORECASE)

# One "(text:weight)" group without capture groups, so findall() returns the
# whole tag; used for statistics and for the example report.
_WEIGHT_TAG_RE = re.compile(r'\([^():]+:' + _WEIGHT_NUMBER + r'\)')

# Weight-stripping patterns ordered outermost-first. If the single-paren
# pattern ran first it would consume the inner group of "((text:1.2))" and
# leave "(text)" behind, so the nested patterns could never match.
_WEIGHT_STRIP_RES = (
    re.compile(r'\(\(\(([^()]+):' + _WEIGHT_NUMBER + r'\)\)\)'),
    re.compile(r'\(\(([^()]+):' + _WEIGHT_NUMBER + r'\)\)'),
    re.compile(r'\(([^():]+):' + _WEIGHT_NUMBER + r'\)'),
)


def _clean_prompt_text(text):
    """Return *text* with LoRA tags and "(text:weight)" weights removed.

    Missing values (NaN) and empty strings are returned unchanged so that
    empty cells stay empty in the output CSV.
    """
    if pd.isna(text) or text == '':
        return text
    text = _LORA_RE.sub('', str(text))
    # Repeat until stable: stripping an inner group can expose an outer one,
    # e.g. "(red (hair:1.2):1.1)" -> "(red hair:1.1)" -> "red hair".
    # Each substitution shortens the string, so the loop terminates.
    previous = None
    while text != previous:
        previous = text
        for pattern in _WEIGHT_STRIP_RES:
            text = pattern.sub(r'\1', text)
    # Normalise the separators left behind by the removals.
    text = re.sub(r'\s+', ' ', text)         # collapse whitespace runs
    text = re.sub(r'\s+,', ',', text)        # no space before a comma
    text = re.sub(r',(?:\s*,)+', ',', text)  # collapse runs of empty items
    return text.strip(' ,')


def _as_text(value):
    """Return *value* as a string, mapping missing values to ''."""
    return str(value) if pd.notna(value) else ''


def _count_matches(series, pattern):
    """Return the total number of *pattern* matches over non-missing cells."""
    return sum(len(pattern.findall(str(text))) for text in series if pd.notna(text))


def _collect_stats(df):
    """Map each prompt column name to its (LoRA tag count, weight count)."""
    return {
        column: (
            _count_matches(df[column], _LORA_RE),
            _count_matches(df[column], _WEIGHT_TAG_RE),
        )
        for column, _ in _PROMPT_COLUMNS
    }


def _print_stats(title, stats):
    """Print the per-column counts produced by ``_collect_stats``."""
    print(f"\n{title}:")
    for column, label in _PROMPT_COLUMNS:
        loras, weights = stats[column]
        print(f"{label}列:")
        print(f" - LoRA标签数量: {loras}")
        print(f" - 权重值数量: {weights}")


def _preview(items, limit=3):
    """Format at most *limit* items, appending '...' when some were omitted."""
    return f"{items[:limit]}{'...' if len(items) > limit else ''}"


def _snippet(text, limit=100):
    """Truncate *text* to *limit* chars, appending '...' only if cut."""
    return text[:limit] + ('...' if len(text) > limit else '')


def _print_change(label, original, cleaned):
    """Print what was removed from one cell plus a before/after comparison."""
    lora_matches = _LORA_RE.findall(original)
    weight_matches = _WEIGHT_TAG_RE.findall(original)
    if lora_matches:
        print(f" {label}中移除的LoRA: {_preview(lora_matches)}")
    if weight_matches:
        print(f" {label}中移除的权重: {_preview(weight_matches)}")
    if len(original) > 100:
        print(f" {label}原文片段: {_snippet(original)}")
        print(f" {label}清理后片段: {_snippet(cleaned)}")
    else:
        print(f" {label}原文: {original}")
        print(f" {label}清理后: {cleaned}")


def _print_examples(df, cleaned_df, max_examples=3):
    """Print up to *max_examples* rows in which a prompt column changed."""
    print(f"\n清理示例:")
    labels = [label for _, label in _PROMPT_COLUMNS]
    before_rows = zip(*(df[column] for column, _ in _PROMPT_COLUMNS))
    after_rows = zip(*(cleaned_df[column] for column, _ in _PROMPT_COLUMNS))
    shown = 0
    for before_row, after_row in zip(before_rows, after_rows):
        if shown >= max_examples:
            break
        changes = [
            (label, _as_text(before), _as_text(after))
            for label, before, after in zip(labels, before_row, after_row)
        ]
        changes = [change for change in changes if change[1] != change[2]]
        if not changes:
            continue
        shown += 1
        print(f"\n示例 {shown}:")
        for label, before, after in changes:
            _print_change(label, before, after)


def clean_prompt_weights_and_loras(input_file, output_file=None):
    """Remove weights and LoRA tags from the 'prompt' and 'neg prompt' columns.

    Args:
        input_file (str | os.PathLike): Path of the input CSV file.
        output_file (str | os.PathLike, optional): Path of the output CSV.
            If omitted, the input file is overwritten in place after a
            byte-for-byte copy is saved to ``<input_file>.backup``.

    Returns:
        bool: True on success; False if the input file does not exist, lacks
        the required columns, or any error occurs while processing.
    """
    if not os.path.exists(input_file):
        print(f"错误: 输入文件 {input_file} 不存在")
        return False
    try:
        print(f"正在读取文件: {input_file}")
        df = pd.read_csv(input_file)
        print(f"原始数据行数: {len(df)}")
        print(f"原始数据列数: {len(df.columns)}")
        print(f"列名: {list(df.columns)}")
        if 'prompt' not in df.columns or 'neg prompt' not in df.columns:
            print("错误: CSV文件中缺少'prompt'或'neg prompt'列")
            return False

        before = _collect_stats(df)
        _print_stats("清理前统计", before)

        cleaned_df = df.copy()
        for column, _ in _PROMPT_COLUMNS:
            cleaned_df[column] = cleaned_df[column].apply(_clean_prompt_text)

        after = _collect_stats(cleaned_df)
        _print_stats("清理后统计", after)

        total_loras_cleaned = sum(before[c][0] - after[c][0] for c in before)
        total_weights_cleaned = sum(before[c][1] - after[c][1] for c in before)
        print(f"\n清理总结:")
        print(f"- 清理的LoRA标签数量: {total_loras_cleaned}")
        print(f"- 清理的权重值数量: {total_weights_cleaned}")
        print(f"- 总共清理的元素数量: {total_loras_cleaned + total_weights_cleaned}")

        if output_file is None:
            output_file = input_file
            # os.fspath lets Path inputs work (str + str concatenation failed).
            backup_file = f"{os.fspath(input_file)}.backup"
            # Copy the raw bytes: re-serialising the parsed DataFrame is lossy
            # (quoting, number formatting such as "007" -> "7", NaN handling).
            shutil.copy2(input_file, backup_file)
            print(f"原始文件已备份到: {backup_file}")

        cleaned_df.to_csv(output_file, index=False)
        print(f"清理后的数据已保存到: {output_file}")

        _print_examples(df, cleaned_df)
        return True
    except Exception as e:
        # CLI boundary: report the failure and signal it via the return value.
        print(f"处理过程中发生错误: {str(e)}")
        return False
def main():
    """Command-line entry point.

    Usage::

        python clean_weights_and_loras.py [input_file [output_file]]

    With no arguments the built-in default CSV path is cleaned in place.
    Exits with status 1 if cleaning fails and status 2 on a usage error.
    """
    # Default input path used when no arguments are given (machine-specific).
    default_input = "/home/ubuntu/lyl/QwenIllustrious/civitai_image.csv"
    args = sys.argv[1:]
    if len(args) > 2:
        # A wrong invocation must not look like success: print usage to
        # stderr and exit with the conventional usage-error status 2.
        print("使用方法:", file=sys.stderr)
        print(" python clean_weights_and_loras.py", file=sys.stderr)
        print(" python clean_weights_and_loras.py <input_file>", file=sys.stderr)
        print(" python clean_weights_and_loras.py <input_file> <output_file>", file=sys.stderr)
        sys.exit(2)
    input_file = args[0] if args else default_input
    output_file = args[1] if len(args) > 1 else None

    print("=" * 60)
    print("Civitai数据清理工具 - 移除权重值和LoRA标签")
    print("=" * 60)

    success = clean_prompt_weights_and_loras(input_file, output_file)
    if success:
        print("\n✅ 数据清理完成!")
    else:
        print("\n❌ 数据清理失败!")
        sys.exit(1)
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()