# lsmpp's picture
# Add files using upload-large-folder tool
# 3f9fa87 verified
#!/usr/bin/env python3
"""
Prompt Augmentation Script for QwenIllustrious
Generates new prompt pairs using Qwen3 model based on samples from civitai_image.csv
"""
import pandas as pd
import random
import json
import argparse
import os
from typing import List, Tuple, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class PromptAugmenter:
    """Generate new prompt pairs with a Qwen3 chat model.

    Each generation round samples a few (positive, negative) prompt pairs from
    a civitai CSV, shows them to the model as few-shot examples, and asks for
    one new JSON-formatted pair. Results are appended to a JSONL file.
    """

    def __init__(self, model_name: str = "Qwen/Qwen3-8B", csv_path: str = None):
        """Initialize the prompt augmenter with a Qwen3 model.

        Args:
            model_name: HF hub name or local path of the Qwen3 checkpoint.
            csv_path: Path to the civitai CSV; defaults to ../../civitai_image.csv.
        """
        self.model_name = model_name
        self.csv_path = csv_path or "../../civitai_image.csv"
        print(f"Loading model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        print("Model loaded successfully!")
        # Load the CSV data
        self.load_data()

    def load_data(self):
        """Load the civitai CSV into self.df, dropping unusable rows.

        Rows missing either 'prompt' or 'neg prompt', or with a blank
        positive prompt, are removed. Re-raises any read/processing error
        after logging it.
        """
        print(f"Loading data from: {self.csv_path}")
        try:
            self.df = pd.read_csv(self.csv_path)
            # Clean the data - remove rows with empty prompts
            self.df = self.df.dropna(subset=['prompt', 'neg prompt'])
            self.df = self.df[self.df['prompt'].str.strip() != '']
            print(f"Loaded {len(self.df)} valid prompt pairs")
        except Exception as e:
            print(f"Error loading CSV: {e}")
            raise

    def get_random_samples(self, n_samples: int = 3) -> List[Tuple[str, str]]:
        """Return up to n_samples random (prompt, neg prompt) pairs from self.df."""
        samples = self.df.sample(n=min(n_samples, len(self.df)))
        return [(row['prompt'], row['neg prompt']) for _, row in samples.iterrows()]

    def clean_prompt(self, prompt: str) -> str:
        """Clean a prompt by removing comma-separated 'embedding:' tags.

        Example: "embedding:Illustrious\\IllusP0s, 1girl" -> "1girl".
        """
        cleaned = prompt
        if 'embedding:' in cleaned:
            # Remove embedding tags like "embedding:Illustrious\IllusP0s, "
            parts = cleaned.split(',')
            parts = [part.strip() for part in parts if not part.strip().startswith('embedding:')]
            cleaned = ', '.join(parts)
        return cleaned.strip()

    def generate_prompt_instruction(self, samples: List[Tuple[str, str]],
                                   multi_character_focus: bool = False) -> str:
        """Build the few-shot instruction string sent to the model.

        Args:
            samples: Example (positive, negative) prompt pairs.
            multi_character_focus: If True, ask for a multi-character scene.

        Returns:
            The full instruction text, ending with the JSON format spec.
        """
        base_instruction = """You are an expert prompt engineer for AI image generation. I will provide you with some example prompt pairs (positive prompt and negative prompt) from a dataset. Your task is to create NEW, ORIGINAL prompt pairs that are similar in style, length, and quality to the examples.
Here are the example prompt pairs:
"""
        # Add sample prompts (positives are cleaned of embedding tags first).
        for i, (pos_prompt, neg_prompt) in enumerate(samples, 1):
            cleaned_pos = self.clean_prompt(pos_prompt)
            base_instruction += f"Example {i}:\n"
            base_instruction += f"Positive prompt: {cleaned_pos}\n"
            base_instruction += f"Negative prompt: {neg_prompt}\n\n"
        if multi_character_focus:
            specific_instruction = """Please generate 1 NEW prompt pair that emphasizes MULTIPLE CHARACTERS interacting with each other. Focus on creating scenes with 2 or more characters, their relationships, interactions, and dynamic compositions. The prompt should be detailed and similar in length to the examples."""
        else:
            specific_instruction = """Please generate 1 NEW prompt pair that is creative and original, maintaining similar style, detail level, and length as the examples. Focus on creating visually interesting and detailed descriptions."""
        format_instruction = """
Format your response as a JSON object with this exact structure:
{
"positive_prompt": "your new positive prompt here",
"negative_prompt": "your new negative prompt here"
}
Make sure:
1. The positive prompt is detailed and descriptive (similar length to examples)
2. The negative prompt includes common quality control terms
3. Both prompts are in English
4. The content is appropriate and creative
5. Return ONLY the JSON object, no additional text"""
        return base_instruction + specific_instruction + format_instruction

    def generate_with_qwen3(self, instruction: str) -> Dict[str, str]:
        """Run one generation with Qwen3 and parse the JSON response.

        Returns:
            Parsed dict with 'positive_prompt'/'negative_prompt' keys, or
            None when generation fails or the output is not valid JSON.
        """
        try:
            messages = [
                {"role": "user", "content": instruction}
            ]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            # Generate with reasonable sampling parameters.
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=2048,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )
            output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
            # Split thinking from answer: 151668 is Qwen3's </think> token id.
            try:
                index = len(output_ids) - output_ids[::-1].index(151668)  # </think>
            except ValueError:
                index = 0  # no </think> found: treat everything as answer
            thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
            content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
            print(f"Thinking content: {thinking_content[:200]}...")
            print(f"Response: {content[:200]}...")
            # Try to parse JSON from the response.
            try:
                # Strip a markdown code fence if the model added one.
                content = content.strip()
                if content.startswith('```json'):
                    content = content[7:]
                if content.endswith('```'):
                    content = content[:-3]
                content = content.strip()
                result = json.loads(content)
                return result
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {e}")
                print(f"Raw content: {content}")
                return None
        except Exception as e:
            print(f"Generation error: {e}")
            return None

    def _flush_batch(self, f, batch: List[Dict[str, str]]):
        """Append each prompt dict in batch to open file f as one JSONL line."""
        for prompt in batch:
            f.write(json.dumps(prompt, ensure_ascii=False) + '\n')
        f.flush()

    def generate_prompts(self, target_count: int = 10000,
                        multi_character_prob: float = 0.4,
                        samples_per_batch: int = 3,
                        output_file: str = "augmented_prompts.jsonl",
                        save_every: int = 10):
        """Generate prompts until target_count lines exist, saving periodically.

        Resumes from an existing JSONL file (counts its lines toward the
        target and appends). Unsaved prompts are flushed on interrupt/error.

        Args:
            target_count: Total number of prompts to have in output_file.
            multi_character_prob: Chance a round asks for a multi-character scene.
            samples_per_batch: How many example pairs to show the model.
            output_file: JSONL output path.
            save_every: Flush to disk every N successful generations.

        Returns:
            List of prompt dicts generated during this run (excludes resumed ones).
        """
        print(f"Starting prompt generation: target={target_count}, multi_char_prob={multi_character_prob}")
        print(f"Saving every {save_every} successful generations to: {output_file}")
        # Check if output file already exists and count existing entries.
        existing_count = 0
        if os.path.exists(output_file):
            try:
                with open(output_file, 'r', encoding='utf-8') as f:
                    existing_count = sum(1 for line in f if line.strip())
                print(f"Found {existing_count} existing prompts in {output_file}")
                if existing_count >= target_count:
                    print(f"Target already reached! ({existing_count} >= {target_count})")
                    return []
            except Exception as e:
                print(f"Warning: Could not read existing file: {e}")
        generated_prompts = []
        successful_generations = existing_count
        total_attempts = 0
        temp_batch = []  # prompts accumulated since the last flush
        # Append to an existing file, otherwise start fresh.
        file_mode = 'a' if existing_count > 0 else 'w'
        try:
            with open(output_file, file_mode, encoding='utf-8') as f:
                while successful_generations < target_count:
                    total_attempts += 1
                    # Get random samples
                    samples = self.get_random_samples(samples_per_batch)
                    # Decide if this should be multi-character focused
                    is_multi_character = random.random() < multi_character_prob
                    # Generate instruction
                    instruction = self.generate_prompt_instruction(samples, is_multi_character)
                    # Generate new prompt
                    print(f"\nAttempt {total_attempts} (Success: {successful_generations}/{target_count})")
                    print(f"Multi-character focus: {is_multi_character}")
                    result = self.generate_with_qwen3(instruction)
                    if result and 'positive_prompt' in result and 'negative_prompt' in result:
                        # Add metadata
                        result['multi_character_focus'] = is_multi_character
                        result['generation_attempt'] = total_attempts
                        result['sample_sources'] = [self.clean_prompt(s[0])[:100] + "..." for s in samples]
                        # Add to temporary batch
                        temp_batch.append(result)
                        generated_prompts.append(result)
                        successful_generations += 1
                        print(f"✓ Generated prompt {successful_generations}")
                        print(f"Positive: {result['positive_prompt'][:100]}...")
                        print(f"Negative: {result['negative_prompt'][:50]}...")
                        # Save batch every N successful generations
                        if len(temp_batch) >= save_every:
                            self._flush_batch(f, temp_batch)
                            print(f"💾 Saved batch of {len(temp_batch)} prompts to file")
                            temp_batch = []
                    else:
                        print("✗ Failed to generate valid prompt")
                    # Progress update every 100 attempts
                    if total_attempts % 100 == 0:
                        success_rate = successful_generations / total_attempts * 100
                        print(f"\n=== Progress Update ===")
                        print(f"Attempts: {total_attempts}")
                        print(f"Successful: {successful_generations}")
                        print(f"Success rate: {success_rate:.1f}%")
                        print(f"Remaining: {target_count - successful_generations}")
                        print(f"Current batch size: {len(temp_batch)}")
                # Save any remaining prompts in the final batch
                if temp_batch:
                    self._flush_batch(f, temp_batch)
                    print(f"💾 Saved final batch of {len(temp_batch)} prompts to file")
        except KeyboardInterrupt:
            print(f"\n🛑 Generation interrupted by user!")
            # The with-block's handle is closed by now; reopen to save stragglers.
            if temp_batch:
                with open(output_file, 'a', encoding='utf-8') as f:
                    self._flush_batch(f, temp_batch)
                print(f"💾 Saved {len(temp_batch)} prompts before exit")
            raise
        except Exception as e:
            print(f"❌ Unexpected error: {e}")
            # Best-effort save of unflushed prompts before re-raising.
            if temp_batch:
                try:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        self._flush_batch(f, temp_batch)
                    print(f"💾 Saved {len(temp_batch)} prompts before error exit")
                except Exception:
                    pass
            raise
        print(f"\n=== Generation Complete ===")
        print(f"Total attempts: {total_attempts}")
        print(f"Successful generations: {successful_generations}")
        # Guard: total_attempts is 0 when the target was met without looping.
        if total_attempts > 0:
            print(f"Success rate: {successful_generations/total_attempts*100:.1f}%")
        print(f"Output saved to: {output_file}")
        return generated_prompts
