File size: 5,401 Bytes
d575ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#!/usr/bin/env python3
"""

CUDA-optimized basic usage examples for Ursa Minor Smashed model

"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from inference_cuda import generate_direct, load_model_direct

def run_basic_examples():
    """Run a fixed set of CUDA-optimized generation examples.

    Loads the model once, then runs five canned prompts (creative writing,
    code, explanation, story continuation, technical writing), each with its
    own sampling parameters. Results are printed; the function returns None.
    Exits early with an error message when CUDA is unavailable.
    """

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use basic_usage_cpu.py for CPU inference.")
        return

    # NOTE: emoji below were mojibake (UTF-8 mis-decoded as Windows-874);
    # restored to the intended characters.
    print("🚀 Ursa Minor Smashed - CUDA Basic Usage Examples")
    print("=" * 60)

    # Load model once and reuse it for every example.
    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("✅ Model loaded!\n")

    examples = [
        {
            "name": "Creative Writing",
            "prompt": "In a world where artificial intelligence has",
            "params": {"max_new_tokens": 150, "temperature": 0.9, "top_k": 50}
        },
        {
            "name": "Code Generation",
            "prompt": "def fibonacci(n):",
            "params": {"max_new_tokens": 120, "temperature": 0.4, "top_k": 40}
        },
        {
            "name": "Explanation",
            "prompt": "Explain how neural networks work:",
            "params": {"max_new_tokens": 200, "temperature": 0.7, "top_k": 50}
        },
        {
            "name": "Story Continuation",
            "prompt": "The spaceship landed on the mysterious planet, and the crew discovered",
            "params": {"max_new_tokens": 180, "temperature": 0.8, "top_k": 45}
        },
        {
            "name": "Technical Writing",
            "prompt": "The benefits of using GPU acceleration include",
            "params": {"max_new_tokens": 100, "temperature": 0.6, "top_k": 40}
        }
    ]

    for i, example in enumerate(examples, 1):
        print(f"📝 Example {i}: {example['name']}")
        print(f"💭 Prompt: {example['prompt']}")
        print("🔄 Generating...")

        try:
            result = generate_direct(
                model,
                example['prompt'],
                **example['params']
            )

            print("✨ Result:")
            print("-" * 40)
            print(result)
            print("-" * 40)
            print()

        except Exception as e:
            # Best-effort: report the failure and continue with the next example.
            print(f"❌ Error: {e}")
            print()

def run_interactive_mode():
    """Run an interactive prompt loop for testing generation parameters.

    Commands: a plain line is treated as a prompt; 'params' lets the user
    edit the sampling parameters; 'quit' (or end-of-input) exits. Results
    are printed; the function returns None. Exits early with an error
    message when CUDA is unavailable.
    """

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use basic_usage_cpu.py for CPU inference.")
        return

    # NOTE: emoji below were mojibake (UTF-8 mis-decoded as Windows-874);
    # restored to the intended characters.
    print("\n🎮 Interactive Mode")
    print("=" * 30)

    # Load model once for the whole session.
    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("✅ Model loaded!")

    print("\nCommands:")
    print("- Enter a prompt to generate text")
    print("- Type 'params' to change generation parameters")
    print("- Type 'quit' to exit")
    print()

    # Default parameters optimized for CUDA
    params = {
        "max_new_tokens": 100,
        "temperature": 0.8,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 1.1
    }

    # These two parameters are integers; everything else parses as float.
    int_params = ("max_new_tokens", "top_k")

    while True:
        try:
            user_input = input("🎯 Prompt (or command): ").strip()
        except EOFError:
            # Stdin closed (piped input exhausted, Ctrl-D): exit gracefully.
            print("\n👋 Goodbye!")
            break

        if user_input.lower() == 'quit':
            print("👋 Goodbye!")
            break
        elif user_input.lower() == 'params':
            print("\nCurrent parameters:")
            for key, value in params.items():
                print(f"  {key}: {value}")

            print("\nEnter new values (press Enter to keep current):")
            for key in params:
                new_value = input(f"  {key} [{params[key]}]: ").strip()
                if new_value:
                    try:
                        if key in int_params:
                            params[key] = int(new_value)
                        else:
                            params[key] = float(new_value)
                    except ValueError:
                        print(f"Invalid value for {key}, keeping current value")
            print()
            continue
        elif user_input == "":
            continue

        # Generate text
        try:
            print("🔄 Generating...")
            result = generate_direct(model, user_input, **params)
            print("✨ Result:")
            print("-" * 40)
            print(result)
            print("-" * 40)
            print()
        except Exception as e:
            # Best-effort: report the failure and keep the session alive.
            print(f"❌ Error: {e}")
            print()

def main():
    """Entry point: ask the user which mode to run and dispatch to it.

    Choice '1' runs the canned examples, '2' the interactive loop; any
    other input falls back to the examples. Ctrl-C or a closed stdin
    exits cleanly instead of propagating an exception.
    """
    print("Choose mode:")
    print("1. Run basic examples")
    print("2. Interactive mode")

    try:
        choice = input("Enter choice (1 or 2): ").strip()

        if choice == "1":
            run_basic_examples()
        elif choice == "2":
            run_interactive_mode()
        else:
            print("Invalid choice. Running basic examples...")
            run_basic_examples()

    # EOFError covers piped/closed stdin, which input() raises on.
    # "👋" was mojibake in the original; restored.
    except (KeyboardInterrupt, EOFError):
        print("\n👋 Goodbye!")
if __name__ == "__main__":
    main()