lsmpp's picture
Add files using upload-large-folder tool
3f9fa87 verified
#!/usr/bin/env python3
"""
Test script for prompt augmentation
Generates a small number of prompts to test the functionality
"""
import sys
import os
sys.path.append(os.path.dirname(__file__))
from augment_prompts import PromptAugmenter
def test_prompt_generation(csv_path="../../civitai_image.csv",
                           output_file="test_output.jsonl"):
    """Smoke-test the prompt augmentation pipeline on a tiny sample.

    Runs the following steps in order, printing a check mark for each step
    that succeeds and stopping at the first step that raises:

    1. Construct a ``PromptAugmenter`` from ``csv_path``.
    2. Read ``len(augmenter.df)`` to confirm the data is accessible.
    3. Draw three samples with ``get_random_samples(3)``.
    4. Clean the first sample's positive prompt with ``clean_prompt``.
    5. Build an instruction with ``generate_prompt_instruction``.
    6. Generate two prompts with ``generate_prompts`` and print a preview.

    Args:
        csv_path: Path handed to ``PromptAugmenter(csv_path=...)``. The
            default is a relative path, so it presumably depends on the
            current working directory -- confirm against ``PromptAugmenter``.
        output_file: File name passed to ``generate_prompts`` as
            ``output_file``.

    Returns:
        ``True`` when every step completes, ``False`` as soon as one fails.
        Exceptions are reported on stdout rather than propagated, because
        this function is the top-level boundary of a manual test script.
    """
    print("=== Testing Prompt Augmentation ===\n")

    # Step 1: construct the augmenter. Judging by the success message this
    # also loads a model, so it is the most likely step to fail.
    try:
        augmenter = PromptAugmenter(csv_path=csv_path)
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        return False

    # Step 2: the CSV contents are expected to be exposed as ``augmenter.df``.
    try:
        print(f"✓ Loaded {len(augmenter.df)} prompt pairs from CSV")
    except Exception as e:
        print(f"✗ Error loading data: {e}")
        return False

    # Step 3: draw a few (positive, negative) prompt pairs.
    try:
        samples = augmenter.get_random_samples(3)
        print(f"✓ Generated {len(samples)} sample prompts")
        for i, (pos, _neg) in enumerate(samples, 1):
            print(f" Sample {i}: {pos[:50]}...")
    except Exception as e:
        print(f"✗ Error getting samples: {e}")
        return False

    # The remaining steps index into ``samples``; report an empty result
    # here so it is not misattributed to the cleaning step as an IndexError.
    if not samples:
        print("✗ Error getting samples: no samples returned")
        return False

    # Step 4: clean the first positive prompt and show before/after.
    try:
        sample_prompt = samples[0][0]
        cleaned = augmenter.clean_prompt(sample_prompt)
        print("✓ Prompt cleaning works")
        print(f" Original: {sample_prompt[:60]}...")
        print(f" Cleaned: {cleaned[:60]}...")
    except Exception as e:
        print(f"✗ Error cleaning prompt: {e}")
        return False

    # Step 5: build the instruction text from the drawn samples.
    try:
        instruction = augmenter.generate_prompt_instruction(
            samples, multi_character_focus=True
        )
        print(f"✓ Generated instruction ({len(instruction)} characters)")
    except Exception as e:
        print(f"✗ Error generating instruction: {e}")
        return False

    # Step 6: end-to-end generation, limited to two prompts to stay quick.
    print("\n=== Testing Actual Generation ===")
    try:
        results = augmenter.generate_prompts(
            target_count=2,
            multi_character_prob=0.5,
            save_every=1,  # Save every prompt for testing
            output_file=output_file,
        )
        print(f"✓ Successfully generated {len(results)} prompts")
        # Previewing stays inside the try so that a malformed result
        # (e.g. a missing key) is reported as a generation failure.
        for i, result in enumerate(results, 1):
            print(f"\nGenerated Prompt {i}:")
            print(f" Positive: {result['positive_prompt'][:80]}...")
            print(f" Negative: {result['negative_prompt'][:50]}...")
            print(f" Multi-char: {result['multi_character_focus']}")
        return True
    except Exception as e:
        print(f"✗ Error in generation: {e}")
        return False
if __name__ == "__main__":
    # Run the smoke test and summarise the outcome for a human reader.
    success = test_prompt_generation()
    if success:
        print("\n🎉 All tests passed! The augmentation script is ready to use.")
        print("\nTo generate 10,000 prompts, run:")
        print("python augment_prompts.py")
    else:
        print("\n❌ Tests failed. Please check the errors above.")
    # Propagate the result as the process exit status so shells and CI jobs
    # can detect a failed run; without this the script always exits with 0.
    sys.exit(0 if success else 1)