#!/usr/bin/env python3
"""
Test script for prompt augmentation.

Generates a small number of prompts to test the functionality.
"""

import sys
import os

# Add this file's directory to the import path so that augment_prompts
# (presumably located next to this script) can be imported.
sys.path.append(os.path.dirname(__file__))

from augment_prompts import PromptAugmenter  # noqa: E402  (needs the path tweak above)


def test_prompt_generation() -> bool:
    """Run a step-by-step smoke test of ``PromptAugmenter``.

    The steps run in order: construct the augmenter, report the size of its
    ``df`` attribute, draw random samples, clean one prompt, build a
    generation instruction, and finally generate two prompts. Each step
    prints a check mark on success. The first step that raises prints the
    error and aborts the run.

    Returns:
        True if every step completed without raising, False otherwise.
    """
    print("=== Testing Prompt Augmentation ===\n")

    # Each step catches ``Exception`` broadly on purpose: this is a
    # diagnostic script, and any failure should be reported to the user
    # rather than surfacing as a traceback.

    # Initialize with a small test
    # NOTE(review): this path is resolved against the current working
    # directory, not this file's directory -- confirm that is intended.
    try:
        augmenter = PromptAugmenter(csv_path="../../civitai_image.csv")
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        return False

    # Test data loading
    try:
        print(f"✓ Loaded {len(augmenter.df)} prompt pairs from CSV")
    except Exception as e:
        print(f"✗ Error loading data: {e}")
        return False

    # Test sample generation
    try:
        samples = augmenter.get_random_samples(3)
        print(f"✓ Generated {len(samples)} sample prompts")
        # Each sample is unpacked as a (positive, negative) pair; only the
        # positive prompt is previewed.
        for i, (pos, _neg) in enumerate(samples, 1):
            print(f" Sample {i}: {pos[:50]}...")
    except Exception as e:
        print(f"✗ Error getting samples: {e}")
        return False

    # Test prompt cleaning
    try:
        # An empty ``samples`` raises IndexError here, which is reported as
        # a cleaning failure by the handler below.
        sample_prompt = samples[0][0]
        cleaned = augmenter.clean_prompt(sample_prompt)
        print("✓ Prompt cleaning works")
        print(f" Original: {sample_prompt[:60]}...")
        print(f" Cleaned: {cleaned[:60]}...")
    except Exception as e:
        print(f"✗ Error cleaning prompt: {e}")
        return False

    # Test instruction generation
    try:
        instruction = augmenter.generate_prompt_instruction(
            samples, multi_character_focus=True
        )
        print(f"✓ Generated instruction ({len(instruction)} characters)")
    except Exception as e:
        print(f"✗ Error generating instruction: {e}")
        return False

    # Test actual generation (just 2 prompts)
    print("\n=== Testing Actual Generation ===")
    try:
        results = augmenter.generate_prompts(
            target_count=2,
            multi_character_prob=0.5,
            save_every=1,  # Save every prompt for testing
            output_file="test_output.jsonl",
        )
        print(f"✓ Successfully generated {len(results)} prompts")

        for i, result in enumerate(results, 1):
            print(f"\nGenerated Prompt {i}:")
            print(f" Positive: {result['positive_prompt'][:80]}...")
            print(f" Negative: {result['negative_prompt'][:50]}...")
            print(f" Multi-char: {result['multi_character_focus']}")

        return True
    except Exception as e:
        print(f"✗ Error in generation: {e}")
        return False


if __name__ == "__main__":
    success = test_prompt_generation()
    if success:
        print("\n🎉 All tests passed! The augmentation script is ready to use.")
        print("\nTo generate 10,000 prompts, run:")
        print("python augment_prompts.py")
    else:
        print("\n❌ Tests failed. Please check the errors above.")
        # Signal the failure to the calling shell / CI instead of exiting 0.
        sys.exit(1)