| |
| """ |
| Prompt Augmentation Script for QwenIllustrious |
| Generates new prompt pairs using Qwen3 model based on samples from civitai_image.csv |
| """ |
|
|
| import pandas as pd |
| import random |
| import json |
| import argparse |
| import os |
| from typing import List, Tuple, Dict |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
|
|
|
|
class PromptAugmenter:
    """Generates new image-generation prompt pairs with a Qwen3 model.

    Example (positive, negative) prompt pairs sampled from a civitai CSV are
    shown to the model, which is asked to produce a new, original pair as JSON.
    Results are appended to a JSONL file with resume support.
    """

    # Token id of "</think>" in the Qwen3 tokenizer — used to split the
    # model's thinking block from the final answer.
    # NOTE(review): hard-coded for Qwen3 tokenizers; verify if the model family changes.
    THINK_END_TOKEN_ID = 151668

    def __init__(self, model_name: str = "Qwen/Qwen3-8B", csv_path: str = None):
        """Initialize the prompt augmenter with Qwen3 model.

        Args:
            model_name: HF hub name or local path of the Qwen3 checkpoint.
            csv_path: Path to the civitai CSV; defaults to ../../civitai_image.csv.
        """
        self.model_name = model_name
        self.csv_path = csv_path or "../../civitai_image.csv"

        print(f"Loading model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto"
        )
        print("Model loaded successfully!")

        self.load_data()

    def load_data(self):
        """Load and process the civitai CSV data.

        Keeps only rows where both 'prompt' and 'neg prompt' are present and
        the positive prompt is non-empty. Re-raises any read/parse failure.
        """
        print(f"Loading data from: {self.csv_path}")
        try:
            self.df = pd.read_csv(self.csv_path)
            self.df = self.df.dropna(subset=['prompt', 'neg prompt'])
            self.df = self.df[self.df['prompt'].str.strip() != '']
            print(f"Loaded {len(self.df)} valid prompt pairs")
        except Exception as e:
            print(f"Error loading CSV: {e}")
            raise

    def get_random_samples(self, n_samples: int = 3) -> List[Tuple[str, str]]:
        """Get random prompt and negative prompt pairs from the dataset.

        Returns at most len(self.df) pairs, so small datasets are safe.
        """
        samples = self.df.sample(n=min(n_samples, len(self.df)))
        return [(row['prompt'], row['neg prompt']) for _, row in samples.iterrows()]

    def clean_prompt(self, prompt: str) -> str:
        """Clean the prompt by removing embeddings and technical prefixes.

        Drops every comma-separated segment that starts with 'embedding:'
        and strips surrounding whitespace.
        """
        cleaned = prompt
        if 'embedding:' in cleaned:
            parts = cleaned.split(',')
            parts = [part.strip() for part in parts if not part.strip().startswith('embedding:')]
            cleaned = ', '.join(parts)
        return cleaned.strip()

    def generate_prompt_instruction(self, samples: List[Tuple[str, str]],
                                    multi_character_focus: bool = False) -> str:
        """Generate the instruction for Qwen3 to create new prompts.

        Args:
            samples: Example (positive, negative) prompt pairs shown to the model.
            multi_character_focus: When True, request multi-character scenes.

        Returns:
            The complete instruction string: task description, examples,
            focus-specific request, and required JSON output format.
        """
        base_instruction = """You are an expert prompt engineer for AI image generation. I will provide you with some example prompt pairs (positive prompt and negative prompt) from a dataset. Your task is to create NEW, ORIGINAL prompt pairs that are similar in style, length, and quality to the examples.

Here are the example prompt pairs:

"""

        # Append each example pair; embeddings are stripped from positives only.
        for i, (pos_prompt, neg_prompt) in enumerate(samples, 1):
            cleaned_pos = self.clean_prompt(pos_prompt)
            base_instruction += f"Example {i}:\n"
            base_instruction += f"Positive prompt: {cleaned_pos}\n"
            base_instruction += f"Negative prompt: {neg_prompt}\n\n"

        if multi_character_focus:
            specific_instruction = """Please generate 1 NEW prompt pair that emphasizes MULTIPLE CHARACTERS interacting with each other. Focus on creating scenes with 2 or more characters, their relationships, interactions, and dynamic compositions. The prompt should be detailed and similar in length to the examples."""
        else:
            specific_instruction = """Please generate 1 NEW prompt pair that is creative and original, maintaining similar style, detail level, and length as the examples. Focus on creating visually interesting and detailed descriptions."""

        format_instruction = """

Format your response as a JSON object with this exact structure:
{
    "positive_prompt": "your new positive prompt here",
    "negative_prompt": "your new negative prompt here"
}

Make sure:
1. The positive prompt is detailed and descriptive (similar length to examples)
2. The negative prompt includes common quality control terms
3. Both prompts are in English
4. The content is appropriate and creative
5. Return ONLY the JSON object, no additional text"""

        return base_instruction + specific_instruction + format_instruction

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        """Remove a surrounding Markdown code fence (```json ... ``` or ``` ... ```).

        Fix: the original only stripped a ```json opener, so a bare ``` fence
        was left in place and always broke JSON parsing.
        """
        content = content.strip()
        if content.startswith('```json'):
            content = content[7:]
        elif content.startswith('```'):
            content = content[3:]
        if content.endswith('```'):
            content = content[:-3]
        return content.strip()

    def generate_with_qwen3(self, instruction: str) -> Dict[str, str]:
        """Generate a new prompt pair using the Qwen3 model.

        Returns:
            The parsed JSON dict on success, or None when generation fails or
            the model output cannot be parsed as JSON.
        """
        try:
            messages = [
                {"role": "user", "content": instruction}
            ]

            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True
            )

            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=2048,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Keep only the newly generated tokens (drop the echoed input).
            output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

            # Split thinking block from answer at the last </think> token.
            try:
                index = len(output_ids) - output_ids[::-1].index(self.THINK_END_TOKEN_ID)
            except ValueError:
                # No </think> emitted; treat the entire output as the answer.
                index = 0

            thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
            content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

            print(f"Thinking content: {thinking_content[:200]}...")
            print(f"Response: {content[:200]}...")

            try:
                content = self._strip_code_fence(content)
                result = json.loads(content)
                return result
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {e}")
                print(f"Raw content: {content}")
                return None

        except Exception as e:
            print(f"Generation error: {e}")
            return None

    @staticmethod
    def _flush_batch(f, batch: List[Dict[str, str]]):
        """Write a batch of prompt dicts as JSONL to an open handle and flush."""
        for prompt in batch:
            f.write(json.dumps(prompt, ensure_ascii=False) + '\n')
        f.flush()

    def generate_prompts(self, target_count: int = 10000,
                         multi_character_prob: float = 0.4,
                         samples_per_batch: int = 3,
                         output_file: str = "augmented_prompts.jsonl",
                         save_every: int = 10):
        """Generate the specified number of new prompts with periodic saving.

        Args:
            target_count: Total prompts wanted in output_file (existing lines count).
            multi_character_prob: Probability a generation targets multi-character scenes.
            samples_per_batch: Example pairs shown to the model per attempt.
            output_file: JSONL output path; appended to when resuming.
            save_every: Flush to disk after this many successful generations.

        Returns:
            List of prompt dicts generated during this run (excludes any
            prompts already present in output_file from earlier runs).
        """
        print(f"Starting prompt generation: target={target_count}, multi_char_prob={multi_character_prob}")
        print(f"Saving every {save_every} successful generations to: {output_file}")

        # Resume support: count prompts already present in the output file.
        existing_count = 0
        if os.path.exists(output_file):
            try:
                with open(output_file, 'r', encoding='utf-8') as f:
                    existing_count = sum(1 for line in f if line.strip())
                print(f"Found {existing_count} existing prompts in {output_file}")
                if existing_count >= target_count:
                    print(f"Target already reached! ({existing_count} >= {target_count})")
                    return []
            except Exception as e:
                print(f"Warning: Could not read existing file: {e}")

        generated_prompts = []
        successful_generations = existing_count
        total_attempts = 0
        temp_batch = []

        # Append when resuming, otherwise start a fresh file.
        file_mode = 'a' if existing_count > 0 else 'w'

        try:
            with open(output_file, file_mode, encoding='utf-8') as f:
                while successful_generations < target_count:
                    total_attempts += 1

                    samples = self.get_random_samples(samples_per_batch)

                    # Randomly decide whether this attempt targets multi-character scenes.
                    is_multi_character = random.random() < multi_character_prob

                    instruction = self.generate_prompt_instruction(samples, is_multi_character)

                    print(f"\nAttempt {total_attempts} (Success: {successful_generations}/{target_count})")
                    print(f"Multi-character focus: {is_multi_character}")

                    result = self.generate_with_qwen3(instruction)

                    if result and 'positive_prompt' in result and 'negative_prompt' in result:
                        # Attach provenance metadata to the generated pair.
                        result['multi_character_focus'] = is_multi_character
                        result['generation_attempt'] = total_attempts
                        result['sample_sources'] = [self.clean_prompt(s[0])[:100] + "..." for s in samples]

                        temp_batch.append(result)
                        generated_prompts.append(result)
                        successful_generations += 1

                        print(f"โ Generated prompt {successful_generations}")
                        print(f"Positive: {result['positive_prompt'][:100]}...")
                        print(f"Negative: {result['negative_prompt'][:50]}...")

                        # Persist periodically so progress survives crashes.
                        if len(temp_batch) >= save_every:
                            self._flush_batch(f, temp_batch)
                            print(f"๐พ Saved batch of {len(temp_batch)} prompts to file")
                            temp_batch = []
                    else:
                        print("โ Failed to generate valid prompt")

                    # Periodic progress report.
                    if total_attempts % 100 == 0:
                        success_rate = successful_generations / total_attempts * 100
                        print(f"\n=== Progress Update ===")
                        print(f"Attempts: {total_attempts}")
                        print(f"Successful: {successful_generations}")
                        print(f"Success rate: {success_rate:.1f}%")
                        print(f"Remaining: {target_count - successful_generations}")
                        print(f"Current batch size: {len(temp_batch)}")

                # Write out any remainder smaller than save_every.
                if temp_batch:
                    self._flush_batch(f, temp_batch)
                    print(f"๐พ Saved final batch of {len(temp_batch)} prompts to file")

        except KeyboardInterrupt:
            print(f"\n๐ Generation interrupted by user!")
            if temp_batch:
                with open(output_file, 'a', encoding='utf-8') as f:
                    self._flush_batch(f, temp_batch)
                print(f"๐พ Saved {len(temp_batch)} prompts before exit")
            raise
        except Exception as e:
            print(f"โ Unexpected error: {e}")
            if temp_batch:
                try:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        self._flush_batch(f, temp_batch)
                    print(f"๐พ Saved {len(temp_batch)} prompts before error exit")
                except Exception:
                    # Best-effort emergency save only; re-raising the original
                    # error below matters more. (Was a bare `except:` which
                    # also swallowed SystemExit/KeyboardInterrupt.)
                    pass
            raise

        print(f"\n=== Generation Complete ===")
        print(f"Total attempts: {total_attempts}")
        print(f"Successful generations: {successful_generations}")
        if total_attempts > 0:
            # Guard: loop may never run (e.g. target_count <= 0), avoiding
            # the original ZeroDivisionError.
            print(f"Success rate: {successful_generations/total_attempts*100:.1f}%")
        print(f"Output saved to: {output_file}")

        return generated_prompts
