#!/usr/bin/env python3
"""CUDA-optimized chat interface for Ursa Minor Smashed model"""

import torch
from inference_cuda import generate_direct, load_model_direct

def main():
    """Run an interactive CUDA chat loop for the Ursa Minor Smashed model.

    Maintains a running ``Human:`` / ``Assistant:`` transcript as generation
    context, truncating the oldest words when the transcript exceeds the
    word budget.  Requires a CUDA device; exits early if none is available.
    """
    print("Ursa Minor Smashed Chat (CUDA)")
    print("Type 'quit' to exit, 'reset' to clear context")
    print("-" * 50)

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use chat_cpu.py for CPU inference.")
        return

    # Load model
    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("Model loaded! Ready to chat.\n")

    context = ""
    max_context_length = 800  # word budget; leaves room for generation

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: exit cleanly instead of dumping a traceback.
            print("\nGoodbye!")
            break

        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        elif user_input.lower() == 'reset':
            context = ""
            print("Context cleared!")
            continue
        elif user_input == "":
            continue

        # Append the new human turn and an open "Assistant:" cue.
        if context:
            context += f"\nHuman: {user_input}\nAssistant:"
        else:
            context = f"Human: {user_input}\nAssistant:"

        # Truncate to the most recent words if the transcript is too long
        # (split once and reuse, rather than splitting twice).
        words = context.split()
        if len(words) > max_context_length:
            context = " ".join(words[-max_context_length:])

        # Generate response with CUDA optimizations
        try:
            full_response = generate_direct(
                model,
                context,
                max_new_tokens=100,  # Match inference_cuda.py default
                temperature=0.8,    # Match inference_cuda.py default
                top_p=0.9,
                top_k=50,  # Higher for better quality
                repetition_penalty=1.1
            )

            # Keep only the newly generated text beyond the prompt.
            response = full_response[len(context):].strip()

            # The model may hallucinate the next "Human:" turn; drop it.
            if "Human:" in response:
                response = response.split("Human:")[0].strip()

            print(f"Assistant: {response}")

            # BUG FIX: carry forward only what was actually displayed.
            # The original stored the raw full_response, which could still
            # contain the hallucinated "Human:" turn stripped above,
            # desynchronizing the stored context from the visible chat.
            context = f"{context} {response}"

        except Exception as e:
            # Best-effort recovery: report and let the user retry or reset.
            print(f"Error generating response: {e}")
            print("Try typing 'reset' to clear context and continue.")

# Script entry point: run the chat loop only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()