File size: 2,609 Bytes
d575ce4 8691f4b d575ce4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
#!/usr/bin/env python3
"""CUDA-optimized chat interface for Ursa Minor Smashed model"""
import torch
from inference_cuda import generate_direct, load_model_direct
def main():
    """Interactive chat REPL for the Ursa Minor Smashed model on CUDA.

    Maintains a rolling "Human: ... / Assistant: ..." transcript as the
    generation context, truncated by word count to leave room for new
    tokens. Commands: 'quit' exits, 'reset' clears the transcript.
    Returns None; requires a CUDA device (bails out otherwise).
    """
    print("Ursa Minor Smashed Chat (CUDA)")
    print("Type 'quit' to exit, 'reset' to clear context")
    print("-" * 50)

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use chat_cpu.py for CPU inference.")
        return

    # Load model
    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("Model loaded! Ready to chat.\n")

    context = ""
    max_context_length = 800  # words (not tokens) -- rough budget to leave room for generation

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C should end the session cleanly, not traceback.
            print("\nGoodbye!")
            break

        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        elif user_input.lower() == 'reset':
            context = ""
            print("Context cleared!")
            continue
        elif user_input == "":
            continue

        # Add user input to context
        if context:
            context += f"\nHuman: {user_input}\nAssistant:"
        else:
            context = f"Human: {user_input}\nAssistant:"

        # Truncate context if too long, keeping the most recent words.
        words = context.split()
        if len(words) > max_context_length:
            context = " ".join(words[-max_context_length:])

        # Generate response with CUDA optimizations
        try:
            full_response = generate_direct(
                model,
                context,
                max_new_tokens=100,      # Match inference_cuda.py default
                temperature=0.8,         # Match inference_cuda.py default
                top_p=0.9,
                top_k=50,                # Higher for better quality
                repetition_penalty=1.1,
            )

            # Extract just the newly generated text. Assumes generate_direct
            # returns prompt + completion -- TODO confirm against inference_cuda.
            response = full_response[len(context):].strip()

            # Stop at the next "Human:" if the model hallucinated a user turn.
            if "Human:" in response:
                response = response.split("Human:")[0].strip()

            print(f"Assistant: {response}")

            # BUG FIX: previously `context = full_response` carried the raw
            # generation -- including any hallucinated "Human:" turn stripped
            # above -- into the next round. Append only the cleaned response
            # so the transcript matches what the user actually saw.
            context = f"{context} {response}"
        except Exception as e:
            # Broad catch is deliberate: keep the chat loop alive on any
            # generation failure and let the user recover with 'reset'.
            print(f"Error generating response: {e}")
            print("Try typing 'reset' to clear context and continue.")


if __name__ == "__main__":
    main()