|
|
|
|
|
""" |
|
|
Test script for prompt augmentation |
|
|
Generates a small number of prompts to test the functionality |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
sys.path.append(os.path.dirname(__file__)) |
|
|
|
|
|
from augment_prompts import PromptAugmenter |
|
|
|
|
|
|
|
|
def _report_failure(step, exc):
    """Print a failure line for *step* followed by the traceback of *exc*.

    Always returns False so callers can write ``return _report_failure(...)``.
    """
    import traceback

    print(f"✗ Error {step}: {exc}")
    # str(exc) alone drops the exception type and where it was raised (a
    # KeyError prints only the missing key), so show the full traceback too.
    # It goes to stdout so it stays in order with the progress lines.
    traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stdout)
    return False


def _preview(value, width):
    """Return the first *width* characters of *value*, plus "..." if truncated.

    Non-string values are converted with str() (None becomes "") so that
    merely displaying a value can never make a stage fail.
    """
    text = "" if value is None else str(value)
    return text if len(text) <= width else text[:width] + "..."


def test_prompt_generation(csv_path=None, output_file="test_output.jsonl"):
    """Smoke-test PromptAugmenter end to end with a small sample.

    Stages run in order, and the function stops at the first one that fails:
    build the augmenter, count the loaded rows (``augmenter.df``), draw 3
    random samples, clean one sample prompt, build a generation instruction,
    and finally generate 2 prompts into *output_file*.

    NOTE: the ``test_`` prefix means pytest may collect this function; pytest
    does not check a boolean return value, so run the script directly to get
    a meaningful pass/fail result.

    Args:
        csv_path: CSV of prompt pairs. Defaults to ``../../civitai_image.csv``
            resolved against this script's directory rather than the current
            working directory, matching how the module import is made
            CWD-independent via ``sys.path``.
        output_file: JSONL file passed to ``generate_prompts``.

    Returns:
        True if every stage succeeded, False as soon as one fails.
    """
    if csv_path is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        csv_path = os.path.join(script_dir, "..", "..", "civitai_image.csv")

    print("=== Testing Prompt Augmentation ===\n")

    # Stage 1: build the augmenter from the CSV.
    try:
        augmenter = PromptAugmenter(csv_path=csv_path)
    except Exception as e:
        return _report_failure("loading model", e)
    print("✓ Model loaded successfully")

    # Stage 2: confirm the data was loaded.
    try:
        row_count = len(augmenter.df)
    except Exception as e:
        return _report_failure("loading data", e)
    print(f"✓ Loaded {row_count} prompt pairs from CSV")

    # Stage 3: draw a few (positive, negative) sample pairs.
    try:
        samples = augmenter.get_random_samples(3)
        if len(samples) == 0:
            # Stage 4 indexes samples[0]; fail here with the real cause rather
            # than an IndexError misreported as a prompt-cleaning error.
            raise ValueError("no samples returned")
        print(f"✓ Generated {len(samples)} sample prompts")
        for i, (pos, _neg) in enumerate(samples, 1):
            print(f"  Sample {i}: {_preview(pos, 50)}")
    except Exception as e:
        return _report_failure("getting samples", e)

    # Stage 4: clean the positive prompt of the first sample.
    try:
        sample_prompt = samples[0][0]
        cleaned = augmenter.clean_prompt(sample_prompt)
    except Exception as e:
        return _report_failure("cleaning prompt", e)
    print("✓ Prompt cleaning works")
    print(f"  Original: {_preview(sample_prompt, 60)}")
    print(f"  Cleaned:  {_preview(cleaned, 60)}")

    # Stage 5: build the instruction that would be sent for generation.
    try:
        instruction = augmenter.generate_prompt_instruction(
            samples, multi_character_focus=True
        )
        print(f"✓ Generated instruction ({len(instruction)} characters)")
    except Exception as e:
        return _report_failure("generating instruction", e)

    # Stage 6: run a tiny real generation and show what came back. Missing
    # result keys still count as a generation failure.
    print("\n=== Testing Actual Generation ===")
    try:
        results = augmenter.generate_prompts(
            target_count=2,
            multi_character_prob=0.5,
            save_every=1,
            output_file=output_file,
        )
        print(f"✓ Successfully generated {len(results)} prompts")
        for i, result in enumerate(results, 1):
            print(f"\nGenerated Prompt {i}:")
            print(f"  Positive: {_preview(result['positive_prompt'], 80)}")
            print(f"  Negative: {_preview(result['negative_prompt'], 50)}")
            print(f"  Multi-char: {result['multi_character_focus']}")
    except Exception as e:
        return _report_failure("in generation", e)

    return True
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
success = test_prompt_generation() |
|
|
if success: |
|
|
print("\nπ All tests passed! The augmentation script is ready to use.") |
|
|
print("\nTo generate 10,000 prompts, run:") |
|
|
print("python augment_prompts.py") |
|
|
else: |
|
|
print("\nβ Tests failed. Please check the errors above.") |
|
|
|