#!/usr/bin/env python3
"""CUDA-optimized chat interface for Ursa Minor Smashed model"""

import torch

from inference_cuda import generate_direct, load_model_direct

# Maximum number of whitespace-delimited words kept in the rolling context;
# leaves headroom for the model's generated tokens.
MAX_CONTEXT_WORDS = 800


def _truncate_context(context: str, max_words: int = MAX_CONTEXT_WORDS) -> str:
    """Return *context* truncated to its most recent *max_words* words.

    NOTE(review): joining on spaces collapses the original newlines, but the
    "Human:" / "Assistant:" markers survive as substrings, which is all the
    extraction logic below relies on.
    """
    words = context.split()
    if len(words) <= max_words:
        return context
    return " ".join(words[-max_words:])


def _extract_response(full_response: str, prompt: str) -> str:
    """Strip the echoed prompt and any hallucinated next 'Human:' turn.

    generate_direct() returns prompt + continuation, so the new text starts
    at len(prompt). The model may keep writing the dialogue past its own
    turn; everything from the first "Human:" onward is discarded.
    """
    response = full_response[len(prompt):].strip()
    if "Human:" in response:
        response = response.split("Human:")[0].strip()
    return response


def main():
    """Interactive CUDA chat REPL: read user turns, generate, print replies.

    Commands: 'quit' exits, 'reset' clears the conversation context.
    Requires a CUDA device; bails out with a message otherwise.
    """
    print("Ursa Minor Smashed Chat (CUDA)")
    print("Type 'quit' to exit, 'reset' to clear context")
    print("-" * 50)

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use chat_cpu.py for CPU inference.")
        return

    # Load model once up front; load_model_direct places it on the GPU.
    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("Model loaded! Ready to chat.\n")

    context = ""

    while True:
        user_input = input("You: ").strip()

        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        elif user_input.lower() == 'reset':
            context = ""
            print("Context cleared!")
            continue
        elif user_input == "":
            continue

        # Append the new user turn in the Human/Assistant transcript format.
        if context:
            context += f"\nHuman: {user_input}\nAssistant:"
        else:
            context = f"Human: {user_input}\nAssistant:"

        # Keep only the most recent context so generation has headroom.
        context = _truncate_context(context)

        # Generate response with CUDA optimizations
        try:
            full_response = generate_direct(
                model,
                context,
                max_new_tokens=100,      # Match inference_cuda.py default
                temperature=0.8,         # Match inference_cuda.py default
                top_p=0.9,
                top_k=50,                # Higher for better quality
                repetition_penalty=1.1
            )

            response = _extract_response(full_response, context)
            print(f"Assistant: {response}")

            # BUG FIX: previously `context = full_response`, which kept any
            # hallucinated "Human:" continuation in the context even though
            # it was stripped from the displayed reply — corrupting later
            # turns. Store only the cleaned assistant response instead.
            context = f"{context} {response}"

        except Exception as e:
            # Best-effort recovery: report and let the user reset the context.
            print(f"Error generating response: {e}")
            print("Try typing 'reset' to clear context and continue.")


if __name__ == "__main__":
    main()