File size: 3,100 Bytes
3f9fa87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
#!/usr/bin/env python3
"""
Test script for prompt augmentation
Generates a small number of prompts to test the functionality
"""
import sys
import os
sys.path.append(os.path.dirname(__file__))
from augment_prompts import PromptAugmenter
def _report_failure(what, exc):
    """Print a uniform failure line for one test stage and return False.

    The exception type is included because some exceptions (e.g. ``KeyError``)
    carry only a bare key as their message, which is unreadable on its own.
    """
    print(f"❌ Error {what}: {type(exc).__name__}: {exc}")
    return False


def test_prompt_generation(
    *,
    csv_path="../../civitai_image.csv",
    output_file="test_output.jsonl",
    sample_count=3,
    target_count=2,
):
    """Smoke-test PromptAugmenter end to end with a small sample.

    Runs each stage in order -- construction, data access, random sampling,
    prompt cleaning, instruction building, and actual generation -- printing
    a ✅/❌ line per stage and stopping at the first stage that raises.

    Keyword Args:
        csv_path: CSV path handed to ``PromptAugmenter``. A relative path is
            resolved against the current working directory, not this file's
            directory.
        output_file: JSONL path handed to ``generate_prompts`` (also
            CWD-relative when relative).
        sample_count: Number of random samples requested from the augmenter.
        target_count: Number of prompts to generate in the final stage.

    Returns:
        True if every stage succeeded, False otherwise.
    """
    print("=== Testing Prompt Augmentation ===\n")

    # Stage 1: construct the augmenter.
    try:
        augmenter = PromptAugmenter(csv_path=csv_path)
        print("✅ Model loaded successfully")
    except Exception as e:
        return _report_failure("loading model", e)

    # Stage 2: the loaded data must be reachable via ``augmenter.df``.
    try:
        print(f"✅ Loaded {len(augmenter.df)} prompt pairs from CSV")
    except Exception as e:
        return _report_failure("loading data", e)

    # Stage 3: draw (positive, negative) sample pairs.
    try:
        samples = augmenter.get_random_samples(sample_count)
        print(f"✅ Generated {len(samples)} sample prompts")
        for i, (pos, _neg) in enumerate(samples, 1):
            print(f" Sample {i}: {pos[:50]}...")
    except Exception as e:
        return _report_failure("getting samples", e)
    # Later stages index samples[0]; report an empty result here instead of
    # letting an IndexError masquerade as a prompt-cleaning failure.
    if len(samples) == 0:
        print("❌ Error getting samples: no samples returned")
        return False

    # Stage 4: clean the first sample's positive prompt.
    try:
        sample_prompt = samples[0][0]
        cleaned = augmenter.clean_prompt(sample_prompt)
        print("✅ Prompt cleaning works")
        print(f" Original: {sample_prompt[:60]}...")
        print(f" Cleaned: {cleaned[:60]}...")
    except Exception as e:
        return _report_failure("cleaning prompt", e)

    # Stage 5: build the generation instruction from the samples.
    try:
        instruction = augmenter.generate_prompt_instruction(
            samples, multi_character_focus=True
        )
        print(f"✅ Generated instruction ({len(instruction)} characters)")
    except Exception as e:
        return _report_failure("generating instruction", e)

    # Stage 6: actually generate a handful of prompts.
    print("\n=== Testing Actual Generation ===")
    try:
        results = augmenter.generate_prompts(
            target_count=target_count,
            multi_character_prob=0.5,
            save_every=1,  # Save every prompt for testing
            output_file=output_file,
        )
        print(f"✅ Successfully generated {len(results)} prompts")
        for i, result in enumerate(results, 1):
            print(f"\nGenerated Prompt {i}:")
            print(f" Positive: {result['positive_prompt'][:80]}...")
            print(f" Negative: {result['negative_prompt'][:50]}...")
            print(f" Multi-char: {result['multi_character_focus']}")
    except Exception as e:
        return _report_failure("in generation", e)
    return True
if __name__ == "__main__":
    success = test_prompt_generation()
    if success:
        print("\n🎉 All tests passed! The augmentation script is ready to use.")
        print("\nTo generate 10,000 prompts, run:")
        print("python augment_prompts.py")
    else:
        print("\n❌ Tests failed. Please check the errors above.")
    # Report the outcome through the process exit status so shells and CI
    # can detect a failed run (previously the script always exited with 0).
    sys.exit(0 if success else 1)
|