# Ursa_Minor_Smashed / chat_cuda.py
# Uploaded via huggingface_hub by Kaileh57 (commit 8691f4b, verified)
#!/usr/bin/env python3
"""CUDA-optimized chat interface for Ursa Minor Smashed model"""
import torch
from inference_cuda import generate_direct, load_model_direct
def main():
    """Run an interactive chat loop against the Ursa Minor Smashed model on CUDA.

    Maintains a running "Human:/Assistant:" transcript as the generation
    prompt, word-truncated to ``max_context_length`` so there is room left
    for the model's reply.  Type 'quit' to exit, 'reset' to clear context.
    """
    print("Ursa Minor Smashed Chat (CUDA)")
    print("Type 'quit' to exit, 'reset' to clear context")
    print("-" * 50)

    # CUDA is mandatory for this entry point; the CPU path is a separate script.
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use chat_cpu.py for CPU inference.")
        return

    # Load model
    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("Model loaded! Ready to chat.\n")

    context = ""
    max_context_length = 800  # word budget for the prompt; leaves room for generation

    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        elif user_input.lower() == 'reset':
            context = ""
            print("Context cleared!")
            continue
        elif user_input == "":
            continue

        # Append the new human turn and an "Assistant:" cue for the model.
        if context:
            context += f"\nHuman: {user_input}\nAssistant:"
        else:
            context = f"Human: {user_input}\nAssistant:"

        # Truncate context if too long, keeping the most recent words.
        words = context.split()
        if len(words) > max_context_length:
            context = " ".join(words[-max_context_length:])

        # Generate response with CUDA optimizations
        try:
            full_response = generate_direct(
                model,
                context,
                max_new_tokens=100,   # Match inference_cuda.py default
                temperature=0.8,      # Match inference_cuda.py default
                top_p=0.9,
                top_k=50,             # Higher for better quality
                repetition_penalty=1.1,
            )
            # Extract just the newly generated text (assumes generate_direct
            # echoes the prompt verbatim at the start of its return value).
            response = full_response[len(context):].strip()
            # Stop at a hallucinated next "Human:" turn, if present.
            if "Human:" in response:
                response = response.split("Human:")[0].strip()
            print(f"Assistant: {response}")
            # Store only the cleaned response in the context.  The original
            # code kept `full_response` here, so any hallucinated "Human: ..."
            # continuation that was stripped from the display still re-entered
            # the prompt on the next turn, derailing the conversation.
            context = f"{context} {response}"
        except Exception as e:
            # Best-effort recovery: report and let the user continue or reset.
            print(f"Error generating response: {e}")
            print("Try typing 'reset' to clear context and continue.")
# Script entry point: run the interactive chat loop only when executed directly.
if __name__ == "__main__":
    main()