def main():
    """Command-line entry point: parse options, build a PromptAugmenter, run generation."""
    parser = argparse.ArgumentParser(description='Generate augmented prompts using Qwen3')
    # (flag, add_argument kwargs) pairs keep the option table easy to scan.
    option_table = [
        ('--model', dict(default='models/Qwen3-8B',
                         help='Model name or path')),
        ('--csv_path', dict(default='civitai_image.csv',
                            help='Path to civitai CSV file')),
        ('--target_count', dict(type=int, default=10000,
                                help='Number of prompts to generate')),
        ('--multi_char_prob', dict(type=float, default=0.4,
                                   help='Probability of multi-character focused prompts')),
        ('--samples_per_batch', dict(type=int, default=3,
                                     help='Number of sample prompts to show to model each time')),
        ('--save_every', dict(type=int, default=10,
                              help='Save to file every N successful generations')),
        ('--output', dict(default='augmented_prompts.jsonl',
                          help='Output file name')),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    args = parser.parse_args()

    # Build the augmenter (loads model + CSV), then run the generation loop.
    augmenter = PromptAugmenter(model_name=args.model, csv_path=args.csv_path)
    augmenter.generate_prompts(
        target_count=args.target_count,
        multi_character_prob=args.multi_char_prob,
        samples_per_batch=args.samples_per_batch,
        output_file=args.output,
        save_every=args.save_every,
    )


if __name__ == "__main__":
    main()