|
|
|
|
def main():
    """Command-line entry point: parse options, build the augmenter, run generation."""
    cli = argparse.ArgumentParser(description='Generate augmented prompts using Qwen3')
    cli.add_argument('--model', default='models/Qwen3-8B',
                     help='Model name or path')
    cli.add_argument('--csv_path', default='civitai_image.csv',
                     help='Path to civitai CSV file')
    cli.add_argument('--target_count', type=int, default=10000,
                     help='Number of prompts to generate')
    cli.add_argument('--multi_char_prob', type=float, default=0.4,
                     help='Probability of multi-character focused prompts')
    cli.add_argument('--samples_per_batch', type=int, default=3,
                     help='Number of sample prompts to show to model each time')
    cli.add_argument('--save_every', type=int, default=10,
                     help='Save to file every N successful generations')
    cli.add_argument('--output', default='augmented_prompts.jsonl',
                     help='Output file name')
    opts = cli.parse_args()

    # Build the augmenter (loads model + CSV) and drive the generation loop.
    augmenter = PromptAugmenter(model_name=opts.model, csv_path=opts.csv_path)
    augmenter.generate_prompts(
        target_count=opts.target_count,
        multi_character_prob=opts.multi_char_prob,
        samples_per_batch=opts.samples_per_batch,
        output_file=opts.output,
        save_every=opts.save_every,
    )


if __name__ == "__main__":
    main()
|
|