File size: 3,100 Bytes
3f9fa87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
"""
Test script for prompt augmentation
Generates a small number of prompts to test the functionality
"""

import sys
import os
sys.path.append(os.path.dirname(__file__))

from augment_prompts import PromptAugmenter


def test_prompt_generation(csv_path="../../civitai_image.csv",
                           output_file="test_output.jsonl"):
    """Smoke-test the PromptAugmenter pipeline on a tiny sample.

    Runs each stage in order and prints one status line per stage:
    construct the augmenter, report the number of loaded rows, draw three
    random samples, clean one sample prompt, build a generation instruction,
    and finally generate two prompts. Stops at the first failing stage.

    Args:
        csv_path: CSV file handed to ``PromptAugmenter``. The default is a
            relative path, so it is resolved against the current working
            directory (not this script's directory).
        output_file: JSONL file passed to ``generate_prompts`` for the
            two-prompt test run; also relative to the working directory by
            default.

    Returns:
        True if every stage succeeded, False as soon as one fails.
    """
    print("=== Testing Prompt Augmentation ===\n")

    # Initialize with a small test
    try:
        augmenter = PromptAugmenter(csv_path=csv_path)
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        return False

    # Test data loading
    try:
        print(f"✓ Loaded {len(augmenter.df)} prompt pairs from CSV")
    except Exception as e:
        print(f"✗ Error loading data: {e}")
        return False

    # Test sample generation
    try:
        samples = augmenter.get_random_samples(3)
        print(f"✓ Generated {len(samples)} sample prompts")

        # Each sample is unpacked as a (positive, negative) pair; only the
        # positive prompt is shown here.
        for i, (pos, _neg) in enumerate(samples, 1):
            print(f"  Sample {i}: {pos[:50]}...")
    except Exception as e:
        print(f"✗ Error getting samples: {e}")
        return False

    # The later stages index samples[0]. Fail here with a clear message
    # rather than letting an IndexError surface as a "cleaning" failure.
    # len() is used (not truthiness) because the container type is not known.
    if len(samples) == 0:
        print("✗ Error getting samples: no samples returned")
        return False

    # Test prompt cleaning
    try:
        sample_prompt = samples[0][0]
        cleaned = augmenter.clean_prompt(sample_prompt)
        print("✓ Prompt cleaning works")
        print(f"  Original: {sample_prompt[:60]}...")
        print(f"  Cleaned:  {cleaned[:60]}...")
    except Exception as e:
        print(f"✗ Error cleaning prompt: {e}")
        return False

    # Test instruction generation
    try:
        instruction = augmenter.generate_prompt_instruction(samples, multi_character_focus=True)
        print(f"✓ Generated instruction ({len(instruction)} characters)")
    except Exception as e:
        print(f"✗ Error generating instruction: {e}")
        return False

    # Test actual generation (just 2 prompts)
    print("\n=== Testing Actual Generation ===")
    try:
        results = augmenter.generate_prompts(
            target_count=2,
            multi_character_prob=0.5,
            save_every=1,  # Save every prompt for testing
            output_file=output_file,
        )
        print(f"✓ Successfully generated {len(results)} prompts")

        for i, result in enumerate(results, 1):
            print(f"\nGenerated Prompt {i}:")
            print(f"  Positive: {result['positive_prompt'][:80]}...")
            print(f"  Negative: {result['negative_prompt'][:50]}...")
            print(f"  Multi-char: {result['multi_character_focus']}")

        return True

    except Exception as e:
        print(f"✗ Error in generation: {e}")
        return False


if __name__ == "__main__":
    # Run the smoke test and report the overall outcome.
    success = test_prompt_generation()
    if success:
        print("\n🎉 All tests passed! The augmentation script is ready to use.")
        print("\nTo generate 10,000 prompts, run:")
        print("python augment_prompts.py")
    else:
        print("\n❌ Tests failed. Please check the errors above.")
        # Exit non-zero so shells and CI pipelines can detect the failure;
        # previously the script always exited with status 0.
        sys.exit(1)